# Source code for pixyz.losses.pdf

import sympy
import torch
from .losses import Loss


class LogProb(Loss):
    r"""The log probability density/mass function.

    .. math::

        \log p(x)

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"],
    ...            features_shape=[10])
    >>> loss_cls = LogProb(p)  # or p.log_prob()
    >>> print(loss_cls)
    \log p(x)
    >>> sample_x = torch.randn(2, 10)  # Pseudo data
    >>> loss = loss_cls.eval({"x": sample_x})
    >>> print(loss)  # doctest: +SKIP
    tensor([12.9894, 15.5280])
    """

    def __init__(self, p, sum_features=True, feature_dims=None):
        """
        Parameters
        ----------
        p : pixyz.distributions.Distribution
            Distribution whose log-probability this loss evaluates.
        sum_features : bool, defaults to True
            Whether to sum the log-probability over the feature dimensions
            (forwarded to ``p.get_log_prob``).
        feature_dims : list of int, defaults to None
            Which dimensions to sum over when ``sum_features`` is True
            (forwarded to ``p.get_log_prob``).
        """
        # The loss consumes both the distribution's own variables and its
        # conditioning variables, so both are declared as input variables.
        input_var = p.var + p.cond_var
        super().__init__(input_var=input_var)
        self.sum_features = sum_features
        self.feature_dims = feature_dims
        self.p = p

    @property
    def _symbol(self):
        # Rendered as "\log p(x)" (prob_text is the distribution's symbol text).
        return sympy.Symbol("\\log {}".format(self.p.prob_text))

    def forward(self, x=None, **kwargs):
        """Evaluate the log-probability of ``p`` for the given input dict.

        Parameters
        ----------
        x : dict, defaults to None
            Mapping from variable names to sample tensors. ``None`` is
            treated as an empty dict. (Originally a mutable default ``x={}``,
            which is shared across calls; ``None`` avoids that pitfall with
            identical behavior.)

        Returns
        -------
        tuple
            The log-probability tensor and an empty dict of auxiliary outputs.
        """
        x = {} if x is None else x
        log_prob = self.p.get_log_prob(x, sum_features=self.sum_features,
                                       feature_dims=self.feature_dims, **kwargs)
        return log_prob, {}
class Prob(LogProb):
    r"""The probability density/mass function.

    .. math::

        p(x) = \exp(\log p(x))

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"],
    ...            features_shape=[10])
    >>> loss_cls = Prob(p)  # or p.prob()
    >>> print(loss_cls)
    p(x)
    >>> sample_x = torch.randn(2, 10)  # Pseudo data
    >>> loss = loss_cls.eval({"x": sample_x})
    >>> print(loss)  # doctest: +SKIP
    tensor([3.2903e-07, 5.5530e-07])
    """

    @property
    def _symbol(self):
        # Rendered as "p(x)" — no log prefix, unlike LogProb._symbol.
        return sympy.Symbol(self.p.prob_text)

    def forward(self, x=None, **kwargs):
        """Evaluate the probability by exponentiating the log-probability.

        Parameters
        ----------
        x : dict, defaults to None
            Mapping from variable names to sample tensors. ``None`` is
            treated as an empty dict (replaces the original shared mutable
            default ``x={}``; behavior is unchanged).

        Returns
        -------
        tuple
            ``exp(log p(x))`` and an empty dict of auxiliary outputs.
        """
        # The second element of LogProb.forward's return is always an empty
        # dict; discard it (the original bound it to ``x`` and never used it).
        log_prob, _ = super().forward({} if x is None else x, **kwargs)
        return torch.exp(log_prob), {}