# Source code for pixyz.distributions.exponential_distributions

import torch
from torch.distributions import Normal as NormalTorch
from torch.distributions import Bernoulli as BernoulliTorch
from torch.distributions import RelaxedBernoulli as RelaxedBernoulliTorch
from torch.distributions import RelaxedOneHotCategorical as RelaxedOneHotCategoricalTorch
from torch.distributions.one_hot_categorical import OneHotCategorical as CategoricalTorch
from torch.distributions import Multinomial as MultinomialTorch
from torch.distributions import Dirichlet as DirichletTorch
from torch.distributions import Beta as BetaTorch
from torch.distributions import Laplace as LaplaceTorch
from torch.distributions import Gamma as GammaTorch

from ..utils import get_dict_values, sum_samples
from .distributions import DistributionBase


class Normal(DistributionBase):
    """Normal distribution parameterized by :attr:`loc` and :attr:`scale`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Normal`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Normal"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return NormalTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["loc", "scale"]
        return keys
class Bernoulli(DistributionBase):
    """Bernoulli distribution parameterized by :attr:`probs`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Bernoulli`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Bernoulli"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return BernoulliTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["probs"]
        return keys
class RelaxedBernoulli(Bernoulli):
    """Relaxed (re-parameterizable) Bernoulli distribution parameterized by :attr:`probs`.

    When :meth:`set_dist` is called with ``sampling=True`` (its default), the
    distribution is :class:`torch.distributions.RelaxedBernoulli` with
    :attr:`temperature`. Otherwise the exact :class:`torch.distributions.Bernoulli`
    is used.

    Parameters
    ----------
    temperature : torch.Tensor
        Relaxation temperature passed to the relaxed distribution.
        NOTE: the default tensor is created once, when the module is imported,
        and is shared by every instance that relies on the default.
    cond_var, var, name, features_shape, **kwargs
        Forwarded unchanged to :class:`Bernoulli`.
    """

    def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p",
                 features_shape=torch.Size(), **kwargs):
        self._temperature = temperature
        super().__init__(cond_var=cond_var, var=var, name=name,
                         features_shape=features_shape, **kwargs)

    @property
    def temperature(self):
        """Relaxation temperature used when sampling."""
        return self._temperature

    @property
    def distribution_torch_class(self):
        """Exact (non-relaxed) PyTorch class, used when ``set_dist`` is not sampling."""
        return BernoulliTorch

    @property
    def relaxed_distribution_torch_class(self):
        """Use relaxed version only when sampling"""
        return RelaxedBernoulliTorch

    @property
    def distribution_name(self):
        return "RelaxedBernoulli"

    def set_dist(self, x_dict={}, sampling=True, batch_n=None, **kwargs):
        """Build ``self._dist`` from the parameters computed for ``x_dict``.

        Parameters
        ----------
        x_dict : dict
            Values passed to :meth:`get_params`.
        sampling : bool
            If exactly ``True``, the relaxed distribution is built; for any other
            value the exact Bernoulli is built.
        batch_n : int or None
            If truthy, a leading batch dimension of size 1 is expanded to ``batch_n``.
        **kwargs
            Accepted for interface compatibility; ignored.

        Raises
        ------
        ValueError
            If ``batch_n`` is given and the leading batch dimension is neither 1
            nor ``batch_n``.
        """
        params = self.get_params(x_dict)
        if sampling is True:
            self._dist = self.relaxed_distribution_torch_class(temperature=self.temperature, **params)
        else:
            self._dist = self.distribution_torch_class(**params)

        # expand batch_n
        if batch_n:
            batch_shape = self._dist.batch_shape
            if batch_shape[0] == 1:
                self._dist = self._dist.expand(torch.Size([batch_n]) + batch_shape[1:])
            elif batch_shape[0] != batch_n:
                raise ValueError(
                    "Cannot expand batch_shape {} to batch size {}: the leading "
                    "dimension must be 1 or {}.".format(tuple(batch_shape), batch_n, batch_n))

    def sample_mean(self, x_dict=None):
        """Return the mean of the exact (non-relaxed) Bernoulli for ``x_dict``.

        ``sampling=False`` is forced because the relaxed distribution is only
        meant for drawing re-parameterized samples.
        """
        if x_dict is None:
            x_dict = {}
        self.set_dist(x_dict, sampling=False)
        return self.dist.mean

    def sample_variance(self, x_dict=None):
        """Return the variance of the exact (non-relaxed) Bernoulli for ``x_dict``."""
        if x_dict is None:
            x_dict = {}
        self.set_dist(x_dict, sampling=False)
        return self.dist.variance
class FactorizedBernoulli(Bernoulli):
    """
    Factorized Bernoulli distribution parameterized by :attr:`probs`.

    Its log-likelihood only counts elements whose observed value is non-zero.
    Log-probabilities at positions where ``x == 0`` are replaced by zero before
    the optional summation over features.

    References
    ----------
    [Vedantam+ 2017] Generative Models of Visually Grounded Imagination
    """

    @property
    def distribution_name(self):
        return "FactorizedBernoulli"

    def get_log_prob(self, x_dict, sum_features=True):
        """Return the masked Bernoulli log-probability of the observed variable.

        Parameters
        ----------
        x_dict : dict
            Must contain the (single) variable of this distribution plus any
            conditioning variables.
        sum_features : bool
            If ``True`` (the default, matching the previous behavior), reduce the
            elementwise log-probabilities with ``sum_samples``. If ``False``,
            return them elementwise.

        Returns
        -------
        torch.Tensor
            Log-probability with zero-valued observations masked out.
        """
        log_prob = super().get_log_prob(x_dict, sum_features=False)
        [_x] = get_dict_values(x_dict, self._var)
        # Out-of-place masking: do not overwrite the tensor produced by the
        # parent class, which is part of the autograd graph.
        # NOTE(review): assumes _x and log_prob have the same shape.
        log_prob = log_prob.masked_fill(_x == 0, 0)
        if sum_features:
            log_prob = sum_samples(log_prob)
        return log_prob
class Categorical(DistributionBase):
    """Categorical distribution parameterized by :attr:`probs`.

    The underlying PyTorch class is
    :class:`torch.distributions.one_hot_categorical.OneHotCategorical`, so
    values are one-hot encoded.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Categorical"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return CategoricalTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["probs"]
        return keys
class RelaxedCategorical(Categorical):
    """Relaxed (re-parameterizable) categorical distribution parameterized by :attr:`probs`.

    When :meth:`set_dist` is called with ``sampling=True`` (its default), the
    distribution is :class:`torch.distributions.RelaxedOneHotCategorical` with
    :attr:`temperature`. Otherwise the exact one-hot categorical is used.

    Parameters
    ----------
    temperature : torch.Tensor
        Relaxation temperature passed to the relaxed distribution.
        NOTE: the default tensor is created once, when the module is imported,
        and is shared by every instance that relies on the default.
    cond_var, var, name, features_shape, **kwargs
        Forwarded unchanged to :class:`Categorical`.
    """

    def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p",
                 features_shape=torch.Size(), **kwargs):
        self._temperature = temperature
        super().__init__(cond_var=cond_var, var=var, name=name,
                         features_shape=features_shape, **kwargs)

    @property
    def temperature(self):
        """Relaxation temperature used when sampling."""
        return self._temperature

    @property
    def distribution_torch_class(self):
        """Exact (non-relaxed) PyTorch class, used when ``set_dist`` is not sampling."""
        return CategoricalTorch

    @property
    def relaxed_distribution_torch_class(self):
        """Use relaxed version only when sampling"""
        return RelaxedOneHotCategoricalTorch

    @property
    def distribution_name(self):
        return "RelaxedCategorical"

    def set_dist(self, x_dict={}, sampling=True, batch_n=None, **kwargs):
        """Build ``self._dist`` from the parameters computed for ``x_dict``.

        Parameters
        ----------
        x_dict : dict
            Values passed to :meth:`get_params`.
        sampling : bool
            If exactly ``True``, the relaxed distribution is built; for any other
            value the exact one-hot categorical is built.
        batch_n : int or None
            If truthy, a leading batch dimension of size 1 is expanded to ``batch_n``.
        **kwargs
            Accepted for interface compatibility; ignored.

        Raises
        ------
        ValueError
            If ``batch_n`` is given and the leading batch dimension is neither 1
            nor ``batch_n``.
        """
        params = self.get_params(x_dict)
        if sampling is True:
            self._dist = self.relaxed_distribution_torch_class(temperature=self.temperature, **params)
        else:
            self._dist = self.distribution_torch_class(**params)

        # expand batch_n
        if batch_n:
            batch_shape = self._dist.batch_shape
            if batch_shape[0] == 1:
                self._dist = self._dist.expand(torch.Size([batch_n]) + batch_shape[1:])
            elif batch_shape[0] != batch_n:
                raise ValueError(
                    "Cannot expand batch_shape {} to batch size {}: the leading "
                    "dimension must be 1 or {}.".format(tuple(batch_shape), batch_n, batch_n))

    def sample_mean(self, x_dict={}):
        """Return the mean of the exact (non-relaxed) categorical for ``x_dict``."""
        self.set_dist(x_dict, sampling=False)
        return self.dist.mean

    def sample_variance(self, x_dict={}):
        """Return the variance of the exact (non-relaxed) categorical for ``x_dict``."""
        self.set_dist(x_dict, sampling=False)
        return self.dist.variance
class Multinomial(DistributionBase):
    """Multinomial distribution parameterized by :attr:`total_count` and :attr:`probs`.

    ``total_count`` is fixed when the object is constructed. ``probs`` is obtained
    through :meth:`get_params`, like the parameters of the other distributions.
    """

    def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
                 total_count=1, **kwargs):
        self._total_count = total_count
        super().__init__(cond_var=cond_var, var=var, name=name,
                         features_shape=features_shape, **kwargs)

    @property
    def total_count(self):
        """Number of trials passed to the PyTorch Multinomial."""
        return self._total_count

    @property
    def params_keys(self):
        return ["probs"]

    @property
    def distribution_torch_class(self):
        return MultinomialTorch

    @property
    def distribution_name(self):
        return "Multinomial"

    def set_dist(self, x_dict={}, sampling=False, batch_n=None, **kwargs):
        """Build ``self._dist`` from the parameters computed for ``x_dict``.

        Parameters
        ----------
        x_dict : dict
            Values passed to :meth:`get_params`.
        sampling : bool
            Accepted for interface compatibility; ignored.
        batch_n : int or None
            If truthy, a leading batch dimension of size 1 is expanded to ``batch_n``.
        **kwargs
            Accepted for interface compatibility; ignored.

        Raises
        ------
        ValueError
            If ``batch_n`` is given and the leading batch dimension is neither 1
            nor ``batch_n``.
        """
        params = self.get_params(x_dict)
        # total_count is not one of params_keys, so it has to be forwarded
        # explicitly; otherwise torch.distributions.Multinomial silently uses
        # its default total_count=1 and the constructor argument is ignored.
        self._dist = self.distribution_torch_class(total_count=self.total_count, **params)

        # expand batch_n
        if batch_n:
            batch_shape = self._dist.batch_shape
            if batch_shape[0] == 1:
                self._dist = self._dist.expand(torch.Size([batch_n]) + batch_shape[1:])
            elif batch_shape[0] != batch_n:
                raise ValueError(
                    "Cannot expand batch_shape {} to batch size {}: the leading "
                    "dimension must be 1 or {}.".format(tuple(batch_shape), batch_n, batch_n))
class Dirichlet(DistributionBase):
    """Dirichlet distribution parameterized by :attr:`concentration`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Dirichlet`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Dirichlet"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return DirichletTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["concentration"]
        return keys
class Beta(DistributionBase):
    """Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Beta`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Beta"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return BetaTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["concentration1", "concentration0"]
        return keys
class Laplace(DistributionBase):
    """Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Laplace`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Laplace"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return LaplaceTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["loc", "scale"]
        return keys
class Gamma(DistributionBase):
    """Gamma distribution parameterized by :attr:`concentration` and :attr:`rate`.

    Thin wrapper whose underlying PyTorch class is :class:`torch.distributions.Gamma`.
    """

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Gamma"

    @property
    def distribution_torch_class(self):
        """PyTorch distribution class this wrapper instantiates."""
        return GammaTorch

    @property
    def params_keys(self):
        """Names of the parameters expected by the PyTorch class."""
        keys = ["concentration", "rate"]
        return keys