# Source code for pixyz.distributions.distributions

from __future__ import print_function
import torch
import numbers
import re
from torch import nn
from copy import deepcopy

from ..utils import get_dict_values, replace_dict_keys, delete_dict_values, tolist
from ..losses import LogProb, Prob


class Distribution(nn.Module):
    """
    Distribution class. In Pixyz, all distributions are required to inherit this class.

    Attributes
    ----------
    var : list
        Variables of this distribution.
    cond_var : list
        Conditional variables of this distribution.
        In case that cond_var is not empty, we must set the corresponding inputs in order to
        sample variables.
    dim : int
        Number of dimensions of this distribution.
        This might be ignored depending on the shape which is set in the sample method and on
        its parent distribution.
        Moreover, this is not considered when this class is inherited by DNNs.
        This is set to 1 by default.
    name : str
        Name of this distribution.
        This name is displayed in prob_text and prob_factorized_text.
        This is set to "p" by default.
    """

    def __init__(self, cond_var=[], var=["x"], name="p", dim=1):
        super().__init__()

        _vars = cond_var + var
        if len(_vars) != len(set(_vars)):
            raise ValueError("There are conflicted variables.")

        # Store copies: the default lists are shared by every call, and the `var` /
        # `cond_var` properties hand these lists out, so storing the originals would let a
        # mutation through one instance leak into the defaults or into the caller's lists.
        self._cond_var = list(cond_var)
        self._var = list(var)
        self.dim = dim
        self._name = name

        self._prob_text = None
        self._prob_factorized_text = None

    @property
    def distribution_name(self):
        """Name of the distribution family (None for this base class)."""
        return None

    @property
    def name(self):
        """Name of this distribution, used in `prob_text`."""
        return self._name

    @name.setter
    def name(self, name):
        if isinstance(name, str):
            self._name = name
            return

        raise ValueError("Name of the distribution class must be set as a string type.")

    @property
    def var(self):
        """Variables of this distribution."""
        return self._var

    @property
    def cond_var(self):
        """Conditional variables of this distribution."""
        return self._cond_var

    @property
    def input_var(self):
        """
        Normally, `input_var` has same values as `cond_var`.
        """
        return self._cond_var

    @property
    def prob_text(self):
        """Text such as ``p(x|z)`` built from `name`, `var` and `cond_var`."""
        _var_text = [','.join(self._var)]
        if len(self._cond_var) != 0:
            _var_text += [','.join(self._cond_var)]

        _prob_text = "{}({})".format(
            self._name,
            "|".join(_var_text)
        )

        return _prob_text

    @property
    def prob_factorized_text(self):
        """Factorized text of this distribution (same as `prob_text` by default)."""
        return self.prob_text

    def _check_input(self, x, var=None):
        """
        Check the type of given input.
        If the input type is `dictionary`, this method checks whether the input keys contains the `var` list.
        In case that its type is `list` or `tensor`, it returns the output formatted in `dictionary`.

        Parameters
        ----------
        x : torch.Tensor, list, or dict
            Input variables
        var : list or None
            Variables to check if given input contains them.
            This is set to None by default.

        Returns
        -------
        checked_x : dict
            Variables checked in this method.
            A dictionary input is returned as-is (not copied), so updating the result also
            updates the caller's dictionary.

        Raises
        ------
        ValueError
            Raises ValueError if the type of input is neither tensor, list, nor dictionary,
            or if a dictionary input lacks some of `var`.
        """
        if var is None:
            var = self.input_var

        # isinstance (rather than exact type checks) so that subclasses such as
        # `nn.Parameter` or `OrderedDict` are accepted too.
        if isinstance(x, torch.Tensor):
            checked_x = {var[0]: x}

        elif isinstance(x, list):
            # TODO: we need to check if all the elements contained in this list are torch.Tensor.
            checked_x = dict(zip(var, x))

        elif isinstance(x, dict):
            if not (set(x.keys()) >= set(var)):
                raise ValueError("Input keys are not valid.")
            checked_x = x

        else:
            raise ValueError("The type of input is not valid, got %s." % type(x))

        return checked_x

    def get_params(self, params_dict={}):
        """
        This method aims to get parameters of this distributions from constant parameters set in
        initialization and outputs of DNNs.

        Parameters
        ----------
        params_dict : dict
            Input parameters.

        Returns
        -------
        output_dict : dict
            Output parameters

        Examples
        --------
        >>> print(dist_1.prob_text, dist_1.distribution_name)
        p(x) Normal
        >>> dist_1.get_params()
        {"loc": 0, "scale": 1}
        >>> print(dist_2.prob_text, dist_2.distribution_name)
        p(x|z) Normal
        >>> dist_2.get_params({"z": 1})
        {"loc": 0, "scale": 1}

        """
        raise NotImplementedError

    def sample(self, x={}, shape=None, batch_size=1, return_all=True, reparam=False):
        """
        Sample variables of this distribution.
        If `cond_var` is not empty, we should set inputs as a dictionary format.

        Parameters
        ----------
        x : torch.Tensor, list, or dict
            Input variables.
        shape : tuple
            Shape of samples.
            If set, `batch_size` and `dim` are ignored.
        batch_size : int
            Batch size of samples. This is set to 1 by default.
        return_all : bool
            Choose whether the output contains input variables.
        reparam : bool
            Choose whether we sample variables with reparameterized trick.

        Returns
        -------
        output : dict
            Samples of this distribution.

        """
        raise NotImplementedError

    def get_log_prob(self, *args, **kwargs):
        """Evaluate the log-probability of given variables (implemented by subclasses)."""
        raise NotImplementedError

    def log_prob(self, sum_features=True, feature_dims=None):
        """Return a `LogProb` loss object wrapping this distribution."""
        return LogProb(self, sum_features=sum_features, feature_dims=feature_dims)

    def prob(self, sum_features=True, feature_dims=None):
        """Return a `Prob` loss object wrapping this distribution."""
        return Prob(self, sum_features=sum_features, feature_dims=feature_dims)

    def log_likelihood(self, *args, **kwargs):
        """Removed; use `log_prob().eval()` instead."""
        raise NotImplementedError("The `log_likelihood()` method has been removed. "
                                  "You should use `log_prob().eval()` instead.")

    def forward(self, *args, **kwargs):
        """
        When this class is inherited by DNNs, it is also intended that this method is overridden.
        """

        raise NotImplementedError

    def sample_mean(self, x):
        """Return the mean of this distribution given `x` (implemented by subclasses)."""
        raise NotImplementedError

    def sample_variance(self, x):
        """Return the variance of this distribution given `x` (implemented by subclasses)."""
        raise NotImplementedError

    def replace_var(self, **replace_dict):
        """Return a `ReplaceVarDistribution` with variables renamed by `replace_dict`."""
        return ReplaceVarDistribution(self, replace_dict)

    def marginalize_var(self, marginalize_list):
        """Return a `MarginalizeVarDistribution` marginalizing the given variable(s)."""
        marginalize_list = tolist(marginalize_list)
        return MarginalizeVarDistribution(self, marginalize_list)

    def __mul__(self, other):
        return MultiplyDistribution(self, other)

    def __str__(self):
        # Distribution
        if self.prob_factorized_text == self.prob_text:
            prob_text = "{} ({})".format(self.prob_text, self.distribution_name)
        else:
            prob_text = "{} = {}".format(self.prob_text, self.prob_factorized_text)
        text = "Distribution:\n  {}\n".format(prob_text)

        # Network architecture (`repr`)
        network_text = self.__repr__()
        network_text = re.sub('^', ' ' * 2, str(network_text), flags=re.MULTILINE)
        text += "Network architecture:\n{}".format(network_text)
        return text


class DistributionBase(Distribution):
    """
    Base class for distributions that wrap a torch-style distribution object.

    Subclasses must provide ``params_keys`` (the exact set of parameter names that
    ``get_params`` has to produce) and ``DistributionTorch`` (the class instantiated with those
    parameters in ``set_distribution``). The instantiated object is stored as ``self.dist`` and
    is used through its ``sample``/``rsample``/``log_prob``/``mean``/``variance`` members.

    Extra keyword arguments given at initialization are parameters: numbers or tensors are
    stored as constants; a string names a conditional variable whose value supplies that
    parameter at call time.
    """

    def __init__(self, cond_var=[], var=["x"], name="p", dim=1, **kwargs):
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim)

        self._set_constant_params(**kwargs)

    def _set_constant_params(self, **params_dict):
        """
        Format constant parameters of this distribution.

        Parameters
        ----------
        params_dict : dict
            Constant parameters of this distribution set at initialization.
            If the values of these dictionaries contain parameters which are named as strings, which means that
            these parameters are set as "variables", the correspondences between these values and the true name of
            these parameters are stored as a dictionary format (`replace_params_dict`).

        Raises
        ------
        ValueError
            If a string value is not one of `cond_var`, or a value is neither a string,
            a number, nor a tensor.
        """

        self.replace_params_dict = {}
        self.constant_params_dict = {}

        for key, value in params_dict.items():
            if isinstance(value, str):
                if value in self._cond_var:
                    self.replace_params_dict[value] = key
                else:
                    raise ValueError("Parameter `{}` is set to `{}`, which is not a conditional variable "
                                     "of this distribution {}.".format(key, value, self._cond_var))
            elif isinstance(value, (numbers.Number, torch.Tensor)):
                self.constant_params_dict[key] = value
            else:
                raise ValueError("Parameter `{}` must be a string, a number, or a torch.Tensor, "
                                 "got {}.".format(key, type(value)))

    def set_distribution(self, x={}, sampling=True):
        """
        Build ``self.dist`` from the parameters computed for `x`.

        Require self.params_keys and self.DistributionTorch

        Parameters
        ----------
        x : dict
            Input variables passed to `get_params`.
        sampling : bool
            Currently unused.

        Raises
        ------
        ValueError
            If the computed parameter names differ from `params_keys`.
        """
        params = self.get_params(x)
        if set(self.params_keys) != set(params.keys()):
            raise ValueError("Parameters of {} must be {}, got {}.".format(
                self.prob_text, list(self.params_keys), list(params.keys())))

        self.dist = self.DistributionTorch(**params)

    def _get_sample(self, reparam=False, sample_shape=torch.Size()):
        """
        Draw a sample from ``self.dist``.

        Parameters
        ----------
        reparam : bool
            Use ``rsample`` (reparameterization trick) instead of ``sample``.
        sample_shape : tuple

        Returns
        -------
        samples_dict : dict
            ``{var[0]: samples}``.

        Raises
        ------
        ValueError
            If `reparam` is True but the wrapped distribution does not implement ``rsample``.
        """

        if reparam:
            try:
                _samples = self.dist.rsample(sample_shape=sample_shape)
            except NotImplementedError as e:
                # Previously this only printed a message and then failed with an
                # UnboundLocalError on `_samples`; fail loudly with the real reason instead.
                raise ValueError("We can not use the reparameterization trick "
                                 "for this distribution.") from e
        else:
            _samples = self.dist.sample(sample_shape=sample_shape)
        samples_dict = {self._var[0]: _samples}

        return samples_dict

    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
        """
        Evaluate the log-probability of `var` given `cond_var`.

        Parameters
        ----------
        x_dict : dict
            Must contain `cond_var` and `var`.
        sum_features : bool
            Sum over all non-batch dimensions (via `sum_samples`).
        feature_dims : None or list
            Currently unused.

        Returns
        -------
        log_prob : torch.Tensor

        """
        _x_dict = get_dict_values(x_dict, self._cond_var, return_dict=True)
        self.set_distribution(_x_dict, sampling=False)

        x_targets = get_dict_values(x_dict, self._var)
        log_prob = self.dist.log_prob(*x_targets)
        if sum_features:
            log_prob = sum_samples(log_prob)

        return log_prob

    def _replace_vars_to_params(self, vars_dict, replace_dict):
        """
        Replace variables in input keys to parameters of this distribution according to
        these correspondences which is formatted in a dictionary and set in `_set_constant_params`.

        Parameters
        ----------
        vars_dict : dict
            Dictionary.
        replace_dict : dict
            Dictionary.

        Returns
        -------
        params_dict : dict
            Dictionary.
        vars_dict : dict
            Dictionary (entries of the input that were not replaced).

        Examples
        --------
        >>> replace_dict
        {"a": "loc"}
        >>> x = {"a": 0, "b": 1}
        >>> distribution._replace_vars_to_params(x, replace_dict)
        {"loc": 0}, {"b": 1}

        """
        params_dict = {replace_dict[key]: value for key, value in vars_dict.items()
                       if key in replace_dict}
        vars_dict = {key: value for key, value in vars_dict.items()
                     if key not in replace_dict}

        return params_dict, vars_dict

    def get_params(self, params_dict={}):
        """Combine network outputs, variable-supplied parameters, and constant parameters."""
        params_dict, vars_dict = self._replace_vars_to_params(params_dict, self.replace_params_dict)
        output_dict = self.forward(**vars_dict)

        # append constant_params to dict
        output_dict.update(params_dict)
        output_dict.update(self.constant_params_dict)

        return output_dict

    def sample(self, x=None, shape=None, batch_size=1, return_all=True, reparam=False):
        """
        Sample variables of this distribution (see `Distribution.sample`).

        When `x` is a dictionary and `return_all` is True, `x` itself is updated with the
        samples and returned.
        """
        if x is None:
            # A fresh dict per call: with `return_all=True` the input dict is updated in place
            # and returned, so a shared `{}` default would make every call return the same
            # object and overwrite earlier samples.
            x = {}

        # check whether the input is valid or convert it to valid dictionary.
        x_dict = self._check_input(x)

        # unconditioned
        if len(self.input_var) == 0:
            if shape:
                sample_shape = shape
            else:
                if self.dim is None:
                    sample_shape = (batch_size, )
                else:
                    sample_shape = (batch_size, self.dim)

            self.set_distribution()
            output_dict = self._get_sample(reparam=reparam,
                                           sample_shape=sample_shape)

        # conditioned
        else:
            # remove redundant variables from x_dict.
            _x_dict = get_dict_values(x_dict, self.input_var, return_dict=True)
            self.set_distribution(_x_dict)
            output_dict = self._get_sample(reparam=reparam)

        if return_all:
            x_dict.update(output_dict)
            return x_dict

        return output_dict

    def _select_input(self, x):
        """Validate `x` and keep only its `input_var` entries, exactly as `sample` does."""
        if x is None:
            x = {}
        return get_dict_values(self._check_input(x), self.input_var, return_dict=True)

    def sample_mean(self, x=None):
        """Return the mean of this distribution given inputs `x`."""
        self.set_distribution(self._select_input(x))
        return self.dist.mean

    def sample_variance(self, x=None):
        """Return the variance of this distribution given inputs `x`."""
        self.set_distribution(self._select_input(x))
        return self.dist.variance

    def forward(self, **params):
        """Identity mapping: parameters are returned unchanged (override in DNN subclasses)."""
        return params


class MultiplyDistribution(Distribution):
    """
    Multiply by given distributions, e.g, :math:`p(x,y|z) = p(x|z,y)p(y|z)`.
    In this class, it is checked if two distributions can be multiplied.

    p(x|z)p(z|y) -> Valid

    p(x|z)p(y|z) -> Valid

    p(x|z)p(y|a) -> Valid

    p(x|z)p(z|x) -> Invalid (recursive)

    p(x|z)p(x|y) -> Invalid (conflict)

    Parameters
    ----------
    a : pixyz.Distribution
        Distribution.

    b : pixyz.Distribution
        Distribution.

    Examples
    --------
    >>> p_multi = MultiplyDistribution(a, b)
    >>> p_multi = a * b

    """

    def __init__(self, a, b):
        if not (isinstance(a, Distribution) and isinstance(b, Distribution)):
            raise ValueError("Given inputs should be `pixyz.Distribution`, got {} and {}.".format(type(a), type(b)))

        # Check parent-child relationship between two distributions.
        # If inherited variables (`_inh_var`) are exist (e.g. c in p(e|c)p(c|a,b)),
        # then p(e|c) is a child and p(c|a,b) is a parent, otherwise it is opposite.
        _vars_a_b = a.cond_var + b.var
        _vars_b_a = b.cond_var + a.var
        _inh_var_a_b = [var for var in set(_vars_a_b) if _vars_a_b.count(var) > 1]
        _inh_var_b_a = [var for var in set(_vars_b_a) if _vars_b_a.count(var) > 1]

        if len(_inh_var_a_b) > 0:
            _child = a
            _parent = b
            _inh_var = _inh_var_a_b

        elif len(_inh_var_b_a) > 0:
            _child = b
            _parent = a
            _inh_var = _inh_var_b_a

        else:
            _child = a
            _parent = b
            _inh_var = []

        # Check if variables of two distributions are "recursive" (e.g. p(x|z)p(z|x)).
        _check_recursive_vars = _child.var + _parent.cond_var
        if len(_check_recursive_vars) != len(set(_check_recursive_vars)):
            raise ValueError("Variables of two distributions, {} and {}, are recursive.".format(_child.prob_text,
                                                                                                 _parent.prob_text))

        # Set variables.
        _var = _child.var + _parent.var
        if len(_var) != len(set(_var)):  # e.g. p(x|z)p(x|y)
            raise ValueError("Variables of two distributions, {} and {}, are conflicted.".format(_child.prob_text,
                                                                                                  _parent.prob_text))

        # Set conditional variables.
        _cond_var = _child.cond_var + _parent.cond_var
        _cond_var = sorted(set(_cond_var), key=_cond_var.index)

        # Delete inh_var in conditional variables.
        _cond_var = [var for var in _cond_var if var not in _inh_var]

        super().__init__(cond_var=_cond_var, var=_var)

        self._inh_var = _inh_var
        self._parent = _parent
        self._child = _child

        # Set input_var (it might be different from cond_var if either a and b contain data distributions.)
        _input_var = [var for var in self._child.input_var if var not in _inh_var]
        _input_var += self._parent.input_var
        self._input_var = sorted(set(_input_var), key=_input_var.index)

    @property
    def inh_var(self):
        """Variables sampled by the parent and consumed as conditions by the child."""
        return self._inh_var

    @property
    def input_var(self):
        """Variables that must be given from outside to sample from this product."""
        return self._input_var

    @property
    def prob_factorized_text(self):
        """Concatenation of the child's and the parent's factorized texts."""
        return self._child.prob_factorized_text + self._parent.prob_factorized_text

    def sample(self, x=None, shape=None, batch_size=1, return_all=True, reparam=False):
        """
        Sample from the parent first, then from the child conditioned on the parent's samples.

        Returns all variables when `return_all` is True; when it is exactly False, only the
        variables of this distribution (`var`) are returned.
        """
        if x is None:
            # A fresh dict per call: the parent/child may update the input dict in place.
            x = {}

        # sample from the parent distribution
        parents_x_dict = x
        child_x_dict = self._parent.sample(x=parents_x_dict, shape=shape, batch_size=batch_size,
                                           return_all=True, reparam=reparam)
        # sample from the child distribution
        output_dict = self._child.sample(x=child_x_dict, shape=shape, batch_size=batch_size,
                                         return_all=True, reparam=reparam)

        if return_all is False:
            # Select from the samples, not from the input `x`: `x` only holds the samples when a
            # child happened to update it in place, which is not the case for tensor/list input.
            output_dict = get_dict_values(output_dict, self._var, return_dict=True)

        return output_dict

    def get_log_prob(self, x, sum_features=True, feature_dims=None):
        """
        Sum of the parent's and the child's log-probabilities.

        Raises
        ------
        ValueError
            If `sum_features` is False and the two log-probabilities differ in size.
        """
        parent_log_prob = self._parent.get_log_prob(x, sum_features=sum_features, feature_dims=feature_dims)
        child_log_prob = self._child.get_log_prob(x, sum_features=sum_features, feature_dims=feature_dims)

        if sum_features:
            return parent_log_prob + child_log_prob

        if parent_log_prob.size() == child_log_prob.size():
            return parent_log_prob + child_log_prob

        raise ValueError("Two PDFs, {} and {}, have different sizes,"
                         " so you must set sum_features=True.".format(self._parent.prob_text, self._child.prob_text))

    def __repr__(self):
        if isinstance(self._parent, MultiplyDistribution):
            text = self._parent.__repr__()
        else:
            text = "{} ({}): {}".format(self._parent.prob_text, self._parent.distribution_name,
                                        self._parent.__repr__())
        text += "\n"

        if isinstance(self._child, MultiplyDistribution):
            text += self._child.__repr__()
        else:
            text += "{} ({}): {}".format(self._child.prob_text, self._child.distribution_name,
                                         self._child.__repr__())
        return text


class ReplaceVarDistribution(Distribution):
    """
    Replace names of variables in Distribution.

    Attributes
    ----------
    a : pixyz.Distribution (not pixyz.MultiplyDistribution)
        Distribution.

    replace_dict : dict
        Dictionary mapping original variable names to their new names.

    """

    def __init__(self, a, replace_dict):
        if not isinstance(a, Distribution):
            raise ValueError("Given input should be `pixyz.Distribution`, got {}.".format(type(a)))

        if isinstance(a, MultiplyDistribution):
            raise ValueError("`pixyz.MultiplyDistribution` is not supported for now.")

        if isinstance(a, MarginalizeVarDistribution):
            raise ValueError("`pixyz.MarginalizeVarDistribution` is not supported for now.")

        _cond_var = deepcopy(a.cond_var)
        _var = deepcopy(a.var)
        all_vars = _cond_var + _var

        if not (set(replace_dict.keys()) <= set(all_vars)):
            raise ValueError("Keys of `replace_dict` must be variables of {}, got {}.".format(
                a.prob_text, list(replace_dict.keys())))

        _replace_inv_cond_var_dict = {replace_dict[var]: var for var in _cond_var if var in replace_dict}
        _replace_inv_dict = {value: key for key, value in replace_dict.items()}

        _cond_var = [replace_dict.get(var, var) for var in _cond_var]
        _var = [replace_dict.get(var, var) for var in _var]
        super().__init__(cond_var=_cond_var, var=_var, name=a.name, dim=a.dim)

        # Assigned after `Module.__init__` so that nn.Module attribute handling is set up.
        self._replace_inv_cond_var_dict = _replace_inv_cond_var_dict
        self._replace_inv_dict = _replace_inv_dict
        self._replace_dict = replace_dict

        self._a = a
        self._input_var = [replace_dict.get(var, var) for var in a.input_var]

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped distribution's `forward`."""
        return self._a.forward(*args, **kwargs)

    def get_params(self, params_dict):
        """Rename conditional variables back to the original names and delegate."""
        params_dict = replace_dict_keys(params_dict, self._replace_inv_cond_var_dict)
        return self._a.get_params(params_dict)

    def sample(self, x=None, shape=None, batch_size=1, return_all=True, reparam=False):
        """
        Sample from the wrapped distribution using the renamed variables.

        When `return_all` is True, `x` is updated in place with the samples and returned;
        otherwise only the (renamed) samples are returned.
        """
        if x is None:
            # A fresh dict per call: `x` is updated in place below, so a shared `{}` default
            # would accumulate samples and be returned by every call.
            x = {}

        input_dict = get_dict_values(x, self.cond_var, return_dict=True)
        replaced_input_dict = replace_dict_keys(input_dict, self._replace_inv_cond_var_dict)

        output_dict = self._a.sample(replaced_input_dict, shape=shape, batch_size=batch_size,
                                     return_all=False, reparam=reparam)
        output_dict = replace_dict_keys(output_dict, self._replace_dict)

        if not return_all:
            return output_dict

        x.update(output_dict)
        return x

    def get_log_prob(self, x_dict, **kwargs):
        """
        Parameters
        ----------
        x_dict : dict
            Variables under their renamed names.

        Returns
        -------
        torch.Tensor
            Log-probability computed by the wrapped distribution.

        """
        input_dict = get_dict_values(x_dict, self.cond_var + self.var, return_dict=True)
        input_dict = replace_dict_keys(input_dict, self._replace_inv_dict)
        return self._a.get_log_prob(input_dict, **kwargs)

    def sample_mean(self, x):
        """Mean of the wrapped distribution given renamed conditional variables."""
        input_dict = get_dict_values(x, self.cond_var, return_dict=True)
        input_dict = replace_dict_keys(input_dict, self._replace_inv_cond_var_dict)
        return self._a.sample_mean(input_dict)

    def sample_variance(self, x):
        """Variance of the wrapped distribution given renamed conditional variables."""
        input_dict = get_dict_values(x, self.cond_var, return_dict=True)
        input_dict = replace_dict_keys(input_dict, self._replace_inv_cond_var_dict)
        return self._a.sample_variance(input_dict)

    @property
    def input_var(self):
        return self._input_var

    @property
    def distribution_name(self):
        return self._a.distribution_name

    def __repr__(self):
        return self._a.__repr__()

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError:
            if item == "_a":
                # `_a` not registered yet: avoid infinite recursion through `self._a`.
                raise
            # `getattr` (not `__getattribute__`) so that the wrapped module's own
            # `__getattr__` can resolve its submodules, parameters and buffers.
            return getattr(self._a, item)


class MarginalizeVarDistribution(Distribution):
    r"""
    Marginalize variables in Distribution.

    :math:`p(x) = \int p(x,z) dz`

    Attributes
    ----------
    a : pixyz.Distribution (not pixyz.DistributionBase)
        Distribution.

    marginalize_list : list
        Variables to marginalize.

    """

    def __init__(self, a, marginalize_list):
        marginalize_list = tolist(marginalize_list)

        if not isinstance(a, Distribution):
            raise ValueError("Given input must be `pixyz.Distribution`, got {}.".format(type(a)))

        if isinstance(a, DistributionBase):
            raise ValueError("`pixyz.DistributionBase` cannot marginalize its variables for now.")

        _var = deepcopy(a.var)
        _cond_var = deepcopy(a.cond_var)

        if len(marginalize_list) == 0:
            raise ValueError("Length of `marginalize_list` must be at least 1, got %d." % len(marginalize_list))

        if not (set(marginalize_list) < set(_var)):
            raise ValueError("`marginalize_list` must be a proper subset of {}, got {}.".format(_var, marginalize_list))

        if not set(marginalize_list).isdisjoint(set(_cond_var)):
            raise ValueError("Conditional variables {} cannot be marginalized, got {}.".format(_cond_var,
                                                                                               marginalize_list))

        _var = [var for var in _var if var not in marginalize_list]

        super().__init__(cond_var=_cond_var, var=_var, name=a.name, dim=a.dim)
        self._a = a
        self._marginalize_list = marginalize_list

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped distribution's `forward`."""
        return self._a.forward(*args, **kwargs)

    def get_params(self, params_dict):
        """Delegate to the wrapped distribution's `get_params`."""
        return self._a.get_params(params_dict)

    def sample(self, x=None, shape=None, batch_size=1, return_all=True, reparam=False):
        """
        Sample from the wrapped distribution and drop the marginalized variables.

        `return_all` is forwarded, so inputs are kept in the output when it is True.
        """
        if x is None:
            # A fresh dict per call: the wrapped distribution may update it in place.
            x = {}

        output_dict = self._a.sample(x=x, shape=shape, batch_size=batch_size,
                                     return_all=return_all, reparam=reparam)
        output_dict = delete_dict_values(output_dict, self._marginalize_list)

        return output_dict

    def sample_mean(self, x):
        """Delegate to the wrapped distribution's `sample_mean`."""
        return self._a.sample_mean(x)

    def sample_variance(self, x):
        """Delegate to the wrapped distribution's `sample_variance`."""
        return self._a.sample_variance(x)

    @property
    def input_var(self):
        return self._a.input_var

    @property
    def distribution_name(self):
        return self._a.distribution_name

    @property
    def prob_factorized_text(self):
        """e.g. ``∫p(x,z)dz`` for marginalizing z."""
        integral_symbol = len(self._marginalize_list) * "∫"
        integral_variables = "".join("d" + str(var) for var in self._marginalize_list)

        return "{}{}{}".format(integral_symbol, self._a.prob_factorized_text, integral_variables)

    def __repr__(self):
        return self._a.__repr__()

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError:
            if item == "_a":
                # `_a` not registered yet: avoid infinite recursion through `self._a`.
                raise
            # `getattr` (not `__getattribute__`) so that the wrapped module's own
            # `__getattr__` can resolve its submodules, parameters and buffers.
            return getattr(self._a, item)


def sum_samples(samples):
    """
    Sum samples over every dimension except the first one.

    Parameters
    ----------
    samples : torch.Tensor
        Tensor with at most 4 dimensions.

    Returns
    -------
    torch.Tensor
        `samples` itself when it is 1-dimensional; otherwise the sum over dimensions 1..dim-1.

    Raises
    ------
    ValueError
        If `samples` has more than 4 dimensions.
    """
    dim = samples.dim()
    if dim == 1:
        return samples

    elif dim <= 4:
        # Plain ints instead of a list of 0-d tensors built with `torch.arange`.
        return torch.sum(samples, dim=tuple(range(1, dim)))

    raise ValueError("The dim of samples must be any of 1, 2, 3, or 4, "
                     "got dim %s." % dim)