Source code for pixyz.distributions.mixture_distributions

import torch
from torch import nn

from ..utils import get_dict_values
from .distributions import Distribution


class MixtureModel(Distribution):
    r"""Mixture models.

    :math:`p(x) = \sum_i p(x|z=i)p(z=i)`

    Parameters
    ----------
    distributions : list
        List of component distributions.
    prior : pixyz.distributions.Categorical
        Prior distribution of the latent variable (i.e., the mixing proportions).
        This must be a categorical distribution, and the number of its categories
        must be the same as the length of the distribution list.

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal, Categorical
    >>> from pixyz.distributions.mixture_distributions import MixtureModel
    >>>
    >>> z_dim = 3  # the number of mixture components
    >>> x_dim = 2  # the input dimension
    >>>
    >>> distributions = []  # the list of component distributions
    >>> for i in range(z_dim):
    ...     loc = torch.randn(x_dim)  # initialize the value of location (mean)
    ...     scale = torch.empty(x_dim).fill_(1.)  # initialize the value of scale
    ...     distributions.append(Normal(loc=loc, scale=scale, var=["x"], name="p_%d" % i))
    >>>
    >>> probs = torch.empty(z_dim).fill_(1. / z_dim)  # initialize the mixing probabilities
    >>> prior = Categorical(probs=probs, var=["z"], name="prior")
    >>>
    >>> p = MixtureModel(distributions=distributions, prior=prior)
    """

    def __init__(self, distributions, prior, name="p"):
        if not isinstance(distributions, list):
            raise ValueError("`distributions` must be a list of distributions.")
        distributions = nn.ModuleList(distributions)

        if prior.distribution_name != "Categorical":
            raise ValueError("The prior must be a categorical distribution.")

        # Check the number of mixture components.
        if len(prior.get_params()["probs"]) != len(distributions):
            raise ValueError("The number of categories of the prior must be the same as "
                             "the length of the distribution list.")

        # Check whether all distributions have the same variable.
        var_list = []
        for d in distributions:
            var_list += d.var
        var_list = list(set(var_list))

        if len(var_list) != 1:
            raise ValueError("All distributions must have the same variable.")

        hidden_var = prior.var

        super().__init__(var=var_list, name=name)

        self._distributions = distributions
        self._prior = prior
        self._hidden_var = hidden_var

    @property
    def prob_text(self):
        _prob_text = "{}({})".format(
            self._name, ','.join(self._var)
        )
        return _prob_text

    @property
    def prob_factorized_text(self):
        _mixture_prob_text = []
        for i, d in enumerate(self._distributions):
            _mixture_prob_text.append("{}({}|{}={}){}({}={})".format(
                d.name, self._var[0], self._hidden_var[0], i,
                self._prior.name, self._hidden_var[0], i
            ))
        _prob_text = ' + '.join(_mixture_prob_text)
        return _prob_text

    @property
    def distribution_name(self):
        return "Mixture Model"

    def posterior(self, name=None):
        return PosteriorMixtureModel(self, name=name)

    def sample(self, batch_size=1, return_hidden=False, **kwargs):
        hidden_output = []
        var_output = []

        for i in range(batch_size):
            # Sample a component index (one-hot) from the prior.
            _hidden_output = self._prior.sample()[self._hidden_var[0]]
            hidden_output.append(_hidden_output)

            # Sample from the selected component distribution.
            var_output.append(
                self._distributions[_hidden_output.argmax(dim=-1)].sample()[self._var[0]])

        output_dict = {self._var[0]: torch.cat(var_output, 0)}

        if return_hidden:
            output_dict.update({self._hidden_var[0]: torch.cat(hidden_output, 0)})

        return output_dict

    def get_log_prob(self, x_dict, return_hidden=False, **kwargs):
        """Evaluate the log-pdf, log p(x) (if return_hidden=False) or log p(x, z) (if return_hidden=True).

        Parameters
        ----------
        x_dict : dict
            Input variables (including `var`).
        return_hidden : bool, defaults to False

        Returns
        -------
        log_prob : torch.Tensor
            The log-pdf value of x.
            If return_hidden is False:
                dim=0 : the size of the batch
            If return_hidden is True:
                dim=0 : the number of mixture components
                dim=1 : the size of the batch
        """
        log_prob_all = []
        _device = x_dict[self._var[0]].device
        eye_tensor = torch.eye(len(self._distributions)).to(_device)  # one-hot vectors for the prior

        for i, d in enumerate(self._distributions):
            # p(z=i)
            prior_log_prob = self._prior.log_prob().eval({self._hidden_var[0]: eye_tensor[i]})
            # p(x|z=i)
            log_prob = d.log_prob().eval(x_dict)
            # p(x, z=i)
            log_prob_all.append(log_prob + prior_log_prob)

        log_prob_all = torch.stack(log_prob_all, dim=0)  # (num_mix, batch_size)

        if return_hidden:
            return log_prob_all

        return torch.logsumexp(log_prob_all, 0)
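
# A minimal sketch (an illustration, not part of the original source) of how the shapes
# returned by `sample` and `get_log_prob` above fit together, continuing the docstring
# example where `p` mixes z_dim = 3 Normal components over a 2-dimensional "x":
#
#     samples = p.sample(batch_size=5, return_hidden=True)
#     samples["x"].shape                                   # expected: (5, 2)
#     samples["z"].shape                                   # expected: (5, 3), one-hot rows
#     p.get_log_prob(samples).shape                        # log p(x): expected (5,)
#     p.get_log_prob(samples, return_hidden=True).shape    # log p(x, z=i): expected (3, 5)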


class PosteriorMixtureModel(Distribution):
    def __init__(self, p, name=None):
        if name is None:
            name = p.name

        super().__init__(var=p.var, name=name)

        self._p = p
        self._hidden_var = p._hidden_var

    @property
    def prob_text(self):
        _prob_text = "{}({}|{})".format(
            self._name, self._hidden_var[0], self._var[0]
        )
        return _prob_text

    @property
    def prob_factorized_text(self):
        _prob_text = "{}({},{})/{}({})".format(
            self._name, self._hidden_var[0], self._var[0],
            self._name, self._var[0])
        return _prob_text

    @property
    def distribution_name(self):
        return "Mixture Model (Posterior)"

    def sample(self, x={}, shape=None, batch_size=1, return_all=True, reparam=False):
        raise NotImplementedError

    def get_log_prob(self, x_dict, **kwargs):
        # log p(z|x) = log p(x, z) - log p(x)
        log_prob = self._p.get_log_prob(x_dict, return_hidden=True) - self._p.get_log_prob(x_dict)
        return log_prob  # (num_mix, batch_size)
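

# A minimal, self-contained usage sketch (an illustration, not part of the library source).
# It assumes `Normal` and `Categorical` accept the arguments shown in the `MixtureModel`
# docstring above; the shapes in the comments are the expected ones, following the
# conventions of `sample`, `get_log_prob`, and `posterior` defined in this module.
if __name__ == "__main__":
    from pixyz.distributions import Normal, Categorical

    # Two 1-dimensional Gaussian components and a uniform prior over the component index.
    components = [
        Normal(loc=torch.tensor([-2.0]), scale=torch.tensor([1.0]), var=["x"], name="p_0"),
        Normal(loc=torch.tensor([2.0]), scale=torch.tensor([1.0]), var=["x"], name="p_1"),
    ]
    prior = Categorical(probs=torch.tensor([0.5, 0.5]), var=["z"], name="prior")
    p = MixtureModel(distributions=components, prior=prior)

    # Draw samples and evaluate the marginal log-density.
    samples = p.sample(batch_size=4, return_hidden=True)    # {"x": (4, 1), "z": (4, 2)}
    log_px = p.get_log_prob({"x": samples["x"]})             # log p(x), expected shape (4,)

    # Posterior over the mixture indicator: log p(z=i|x) = log p(x, z=i) - log p(x).
    posterior = p.posterior()
    log_post = posterior.get_log_prob({"x": samples["x"]})   # expected shape (2, 4)
    responsibilities = torch.exp(log_post)                   # each column sums to 1
    print(responsibilities)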