pixyz.distributions (Distribution API)
Distribution
class pixyz.distributions.distributions.Distribution(cond_var=[], var=['x'], name='p', dim=1)
Bases: torch.nn.modules.module.Module
Base distribution class. In pixyz, every distribution must inherit from this class.
Parameters:
- var : list
- Variables of this distribution.
- cond_var : list
- Conditional variables of this distribution. If cond_var is not empty, the corresponding inputs must be given in order to sample variables or estimate the log-likelihood.
- dim : int
- Number of dimensions of this distribution. It may be ignored depending on the shape set in the sample method and on the parent distribution. It is also not considered when this class is inherited by DNNs. Defaults to 1.
- name : str
- Name of this distribution, displayed in prob_text and prob_factorized_text. Defaults to “p”.
Attributes:
- distribution_name
- name
- var
- cond_var
- input_var : Normally has the same values as cond_var.
- prob_text
- prob_factorized_text
get_params(params_dict)
Get the parameters of this distribution from the constant parameters set at initialization and from the outputs of DNNs.
Parameters:
- params_dict : dict
- Input parameters.
Returns:
- output_dict : dict
- Output parameters.
>>> print(dist_1.prob_text, dist_1.distribution_name)
p(x) Normal
>>> dist_1.get_params()
{"loc": 0, "scale": 1}
>>> print(dist_2.prob_text, dist_2.distribution_name)
p(x|z) Normal
>>> dist_2.get_params({"z": 1})
{"loc": 0, "scale": 1}
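The idea behind get_params, combining constants fixed at initialization with parameters computed from the conditional inputs, can be mimicked in plain Python. ToyConditionalNormal and the lambda standing in for a DNN are hypothetical; the real method works with torch tensors and neural networks.

```python
class ToyConditionalNormal:
    """Hypothetical sketch of get_params (not pixyz's implementation)."""

    def __init__(self, cond_var, constants, network=None):
        self.cond_var = list(cond_var)
        self.constants = dict(constants)  # parameters fixed at initialization
        self.network = network            # maps conditional inputs -> parameters

    def get_params(self, params_dict=None):
        params_dict = params_dict or {}
        missing = [v for v in self.cond_var if v not in params_dict]
        if missing:
            raise ValueError(f"missing conditional inputs: {missing}")
        output = dict(self.constants)
        if self.network is not None:
            # Parameters produced by the "DNN" override or extend the constants.
            inputs = {v: params_dict[v] for v in self.cond_var}
            output.update(self.network(**inputs))
        return output


prior = ToyConditionalNormal(cond_var=[], constants={"loc": 0.0, "scale": 1.0})
likelihood = ToyConditionalNormal(
    cond_var=["z"], constants={"scale": 1.0},
    network=lambda z: {"loc": 2.0 * z},  # hypothetical network: loc = 2z
)
print(prior.get_params())               # {'loc': 0.0, 'scale': 1.0}
print(likelihood.get_params({"z": 1}))  # {'scale': 1.0, 'loc': 2.0}
```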
sample(x={}, shape=None, batch_size=1, return_all=True, reparam=False)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
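A minimal sketch of the sample options, assuming a univariate p(x|z) = Normal(loc=2z, scale=1). The function sample_normal and its parameters are hypothetical; only the roles of x, batch_size, return_all, and reparam follow the description above. In plain Python there are no gradients, so the two branches differ only in how the draw is expressed.

```python
import random


def sample_normal(x=None, batch_size=1, return_all=True, reparam=False, rng=None):
    """Toy sample() for p(x|z) = Normal(loc=2*z, scale=1) (assumed parameters)."""
    rng = rng if rng is not None else random.Random()
    x = dict(x or {})
    if "z" not in x:
        # cond_var is non-empty, so its value must be supplied in a dict.
        raise ValueError("cond_var 'z' must be given, e.g. {'z': 0.5}")
    loc, scale = 2.0 * x["z"], 1.0
    if reparam:
        # Reparameterization trick: draw fixed noise eps ~ N(0, 1) and transform
        # it deterministically, x = loc + scale * eps (with autograd, gradients
        # would flow to loc and scale through this expression).
        samples = [loc + scale * rng.gauss(0.0, 1.0) for _ in range(batch_size)]
    else:
        # Direct draw from N(loc, scale).
        samples = [rng.gauss(loc, scale) for _ in range(batch_size)]
    output = {"x": samples}
    if return_all:
        # Also return the input variables next to the new samples.
        output = {**x, **output}
    return output


out = sample_normal({"z": 0.5}, batch_size=3, reparam=True, rng=random.Random(0))
print(sorted(out))  # ['x', 'z']
```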
log_likelihood(x_dict)
Estimate the log-likelihood of this distribution from inputs given as a dictionary.
Parameters:
- x_dict : dict
- Input samples.
Returns:
- log_like : torch.Tensor
- Log-likelihood.
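As an illustration of what log_likelihood computes, the sketch evaluates the log-density of a factorized Normal and sums it over independent dimensions. The function names and the {"x": [...]} layout are assumptions for this example; pixyz itself operates on torch tensors.

```python
import math


def normal_log_density(value, loc=0.0, scale=1.0):
    """Log-density of a univariate Normal(loc, scale) at value."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(scale)
            - 0.5 * ((value - loc) / scale) ** 2)


def log_likelihood(x_dict, params):
    # Dimensions are treated as independent, so their log-densities add up.
    return sum(normal_log_density(v, **params) for v in x_dict["x"])


print(round(log_likelihood({"x": [0.0]}, {"loc": 0.0, "scale": 1.0}), 4))  # -0.9189
```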
Exponential families
Normal
Bernoulli
RelaxedBernoulli
class pixyz.distributions.RelaxedBernoulli(temperature, cond_var=[], var=['x'], name='p', dim=1, **kwargs)
Bases: pixyz.distributions.distributions.DistributionBase
Attributes:
- distribution_name
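The temperature argument suggests the relaxed Bernoulli (binary Concrete) distribution. Under that assumption, one reparameterized draw adds logistic noise to the logit and squashes the result with a temperature-scaled sigmoid. This is a plain-Python sketch of the technique, not pixyz's implementation; sample_relaxed_bernoulli and its probs argument are hypothetical, and each probability must lie strictly between 0 and 1.

```python
import math
import random


def _sigmoid(y):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if y >= 0:
        return 1.0 / (1.0 + math.exp(-y))
    e = math.exp(y)
    return e / (1.0 + e)


def sample_relaxed_bernoulli(probs, temperature, rng=None):
    """One binary-Concrete (relaxed Bernoulli) draw per probability in probs."""
    rng = rng if rng is not None else random.Random()
    samples = []
    for p in probs:
        logit = math.log(p) - math.log1p(-p)
        # Logistic noise log(u) - log(1 - u), u ~ Uniform(0, 1) kept away from 0 and 1.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noise = math.log(u) - math.log1p(-u)
        # Lower temperatures push the result toward 0 or 1.
        samples.append(_sigmoid((logit + noise) / temperature))
    return samples


draw = sample_relaxed_bernoulli([0.1, 0.5, 0.9], temperature=0.5, rng=random.Random(0))
print(draw)
```

The output is continuous in [0, 1], which is what makes the relaxation usable with the reparameterization trick.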
FactorizedBernoulli
Categorical
RelaxedCategorical
class pixyz.distributions.RelaxedCategorical(temperature, cond_var=[], var=['x'], name='p', dim=1, **kwargs)
Bases: pixyz.distributions.distributions.DistributionBase
Attributes:
- distribution_name
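Assuming RelaxedCategorical follows the Gumbel-softmax (Concrete) construction implied by its temperature argument, a draw perturbs the log-probabilities with Gumbel noise and applies a temperature-scaled softmax. The function and its probs argument are hypothetical; this is a sketch of the technique, not pixyz's code.

```python
import math
import random


def sample_relaxed_categorical(probs, temperature, rng=None):
    """One Gumbel-softmax (Concrete) draw: a point on the probability simplex."""
    rng = rng if rng is not None else random.Random()
    scores = []
    for p in probs:
        # Gumbel(0, 1) noise via -log(-log(u)), with u kept away from 0 and 1.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        scores.append((math.log(p) + gumbel) / temperature)
    # Softmax with the maximum subtracted first for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


sample = sample_relaxed_categorical([0.2, 0.3, 0.5], temperature=0.5, rng=random.Random(0))
print(sample, sum(sample))
```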
Special distributions
NormalPoE
Deterministic
class pixyz.distributions.Deterministic(**kwargs)
Bases: pixyz.distributions.distributions.Distribution
Deterministic distribution (also called a degenerate distribution).
Attributes:
- distribution_name
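A deterministic (degenerate) p(y|x) puts all of its mass on y = f(x), so sampling reduces to evaluating f. ToyDeterministic and its arguments are hypothetical and only illustrate this idea together with the return_all behavior; they are not pixyz's API.

```python
class ToyDeterministic:
    """Hypothetical degenerate distribution: y = f(inputs), no randomness."""

    def __init__(self, f, cond_var, var):
        self.f, self.cond_var, self.var = f, list(cond_var), list(var)

    def sample(self, x=None, return_all=True):
        x = dict(x or {})
        inputs = [x[name] for name in self.cond_var]
        outputs = self.f(*inputs)
        if len(self.var) == 1:
            outputs = (outputs,)  # wrap a single output so zip() pairs it
        result = dict(zip(self.var, outputs))
        # Optionally keep the input variables in the returned dict.
        return {**x, **result} if return_all else result


double = ToyDeterministic(lambda x: 2 * x, cond_var=["x"], var=["y"])
print(double.sample({"x": 3}))  # {'x': 3, 'y': 6}
```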
sample(x={}, return_all=True, **kwargs)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
DataDistribution
class pixyz.distributions.DataDistribution(var, name='p_data')
Bases: pixyz.distributions.distributions.Distribution
Data distribution. TODO: fix its behavior when it is multiplied with other distributions.
Attributes:
- distribution_name
sample(x={}, **kwargs)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
Attributes:
- input_var : In DataDistribution, input_var is the same as var.
Flow-based
PlanarFlow
RealNVP
class pixyz.distributions.RealNVP(prior, dim, num_multiscale_layers=2, var=[], image=False, name='p', **kwargs)
Bases: pixyz.distributions.distributions.Distribution
Attributes:
- prob_text
forward(x, inverse=False, jacobian=False)
When this class is inherited by DNNs, this method is also intended to be overridden.
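The inverse and jacobian flags of forward can be illustrated with a single affine coupling layer, the building block RealNVP is named after: half of the input is copied unchanged and parameterizes an invertible affine map of the other half, so the log-determinant of the Jacobian is just the sum of the log-scales. All names below (coupling, scale_net, shift_net) are hypothetical, the "networks" are fixed toy functions, and the multiscale structure (num_multiscale_layers) is not modeled.

```python
import math


def scale_net(h):
    # Stand-in for a neural network producing log-scales (bounded by tanh).
    return [0.5 * math.tanh(v) for v in h]


def shift_net(h):
    # Stand-in for a neural network producing shifts.
    return [v + 1.0 for v in h]


def coupling(x, inverse=False, jacobian=False):
    """Toy affine coupling layer; returns (y, log|det J|) if jacobian=True."""
    if len(x) % 2:
        raise ValueError("this toy layer expects an even number of dimensions")
    d = len(x) // 2
    x1, x2 = x[:d], x[d:]
    # The first half passes through unchanged, so s and t can be recomputed
    # exactly from it when inverting.
    s, t = scale_net(x1), shift_net(x1)
    if not inverse:
        y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
        log_det = sum(s)
    else:
        y2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(x2, s, t)]
        log_det = -sum(s)
    y = x1 + y2
    return (y, log_det) if jacobian else y


x = [0.3, -1.2, 0.7, 2.0]
y, log_det = coupling(x, jacobian=True)
x_rec, inv_log_det = coupling(y, inverse=True, jacobian=True)
print(x_rec, log_det + inv_log_det)
```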
sample(x={}, only_flow=False, return_all=True, **kwargs)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
Operators
ReplaceVarDistribution
class pixyz.distributions.distributions.ReplaceVarDistribution(a, replace_dict)
Bases: pixyz.distributions.distributions.Distribution
Replace the names of variables in a Distribution.
Parameters:
- a : pixyz.Distribution (not pixyz.MultiplyDistribution)
- Distribution whose variables are renamed.
- replace_dict : dict
- Dictionary mapping original variable names to new names.
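A minimal sketch of the bookkeeping behind renaming, assuming replace_dict maps each original name to its replacement (for example {"x": "y"} turns p(x|z) into p(y|z)). Samples from the wrapped distribution come back under the new names, and inputs given under new names are mapped back before use. The helper functions are hypothetical, not pixyz's code.

```python
def replace_vars(names, replace_dict):
    # Rename a list of variables; names not in replace_dict stay as they are.
    return [replace_dict.get(n, n) for n in names]


def rename_sample(sample, replace_dict):
    # Samples produced under the original names are exposed under the new names.
    return {replace_dict.get(k, k): v for k, v in sample.items()}


def to_original(inputs, replace_dict):
    # Inputs given under the new names are mapped back to the original names
    # before they are passed to the wrapped distribution.
    inverse = {new: old for old, new in replace_dict.items()}
    return {inverse.get(k, k): v for k, v in inputs.items()}


replace_dict = {"x": "y"}
print(replace_vars(["x"], replace_dict))                      # ['y']
print(rename_sample({"x": 1.0, "z": 0.0}, replace_dict))      # {'y': 1.0, 'z': 0.0}
```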
forward(*args, **kwargs)
When this class is inherited by DNNs, this method is also intended to be overridden.
get_params(params_dict)
Get the parameters of this distribution from the constant parameters set at initialization and from the outputs of DNNs.
Parameters:
- params_dict : dict
- Input parameters.
Returns:
- output_dict : dict
- Output parameters.
>>> print(dist_1.prob_text, dist_1.distribution_name)
p(x) Normal
>>> dist_1.get_params()
{"loc": 0, "scale": 1}
>>> print(dist_2.prob_text, dist_2.distribution_name)
p(x|z) Normal
>>> dist_2.get_params({"z": 1})
{"loc": 0, "scale": 1}
sample(x={}, shape=None, batch_size=1, return_all=True, reparam=False)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
log_likelihood(x)
Estimate the log-likelihood of this distribution from inputs given as a dictionary.
Parameters:
- x : dict
- Input samples.
Returns:
- log_like : torch.Tensor
- Log-likelihood.
Attributes:
- input_var : Normally has the same values as cond_var.
- distribution_name
MarginalizeVarDistribution
class pixyz.distributions.distributions.MarginalizeVarDistribution(a, marginalize_list)
Bases: pixyz.distributions.distributions.Distribution
Marginalize variables in a Distribution.
Parameters:
- a : pixyz.Distribution (not pixyz.DistributionBase)
- Distribution whose variables are marginalized.
- marginalize_list : list
- Variables to marginalize.
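For sampling, marginalization can be done by drawing from the joint and discarding the marginalized variables: if (x, z) ~ p(x, z), then x alone is a draw from p(x). The sketch assumes a hypothetical joint with z ~ N(0, 1) and x|z ~ N(z, 1), so the marginal of x is N(0, 2). It covers sampling only and is not pixyz's implementation.

```python
import random


def sample_joint(rng):
    # Hypothetical joint p(x, z) = p(x|z) p(z) with z ~ N(0, 1), x|z ~ N(z, 1).
    z = rng.gauss(0.0, 1.0)
    x = rng.gauss(z, 1.0)
    return {"x": x, "z": z}


def sample_marginal(marginalize_list, batch_size=1, rng=None):
    """Draw from the joint, then drop the marginalized variables."""
    rng = rng if rng is not None else random.Random()
    samples = []
    for _ in range(batch_size):
        s = sample_joint(rng)
        for name in marginalize_list:
            del s[name]  # discarding z leaves a draw from the marginal p(x)
        samples.append(s)
    return samples


draws = sample_marginal(["z"], batch_size=20000, rng=random.Random(0))
xs = [d["x"] for d in draws]
mean = sum(xs) / len(xs)
var = sum((v - mean) ** 2 for v in xs) / len(xs)
print(round(mean, 2), round(var, 2))  # variance should be close to 1 + 1 = 2
```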
forward(*args, **kwargs)
When this class is inherited by DNNs, this method is also intended to be overridden.
get_params(params_dict)
Get the parameters of this distribution from the constant parameters set at initialization and from the outputs of DNNs.
Parameters:
- params_dict : dict
- Input parameters.
Returns:
- output_dict : dict
- Output parameters.
>>> print(dist_1.prob_text, dist_1.distribution_name)
p(x) Normal
>>> dist_1.get_params()
{"loc": 0, "scale": 1}
>>> print(dist_2.prob_text, dist_2.distribution_name)
p(x|z) Normal
>>> dist_2.get_params({"z": 1})
{"loc": 0, "scale": 1}
sample(x={}, shape=None, batch_size=1, return_all=True, reparam=False)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
log_likelihood(x)
Estimate the log-likelihood of this distribution from inputs given as a dictionary.
Parameters:
- x : dict
- Input samples.
Returns:
- log_like : torch.Tensor
- Log-likelihood.
Attributes:
- input_var : Normally has the same values as cond_var.
- distribution_name
- prob_factorized_text
MultiplyDistribution
class pixyz.distributions.distributions.MultiplyDistribution(a, b)
Bases: pixyz.distributions.distributions.Distribution
Multiply the given distributions. This class checks whether the two distributions can be multiplied, for example:
p(x|z)p(z|y) -> Valid
p(x|z)p(y|z) -> Valid
p(x|z)p(y|a) -> Valid
p(x|z)p(z|x) -> Invalid (recursive)
p(x|z)p(x|y) -> Invalid (conflict)
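The five cases above can be reproduced by a simple set-based check: a conflict occurs when both factors generate the same variable, and a recursion occurs when each factor is conditioned on a variable generated by the other. check_multiply is a hypothetical reconstruction consistent with the listed examples, not pixyz's actual code, which may handle further cases.

```python
def check_multiply(a_var, a_cond, b_var, b_cond):
    """Return 'Valid', 'Invalid (conflict)' or 'Invalid (recursive)'."""
    a_var, a_cond, b_var, b_cond = map(set, (a_var, a_cond, b_var, b_cond))
    if a_var & b_var:
        # Both factors claim to generate the same variable.
        return "Invalid (conflict)"
    if (a_var & b_cond) and (b_var & a_cond):
        # Each factor is conditioned on a variable the other one generates.
        return "Invalid (recursive)"
    return "Valid"


print(check_multiply(["x"], ["z"], ["z"], ["y"]))  # p(x|z)p(z|y) -> Valid
print(check_multiply(["x"], ["z"], ["z"], ["x"]))  # p(x|z)p(z|x) -> Invalid (recursive)
```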
Parameters:
- a : pixyz.Distribution
- Distribution.
- b : pixyz.Distribution
- Distribution.
>>> p_multi = MultiplyDistribution(a, b)
>>> p_multi = a * b
Attributes:
- inh_var
- input_var : Normally has the same values as cond_var.
- prob_factorized_text
sample(x={}, shape=None, batch_size=1, return_all=True, reparam=False)
Sample variables of this distribution. If cond_var is not empty, the inputs must be given in dictionary format.
Parameters:
- x : torch.Tensor, list, or dict
- Input variables.
- shape : tuple
- Shape of samples. If set, batch_size and dim are ignored.
- batch_size : int
- Batch size of samples. Defaults to 1.
- return_all : bool
- Whether the output also contains the input variables.
- reparam : bool
- Whether to sample variables with the reparameterization trick.
Returns:
- output : dict
- Samples of this distribution.
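Sampling from a product such as p(x|z)p(z) has to respect the dependencies: z must be drawn before it can be fed to p(x|z). The sketch below shows this ancestral-sampling idea with plain Python callables. The (var, cond_var, sampler) triples and sample_product are hypothetical and do not reflect how pixyz stores or orders factors; x is assumed to hold only the product's own conditional inputs.

```python
import random


def sample_product(factors, x=None, rng=None):
    """Ancestral sampling from a product of factors.

    factors: list of (var, cond_var, sampler) triples, where
    sampler(inputs, rng) draws var given a dict of its conditional inputs.
    """
    rng = rng if rng is not None else random.Random()
    values = dict(x or {})
    pending = list(factors)
    while pending:
        for factor in pending:
            var, cond_var, sampler = factor
            # A factor is ready once every variable it conditions on is known.
            if all(c in values for c in cond_var):
                values[var] = sampler({c: values[c] for c in cond_var}, rng)
                pending.remove(factor)
                break
        else:
            # No factor could be sampled: some conditional input is missing.
            raise ValueError("missing conditional inputs for: "
                             + ", ".join(f[0] for f in pending))
    return values


p_z = ("z", [], lambda inputs, rng: rng.gauss(0.0, 1.0))
p_x_given_z = ("x", ["z"], lambda inputs, rng: rng.gauss(inputs["z"], 1.0))
# The factor order does not matter: z is drawn first because p(x|z) needs it.
print(sample_product([p_x_given_z, p_z], rng=random.Random(0)))
```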