# Source code for pixyz.losses.expectations

from .losses import Loss
from ..utils import get_dict_values

class CrossEntropy(Loss):
    r"""
    Cross entropy, a.k.a., the negative expected value of log-likelihood (Monte Carlo approximation).

    .. math::

        -\mathbb{E}_{q(x)}[\log p(x)] \approx -\frac{1}{L}\sum_{l=1}^L \log p(x_l),

    where :math:`x_l \sim q(x)`.

    Args:
        p1: Distribution :math:`q` that samples are drawn from in ``estimate``.
        p2: Distribution :math:`p` whose log-likelihood is evaluated on those samples.
        input_var (list, optional): Input variables of this loss. Defaults to the
            input variables of ``p1`` followed by the variables of ``p2``, with
            duplicates removed.
    """

    def __init__(self, p1, p2, input_var=None):
        if input_var is None:
            # dict.fromkeys drops duplicates but keeps first-occurrence order,
            # so the default input_var is the same on every run (the order of
            # a set of strings depends on hash randomization).
            input_var = list(dict.fromkeys(p1.input_var + p2.var))
        super().__init__(p1, p2, input_var=input_var)

    @property
    def loss_text(self):
        return "-E_{}[log {}]".format(self._p1.prob_text, self._p2.prob_text)

    def estimate(self, x=None):
        """Return the Monte Carlo estimate of the cross entropy.

        Args:
            x (dict, optional): Values of the input variables. ``None`` (the
                default) is treated as an empty dict; a fresh dict is used on
                every call instead of a shared mutable default.

        Returns:
            The negated log-likelihood of ``p2`` evaluated on samples drawn
            from ``p1``.
        """
        if x is None:
            x = {}
        _x = super().estimate(x)
        _p1_input = get_dict_values(_x, self._p1.input_var, return_dict=True)
        # return_all=False presumably restricts the result to p1's own sampled
        # variables (TODO confirm against Distribution.sample).
        samples = self._p1.sample(_p1_input, reparam=True, return_all=False)

        # Values from _x for p2's variables are layered on top of the samples;
        # on a key clash the given value wins (dict.update semantics).
        _p2_input = get_dict_values(_x, self._p2.var, return_dict=True)
        samples.update(_p2_input)

        loss = -self._p2.log_likelihood(samples)

        return loss

class Entropy(Loss):
    r"""
    Entropy (Monte Carlo approximation).

    .. math::

        -\mathbb{E}_{p(x)}[\log p(x)] \approx -\frac{1}{L}\sum_{l=1}^L \log p(x_l),

    where :math:`x_l \sim p(x)`.

    Note:
        This class is a special case of the CrossEntropy class. You can get the same result with CrossEntropy.

    Args:
        p1: Distribution :math:`p` that is both sampled from and evaluated.
        input_var (list, optional): Input variables of this loss. Defaults to
            ``p1.input_var``.
    """

    def __init__(self, p1, input_var=None):
        if input_var is None:
            input_var = p1.input_var
        # No second distribution: the same distribution plays both roles.
        super().__init__(p1, None, input_var=input_var)

    @property
    def loss_text(self):
        return "-E_{}[log {}]".format(self._p1.prob_text, self._p1.prob_text)

    def estimate(self, x=None):
        """Return the Monte Carlo estimate of the entropy of ``p1``.

        Args:
            x (dict, optional): Values of the input variables. ``None`` (the
                default) is treated as an empty dict.

        Returns:
            The negated log-likelihood of ``p1`` evaluated on its own samples.
        """
        if x is None:
            x = {}
        _x = super().estimate(x)
        samples = self._p1.sample(_x, reparam=True)

        # Entropy is the *negative* expected log-likelihood. The negation keeps
        # this consistent with the formula above, loss_text and CrossEntropy.
        loss = -self._p1.log_likelihood(samples)

        return loss

class StochasticReconstructionLoss(Loss):
    r"""
    Reconstruction Loss (Monte Carlo approximation).

    .. math::

        -\mathbb{E}_{q(z|x)}[\log p(x|z)] \approx -\frac{1}{L}\sum_{l=1}^L \log p(x|z_l),

    where :math:`z_l \sim q(z|x)`.

    Note:
        This class is a special case of the CrossEntropy class. You can get the same result with CrossEntropy.

    Args:
        encoder: Distribution :math:`q(z|x)` that samples are drawn from.
        decoder: Distribution :math:`p(x|z)` whose log-likelihood is evaluated.
        input_var (list, optional): Input variables of this loss. Defaults to
            ``encoder.input_var``.
    """

    def __init__(self, encoder, decoder, input_var=None):
        if input_var is None:
            input_var = encoder.input_var
        super().__init__(encoder, decoder, input_var=input_var)

    @property
    def loss_text(self):
        return "-E_{}[log {}]".format(self._p1.prob_text, self._p2.prob_text)

    def estimate(self, x=None):
        """Return the Monte Carlo estimate of the reconstruction loss.

        Args:
            x (dict, optional): Values of the input variables. ``None`` (the
                default) is treated as an empty dict.

        Returns:
            The negated log-likelihood of the decoder evaluated on samples
            drawn from the encoder.
        """
        if x is None:
            x = {}
        _x = super().estimate(x)
        # Unlike CrossEntropy, sample() is called without return_all=False.
        # The result is presumably expected to still carry the original inputs
        # needed by the decoder's log-likelihood (TODO confirm).
        samples = self._p1.sample(_x, reparam=True)
        loss = -self._p2.log_likelihood(samples)

        return loss