Source code for pixyz.losses.divergences

import torch
from ..utils import get_dict_values
from .losses import Loss


class KullbackLeibler(Loss):
    r"""Kullback-Leibler divergence (analytical).

    .. math::

        D_{KL}[p||q] = \mathbb{E}_{p(x)}\left[\log \frac{p(x)}{q(x)}\right]
    """

    def __init__(self, p1, p2, input_var=None):
        super().__init__(p1, p2, input_var)

    @property
    def loss_text(self):
        return "KL[{}||{}]".format(self._p1.prob_text, self._p2.prob_text)
    def estimate(self, x, **kwargs):
        x = super().estimate(x)

        if self._p1.distribution_name == "Normal" and self._p2.distribution_name == "Normal":
            # Both distributions are Gaussian, so the KL divergence has a closed form.
            inputs = get_dict_values(x, self._p1.input_var, True)
            params1 = self._p1.get_params(inputs, **kwargs)

            inputs = get_dict_values(x, self._p2.input_var, True)
            params2 = self._p2.get_params(inputs, **kwargs)

            return gauss_gauss_kl(params1["loc"], params1["scale"],
                                  params2["loc"], params2["scale"])

        raise Exception("The analytical KL divergence is not supported for these distributions, "
                        "got %s and %s." % (self._p1.distribution_name,
                                            self._p2.distribution_name))
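
# --- Usage sketch (not part of the original module) ---
# A minimal sketch of how this loss is typically driven: build two Gaussian
# distributions, wrap them in KullbackLeibler, and call `estimate` with a dict
# of input tensors keyed by variable name. Only KullbackLeibler(p1, p2) and
# estimate(x) are taken from this module; the import path and the Normal
# constructor arguments below are assumptions and may differ between pixyz
# versions.
def _kl_usage_sketch():
    from pixyz.distributions import Normal  # assumed import path

    # Two unconditional diagonal Gaussians over the same variable "z"
    # (constructor signature assumed).
    p = Normal(loc=torch.zeros(1, 64), scale=torch.ones(1, 64), var=["z"])
    q = Normal(loc=torch.ones(1, 64), scale=2. * torch.ones(1, 64), var=["z"])

    kl = KullbackLeibler(p, q)
    print(kl.loss_text)  # e.g. "KL[p(z)||p(z)]"; exact text depends on prob_text

    # Neither distribution is conditioned on other variables, so no inputs are needed.
    return kl.estimate({})
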
def gauss_gauss_kl(loc1, scale1, loc2, scale2, dim=None):
    # Closed-form KL divergence between two diagonal Gaussians, following
    # https://github.com/pytorch/pytorch/blob/85408e744fc1746ab939ae824a26fd6821529a94/torch/distributions/kl.py#L384
    var_ratio = (scale1 / scale2).pow(2)
    t1 = ((loc1 - loc2) / scale2).pow(2)
    _kl = 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

    if dim:
        return torch.sum(_kl, dim=dim)

    # By default, sum over all dimensions except the first (batch) dimension.
    dim_list = list(range(1, _kl.dim()))
    return torch.sum(_kl, dim=dim_list)
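
# --- Sanity check sketch (not part of the original module) ---
# Compares the closed-form helper above against torch.distributions, which the
# comment in gauss_gauss_kl already points to as the reference implementation.
# Shapes are arbitrary; with dim=None the helper sums over every dimension
# except the first (batch) one.
def _check_gauss_gauss_kl():
    import torch.distributions as dist

    loc1, scale1 = torch.randn(8, 16), torch.rand(8, 16) + 0.1
    loc2, scale2 = torch.randn(8, 16), torch.rand(8, 16) + 0.1

    kl_ours = gauss_gauss_kl(loc1, scale1, loc2, scale2)  # shape: (8,)
    kl_ref = dist.kl_divergence(dist.Normal(loc1, scale1),
                                dist.Normal(loc2, scale2)).sum(dim=1)

    assert torch.allclose(kl_ours, kl_ref, atol=1e-5)
    return kl_ours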