# Source code for pixyz.losses.mmd

import torch
import sympy
from .losses import Loss
from ..utils import get_dict_values

class MMD(Loss):
    r"""
    The Maximum Mean Discrepancy (MMD).

    .. math::

        D_{MMD^2}[p||q] = \mathbb{E}_{p(x), p(x')}[k(x, x')] + \mathbb{E}_{q(x), q(x')}[k(x, x')]
        - 2\mathbb{E}_{p(x), q(x')}[k(x, x')]

    where :math:`k(x, x')` is any positive definite kernel.

    The expectations are estimated from :math:`N` samples drawn from each of ``p`` and ``q``,
    where :math:`N` is the batch size of the input. The estimator is the unbiased one:

    .. math::

        \frac{1}{N(N-1)}\sum_{i \neq j} k(x_i, x_j) + \frac{1}{N(N-1)}\sum_{i \neq j} k(y_i, y_j)
        - \frac{2}{N^2}\sum_{i, j} k(x_i, y_j),

    with :math:`x_i \sim p` and :math:`y_j \sim q`. When :math:`N = 1` the unbiased estimator
    is undefined. In that case the biased estimator (the mean over all kernel entries,
    diagonal included) is returned instead.

    Parameters
    ----------
    p, q : pixyz.distributions.Distribution
        Distributions over the same single variable. At least one of them must have input
        (conditioning) variables; the first input variable of that distribution determines
        the number of samples drawn.
    input_var : list, optional
        Input variables of this loss. Defaults to ``p.input_var + q.input_var``.
    kernel : {"gaussian", "inv-multiquadratic"}
        Kernel used to compare samples.
    **kernel_params
        Extra keyword arguments forwarded to the kernel function (e.g. ``sigma_sqr``).

    Raises
    ------
    ValueError
        If ``p`` and ``q`` have different variables, or more than one variable.
    NotImplementedError
        If neither distribution has input variables, or ``kernel`` is not supported.

    Examples
    --------
    >>> import torch
    >>> from pixyz.distributions import Normal
    >>> p = Normal(loc="x", scale=torch.tensor(1.), var=["z"], cond_var=["x"], features_shape=[64], name="p")
    >>> q = Normal(loc="x", scale=torch.tensor(1.), var=["z"], cond_var=["x"], features_shape=[64], name="q")
    >>> loss_cls = MMD(p, q, kernel="gaussian")
    >>> print(loss_cls)
    D_{MMD^2} \left[p(z|x)||q(z|x) \right]
    >>> loss = loss_cls.eval({"x": torch.randn(1, 64)})
    >>> # Use the inverse (multi-)quadric kernel
    >>> loss = MMD(p, q, kernel="inv-multiquadratic").eval({"x": torch.randn(10, 64)})
    """

    def __init__(self, p, q, input_var=None, kernel="gaussian", **kernel_params):
        if p.var != q.var:
            raise ValueError("The two distribution variables must be the same.")

        if len(p.var) != 1:
            raise ValueError("A given distribution must have only one variable.")

        # The number of samples to draw is read from the first input variable of this
        # distribution (see ``_get_batch_n``), so at least one of p, q must be conditional.
        if len(p.input_var) > 0:
            self.input_dist = p
        elif len(q.input_var) > 0:
            self.input_dist = q
        else:
            raise NotImplementedError(
                "MMD requires at least one of the two distributions to have input variables.")

        kernels = {
            "gaussian": gaussian_rbf_kernel,
            "inv-multiquadratic": inverse_multiquadratic_rbf_kernel,
        }
        try:
            self.kernel = kernels[kernel]
        except KeyError:
            raise NotImplementedError(
                "Unsupported kernel {!r}; expected one of {}.".format(kernel, sorted(kernels))) from None

        self.kernel_params = kernel_params

        if input_var is None:
            input_var = p.input_var + q.input_var

        super().__init__(p, q, input_var=input_var)

    @property
    def _symbol(self):
        return sympy.Symbol("D_{{MMD^2}} \\left[{}||{} \\right]".format(self.p.prob_text, self.q.prob_text))

    def _get_batch_n(self, x_dict):
        # Number of samples = size of axis 0 of the first input variable's value.
        return get_dict_values(x_dict, self.input_dist.input_var[0])[0].shape[0]

    def _get_eval(self, x_dict=None, **kwargs):
        """Estimate the squared MMD between ``p`` and ``q``.

        Returns
        -------
        tuple
            ``(mmd_loss, x_dict)``, where ``mmd_loss`` is the reduced (scalar) estimate.

        Raises
        ------
        ValueError
            If the samples of ``p`` and ``q`` differ in shape or are not 2-D.
        """
        if x_dict is None:
            x_dict = {}

        batch_n = self._get_batch_n(x_dict)

        # sample from distributions
        p_x = get_dict_values(self.p.sample(x_dict, batch_n=batch_n), self.p.var)[0]
        q_x = get_dict_values(self.q.sample(x_dict, batch_n=batch_n), self.q.var)[0]

        if p_x.shape != q_x.shape:
            raise ValueError("The two distribution variables must have the same shape.")

        if len(p_x.shape) != 2:
            raise ValueError("The number of axes of a given sample must be 2, got %d" % len(p_x.shape))

        # Kernel matrices are (N, N), indexed by sample (axis 0), not by feature.
        k_pp = self.kernel(p_x, p_x, **self.kernel_params)
        k_qq = self.kernel(q_x, q_x, **self.kernel_params)
        k_pq = self.kernel(p_x, q_x, **self.kernel_params)

        n = p_x.shape[0]
        if n < 2:
            # The unbiased estimator divides by N(N-1); fall back to the biased estimator.
            mmd_loss = k_pp.mean() + k_qq.mean() - 2 * k_pq.mean()
            return mmd_loss, x_dict

        # estimate the squared MMD (unbiased estimator): within-distribution sums exclude
        # the i == j terms and are normalised by the number of ordered sample pairs.
        n_pairs = n * (n - 1)
        p_kernel = (k_pp.sum() - k_pp.diagonal().sum()) / n_pairs
        q_kernel = (k_qq.sum() - k_qq.diagonal().sum()) / n_pairs
        pq_kernel = k_pq.mean()
        mmd_loss = p_kernel + q_kernel - 2 * pq_kernel

        return mmd_loss, x_dict

def pairwise_distance_matrix(x, y, metric="euclidean"):
    r"""
    Compute the matrix of pairwise squared Euclidean distances between the rows of x and y.

    .. math::

        D_{ij} = ||x_i - y_j||^2

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose first axis indexes samples (MMD passes ``(n, d)``).
    y : torch.Tensor
        Tensor whose first axis indexes samples (MMD passes ``(m, d)``).
    metric : str
        Only ``"euclidean"`` (squared) is supported.

    Returns
    -------
    torch.Tensor
        Squared distances; shape ``(n, m)`` for 2-D inputs.

    Raises
    ------
    NotImplementedError
        For any other ``metric``.
    """
    if metric != "euclidean":
        raise NotImplementedError

    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d); reduce over the feature axis.
    differences = x[:, None, :] - y[None, :, :]
    return torch.sum(differences ** 2, dim=-1)

def gaussian_rbf_kernel(x, y, sigma_sqr=2., **kwargs):
    r"""
    Gaussian radial basis function (RBF) kernel.

    .. math::

        k(x, y) = \exp \left( -\frac{||x-y||^2}{\sigma^2} \right)

    Parameters
    ----------
    x : torch.Tensor
        First set of samples (MMD passes shape ``(n, d)``).
    y : torch.Tensor
        Second set of samples (MMD passes shape ``(m, d)``).
    sigma_sqr : float
        Bandwidth :math:`\sigma^2` (default ``2.``).
    **kwargs
        Accepted and ignored.

    Returns
    -------
    torch.Tensor
        Kernel matrix with entries :math:`k(x_i, y_j)`; shape ``(n, m)`` for 2-D inputs.
    """
    # ``1. *`` promotes an integer sigma_sqr to float before the division.
    return torch.exp(-pairwise_distance_matrix(x, y) / (1. * sigma_sqr))

def inverse_multiquadratic_rbf_kernel(x, y, sigma_sqr=2., **kwargs):
    r"""
    Inverse multi-quadratic radial basis function (RBF) kernel.

    .. math::

        k(x, y) = \frac{\sigma^2}{||x-y||^2 + \sigma^2}

    Parameters
    ----------
    x : torch.Tensor
        First set of samples (MMD passes shape ``(n, d)``).
    y : torch.Tensor
        Second set of samples (MMD passes shape ``(m, d)``).
    sigma_sqr : float
        Scale :math:`\sigma^2` (default ``2.``).
    **kwargs
        Accepted and ignored.

    Returns
    -------
    torch.Tensor
        Kernel matrix with entries :math:`k(x_i, y_j)`; shape ``(n, m)`` for 2-D inputs.
    """
    squared_distances = pairwise_distance_matrix(x, y)
    denominator = squared_distances + sigma_sqr
    return sigma_sqr / denominator