# Source code for pixyz.losses.entropy

import sympy
import torch

from pixyz.losses.losses import Loss
from pixyz.losses.divergences import KullbackLeibler

def Entropy(p, analytical=True, sample_shape=torch.Size([1])):
    r"""
    Entropy of a distribution, evaluated analytically or by Monte Carlo approximation.

    .. math::

        H(p) &= -\mathbb{E}_{p(x)}[\log p(x)] \qquad \text{(analytical)}\\
        &\approx -\frac{1}{L}\sum_{l=1}^{L} \log p(x_l), \quad \text{where} \quad x_l \sim p(x)
        \quad \text{(Monte Carlo)}

    Parameters
    ----------
    p : pixyz.distributions.Distribution
        Distribution whose entropy is computed.
    analytical : bool, defaults to True
        If True, return an ``AnalyticalEntropy`` loss, which evaluates the entropy
        through ``p.get_entropy``. If False, return the Monte Carlo estimate
        ``-p.log_prob().expectation(p, sample_shape=sample_shape)``.
    sample_shape : torch.Size or list, defaults to ``torch.Size([1])``
        Shape of the samples drawn from ``p`` for the Monte Carlo estimate.
        Ignored when ``analytical`` is True.

    Returns
    -------
    Loss object representing the (exact or estimated) entropy of ``p``.

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], features_shape=[64])
    >>> loss_cls = Entropy(p,analytical=True)
    >>> print(loss_cls)
    H \left[ {p(x)} \right]
    >>> loss_cls.eval()
    tensor([90.8121])
    >>> loss_cls = Entropy(p,analytical=False,sample_shape=[10])
    >>> print(loss_cls)
    - \mathbb{E}_{p(x)} \left[\log p(x) \right]
    >>> loss_cls.eval() # doctest: +SKIP
    tensor([90.5991])
    """
    if not analytical:
        # Monte Carlo: negated expectation of log p(x) under samples drawn from p itself.
        return -p.log_prob().expectation(p, sample_shape=sample_shape)
    return AnalyticalEntropy(p)

class AnalyticalEntropy(Loss):
    """Loss that evaluates the entropy ``H[p]`` of ``p`` via ``p.get_entropy``.

    Parameters
    ----------
    p : pixyz.distributions.Distribution
        Distribution whose entropy is evaluated. Its ``input_var`` list (copied)
        becomes the input variables of this loss.
    """

    def __init__(self, p):
        # Pass a copy so this loss does not share the same list object as p.input_var.
        super().__init__(p.input_var.copy())
        self.p = p

    @property
    def _symbol(self):
        """SymPy symbol rendering this loss as ``H \\left[ {<prob_text>} \\right]``."""
        return sympy.Symbol("H \\left[ {} \\right]".format("{" + self.p.prob_text + "}"))

    def forward(self, x_dict, **kwargs):
        """Return ``(entropy, {})`` computed by ``self.p.get_entropy(x_dict)``.

        Raises
        ------
        ValueError
            If ``self.p`` has no ``distribution_torch_class`` attribute.
        """
        if hasattr(self.p, 'distribution_torch_class'):
            return self.p.get_entropy(x_dict), {}
        raise ValueError("Entropy of this distribution cannot be evaluated, "
                         "got %s." % self.p.distribution_name)

def CrossEntropy(p, q, analytical=False, sample_shape=torch.Size([1])):
    r"""
    Cross entropy, a.k.a. the negative expected log-likelihood of ``q`` under ``p``
    (Monte Carlo approximation or analytical).

    .. math::

        H(p,q) &= -\mathbb{E}_{p(x)}[\log q(x)] = H(p) + D_{KL}[p(x)||q(x)] \qquad \text{(analytical)}\\
        &\approx -\frac{1}{L}\sum_{l=1}^{L} \log q(x_l), \quad \text{where} \quad x_l \sim p(x)
        \quad \text{(Monte Carlo)}

    Parameters
    ----------
    p : pixyz.distributions.Distribution
        Distribution the expectation is taken over.
    q : pixyz.distributions.Distribution
        Distribution whose log-likelihood is evaluated.
    analytical : bool, defaults to False
        If True, return ``Entropy(p) + KullbackLeibler(p, q)``. If False, return the
        Monte Carlo estimate ``-q.log_prob().expectation(p, sample_shape=sample_shape)``.
    sample_shape : torch.Size or list, defaults to ``torch.Size([1])``
        Shape of the samples drawn from ``p`` for the Monte Carlo estimate.
        Ignored when ``analytical`` is True.

    Returns
    -------
    Loss object representing the (exact or estimated) cross entropy.

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], features_shape=[64], name="p")
    >>> q = Normal(loc=torch.tensor(1.), scale=torch.tensor(1.), var=["x"], features_shape=[64], name="q")
    >>> loss_cls = CrossEntropy(p,q,analytical=True)
    >>> print(loss_cls)
    D_{KL} \left[p(x)||q(x) \right] + H \left[ {p(x)} \right]
    >>> loss_cls.eval()
    tensor([122.8121])
    >>> loss_cls = CrossEntropy(p,q,analytical=False,sample_shape=[10])
    >>> print(loss_cls)
    - \mathbb{E}_{p(x)} \left[\log q(x) \right]
    >>> loss_cls.eval() # doctest: +SKIP
    tensor([123.2192])
    """
    if not analytical:
        # Monte Carlo: average -log q(x) over samples x drawn from p.
        return -q.log_prob().expectation(p, sample_shape=sample_shape)
    # Decomposition H(p, q) = H(p) + KL(p || q); operand order kept for the symbolic form.
    return Entropy(p) + KullbackLeibler(p, q)

class StochasticReconstructionLoss(Loss):
    """Obsolete loss: constructing it always raises ``NotImplementedError``.

    Build the equivalent loss with ``-decoder.log_prob().expectation(encoder)``.
    """

    def __init__(self, encoder, decoder, sample_shape=torch.Size([1])):
        message = ("This function is obsolete."
                   " please use -decoder.log_prob().expectation(encoder) instead of it.")
        raise NotImplementedError(message)