"""Exponential-family distribution wrappers (pixyz.distributions.exponential_distributions)."""

from torch.distributions import Normal as NormalTorch
from torch.distributions import Bernoulli as BernoulliTorch
from torch.distributions import RelaxedBernoulli as RelaxedBernoulliTorch
from torch.distributions \
    import RelaxedOneHotCategorical as RelaxedOneHotCategoricalTorch
from torch.distributions.one_hot_categorical\
    import OneHotCategorical as CategoricalTorch
from torch.distributions import Dirichlet as DirichletTorch
from torch.distributions import Beta as BetaTorch
from torch.distributions import Laplace as LaplaceTorch
from torch.distributions import Gamma as GammaTorch

from ..utils import get_dict_values
from .distributions import DistributionBase, sum_samples


class Normal(DistributionBase):
    """Normal distribution parameterized by :attr:`loc` and :attr:`scale`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        # Must be set before super().__init__, which reads them.
        self.params_keys = ["loc", "scale"]
        self.DistributionTorch = NormalTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Normal"
class Bernoulli(DistributionBase):
    """Bernoulli distribution parameterized by :attr:`probs`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["probs"]
        self.DistributionTorch = BernoulliTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Bernoulli"
class RelaxedBernoulli(DistributionBase):
    """Relaxed (reparameterizable) Bernoulli distribution parameterized by
    :attr:`probs`.

    Sampling uses the continuous relaxation (so gradients can flow through
    samples); log-likelihood is evaluated under the discrete Bernoulli.
    """

    def __init__(self, temperature, cond_var=None, var=None, name="p",
                 dim=None, **kwargs):
        """
        Parameters
        ----------
        temperature : float or torch.Tensor
            Relaxation temperature passed to the relaxed distribution.
        cond_var : list of str, optional
            Names of conditioning variables (default: none).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["probs"]
        self.DistributionTorch = BernoulliTorch
        # use relaxed version only when sampling
        self.RelaxedDistributionTorch = RelaxedBernoulliTorch
        self.temperature = temperature
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "RelaxedBernoulli"

    def set_distribution(self, x=None, sampling=True, **kwargs):
        """Instantiate ``self.dist`` from parameters computed on ``x``.

        The relaxed distribution is used when ``sampling`` is exactly
        ``True``; otherwise the discrete Bernoulli is used.
        """
        x = {} if x is None else x
        params = self.get_params(x, **kwargs)
        if sampling is True:
            self.dist = self.RelaxedDistributionTorch(
                temperature=self.temperature, **params)
        else:
            self.dist = self.DistributionTorch(**params)

    def log_likelihood(self, x):
        """Return the summed log-likelihood of ``x`` under the discrete
        Bernoulli.

        Parameters
        ----------
        x : dict
            Must contain entries for all conditioning and sampled variables.

        Raises
        ------
        ValueError
            If ``x`` is missing any required variable key.
        """
        if not set(x.keys()) >= set(self._cond_var + self._var):
            raise ValueError("Input keys are not valid.")

        if len(self._cond_var) > 0:
            # conditional distribution: recompute params from the condition,
            # forcing the discrete (non-relaxed) distribution.
            _x = get_dict_values(x, self._cond_var, True)
            self.set_distribution(_x, sampling=False)

        log_like = self._get_log_like(x)
        return sum_samples(log_like)
class FactorizedBernoulli(Bernoulli):
    """Factorized Bernoulli distribution parameterized by :attr:`probs`.

    See `Generative Models of Visually Grounded Imagination`.
    Entries of the observation that are exactly 0 contribute nothing to
    the log-likelihood.
    """

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``Bernoulli``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "FactorizedBernoulli"

    def _get_log_like(self, x):
        # Zero out the log-likelihood wherever the observation is exactly 0,
        # so absent entries do not affect the total.
        log_like = super()._get_log_like(x)
        [_x] = get_dict_values(x, self._var)
        log_like[_x == 0] = 0
        return log_like
class Categorical(DistributionBase):
    """Categorical (one-hot) distribution parameterized by :attr:`probs`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["probs"]
        # NOTE: backed by torch's OneHotCategorical (samples are one-hot).
        self.DistributionTorch = CategoricalTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Categorical"
class RelaxedCategorical(DistributionBase):
    """Relaxed (reparameterizable) categorical distribution parameterized by
    :attr:`probs`.

    Sampling uses the continuous relaxation (RelaxedOneHotCategorical, so
    gradients can flow through samples); log-likelihood is evaluated under
    the discrete OneHotCategorical.
    """

    def __init__(self, temperature, cond_var=None, var=None, name="p",
                 dim=None, **kwargs):
        """
        Parameters
        ----------
        temperature : float or torch.Tensor
            Relaxation temperature passed to the relaxed distribution.
        cond_var : list of str, optional
            Names of conditioning variables (default: none).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["probs"]
        self.DistributionTorch = CategoricalTorch
        # use relaxed version only when sampling
        self.RelaxedDistributionTorch = RelaxedOneHotCategoricalTorch
        self.temperature = temperature
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "RelaxedCategorical"

    def set_distribution(self, x=None, sampling=True, **kwargs):
        """Instantiate ``self.dist`` from parameters computed on ``x``.

        The relaxed distribution is used when ``sampling`` is exactly
        ``True``; otherwise the discrete OneHotCategorical is used.
        """
        x = {} if x is None else x
        params = self.get_params(x, **kwargs)
        if sampling is True:
            self.dist = self.RelaxedDistributionTorch(
                temperature=self.temperature, **params)
        else:
            self.dist = self.DistributionTorch(**params)

    def log_likelihood(self, x):
        """Return the summed log-likelihood of ``x`` under the discrete
        OneHotCategorical.

        Parameters
        ----------
        x : dict
            Must contain entries for all conditioning and sampled variables.

        Raises
        ------
        ValueError
            If ``x`` is missing any required variable key.
        """
        if not set(x.keys()) >= set(self._cond_var + self._var):
            raise ValueError("Input keys are not valid.")

        if len(self._cond_var) > 0:
            # conditional distribution: recompute params from the condition,
            # forcing the discrete (non-relaxed) distribution.
            _x = get_dict_values(x, self._cond_var, True)
            self.set_distribution(_x, sampling=False)

        log_like = self._get_log_like(x)
        return sum_samples(log_like)
class Dirichlet(DistributionBase):
    """Dirichlet distribution parameterized by :attr:`concentration`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["concentration"]
        self.DistributionTorch = DirichletTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Dirichlet"
class Beta(DistributionBase):
    """Beta distribution parameterized by :attr:`concentration1` and
    :attr:`concentration0`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["concentration1", "concentration0"]
        self.DistributionTorch = BetaTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Beta"
class Laplace(DistributionBase):
    """Laplace distribution parameterized by :attr:`loc` and :attr:`scale`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["loc", "scale"]
        self.DistributionTorch = LaplaceTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Laplace"
class Gamma(DistributionBase):
    """Gamma distribution parameterized by :attr:`concentration` and
    :attr:`rate`."""

    def __init__(self, cond_var=None, var=None, name="p", dim=None, **kwargs):
        """
        Parameters
        ----------
        cond_var : list of str, optional
            Names of conditioning variables (default: none, i.e. unconditional).
        var : list of str, optional
            Names of the sampled variables (default: ``["x"]``).
        name : str
            Distribution name used in printed expressions.
        dim : int or None
            Dimension hint forwarded to ``DistributionBase``.
        """
        # Normalize None sentinels instead of using shared mutable defaults.
        cond_var = [] if cond_var is None else cond_var
        var = ["x"] if var is None else var
        self.params_keys = ["concentration", "rate"]
        self.DistributionTorch = GammaTorch
        super().__init__(cond_var=cond_var, var=var, name=name, dim=dim,
                         **kwargs)

    @property
    def distribution_name(self):
        return "Gamma"