pixyz.utils

pixyz.utils.set_epsilon(eps)

Set the global epsilon value.

Parameters: eps (int or float) – New epsilon value.

Examples

>>> from unittest import mock
>>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
...     set_epsilon(1e-06)
...     epsilon()
1e-06
pixyz.utils.epsilon()

Get the current epsilon value.

Returns: The current epsilon value.
Return type: int or float

Examples

>>> from unittest import mock
>>> with mock.patch('pixyz.utils._EPSILON', 1e-07):
...     epsilon()
1e-07
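Because the examples override the value with `mock.patch('pixyz.utils._EPSILON', ...)`, epsilon is stored as a module-level attribute. The following is a minimal standalone sketch of that getter/setter pattern, not the library's source; the default value `1e-07` is an assumption made only for this sketch.

```python
# Module-level store for the epsilon value (default assumed for this sketch).
_EPSILON = 1e-07


def set_epsilon(eps):
    """Overwrite the module-level epsilon value."""
    global _EPSILON
    _EPSILON = eps


def epsilon():
    """Return the current module-level epsilon value."""
    return _EPSILON


previous = epsilon()
set_epsilon(1e-06)
print(epsilon())  # 1e-06
set_epsilon(previous)  # restore the earlier value
```

Since `epsilon()` reads the global at call time, patching the module attribute (as `mock.patch` does) is enough to change what every caller sees.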
pixyz.utils.get_dict_values(dicts, keys, return_dict=False)

Get the values of dicts for the given keys.

If return_dict is True, the values are returned as a dictionary; otherwise they are returned as a list. Keys that are not in dicts are skipped.

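Parameters:
  • dicts (dict) – Dictionary to read from.
  • keys (list) – Keys whose values are retrieved.
  • return_dict (bool) – If True, return a dictionary instead of a list.
Returns:

The selected values.

Return type:

dict or list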

Examples

>>> get_dict_values({"a":1,"b":2,"c":3}, ["b"])
[2]
>>> get_dict_values({"a":1,"b":2,"c":3}, ["b", "d"], True)
{'b': 2}
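A minimal sketch of the selection logic, written from the examples above rather than from the pixyz source. It assumes that missing keys are skipped in both output modes (the documented example shows this only for `return_dict=True`) and that values follow the order of `keys`.

```python
def get_dict_values(dicts, keys, return_dict=False):
    """Sketch: select the entries of `dicts` whose keys appear in `keys`."""
    # Iterate over `keys` so the result follows the requested order;
    # keys that are not in `dicts` are silently skipped.
    selected = {key: dicts[key] for key in keys if key in dicts}
    if return_dict:
        return selected
    return list(selected.values())


print(get_dict_values({"a": 1, "b": 2, "c": 3}, ["c", "a"]))        # [3, 1]
print(get_dict_values({"a": 1, "b": 2, "c": 3}, ["b", "d"], True))  # {'b': 2}
```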
pixyz.utils.delete_dict_values(dicts, keys)

Delete the entries specified by keys from dicts and return the result as a new dictionary. Keys that are not in dicts are ignored.

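Parameters:
  • dicts (dict) – Dictionary to delete entries from.
  • keys (list) – Keys of the entries to delete.
Returns:

new_dicts – Dictionary without the deleted entries.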

Return type:

dict

Examples

>>> delete_dict_values({"a":1,"b":2,"c":3}, ["b","d"])
{'a': 1, 'c': 3}
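A minimal sketch of the same behavior, assuming (as the return name new_dicts suggests) that a fresh dictionary is built and the input is left untouched. This is an illustration, not the library's implementation.

```python
def delete_dict_values(dicts, keys):
    """Sketch: return a new dict without the entries whose keys are in `keys`."""
    # Build a fresh dict so the caller's dictionary is not modified.
    return {key: value for key, value in dicts.items() if key not in keys}


original = {"a": 1, "b": 2, "c": 3}
print(delete_dict_values(original, ["b", "d"]))  # {'a': 1, 'c': 3}
print(original)  # {'a': 1, 'b': 2, 'c': 3}, the input is unchanged
```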
pixyz.utils.detach_dict(dicts)

Detach all values (tensors) in dicts from the computational graph.

Parameters: dicts (dict) – Dictionary of tensors.
Returns: Dictionary with the detached values.
Return type: dict
pixyz.utils.replace_dict_keys(dicts, replace_list_dict)

Rename the keys of dicts according to replace_list_dict, which maps old keys to new keys. Keys that are not in replace_list_dict are kept as they are.

Parameters:
  • dicts (dict) – Dictionary whose keys are renamed.
  • replace_list_dict (dict) – Mapping from old keys to new keys.
Returns:

replaced_dicts – Dictionary with the renamed keys.

Return type:

dict

Examples

>>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","b":"y"})
{'x': 1, 'y': 2, 'c': 3}
>>> replace_dict_keys({"a":1,"b":2,"c":3}, {"a":"x","e":"y"})  # keys of `replace_list_dict` that are not in `dicts` are ignored
{'x': 1, 'b': 2, 'c': 3}
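A minimal sketch of the renaming rule shown in the examples, written as an illustration rather than the pixyz source. `dict.get(key, key)` returns the new name when the key is in the mapping and the original key otherwise, so mapping entries for absent keys have no effect; this sketch also keeps the original entry order.

```python
def replace_dict_keys(dicts, replace_list_dict):
    """Sketch: rename keys of `dicts` using the old-to-new mapping `replace_list_dict`."""
    # Keys without a mapping entry keep their name; unused mapping entries are ignored.
    return {replace_list_dict.get(key, key): value for key, value in dicts.items()}


print(replace_dict_keys({"a": 1, "b": 2, "c": 3}, {"a": "x", "b": "y"}))  # {'x': 1, 'y': 2, 'c': 3}
print(replace_dict_keys({"a": 1, "b": 2, "c": 3}, {"a": "x", "e": "y"}))  # {'x': 1, 'b': 2, 'c': 3}
```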
pixyz.utils.replace_dict_keys_split(dicts, replace_list_dict)

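Rename the keys of dicts according to replace_list_dict.

The result is split into replaced_dict, which contains the renamed entries, and remain_dict, which contains the entries whose keys are not in replace_list_dict.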

Parameters:
  • dicts (dict) – Dictionary whose keys are renamed.
  • replace_list_dict (dict) – Mapping from old keys to new keys.
Returns:

  • replaced_dict (dict) – Entries whose keys were renamed.
  • remain_dict (dict) – Entries whose keys were left unchanged.

Examples

>>> replace_list_dict = {'a': 'loc'}
>>> x_dict = {'a': 0, 'b': 1}
>>> print(replace_dict_keys_split(x_dict, replace_list_dict))
({'loc': 0}, {'b': 1})
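A minimal sketch of the split, reconstructed from the example above rather than taken from the pixyz source: each entry goes to exactly one of the two returned dictionaries, depending on whether its key appears in replace_list_dict.

```python
def replace_dict_keys_split(dicts, replace_list_dict):
    """Sketch: split `dicts` into renamed entries and untouched entries."""
    replaced_dict = {}
    remain_dict = {}
    for key, value in dicts.items():
        if key in replace_list_dict:
            # Renamed entry goes to the first dictionary under its new key.
            replaced_dict[replace_list_dict[key]] = value
        else:
            # Untouched entry keeps its key and goes to the second dictionary.
            remain_dict[key] = value
    return replaced_dict, remain_dict


print(replace_dict_keys_split({'a': 0, 'b': 1}, {'a': 'loc'}))  # ({'loc': 0}, {'b': 1})
```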
class pixyz.utils.FrozenSampleDict(dict_)

Bases: object

pixyz.utils.lru_cache_for_sample_dict(maxsize=0)

Memoize the results of the target function, keyed by its sample dict arguments. Note that every dictionary argument of the target function must be a sample dict.

Parameters: maxsize (int) – Cache size prepared for the target method.
Returns: Decorator that memoizes the target method.
Return type: function

Examples

>>> import time
>>> import torch
>>> import torch.nn as nn
>>> import pixyz.utils as utils
>>> # utils.CACHE_SIZE = 2  # alternatively, set this module option to enable memoization for all distributions
>>> import pixyz.distributions as pd
>>> class LongEncoder(pd.Normal):
...     def __init__(self):
...         super().__init__(cond_var=['y'], var=['x'])
...         # A deliberately deep network so that each forward pass is slow.
...         self.nn = nn.Sequential(*(nn.Linear(1, 1) for i in range(10000)))
...     def forward(self, y):
...         return {'loc': self.nn(y), 'scale': torch.ones(1, 1)}
...     @utils.lru_cache_for_sample_dict(maxsize=2)
...     def get_params(self, params_dict={}, **kwargs):
...         return super().get_params(params_dict, **kwargs)
>>> def measure_time(func):
...     start = time.perf_counter()
...     func()
...     elapsed_time = time.perf_counter() - start
...     return elapsed_time
>>> le = LongEncoder()
>>> y = torch.ones(1, 1)
>>> t_sample1 = measure_time(lambda: le.sample({'y': y}))
>>> print("sample1:{0}[sec]".format(t_sample1))  # doctest: +SKIP
>>> t_log_prob = measure_time(lambda: le.get_log_prob({'x': y, 'y': y}))
>>> print("log_prob:{0}[sec]".format(t_log_prob))  # doctest: +SKIP
>>> t_sample2 = measure_time(lambda: le.sample({'y': y}))
>>> print("sample2:{0}[sec]".format(t_sample2))  # doctest: +SKIP
>>> assert t_sample1 > t_sample2, "processing time increases: {0}".format(t_sample2 - t_sample1)
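A plain dict is unhashable, so it cannot be a cache key for `functools.lru_cache` directly. One way to memoize on dict arguments, sketched below under stated assumptions, is to wrap each dict in a hashable, read-only wrapper before it reaches the cache. `FrozenDict` and `lru_cache_for_dicts` are hypothetical names for this illustration only; they are not pixyz APIs, the sketch assumes hashable values such as ints, and it does not reproduce how pixyz compares tensor-valued sample dicts.

```python
import functools


class FrozenDict:
    """Hypothetical hashable, read-only wrapper around a dict with hashable values."""

    def __init__(self, dict_):
        self.dict = dict(dict_)  # private copy, so later mutation cannot corrupt the key
        self._key = frozenset(self.dict.items())  # insertion order does not matter

    def __hash__(self):
        return hash(self._key)

    def __eq__(self, other):
        return isinstance(other, FrozenDict) and self._key == other._key


def _unfreeze(value):
    return value.dict if isinstance(value, FrozenDict) else value


def _freeze(value):
    return FrozenDict(value) if isinstance(value, dict) else value


def lru_cache_for_dicts(maxsize=128):
    """Hypothetical decorator: memoize `func`, converting dict arguments to FrozenDict."""

    def decorator(func):
        @functools.lru_cache(maxsize=maxsize)
        def cached(*args, **kwargs):
            # Unwrap the frozen arguments before calling the real function.
            args = [_unfreeze(a) for a in args]
            kwargs = {k: _unfreeze(v) for k, v in kwargs.items()}
            return func(*args, **kwargs)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Freeze dict arguments so that lru_cache can hash them.
            args = [_freeze(a) for a in args]
            kwargs = {k: _freeze(v) for k, v in kwargs.items()}
            return cached(*args, **kwargs)

        wrapper.cache_info = cached.cache_info  # expose hit/miss statistics
        return wrapper

    return decorator


calls = []


@lru_cache_for_dicts(maxsize=2)
def total(sample):
    calls.append(sample)  # record real (uncached) evaluations
    return sum(sample.values())


print(total({"x": 1, "y": 2}))  # 3 (computed)
print(total({"y": 2, "x": 1}))  # 3 (cache hit: same items, different order)
print(len(calls))               # 1
```

Copying the dict inside the wrapper is a deliberate choice: if the caller mutates its dictionary after the call, the stored cache key stays consistent with the value that was actually computed.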
pixyz.utils.tolist(a)

Convert a given input to a list. A list is returned unchanged; any other value is wrapped in a single-element list.

Parameters: a (list or other) – Input value.
Returns: a itself if it is a list, otherwise [a].
Return type: list

Examples

>>> tolist(2)
[2]
>>> tolist([1, 2])
[1, 2]
>>> tolist([])
[]
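A minimal sketch of the conversion rule illustrated by the examples, not the library's source. It assumes only `list` instances pass through unchanged, so tuples, strings, and all other values are wrapped rather than converted element by element.

```python
def tolist(a):
    """Sketch: return `a` unchanged if it is a list, otherwise wrap it in a list."""
    if isinstance(a, list):
        return a
    return [a]


print(tolist(2))       # [2]
print(tolist([1, 2]))  # [1, 2]
print(tolist("ab"))    # ['ab'], a string is wrapped, not split into characters
```

Wrapping instead of calling `list(a)` is what keeps a single variable name such as `"x"` from being split into characters.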
pixyz.utils.sum_samples(samples)

Sum a given sample over all axes except the first.

Parameters: samples (torch.Tensor) – Input sample. Its number of axes is assumed to be 4 or less.
Returns: Sum over all axes except the first axis.
Return type: torch.Tensor

Examples

>>> import torch
>>> a = torch.ones([2])
>>> sum_samples(a).size()
torch.Size([2])
>>> a = torch.ones([2, 3])
>>> sum_samples(a).size()
torch.Size([2])
>>> a = torch.ones([2, 3, 4])
>>> sum_samples(a).size()
torch.Size([2])
pixyz.utils.print_latex(obj)

Print the formula of obj in LaTeX format.

Parameters: obj (pixyz.distributions.distributions.Distribution, pixyz.losses.losses.Loss or pixyz.models.model.Model) – Object whose formula is printed.
pixyz.utils.convert_latex_name(name)