# Source code for pixyz.distributions.poe

from __future__ import print_function
import torch
from torch import nn

from ..utils import tolist, get_dict_values
from .exponential_distributions import Normal


class ProductOfNormal(Normal):
    r"""Product of normal distributions.

    .. math::
        p(z|x,y) \propto p(z)p(z|x)p(z|y)

    In this model, :math:`p(z|x)` and :math:`p(z|y)` act as `experts` and
    :math:`p(z)` is the prior of the `experts` (zero mean, unit precision).

    References
    ----------
    [Vedantam+ 2017] Generative Models of Visually Grounded Imagination

    [Wu+ 2018] Multimodal Generative Models for Scalable Weakly-Supervised Learning

    Examples
    --------
    >>> pon = ProductOfNormal([p_x, p_y])  # doctest: +SKIP
    >>> pon.sample({"x": x, "y": y})  # doctest: +SKIP
    {'x': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.]]),
     'y': tensor([[0., 0., 0.,  ..., 0., 0., 1.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 1.]]),
     'z': tensor([[ 0.6611,  0.3811,  0.7778,  ..., -0.0468, -0.3615, -0.6569],
            ...,
            [-0.2563, -0.0828,  0.1605,  ...,  0.2767, -0.8456,  0.7364]])}
    >>> pon.sample({"y": y})  # doctest: +SKIP
    {'y': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.]]),
     'z': tensor([[-0.3264, -0.4448,  0.3610,  ..., -0.7378,  0.3002,  0.4370],
            ...,
            [ 0.0299,  0.5148, -0.1001,  ...,  0.9938,  1.0689, -1.1902]])}
    >>> pon.sample()  # same as sampling from unit Gaussian.  # doctest: +SKIP
    {'z': tensor(-0.4494)}

    """

    def __init__(self, p=[], name="p", features_shape=torch.Size()):
        """
        Parameters
        ----------
        p : :obj:`list` of :class:`pixyz.distributions.Normal`.
            List of experts. (The default list is never mutated here.)
        name : :obj:`str`, defaults to "p"
            Name of this distribution.
            This name is displayed in prob_text and prob_factorized_text.
        features_shape : :obj:`torch.Size` or :obj:`list`, defaults to torch.Size())
            Shape of dimensions (features) of this distribution.

        Raises
        ------
        ValueError
            If no expert is given, if the experts do not share the same
            variables, or if an expert is not a Normal distribution.

        """
        p = tolist(p)
        if len(p) == 0:
            raise ValueError("At least one expert must be given.")

        var = p[0].var
        cond_var = []

        for _p in p:
            if _p.var != var:
                raise ValueError(
                    "All experts must share the same variables, got {} and {}.".format(var, _p.var))
            if _p.distribution_name != "Normal":
                raise ValueError(
                    "All experts must be Normal distributions, got {}.".format(_p.distribution_name))
            cond_var += _p.cond_var

        super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape)

        # A single expert is stored as-is (subclasses call ``self.p.get_params`` directly);
        # several experts are registered together through a ModuleList.
        if len(p) == 1:
            self.p = p[0]
        else:
            self.p = nn.ModuleList(p)

    @property
    def _experts(self):
        """Experts as an iterable, regardless of whether ``self.p`` holds one or several."""
        if isinstance(self.p, nn.ModuleList):
            return self.p
        return [self.p]

    @property
    def prob_factorized_text(self):
        """str: ``p(var)`` followed by the ``prob_text`` of every expert when conditioned."""
        prob_text = "p({})".format(
            ','.join(self._var)
        )
        if len(self._cond_var) != 0:
            # Iterate over ``_experts`` rather than ``self.p``: with a single expert,
            # ``self.p`` is the expert itself and not a container.
            prob_text += "".join([p.prob_text for p in self._experts])

        return prob_text

    @property
    def prob_joint_factorized_and_text(self):
        """str: Return a formula of the factorized probability distribution."""
        if self.prob_factorized_text == self.prob_text:
            prob_text = self.prob_text
        else:
            prob_text = "{} \\propto {}".format(self.prob_text,
                                                self.prob_factorized_text)
        return prob_text

    def _get_expert_params(self, params_dict={}, **kwargs):
        """Get the output parameters of all experts.

        Experts for which ``params_dict`` provides no input are skipped.

        Parameters
        ----------
        params_dict : dict
        **kwargs
            Arbitrary keyword arguments, forwarded to each expert's ``get_params``.

        Returns
        -------
        loc : torch.Tensor
            Stacked ``loc`` outputs of the specified experts.
        scale : torch.Tensor
            Stacked ``scale`` outputs of the specified experts.

        """
        loc = []
        scale = []

        for _p in self._experts:
            # presumably returns the sub-dict of this expert's available conditioning variables
            inputs_dict = get_dict_values(params_dict, _p.cond_var, True)
            if len(inputs_dict) != 0:
                outputs = _p.get_params(inputs_dict, **kwargs)
                loc.append(outputs["loc"])
                scale.append(outputs["scale"])

        if not loc:
            # No expert received its inputs (e.g. ``params_dict`` only holds unrelated keys).
            # ``torch.stack`` rejects an empty list, so fall back to the same "no expert"
            # placeholder that ``get_params`` uses for an empty ``params_dict``.
            return torch.zeros(1), torch.zeros(1)

        return torch.stack(loc), torch.stack(scale)

    def get_params(self, params_dict={}, **kwargs):
        """Compute ``loc`` and ``scale`` of the product of the prior and the specified experts.

        Parameters
        ----------
        params_dict : dict
            Inputs for the experts. When empty, only the prior is used.
        **kwargs
            Arbitrary keyword arguments, forwarded to the experts.

        Returns
        -------
        dict
            ``{"loc": ..., "scale": ...}`` of this distribution.

        """
        # experts
        if len(params_dict) > 0:
            loc, scale = self._get_expert_params(params_dict, **kwargs)  # (n_expert, n_batch, output_dim)
        else:
            # A zero scale marks "no expert"; it contributes zero precision below.
            loc = torch.zeros(1)
            scale = torch.zeros(1)

        output_loc, output_scale = self._compute_expert_params(loc, scale)
        output_dict = {"loc": output_loc, "scale": output_scale}

        return output_dict

    @staticmethod
    def _compute_expert_params(loc, scale):
        """Compute parameters for the product of experts.
        It is assumed that unspecified experts are excluded from inputs.

        Parameters
        ----------
        loc : torch.Tensor
            Concatenation of mean vectors for specified experts. (n_expert, n_batch, output_dim)

        scale : torch.Tensor
            Concatenation of the square root of a diagonal covariance matrix for specified experts.
            (n_expert, n_batch, output_dim)

        Returns
        -------
        output_loc : torch.Tensor
            Mean vectors for this distribution. (n_batch, output_dim)

        output_scale : torch.Tensor
            The square root of diagonal covariance matrices for this distribution. (n_batch, output_dim)

        """
        # parameter for prior
        prior_prec = 1  # prior_loc is not specified because it is equal to 0.

        # ``scale`` is a standard deviation, so the precision is 1 / scale ** 2.
        variance = scale ** 2

        # compute the diagonal precision matrix.
        # Entries with zero variance (unspecified experts) keep zero precision,
        # which also avoids a division by zero.
        prec = torch.zeros_like(variance)
        specified = variance != 0
        prec[specified] = 1. / variance[specified]

        # compute the square root of a diagonal covariance matrix for the product of distributions.
        output_prec = torch.sum(prec, dim=0) + prior_prec
        output_variance = 1. / output_prec  # (n_batch, output_dim)

        # compute the mean vectors for the product of normal distributions.
        output_loc = torch.sum(prec * loc, dim=0)  # (n_batch, output_dim)
        output_loc = output_loc * output_variance

        return output_loc, torch.sqrt(output_variance)

    def _check_input(self, x, var=None):
        """Normalize ``x`` into a dict keyed by variable names.

        Unlike ``ElementWiseProductOfNormal._check_input``, a dict is accepted as-is
        without checking its keys, so that only a subset of experts can be specified.

        Parameters
        ----------
        x : torch.Tensor, list, or dict
        var : list of str, optional
            Variable names; defaults to ``self.input_var``.

        Returns
        -------
        dict

        Raises
        ------
        ValueError
            If ``x`` is not a tensor, a list, or a dict.

        """
        if var is None:
            var = self.input_var

        if isinstance(x, torch.Tensor):
            checked_x = {var[0]: x}

        elif isinstance(x, list):
            # TODO: we need to check if all the elements contained in this list are torch.Tensor.
            checked_x = dict(zip(var, x))

        elif isinstance(x, dict):
            # point of modification
            checked_x = x

        else:
            raise ValueError("The type of input is not valid, got %s." % type(x))

        return checked_x

    def log_prob(self, sum_features=True, feature_dims=None):
        """Not implemented for a product of experts."""
        raise NotImplementedError

    def prob(self, sum_features=True, feature_dims=None):
        """Not implemented for a product of experts."""
        raise NotImplementedError

    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
        """Not implemented for a product of experts."""
        raise NotImplementedError
class ElementWiseProductOfNormal(ProductOfNormal):
    r"""Product of normal distributions.
    In this distribution, each element of the input vector on the given distribution is considered as
    a different expert.

    .. math::
        p(z|x) = p(z|x_1, x_2) \propto p(z)p(z|x_1)p(z|x_2)

    Examples
    --------
    >>> pon = ElementWiseProductOfNormal(p)  # doctest: +SKIP
    >>> pon.sample({"x": x})  # doctest: +SKIP
    {'x': tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
     'z': tensor([[-0.3572, -0.0632,  0.4872,  ...,  0.0843, -0.3065, -0.4922],
            [ 0.3770, -0.0413,  0.9102,  ...,  0.9553, -0.3244,  1.5373]])}
    >>> pon.sample({"x": torch.zeros_like(x)})  # same as sampling from unit Gaussian.  # doctest: +SKIP
    {'x': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
     'z': tensor([[-0.7777, -0.5908, -1.5498,  ..., -0.7236,  0.2330,  0.3117],
            [ 0.5495,  0.7210, -0.4708,  ...,  0.0667,  1.2550, -0.5117]])}

    """

    def __init__(self, p, name="p", features_shape=torch.Size()):
        r"""
        Parameters
        ----------
        p : pixyz.distributions.Normal
            Each element of this input vector is considered as a different expert.
            When some elements are 0, experts corresponding to these elements are considered not to be specified.
            :math:`p(z|x) = p(z|x_1, x_2=0) \propto p(z)p(z|x_1)`
        name : str, defaults to "p"
            Name of this distribution.
            This name is displayed in prob_text and prob_factorized_text.
        features_shape : :obj:`torch.Size` or :obj:`list`, defaults to torch.Size())
            Shape of dimensions (features) of this distribution.

        Raises
        ------
        ValueError
            If ``p`` does not have exactly one conditioning variable.

        """
        if len(p.cond_var) != 1:
            raise ValueError(
                "The given distribution must have exactly one conditioning variable, "
                "got {}.".format(len(p.cond_var)))

        super().__init__(p=p, name=name, features_shape=features_shape)

    def _check_input(self, x, var=None):
        """Normalize ``x`` into a dict keyed by variable names.

        Parameters
        ----------
        x : torch.Tensor, list, or dict
        var : list of str, optional
            Variable names; defaults to ``self.input_var``.

        Returns
        -------
        dict

        Raises
        ------
        ValueError
            If ``x`` is a dict that lacks some of the names in ``var``,
            or if ``x`` is not a tensor, a list, or a dict.

        """
        if var is None:
            var = self.input_var

        if isinstance(x, torch.Tensor):
            checked_x = {var[0]: x}

        elif isinstance(x, list):
            # TODO: we need to check if all the elements contained in this list are torch.Tensor.
            checked_x = dict(zip(var, x))

        elif isinstance(x, dict):
            # Every requested variable must be present; extra keys are allowed.
            if not (set(x.keys()) >= set(var)):
                raise ValueError("Input keys are not valid.")
            checked_x = x

        else:
            raise ValueError("The type of input is not valid, got %s." % type(x))

        return checked_x

    @staticmethod
    def _get_mask(inputs, index):
        """Get a mask to the input to specify an expert identified by index.

        Parameters
        ----------
        inputs : torch.Tensor
        index : int

        Returns
        -------
        torch.Tensor
            Tensor shaped like ``inputs`` that is 1 in column ``index`` and 0 elsewhere.

        """
        mask = torch.zeros_like(inputs)
        mask[:, index] = 1
        return mask

    def _get_params_with_masking(self, inputs, index, **kwargs):
        """Get the output parameters of the index-specified expert.

        The wrapped distribution is evaluated on ``inputs`` with every column except
        ``index`` zeroed out. For the rows whose ``index``-th input element is 0
        (expert not specified), both ``loc`` and ``scale`` are then set to 0.

        Parameters
        ----------
        inputs : torch.Tensor
        index : int
        **kwargs
            Arbitrary keyword arguments, forwarded to the wrapped distribution's ``get_params``.

        Returns
        -------
        outputs : torch.Tensor
            ``loc`` and ``scale`` stacked along a new first dimension.

        Examples
        --------
        With ``a = torch.tensor([[1, 0, 0], [0, 1, 0]])``:

        * ``pon._get_params_with_masking(a, 0)``: row 0 holds the expert output, row 1 is all zeros.
        * ``pon._get_params_with_masking(a, 1)``: row 1 holds the expert output, row 0 is all zeros.
        * ``pon._get_params_with_masking(a, 2)``: every row is all zeros (for both loc and scale).

        """
        mask = self._get_mask(inputs, index)  # (n_batch, n_expert)
        outputs_dict = self.p.get_params({self.cond_var[0]: inputs * mask}, **kwargs)
        outputs = torch.stack([outputs_dict["loc"], outputs_dict["scale"]])  # (2, n_batch, output_dim)

        # When the index-th expert in the output examples is not specified, set zero to them.
        outputs[:, inputs[:, index] == 0, :] = 0
        return outputs

    def _get_expert_params(self, params_dict={}, **kwargs):
        """Get the output parameters of all experts.

        Parameters
        ----------
        params_dict : dict
        **kwargs
            Arbitrary keyword arguments, forwarded to every expert evaluation.

        Returns
        -------
        torch.Tensor
            ``loc`` of every expert.
        torch.Tensor
            ``scale`` of every expert.

        """
        inputs = get_dict_values(params_dict, self.cond_var)[0]  # (n_batch, n_expert=input_dim)

        n_expert = inputs.size()[1]

        # Forward **kwargs so that options given to ``get_params`` reach the wrapped
        # distribution, as in ``ProductOfNormal._get_expert_params``.
        outputs = [self._get_params_with_masking(inputs, i, **kwargs) for i in range(n_expert)]
        outputs = torch.stack(outputs)  # (n_expert, 2, n_batch, output_dim)

        return outputs[:, 0, :, :], outputs[:, 1, :, :]  # (n_expert, n_batch, output_dim)