from __future__ import print_function
from .distributions import Distribution
class Deterministic(Distribution):
    """
    Deterministic distribution (or degenerate distribution).

    Subclasses implement ``forward`` so that it returns a dict whose keys are
    exactly the variables listed in ``var``; :meth:`sample` raises
    ``ValueError`` otherwise.

    Examples
    --------
    >>> import torch
    >>> class Generator(Deterministic):
    ...     def __init__(self):
    ...         super().__init__(cond_var=["z"], var=["x"])
    ...         self.model = torch.nn.Linear(64, 512)
    ...     def forward(self, z):
    ...         return {"x": self.model(z)}
    >>> p = Generator()
    >>> print(p)
    Distribution:
      p(x|z)
    Network architecture:
      Generator(
        name=p, distribution_name=Deterministic,
        var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
        (model): Linear(in_features=64, out_features=512, bias=True)
      )
    >>> sample = p.sample({"z": torch.randn(1, 64)})
    >>> p.log_prob().eval(sample) # log_prob is not defined.
    Traceback (most recent call last):
     ...
    NotImplementedError
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``Distribution.__init__``."""
        super().__init__(**kwargs)

    @property
    def distribution_name(self):
        """str: Name of this distribution, always ``"Deterministic"``."""
        return "Deterministic"

    def sample(self, x_dict=None, return_all=True, **kwargs):
        """
        Compute the output variables by running ``forward`` on the given inputs.

        Parameters
        ----------
        x_dict : dict, optional
            Input values, passed to ``_get_input_dict`` to select the
            arguments of ``forward``. ``None`` (the default) is treated as an
            empty dict.
        return_all : bool, default True
            If True, return a copy of ``x_dict`` updated with the outputs of
            ``forward``; otherwise return only the outputs of ``forward``.
        **kwargs
            Accepted so callers can pass extra sampling options; not used by
            this method.

        Returns
        -------
        dict
            The output of ``forward``, merged into a copy of ``x_dict`` when
            ``return_all`` is True.

        Raises
        ------
        ValueError
            If the keys returned by ``forward`` differ from ``var``.
        """
        # A ``None`` sentinel replaces the former mutable ``{}`` default so
        # that no dict object is shared between calls.
        if x_dict is None:
            x_dict = {}

        input_dict = self._get_input_dict(x_dict)
        output_dict = self.forward(**input_dict)

        if set(output_dict.keys()) != set(self._var):
            raise ValueError("Output variables are not the same as `var`.")

        if return_all:
            # Work on a copy so the caller's dict is left untouched.
            x_dict = x_dict.copy()
            x_dict.update(output_dict)
            return x_dict

        return output_dict

    def sample_mean(self, x_dict):
        """
        Return the value computed for the first variable in ``var``.

        This calls :meth:`sample` with ``return_all=False`` and indexes the
        result by ``self._var[0]``.
        """
        return self.sample(x_dict, return_all=False)[self._var[0]]

    @property
    def has_reparam(self):
        """bool: Always True for this class."""
        return True
class DataDistribution(Distribution):
    """
    Data distribution.

    Samples from this distribution equal given inputs.

    Examples
    --------
    >>> import torch
    >>> p = DataDistribution(var=["x"])
    >>> print(p)
    Distribution:
      p_{data}(x)
    Network architecture:
      DataDistribution(
        name=p_{data}, distribution_name=Data distribution,
        var=['x'], cond_var=[], input_var=['x'], features_shape=torch.Size([])
      )
    >>> sample = p.sample({"x": torch.randn(1, 64)})
    """

    def __init__(self, var, name="p_{data}"):
        """
        Parameters
        ----------
        var : list of str
            Names of the variables of this distribution.
        name : str, default "p_{data}"
            Name of the distribution, forwarded to ``Distribution.__init__``.
        """
        # A data distribution is never conditioned on other variables.
        super().__init__(var=var, cond_var=[], name=name)

    @property
    def distribution_name(self):
        """str: Name of this distribution, always ``"Data distribution"``."""
        return "Data distribution"

    def sample(self, x_dict=None, **kwargs):
        """
        Return the entries of ``x_dict`` selected by ``_get_input_dict``.

        Parameters
        ----------
        x_dict : dict, optional
            Given data. ``None`` (the default) is treated as an empty dict.
        **kwargs
            Accepted so callers can pass extra sampling options (for example
            ``return_all``); not used by this method.

        Returns
        -------
        dict
            The result of ``self._get_input_dict(x_dict)``.
        """
        # A ``None`` sentinel replaces the former mutable ``{}`` default so
        # that no dict object is shared between calls.
        if x_dict is None:
            x_dict = {}

        output_dict = self._get_input_dict(x_dict)
        return output_dict

    def sample_mean(self, x_dict):
        """
        Return the given value of the first variable in ``var``.

        ``return_all=False`` is passed for parity with other distributions;
        :meth:`sample` absorbs it through ``**kwargs``.
        """
        return self.sample(x_dict, return_all=False)[self._var[0]]

    @property
    def input_var(self):
        """
        In DataDistribution, `input_var` is same as `var`.
        """
        return self.var

    @property
    def has_reparam(self):
        """bool: Always True for this class."""
        return True