Source code for pixyz.distributions.custom_distributions

from ..utils import get_dict_values, sum_samples
from .distributions import Distribution


class CustomProb(Distribution):
    """Distribution defined by a user-supplied log-probability function.

    The density/mass function is given directly by the caller as a callable
    returning log-probabilities. Note that this distribution cannot perform
    sampling.

    Examples
    --------
    >>> import torch
    >>> # banana shaped distribution
    >>> def log_prob(z):
    ...     z1, z2 = torch.chunk(z, chunks=2, dim=1)
    ...     norm = torch.sqrt(z1 ** 2 + z2 ** 2)
    ...     exp1 = torch.exp(-0.5 * ((z1 - 2) / 0.6) ** 2)
    ...     exp2 = torch.exp(-0.5 * ((z1 + 2) / 0.6) ** 2)
    ...     u = 0.5 * ((norm - 2) / 0.4) ** 2 - torch.log(exp1 + exp2)
    ...     return -u
    ...
    >>> p = CustomProb(log_prob, var=["z"])
    >>> loss = p.log_prob().eval({"z": torch.randn(10, 2)})

    """

    def __init__(self, log_prob_function, var, distribution_name="Custom PDF", **kwargs):
        """Store the user-defined function and initialise the base distribution.

        Parameters
        ----------
        log_prob_function : function
            User-defined log-probability density/mass function. It is later
            called with the selected variables as keyword arguments.
        var : list
            Variables of this distribution.
        distribution_name : :obj:`str`, optional
            Name of this distribution.
        **kwargs :
            Arbitrary keyword arguments, forwarded to the base class
            constructor.

        """
        # These attributes are assigned before the base constructor runs,
        # exactly as in the original ordering.
        self._log_prob_function = log_prob_function
        self._distribution_name = distribution_name

        # This distribution is always unconditional: cond_var is fixed to [].
        super().__init__(var=var, cond_var=[], **kwargs)

    @property
    def log_prob_function(self):
        """User-defined log-probability density/mass function."""
        return self._log_prob_function

    @property
    def input_var(self):
        """Input variables of this distribution (the same object as ``self.var``)."""
        return self.var

    @property
    def distribution_name(self):
        """Name given to this distribution at construction time."""
        return self._distribution_name

    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
        """Evaluate the user-defined log-probability function on ``x_dict``.

        Parameters
        ----------
        x_dict : dict
            Values keyed by variable name. It is passed through
            ``get_dict_values`` with ``self._var`` before use (presumably to
            keep only this distribution's variables -- confirm against the
            helper).
        sum_features : bool, optional
            When true, the result is additionally passed through
            ``sum_samples``.
        feature_dims : optional
            Accepted but not used by this method.
            NOTE(review): presumably kept for signature compatibility with
            other distributions -- confirm.

        Returns
        -------
        The value returned by ``log_prob_function``, reduced by
        ``sum_samples`` when ``sum_features`` is true.

        """
        inputs = get_dict_values(x_dict, self._var, return_dict=True)
        # The selected entries are handed to the user function as keyword
        # arguments, so its parameter names must match the dictionary keys.
        value = self.log_prob_function(**inputs)

        if not sum_features:
            return value
        return sum_samples(value)

    @property
    def has_reparam(self):
        """Always ``False`` for this distribution."""
        return False