Source code for pixyz.distributions.exponential_distributions

from torch.distributions import Normal as NormalTorch
from torch.distributions import Bernoulli as BernoulliTorch
from torch.distributions import RelaxedBernoulli as RelaxedBernoulliTorch
from torch.distributions \
    import RelaxedOneHotCategorical as RelaxedOneHotCategoricalTorch
from torch.distributions.one_hot_categorical\
    import OneHotCategorical as CategoricalTorch

from ..utils import get_dict_values
from .distributions import DistributionBase, sum_samples


[docs]class Normal(DistributionBase): def __init__(self, cond_var=[], var=["x"], name="p", dim=None, **kwargs): self.params_keys = ["loc", "scale"] self.DistributionTorch = NormalTorch super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "Normal"
[docs] def sample_mean(self, x): params = self.forward(**x) return params["loc"]
[docs]class Bernoulli(DistributionBase): def __init__(self, cond_var=[], var=["x"], name="p", dim=None, **kwargs): self.params_keys = ["probs"] self.DistributionTorch = BernoulliTorch super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "Bernoulli"
[docs] def sample_mean(self, x): params = self.forward(**x) return params["probs"]
[docs]class RelaxedBernoulli(DistributionBase): def __init__(self, temperature, cond_var=[], var=["x"], name="p", dim=None, **kwargs): self.params_keys = ["probs"] self.DistributionTorch = BernoulliTorch # use relaxed version only when sampling self.RelaxedDistributionTorch = RelaxedBernoulliTorch self.temperature = temperature super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "RelaxedBernoulli" def _set_distribution(self, x={}, sampling=True, **kwargs): params = self.get_params(x, **kwargs) if sampling is True: self.dist =\ self.RelaxedDistributionTorch(temperature=self.temperature, **params) else: self.dist = self.DistributionTorch(**params)
[docs] def log_likelihood(self, x): # input : dict # output : dict if not set(list(x.keys())) >= set(self._cond_var + self._var): raise ValueError("Input's keys are not valid.") if len(self._cond_var) > 0: # conditional distribution _x = get_dict_values(x, self._cond_var, True) self._set_distribution(_x, sampling=False) log_like = self._get_log_like(x) return sum_samples(log_like)
[docs] def sample_mean(self, x): params = self.forward(**x) return params["probs"]
[docs]class FactorizedBernoulli(Bernoulli): """ Generative Models of Visually Grounded Imagination """ def __init__(self, cond_var=[], var=["x"], name="p", dim=None, **kwargs): super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "FactorizedBernoulli" def _get_log_like(self, x): log_like = super()._get_log_like(x) [_x] = get_dict_values(x, self._var) log_like[_x == 0] = 0 return log_like
[docs]class Categorical(DistributionBase): def __init__(self, cond_var=[], var=["x"], name="p", dim=None, **kwargs): self.params_keys = ["probs"] self.DistributionTorch = CategoricalTorch super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "Categorical"
[docs] def sample_mean(self, x): params = self.forward(**x) return params["probs"]
[docs]class RelaxedCategorical(DistributionBase): def __init__(self, temperature, cond_var=[], var=["x"], name="p", dim=None, **kwargs): self.params_keys = ["probs"] self.DistributionTorch = CategoricalTorch # use relaxed version only when sampling self.RelaxedDistributionTorch = RelaxedOneHotCategoricalTorch self.temperature = temperature super().__init__(cond_var=cond_var, var=var, name=name, dim=dim, **kwargs) @property def distribution_name(self): return "RelaxedCategorical" def _set_distribution(self, x={}, sampling=True, **kwargs): params = self.get_params(x, **kwargs) if sampling is True: self.dist =\ self.RelaxedDistributionTorch(temperature=self.temperature, **params) else: self.dist = self.DistributionTorch(**params)
[docs] def log_likelihood(self, x): # input : dict # output : dict if not set(list(x.keys())) >= set(self._cond_var + self._var): raise ValueError("Input's keys are not valid.") if len(self._cond_var) > 0: # conditional distribution _x = get_dict_values(x, self._cond_var, True) self._set_distribution(_x, sampling=False) log_like = self._get_log_like(x) return sum_samples(log_like)
[docs] def sample_mean(self, x): params = self.forward(**x) return params["probs"]