import torch
import sympy
from IPython.display import Math
import pixyz
# Module-wide epsilon value: returned by `epsilon()` and overwritten by
# `set_epsilon()`.
# NOTE(review): presumably a small constant used for numerical stability by
# callers elsewhere in the package -- confirm against the call sites.
_EPSILON = 1e-07
def set_epsilon(eps):
    """Overwrite the module-wide `epsilon` value.

    Parameters
    ----------
    eps : int or float
        New value to store in the module-level ``_EPSILON``.

    Returns
    -------
    None

    Examples
    --------
    >>> from unittest import mock
    >>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
    ...     set_epsilon(1e-06)
    ...     epsilon()
    1e-06
    """
    # Declare the name as global so the assignment rebinds the module-level
    # value instead of creating a function-local variable.
    global _EPSILON
    _EPSILON = eps
def epsilon():
    """Return the current module-wide `epsilon` value.

    Returns
    -------
    int or float
        Whatever is currently stored in the module-level ``_EPSILON``.

    Examples
    --------
    >>> from unittest import mock
    >>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
    ...     epsilon()
    1e-07
    """
    # Read-only access: no `global` declaration is needed to look the name up.
    return _EPSILON
def get_dict_values(dicts, keys, return_dict=False):
    """Get values from `dicts` specified by `keys`.

    Keys that are not present in `dicts` are silently skipped. The result
    follows the order of `keys`; a key listed more than once appears once.

    Parameters
    ----------
    dicts : dict
        Dictionary to read from.
    keys : list
        Keys to look up.
    return_dict : bool
        When it is exactly ``False`` (the default) a list of values is
        returned; otherwise the selected ``{key: value}`` pairs are returned.

    Returns
    -------
    dict or list

    Examples
    --------
    >>> get_dict_values({"a":1,"b":2,"c":3}, ["b"])
    [2]
    >>> get_dict_values({"a":1,"b":2,"c":3}, ["b", "d"], True)
    {'b': 2}
    """
    # Test membership on the dict itself (hash lookup) instead of rebuilding
    # and scanning a list of its keys for every requested key.
    selected = {key: dicts[key] for key in keys if key in dicts}
    # The identity check is kept on purpose so existing callers see the same
    # list-vs-dict choice as before.
    if return_dict is False:
        return list(selected.values())
    return selected
def delete_dict_values(dicts, keys):
    """Return a copy of `dicts` without the entries named in `keys`.

    The input dictionary is not modified. Entries of `keys` that do not
    occur in `dicts` are ignored.

    Parameters
    ----------
    dicts : dict
        Dictionary to filter.
    keys : list
        Keys whose entries are dropped.

    Returns
    -------
    new_dicts : dict
        The remaining entries, in their original order.

    Examples
    --------
    >>> delete_dict_values({"a":1,"b":2,"c":3}, ["b","d"])
    {'a': 1, 'c': 3}
    """
    kept = {}
    for name, item in dicts.items():
        if name in keys:
            # Listed for removal: leave it out of the result.
            continue
        kept[name] = item
    return kept
def detach_dict(dicts):
    """Build a new dictionary holding the detached version of every value.

    Each value must provide a ``detach()`` method (e.g. ``torch.Tensor``);
    the keys and their order are preserved.

    Parameters
    ----------
    dicts : dict
        Dictionary whose values are detached.

    Returns
    -------
    dict
        Same keys, each mapped to ``value.detach()``.
    """
    detached = {}
    for name, item in dicts.items():
        detached[name] = item.detach()
    return detached
def replace_dict_keys(dicts, replace_list_dict):
    """Rename keys in `dicts` according to `replace_list_dict`.

    Every key of `dicts` that also appears as a key of `replace_list_dict`
    is replaced by the mapped name; all other keys are kept as they are.
    Values and entry order are preserved. If two entries end up with the
    same key, the later one wins.

    Parameters
    ----------
    dicts : dict
        Dictionary whose keys are renamed.
    replace_list_dict : dict
        Mapping from old key to new key.

    Returns
    -------
    replaced_dicts : dict
        New dictionary with the renamed keys.

    Examples
    --------
    >>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","b":"y"})
    {'x': 1, 'y': 2, 'c': 3}
    >>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","e":"y"})  # keys absent from `dicts` are ignored
    {'x': 1, 'b': 2, 'c': 3}
    """
    # `get(key, key)` falls back to the original key with a single hash
    # lookup, instead of scanning a freshly built key list for every item.
    return {replace_list_dict.get(key, key): value for key, value in dicts.items()}
def replace_dict_keys_split(dicts, replace_list_dict):
    """Rename keys in `dicts` and split the result in two.

    Entries whose key appears in :attr:`replace_list_dict` are renamed and
    collected in :attr:`replaced_dict`; all other entries are collected
    unchanged in :attr:`remain_dict`.

    Parameters
    ----------
    dicts : dict
        Dictionary to split.
    replace_list_dict : dict
        Mapping from old key to new key.

    Returns
    -------
    replaced_dict : dict
        Entries that were renamed, stored under their new keys.
    remain_dict : dict
        Entries whose keys are not in `replace_list_dict`.

    Examples
    --------
    >>> replace_list_dict = {'a': 'loc'}
    >>> x_dict = {'a': 0, 'b': 1}
    >>> print(replace_dict_keys_split(x_dict, replace_list_dict))
    ({'loc': 0}, {'b': 1})
    """
    replaced_dict = {}
    remain_dict = {}
    # A single pass with hash-based membership replaces two passes that each
    # rebuilt and scanned a list of the mapping's keys for every item.
    for key, value in dicts.items():
        if key in replace_list_dict:
            replaced_dict[replace_list_dict[key]] = value
        else:
            remain_dict[key] = value
    return replaced_dict, remain_dict
def tolist(a):
    """Convert a given input to the list format.

    A list (or an instance of a list subclass) is returned unchanged, i.e.
    the very same object and not a copy. Any other input, including tuples
    and strings, is wrapped in a new single-element list.

    Parameters
    ----------
    a : list or other

    Returns
    -------
    list

    Examples
    --------
    >>> tolist(2)
    [2]
    >>> tolist([1, 2])
    [1, 2]
    >>> tolist([])
    []
    """
    # isinstance (rather than an exact type comparison) so that list
    # subclasses are treated as lists instead of being wrapped again.
    if isinstance(a, list):
        return a
    return [a]
def sum_samples(samples):
    """Sum a given sample over every axis except the first one.

    Parameters
    ----------
    samples : torch.Tensor
        Input sample. It must have at most 4 axes.

    Returns
    -------
    torch.Tensor
        Sum over all axes except the first axis. A 1-axis input is returned
        unchanged.

    Raises
    ------
    ValueError
        If `samples` has more than 4 axes.

    Examples
    --------
    >>> a = torch.ones([2])
    >>> sum_samples(a).size()
    torch.Size([2])
    >>> a = torch.ones([2, 3])
    >>> sum_samples(a).size()
    torch.Size([2])
    >>> a = torch.ones([2, 3, 4])
    >>> sum_samples(a).size()
    torch.Size([2])
    """
    dim = samples.dim()
    if dim == 1:
        # Nothing to reduce: only the first axis exists.
        return samples
    if dim <= 4:
        # Reduce over axes 1..dim-1, given as plain Python ints rather than
        # 0-d tensors produced by `torch.arange`.
        return torch.sum(samples, dim=list(range(1, dim)))
    raise ValueError("The number of sample axes must be any of 1, 2, 3, or 4, "
                     "got %s." % dim)
def print_latex(obj):
    """Print formulas in latex format.

    Parameters
    ----------
    obj : pixyz.distributions.distributions.Distribution, pixyz.losses.losses.Loss or pixyz.models.model.Model.
        Object whose formula text is rendered.

    Returns
    -------
    IPython.display.Math
        The formula text of `obj` wrapped in a ``Math`` display object.

    Raises
    ------
    TypeError
        If `obj` is none of the supported types.
    """
    if isinstance(obj, pixyz.distributions.distributions.Distribution):
        latex_text = obj.prob_joint_factorized_and_text
    elif isinstance(obj, pixyz.losses.losses.Loss):
        latex_text = obj.loss_text
    elif isinstance(obj, pixyz.models.model.Model):
        # A model exposes its formula through the loss object it holds.
        latex_text = obj.loss_cls.loss_text
    else:
        # Fail with an explicit message instead of hitting an unbound
        # `latex_text` in the return statement below.
        raise TypeError("print_latex expects a Distribution, Loss or Model, "
                        "got %s." % type(obj).__name__)
    return Math(latex_text)
def convert_latex_name(name):
    """Return the LaTeX string that sympy produces for a symbol called `name`.

    Parameters
    ----------
    name : str
        Name passed to ``sympy.Symbol``.

    Returns
    -------
    str
        Result of ``sympy.latex`` applied to that symbol.
    """
    symbol = sympy.Symbol(name)
    return sympy.latex(symbol)