# Source code for pixyz.distributions.mixture_distributions

import torch
from torch import nn

from ..distributions.distributions import Distribution
from ..utils import convert_latex_name

class MixtureModel(Distribution):
    r"""Mixture models.

    .. math::

        p(x) = \sum_i p(x|z=i)p(z=i)

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal, Categorical
    >>> from pixyz.distributions.mixture_distributions import MixtureModel
    >>> z_dim = 3  # the number of mixture components
    >>> x_dim = 2  # the input dimension
    >>> distributions = []  # the list of component distributions
    >>> for i in range(z_dim):
    ...     loc = torch.randn(x_dim)  # initialize the location (mean)
    ...     scale = torch.empty(x_dim).fill_(1.)  # initialize the scale (standard deviation)
    ...     distributions.append(Normal(loc=loc, scale=scale, var=["x"], name="p_%d" % i))
    >>> probs = torch.empty(z_dim).fill_(1. / z_dim)  # initialize the mixing probabilities
    >>> prior = Categorical(probs=probs, var=["z"], name="prior")
    >>> p = MixtureModel(distributions=distributions, prior=prior)
    >>> print(p)
    Distribution:
      p(x) = p_{0}(x|z=0)prior(z=0) + p_{1}(x|z=1)prior(z=1) + p_{2}(x|z=2)prior(z=2)
    Network architecture:
      MixtureModel(
        name=p, distribution_name=Mixture Model,
        var=['x'], cond_var=[], input_var=[], features_shape=torch.Size([])
        (distributions): ModuleList(
          (0): Normal(
            name=p_{0}, distribution_name=Normal,
            var=['x'], cond_var=[], input_var=[], features_shape=torch.Size([2])
            (loc): torch.Size([1, 2])
            (scale): torch.Size([1, 2])
          )
          (1): Normal(
            name=p_{1}, distribution_name=Normal,
            var=['x'], cond_var=[], input_var=[], features_shape=torch.Size([2])
            (loc): torch.Size([1, 2])
            (scale): torch.Size([1, 2])
          )
          (2): Normal(
            name=p_{2}, distribution_name=Normal,
            var=['x'], cond_var=[], input_var=[], features_shape=torch.Size([2])
            (loc): torch.Size([1, 2])
            (scale): torch.Size([1, 2])
          )
        )
        (prior): Categorical(
          name=prior, distribution_name=Categorical,
          var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([3])
          (probs): torch.Size([1, 3])
        )
      )
    """

    def __init__(self, distributions, prior, name="p"):
        """
        Parameters
        ----------
        distributions : list
            List of the component distributions.
        prior : pixyz.distributions.Categorical
            Prior distribution of the latent variable (i.e., the mixing coefficients).
            This must be a categorical distribution, and the number of its categories
            must equal the length of :attr:`distributions`.
        name : :obj:`str`, defaults to "p"
            Name of this distribution.
            This name is displayed in :attr:`prob_text` and :attr:`prob_factorized_text`.

        """
        if not isinstance(distributions, list):
            raise ValueError("`distributions` must be a list of distributions.")
        distributions = nn.ModuleList(distributions)

        if prior.distribution_name != "Categorical":
            raise ValueError("The prior must be a categorical distribution.")

        # Check the number of mixture components.
        if prior.get_params()["probs"].shape[-1] != len(distributions):
            raise ValueError("The number of the prior's categories must equal the length of the distribution list.")

        # Check that all distributions share the same variable.
        var_list = []
        for d in distributions:
            var_list += d.var
        var_list = list(set(var_list))

        if len(var_list) != 1:
            raise ValueError("All distributions must have the same variable.")

        hidden_var = prior.var

        super().__init__(var=var_list, name=name)

        self.distributions = distributions
        self.prior = prior

        self._hidden_var = hidden_var

    @property
    def hidden_var(self):
        """list: Hidden variables of this distribution."""
        return self._hidden_var

    @property
    def prob_text(self):
        _prob_text = "{}({})".format(
            self._name, ','.join(convert_latex_name(var_name) for var_name in self._var)
        )

        return _prob_text

    @property
    def prob_factorized_text(self):
        _mixture_prob_text = []
        for i, d in enumerate(self.distributions):
            _mixture_prob_text.append("{}({}|{}={}){}({}={})".format(
                d.name, self.var[0], self._hidden_var[0], i,
                self.prior.name, self._hidden_var[0], i
            ))

        _prob_text = ' + '.join(_mixture_prob_text)

        return _prob_text

    @property
    def distribution_name(self):
        return "Mixture Model"

    def posterior(self, name=None):
        return PosteriorMixtureModel(self, name=name)

    def sample(self, batch_n=None, sample_shape=torch.Size(), return_hidden=False, **kwargs):
        # Sample one-hot assignments z from the prior.
        hidden_output = self.prior.sample(batch_n=batch_n)[self._hidden_var[0]]

        var_output = []
        for _hidden_output in hidden_output:
            # Draw x from the component selected by each one-hot assignment.
            var_output.append(self.distributions[_hidden_output.argmax(dim=-1)].sample()[self._var[0]])

        var_output = torch.cat(var_output, dim=0)
        output_dict = {self._var[0]: var_output}

        if return_hidden:
            output_dict.update({self._hidden_var[0]: hidden_output})

        return output_dict
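
    # A minimal usage sketch (illustrative, not part of the library API): with the
    # mixture `p` built as in the class docstring (z_dim=3, x_dim=2), sampling
    # returns a dict keyed by the variable names, e.g.
    #
    #   samples = p.sample(batch_n=5, return_hidden=True)
    #   samples["x"]  # draws from the selected components, expected shape (5, 2)
    #   samples["z"]  # one-hot component assignments, expected shape (5, 3)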

    def get_log_prob(self, x_dict, return_hidden=False, **kwargs):
        """Evaluate the log-pdf, log p(x) (if `return_hidden=False`) or log p(x, z) (if `return_hidden=True`).

        Parameters
        ----------
        x_dict : dict
            Input variables (including `var`).
        return_hidden : :obj:`bool`, defaults to False
            If True, return the joint log-pdf for every mixture component
            instead of marginalizing over z.

        Returns
        -------
        log_prob : torch.Tensor
            The log-pdf value of x.

            If `return_hidden` is False:
                dim=0: the size of the batch.

            If `return_hidden` is True:
                dim=0: the number of mixture components.

                dim=1: the size of the batch.

        """

        log_prob_all = []

        _device = x_dict[self._var[0]].device
        eye_tensor = torch.eye(len(self.distributions)).to(_device)  # one-hot vectors for the prior

        for i, d in enumerate(self.distributions):
            # log p(z=i)
            prior_log_prob = self.prior.log_prob().eval({self._hidden_var[0]: eye_tensor[i]})
            # log p(x|z=i)
            log_prob = d.log_prob().eval(x_dict)
            # log p(x, z=i)
            log_prob_all.append(log_prob + prior_log_prob)

        log_prob_all = torch.stack(log_prob_all, dim=0)  # (num_mix, batch_size)

        if return_hidden:
            return log_prob_all

        # log p(x) = log sum_i exp(log p(x, z=i)), marginalizing over the components.
        return torch.logsumexp(log_prob_all, dim=0)
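
    # A minimal usage sketch (illustrative): with `p` from the class docstring
    # (z_dim=3, x_dim=2) and a batch of 5 inputs,
    #
    #   x = torch.randn(5, 2)
    #   p.get_log_prob({"x": x}).shape                      # torch.Size([5])
    #   p.get_log_prob({"x": x}, return_hidden=True).shape  # torch.Size([3, 5])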

class PosteriorMixtureModel(Distribution):
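    r"""Posterior over the mixture assignment, obtained by Bayes' rule (a worked
    form of what :meth:`get_log_prob` evaluates in log-space):

    .. math::

        p(z=i|x) = \frac{p(x|z=i)p(z=i)}{\sum_j p(x|z=j)p(z=j)}

    """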
    def __init__(self, p, name=None):
        if name is None:
            name = p.name
        super().__init__(var=p.var, name=name)

        self.p = p
        self._hidden_var = p.hidden_var

    @property
    def hidden_var(self):
        """list: Hidden variables of this distribution."""
        return self._hidden_var

    @property
    def prob_text(self):
        _prob_text = "{}({}|{})".format(
            self._name, convert_latex_name(self._hidden_var[0]), convert_latex_name(self._var[0])
        )
        return _prob_text

    @property
    def prob_factorized_text(self):
        numerator = "{" + "{}({},{})".format(self._name, self._hidden_var[0], self._var[0]) + "}"
        denominator = "{" + "{}({})".format(self._name, self._var[0]) + "}"

        _prob_text = "\\frac{}{}".format(numerator, denominator)

        return _prob_text

    @property
    def distribution_name(self):
        return "Mixture Model (Posterior)"

    def sample(self, *args, **kwargs):
        raise NotImplementedError

    def get_log_prob(self, x_dict, **kwargs):
        # log p(z|x) = log p(x, z) - log p(x)
        log_prob = self.p.get_log_prob(x_dict, return_hidden=True) - self.p.get_log_prob(x_dict)
        return log_prob  # (num_mix, batch_size)
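
# A minimal usage sketch (illustrative): the posterior gives the per-component
# responsibilities gamma_{i,n} = p(z=i | x_n), as used e.g. in the E-step of EM.
#
#   post = p.posterior()
#   log_gamma = post.get_log_prob({"x": torch.randn(5, 2)})  # shape (3, 5)
#   gamma = log_gamma.exp()  # each column sums to 1 over the components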