Source code for pixyz.losses.pdf

import sympy
import torch
from .losses import Loss


class LogProb(Loss):
    r"""
    The log probability density/mass function.

    .. math::

        \log p(x)

    Parameters
    ----------
    p : distribution
        The distribution to evaluate. Its ``var`` and ``cond_var`` together
        form the input variables of this loss, and its ``prob_text`` is used
        to build the printed symbol.
    sum_features : bool, optional
        Forwarded unchanged to ``p.get_log_prob`` (default ``True``).
    feature_dims : optional
        Forwarded unchanged to ``p.get_log_prob`` (default ``None``).

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"],
    ...            features_shape=[10])
    >>> loss_cls = LogProb(p)  # or p.log_prob()
    >>> print(loss_cls)
    \log p(x)
    >>> sample_x = torch.randn(2, 10) # Pseudo data
    >>> loss = loss_cls.eval({"x": sample_x})
    >>> print(loss) # doctest: +SKIP
    tensor([12.9894, 15.5280])
    """

    def __init__(self, p, sum_features=True, feature_dims=None):
        # Both observed and conditioning variables must be supplied at eval time.
        input_var = p.var + p.cond_var
        # Stored before the base initialiser runs, mirroring the original order.
        self.sum_features = sum_features
        self.feature_dims = feature_dims
        super().__init__(p, input_var=input_var)

    @property
    def _symbol(self):
        # Rendered as e.g. ``\log p(x)``.
        return sympy.Symbol(f"\\log {self.p.prob_text}")

    def _get_eval(self, x=None, **kwargs):
        """Evaluate ``log p`` on the variable dict ``x``.

        Parameters
        ----------
        x : dict, optional
            Mapping of variable names to values. ``None`` (the default) means
            an empty dict; a fresh one is created on every call.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            ``(log_prob, x)``: the value returned by ``p.get_log_prob`` and the
            same ``x`` dict that was passed in (presumably required by the
            ``Loss`` evaluation protocol -- confirm in ``Loss``).
        """
        # A literal ``{}`` default would be a single shared dict that is also
        # returned to the caller, so mutations by callers would leak across calls.
        if x is None:
            x = {}
        log_prob = self.p.get_log_prob(x, sum_features=self.sum_features,
                                       feature_dims=self.feature_dims)
        return log_prob, x
class Prob(LogProb):
    r"""
    The probability density/mass function.

    .. math::

        p(x) = \exp(\log p(x))

    Constructor arguments (``p``, ``sum_features``, ``feature_dims``) are
    inherited unchanged from :class:`LogProb`; the value is computed by
    exponentiating the log-probability that :class:`LogProb` produces.

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"],
    ...            features_shape=[10])
    >>> loss_cls = Prob(p)  # or p.prob()
    >>> print(loss_cls)
    p(x)
    >>> sample_x = torch.randn(2, 10) # Pseudo data
    >>> loss = loss_cls.eval({"x": sample_x})
    >>> print(loss) # doctest: +SKIP
    tensor([3.2903e-07, 5.5530e-07])
    """

    @property
    def _symbol(self):
        # Rendered as the bare probability text, e.g. ``p(x)``.
        return sympy.Symbol(self.p.prob_text)

    def _get_eval(self, x=None, **kwargs):
        """Evaluate ``p`` on the variable dict ``x`` as ``exp(log p)``.

        Parameters
        ----------
        x : dict, optional
            Mapping of variable names to values. ``None`` (the default) means
            an empty dict; a fresh one is created on every call.
        **kwargs
            Forwarded to :meth:`LogProb._get_eval`.

        Returns
        -------
        tuple
            ``(prob, x)`` where ``prob = torch.exp(log_prob)`` and ``x`` is the
            dict returned by the parent evaluation.
        """
        # Build the default dict here (rather than a shared ``{}`` default) so
        # the parent never receives, and never returns, a cross-call shared dict.
        if x is None:
            x = {}
        log_prob, x = super()._get_eval(x, **kwargs)
        return torch.exp(log_prob), x