# Source code for pixyz.utils

import torch
import sympy
from IPython.display import Math
import pixyz

# Module-level value returned by ``epsilon()``; overwrite it with ``set_epsilon()``.
_EPSILON = 1e-7


def set_epsilon(eps):
    """Overwrite the module-level `epsilon` value.

    After this call, :func:`epsilon` returns `eps`.

    Parameters
    ----------
    eps : int or float
        New value to store.

    Returns
    -------
    None

    Examples
    --------
    >>> from unittest import mock
    >>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
    ...     set_epsilon(1e-06)
    ...     epsilon()
    1e-06
    """
    # Rebind the module global rather than creating a function-local name.
    global _EPSILON
    _EPSILON = eps
def epsilon():
    """Return the current module-level `epsilon` value.

    The value can be changed with :func:`set_epsilon`.

    Returns
    -------
    int or float

    Examples
    --------
    >>> from unittest import mock
    >>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
    ...     epsilon()
    1e-07
    """
    current = _EPSILON
    return current
def get_dict_values(dicts, keys, return_dict=False):
    """Get values from `dicts` specified by `keys`.

    Keys that are not present in `dicts` are silently skipped, and the
    result follows the order of `keys`. When `return_dict` is False (the
    default), only the values are returned as a list. Any other value of
    `return_dict` returns a ``{key: value}`` dictionary instead.

    Parameters
    ----------
    dicts : dict
        Source dictionary.
    keys : list
        Keys to look up in `dicts`.
    return_dict : bool
        Whether to return a dictionary instead of a list of values.

    Returns
    -------
    dict or list

    Examples
    --------
    >>> get_dict_values({"a":1,"b":2,"c":3}, ["b"])
    [2]
    >>> get_dict_values({"a":1,"b":2,"c":3}, ["b", "d"], True)
    {'b': 2}
    """
    # Test membership directly on the dict (O(1) hash lookup). The previous
    # ``key in list(dicts.keys())`` rebuilt a list for every key.
    new_dicts = {key: dicts[key] for key in keys if key in dicts}
    if return_dict is False:
        return list(new_dicts.values())
    return new_dicts
def delete_dict_values(dicts, keys):
    """Return a copy of `dicts` without the entries whose key is in `keys`.

    The input dictionary is not modified. Keys listed in `keys` but absent
    from `dicts` are ignored.

    Parameters
    ----------
    dicts : dict
        Source dictionary.
    keys : list
        Keys to drop.

    Returns
    -------
    new_dicts : dict
        New dictionary containing the remaining entries in their original order.

    Examples
    --------
    >>> delete_dict_values({"a":1,"b":2,"c":3}, ["b","d"])
    {'a': 1, 'c': 3}
    """
    kept = {}
    for name, val in dicts.items():
        if name in keys:
            continue
        kept[name] = val
    return kept
def detach_dict(dicts):
    """Return a new dictionary with ``.detach()`` applied to every value.

    Parameters
    ----------
    dicts : dict
        Dictionary whose values provide a ``detach()`` method
        (e.g. ``torch.Tensor``).

    Returns
    -------
    dict
        Same keys, detached values.
    """
    detached = {}
    for name, tensor in dicts.items():
        detached[name] = tensor.detach()
    return detached
def replace_dict_keys(dicts, replace_list_dict):
    """Rename keys of `dicts` according to `replace_list_dict`.

    Every key of `dicts` that is also a key of `replace_list_dict` is renamed
    to ``replace_list_dict[key]``, and its value is kept. Other entries are
    copied unchanged. Keys of `replace_list_dict` that do not occur in `dicts`
    are ignored. The iteration order of `dicts` is preserved. If a renamed key
    collides with another key, the value of the later entry wins.

    Parameters
    ----------
    dicts : dict
        Dictionary whose keys are renamed.
    replace_list_dict : dict
        Mapping from old key to new key.

    Returns
    -------
    replaced_dicts : dict
        New dictionary with renamed keys.

    Examples
    --------
    >>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","b":"y"})
    {'x': 1, 'y': 2, 'c': 3}
    >>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","e":"y"})  # "e" is not in `dicts`, so it is ignored
    {'x': 1, 'b': 2, 'c': 3}
    """
    # ``key in replace_list_dict`` is an O(1) hash lookup. The previous
    # ``key in list(replace_list_dict.keys())`` rebuilt a list per entry.
    replaced_dicts = {
        (replace_list_dict[key] if key in replace_list_dict else key): value
        for key, value in dicts.items()
    }
    return replaced_dicts
def replace_dict_keys_split(dicts, replace_list_dict):
    """Rename keys of `dicts` and split the result into two dictionaries.

    Entries whose key appears in `replace_list_dict` are renamed to
    ``replace_list_dict[key]`` and collected in `replaced_dict`. All other
    entries are copied unchanged into `remain_dict`. Both results keep the
    iteration order of `dicts`.

    Parameters
    ----------
    dicts : dict
        Dictionary whose keys are renamed.
    replace_list_dict : dict
        Mapping from old key to new key.

    Returns
    -------
    replaced_dict : dict
        Entries whose keys were renamed.
    remain_dict : dict
        Entries whose keys were not in `replace_list_dict`.

    Examples
    --------
    >>> replace_list_dict = {'a': 'loc'}
    >>> x_dict = {'a': 0, 'b': 1}
    >>> print(replace_dict_keys_split(x_dict, replace_list_dict))
    ({'loc': 0}, {'b': 1})
    """
    replaced_dict = {}
    remain_dict = {}
    # Single pass with O(1) dict membership. The previous version made two
    # passes and rebuilt ``list(replace_list_dict.keys())`` for every entry.
    for key, value in dicts.items():
        if key in replace_list_dict:
            replaced_dict[replace_list_dict[key]] = value
        else:
            remain_dict[key] = value
    return replaced_dict, remain_dict
def tolist(a):
    """Wrap a given input in a list unless it already is a list.

    Only objects whose type is exactly ``list`` are returned as-is. Anything
    else is wrapped in a new one-element list, including tuples and
    subclasses of ``list``.

    Parameters
    ----------
    a : list or other

    Returns
    -------
    list

    Examples
    --------
    >>> tolist(2)
    [2]
    >>> tolist([1, 2])
    [1, 2]
    >>> tolist([])
    []
    """
    # Exact type check (not isinstance) is kept deliberately to preserve
    # existing behavior for list subclasses.
    if type(a) is list:
        return a
    return [a]
def sum_samples(samples):
    """Sum a given sample over every axis except the first.

    Parameters
    ----------
    samples : torch.Tensor
        Input sample with at most 4 axes. The first axis is kept and all
        remaining axes are summed out.

    Returns
    -------
    torch.Tensor
        For 2- to 4-dimensional input, the sum over all axes except the
        first. 1-dimensional (and 0-dimensional) input is returned unchanged.

    Raises
    ------
    ValueError
        If `samples` has more than 4 axes.

    Examples
    --------
    >>> a = torch.ones([2])
    >>> sum_samples(a).size()
    torch.Size([2])
    >>> a = torch.ones([2, 3])
    >>> sum_samples(a).size()
    torch.Size([2])
    >>> a = torch.ones([2, 3, 4])
    >>> sum_samples(a).size()
    torch.Size([2])
    """
    dim = samples.dim()
    if dim <= 1:
        # Nothing to reduce beyond the first axis (a 0-d tensor's "sum over
        # all axes" is its own value, so it is returned as well).
        return samples
    if dim <= 4:
        # Reduce with plain Python ints. The previous
        # ``list(torch.arange(dim))[1:]`` passed 0-d tensors as dims.
        return torch.sum(samples, dim=tuple(range(1, dim)))
    raise ValueError("The number of sample axes must be any of 1, 2, 3, or 4, "
                     "got %s." % dim)
def convert_latex_name(name):
    """Return the LaTeX rendering of `name` treated as a sympy symbol.

    Parameters
    ----------
    name : str
        Name passed to ``sympy.Symbol``.

    Returns
    -------
    str
        Output of ``sympy.latex`` for that symbol.
    """
    symbol = sympy.Symbol(name)
    return sympy.latex(symbol)