from __future__ import print_function
from .distributions import Distribution
class Deterministic(Distribution):
    """
    Deterministic distribution (or degenerate distribution).

    Sampling does not draw randomly: it evaluates ``self.forward`` on the
    supplied inputs and checks that the produced variables match ``var``.
    """

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded unchanged to Distribution.
        super().__init__(**kwargs)

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Deterministic"

    def sample(self, x=None, return_all=True, **kwargs):
        """
        Evaluate the deterministic mapping on the given inputs.

        Parameters
        ----------
        x : dict, optional
            Input values, passed through ``self._check_input`` before being
            handed to ``self.forward`` as keyword arguments. ``None`` (the
            default) or an empty input is rejected.
        return_all : bool, optional
            If True, the checked inputs are merged into the returned dict
            alongside the outputs.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        dict
            The dict returned by ``self.forward`` (whose keys must equal
            ``self._var``), updated with the inputs when ``return_all`` is True.

        Raises
        ------
        ValueError
            If no inputs are given, or if ``forward`` returns a different set
            of variables than ``var``.
        """
        # ``None`` default instead of ``{}`` so no mutable default object is
        # shared between calls.
        if x is None:
            x = {}
        if len(x) == 0:
            raise ValueError("You should set inputs.")

        x_dict = self._check_input(x)
        output_dict = self.forward(**x_dict)
        if set(output_dict.keys()) != set(self._var):
            raise ValueError("Output variables are not same as `var`.")

        if return_all:
            output_dict.update(x_dict)
        return output_dict
class DataDistribution(Distribution):
    """
    Data distribution.

    An unconditional distribution (no ``cond_var``) whose "samples" are the
    supplied data itself, after validation by ``self._check_input``.

    TODO: Fix this behavior if multiplied with other distributions
    """

    def __init__(self, var, name="p_data"):
        # Unconditional: cond_var is empty; dim=1 is passed through to Distribution.
        super().__init__(var=var, cond_var=[], name=name, dim=1)

    @property
    def distribution_name(self):
        """Human-readable name of this distribution."""
        return "Data distribution"

    def sample(self, x=None, **kwargs):
        """
        Return the supplied data as the sample.

        Parameters
        ----------
        x : dict, optional
            Data values, passed through ``self._check_input``. ``None`` (the
            default) or an empty input is rejected.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        dict
            The result of ``self._check_input(x)``.

        Raises
        ------
        ValueError
            If no inputs are given.
        """
        # ``None`` default instead of ``{}`` so no mutable default object is
        # shared between calls.
        if x is None:
            x = {}
        if len(x) == 0:
            raise ValueError("You should set inputs.")
        return self._check_input(x)

    @property
    def input_var(self):
        """
        In DataDistribution, `input_var` is the same as `var`: the data
        variables themselves must be supplied as inputs.
        """
        return self.var