pixyz.losses (Loss API)

Loss

class pixyz.losses.losses.Loss(p, q=None, input_var=None)[source]

Bases: object

input_var
loss_text
abs()[source]
mean()[source]
sum()[source]
eval(x={}, return_dict=False, **kwargs)[source]
expectation(p, input_var=None)[source]
estimate(*args, **kwargs)[source]
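
Loss objects are symbolic: constructing one only records the computation, loss_text shows its formula, and eval runs it on a dictionary of observed variables. A minimal sketch of these mechanics, assuming the pixyz.distributions.Normal constructor accepts parameter tensors, var, and name as in the pixyz examples (all dimensions and names below are illustrative):

    import torch
    from pixyz.distributions import Normal
    from pixyz.losses import KullbackLeibler

    z_dim = 16
    # Two fixed Gaussians over the same variable "z" (parameter tensors are passed at
    # construction time; the exact keywords may differ between pixyz versions).
    p = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"], name="p")
    q = Normal(loc=0.5 * torch.ones(z_dim), scale=2.0 * torch.ones(z_dim), var=["z"], name="q")

    loss = KullbackLeibler(q, p)   # a Loss object; nothing is computed at construction time
    print(loss.loss_text)          # symbolic form of the loss
    print(loss.input_var)          # variables that eval() expects in its input dictionary
    print(loss.eval())             # these fixed distributions need no observed inputs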

Negative expected value of log-likelihood (entropy)

CrossEntropy

class pixyz.losses.CrossEntropy(p, q, input_var=None)[source]

Bases: pixyz.losses.losses.SetLoss

Cross entropy, i.e., the negative expected value of the log-likelihood (Monte Carlo approximation).

H[p||q] = -\mathbb{E}_{p(x)}[\log q(x)] \approx -\frac{1}{L}\sum_{l=1}^L \log q(x_l),

where x_l \sim p(x).

Note:
This class is a special case of the Expectation class.
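
For example, the reconstruction term of a VAE can be written as CrossEntropy(q, p) with q an encoder q(z|x) and p a decoder p(x|z). A hedged sketch (the networks, dimensions, and variable names are illustrative):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal, Bernoulli
    from pixyz.losses import CrossEntropy

    x_dim, z_dim = 784, 16

    class Q(Normal):  # q(z|x): illustrative Gaussian encoder
        def __init__(self):
            super().__init__(cond_var=["x"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    class P(Bernoulli):  # p(x|z): illustrative Bernoulli decoder
        def __init__(self):
            super().__init__(cond_var=["z"], var=["x"], name="p")
            self.fc = nn.Linear(z_dim, x_dim)
        def forward(self, z):
            return {"probs": torch.sigmoid(self.fc(z))}

    q, p = Q(), P()
    ce = CrossEntropy(q, p).mean()   # -E_{q(z|x)}[log p(x|z)], averaged over the batch
    print(ce.loss_text)
    print(ce.eval({"x": torch.rand(32, x_dim)}))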

Entropy

class pixyz.losses.Entropy(p, input_var=None)[source]

Bases: pixyz.losses.losses.SetLoss

Entropy (Monte Carlo approximation).

H[p] = -\mathbb{E}_{p(x)}[\log p(x)] \approx -\frac{1}{L}\sum_{l=1}^L \log p(x_l),

where x_l \sim p(x).

Note:
This class is a special case of the Expectation class.
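
A sketch of estimating the entropy of a conditional distribution on a batch of inputs (the encoder network, dimensions, and variable names are illustrative):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal
    from pixyz.losses import Entropy

    x_dim, z_dim = 784, 16

    class Q(Normal):  # q(z|x): illustrative Gaussian encoder
        def __init__(self):
            super().__init__(cond_var=["x"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    # H[q(z|x)] estimated with samples z ~ q(z|x); one value per batch element before .mean()
    h = Entropy(Q()).mean()
    print(h.loss_text)
    print(h.eval({"x": torch.randn(32, x_dim)}))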

StochasticReconstructionLoss

class pixyz.losses.StochasticReconstructionLoss(encoder, decoder, input_var=None)[source]

Bases: pixyz.losses.losses.SetLoss

Reconstruction Loss (Monte Carlo approximation).

-\mathbb{E}_{q(z|x)}[\log p(x|z)] \approx -\frac{1}{L}\sum_{l=1}^L \log p(x|z_l),

where z_l \sim q(z|x).

Note:
This class is a special case of the Expectation class.
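
By the formulas above this is the same quantity as CrossEntropy(encoder, decoder). A sketch with an illustrative encoder/decoder pair:

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal, Bernoulli
    from pixyz.losses import StochasticReconstructionLoss

    x_dim, z_dim = 784, 16

    class Encoder(Normal):  # q(z|x)
        def __init__(self):
            super().__init__(cond_var=["x"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    class Decoder(Bernoulli):  # p(x|z)
        def __init__(self):
            super().__init__(cond_var=["z"], var=["x"], name="p")
            self.fc = nn.Linear(z_dim, x_dim)
        def forward(self, z):
            return {"probs": torch.sigmoid(self.fc(z))}

    # -E_{q(z|x)}[log p(x|z)]: sample z from the encoder, score the observed x under the decoder
    recon = StochasticReconstructionLoss(Encoder(), Decoder()).mean()
    print(recon.loss_text)
    print(recon.eval({"x": torch.rand(32, x_dim)}))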

LossExpectation

Negative log-likelihood

NLL

Lower bound

ELBO

class pixyz.losses.ELBO(p, q, input_var=None)[source]

Bases: pixyz.losses.losses.SetLoss

The evidence lower bound (Monte Carlo approximation).

\mathbb{E}_{q(z|x)}\left[\log \frac{p(x,z)}{q(z|x)}\right] \approx \frac{1}{L}\sum_{l=1}^L \log \frac{p(x, z_l)}{q(z_l|x)},

where z_l \sim q(z|x).

Note:
This class is a special case of the Expectation class.
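
The formula above takes the generative joint p(x,z). A sketch that forms it with the distribution product (decoder * prior) and minimizes the negative ELBO; the networks, dimensions, and variable names are illustrative:

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal, Bernoulli
    from pixyz.losses import ELBO

    x_dim, z_dim = 784, 16

    class Encoder(Normal):  # q(z|x)
        def __init__(self):
            super().__init__(cond_var=["x"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    class Decoder(Bernoulli):  # p(x|z)
        def __init__(self):
            super().__init__(cond_var=["z"], var=["x"], name="p")
            self.fc = nn.Linear(z_dim, x_dim)
        def forward(self, z):
            return {"probs": torch.sigmoid(self.fc(z))}

    prior = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"], name="p_prior")
    q = Encoder()
    p_joint = Decoder() * prior             # p(x,z) = p(x|z) p(z)

    vae_loss = (-ELBO(p_joint, q)).mean()   # negative ELBO as a training objective
    print(vae_loss.loss_text)
    print(vae_loss.eval({"x": torch.rand(32, x_dim)}))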

Statistical distance

KullbackLeibler

class pixyz.losses.KullbackLeibler(p, q, input_var=None, dim=None)[source]

Bases: pixyz.losses.losses.Loss

Kullback-Leibler divergence (analytical).

D_{KL}[p||q] = \mathbb{E}_{p(x)}[\log \frac{p(x)}{q(x)}]

TODO: This class seems to be slightly slower than the previous implementation
(perhaps because of set_distribution).
loss_text
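
A typical use is the VAE regularizer D_{KL}[q(z|x)||p(z)], computed analytically between two Gaussians. A sketch with an illustrative encoder and a fixed prior:

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal
    from pixyz.losses import KullbackLeibler

    x_dim, z_dim = 784, 16

    class Encoder(Normal):  # q(z|x)
        def __init__(self):
            super().__init__(cond_var=["x"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    prior = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"], name="p")

    kl = KullbackLeibler(Encoder(), prior).mean()   # closed form, no sampling involved
    print(kl.loss_text)
    print(kl.eval({"x": torch.randn(32, x_dim)}))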

WassersteinDistance

class pixyz.losses.WassersteinDistance(p, q, metric=PairwiseDistance(), input_var=None)[source]

Bases: pixyz.losses.losses.Loss

Wasserstein distance.

W(p, q) = \inf_{\Gamma \in \mathcal{P}(x_p\sim p, x_q\sim q)} \mathbb{E}_{(x_p, x_q) \sim \Gamma}[d(x_p, x_q)]

However, instead of the true distance above, this class computes the following quantity.

W'(p, q) = \mathbb{E}_{x_p\sim p, x_q \sim q}[d(x_p, x_q)].

Here, W' is an upper bound of W (i.e., W\leq W'), and the two coincide when both p and q are degenerate (deterministic) distributions.

loss_text
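
A sketch comparing the latent samples produced by two encoders under the default pairwise metric (the networks, dimensions, and variable names are illustrative):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal
    from pixyz.losses import WassersteinDistance

    x_dim, z_dim = 784, 16

    class Encoder(Normal):  # a Gaussian encoder over "z" given "x"
        def __init__(self, name):
            super().__init__(cond_var=["x"], var=["z"], name=name)
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    # W'(p, q): draw z_p ~ p(z|x) and z_q ~ q(z|x), then average d(z_p, z_q) over the batch
    wd = WassersteinDistance(Encoder("p"), Encoder("q")).mean()
    print(wd.loss_text)
    print(wd.eval({"x": torch.randn(32, x_dim)}))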

MMD

class pixyz.losses.MMD(p, q, input_var=None, kernel='gaussian', **kernel_params)[source]

Bases: pixyz.losses.losses.Loss

The Maximum Mean Discrepancy (MMD).

D_{MMD^2}[p||q] = \mathbb{E}_{p(x), p(x')}[k(x, x')] + \mathbb{E}_{q(x), q(x')}[k(x, x')]
- 2\mathbb{E}_{p(x), q(x')}[k(x, x')]

where k(x, x') is any positive definite kernel.

loss_text
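
A sketch in the same spirit, comparing the latent samples of two encoders under the default Gaussian kernel (the networks, dimensions, and variable names are illustrative):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal
    from pixyz.losses import MMD

    x_dim, z_dim = 784, 16

    class Encoder(Normal):  # a Gaussian encoder over "z" given "x"
        def __init__(self, name):
            super().__init__(cond_var=["x"], var=["z"], name=name)
            self.fc = nn.Linear(x_dim, 2 * z_dim)
        def forward(self, x):
            loc, scale = self.fc(x).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    # Kernel values are computed within and across the sample sets drawn from p and q
    mmd = MMD(Encoder("p"), Encoder("q"), kernel="gaussian")
    print(mmd.loss_text)
    print(mmd.eval({"x": torch.randn(32, x_dim)}))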

Adversarial statistical distance (GAN loss)

AdversarialJensenShannon

class pixyz.losses.AdversarialJensenShannon(p, q, discriminator, input_var=None, optimizer=<class 'torch.optim.adam.Adam'>, optimizer_params={}, inverse_g_loss=True)[source]

Bases: pixyz.losses.adversarial_loss.AdversarialLoss

Jensen-Shannon divergence (adversarial training).

D_{JS}[p(x)||q(x)] \leq 2 \cdot D_{JS}[p(x)||q(x)] = \mathbb{E}_{p(x)}[\log d^*(x)] + \mathbb{E}_{q(x)}[\log (1-d^*(x))] + 2 \log 2,

where d^*(x) = \arg\max_{d} \mathbb{E}_{p(x)}[\log d(x)] + \mathbb{E}_{q(x)}[\log (1-d(x))].

loss_text
d_loss(y_p, y_q, batch_size)[source]
g_loss(y_p, y_q, batch_size)[source]
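
A sketch of the GAN-style setup: two distributions over the same variable and a discriminator wrapped as a Deterministic distribution producing t = d(x) in (0, 1). The toy distributions and the network are illustrative; d_loss and g_loss (above) give the discriminator- and generator-side objectives, and the discriminator is updated with the optimizer passed to the constructor:

    import torch
    from torch import nn
    from pixyz.distributions import Normal, Deterministic
    from pixyz.losses import AdversarialJensenShannon

    x_dim = 2

    # A fixed "data-like" distribution and a model distribution over the same variable "x"
    p_data = Normal(loc=torch.zeros(x_dim), scale=torch.ones(x_dim), var=["x"], name="p")
    p_model = Normal(loc=2.0 * torch.ones(x_dim), scale=torch.ones(x_dim), var=["x"], name="q")

    class D(Deterministic):  # discriminator d(x) -> t in (0, 1), illustrative network
        def __init__(self):
            super().__init__(cond_var=["x"], var=["t"], name="d")
            self.net = nn.Sequential(nn.Linear(x_dim, 32), nn.ReLU(), nn.Linear(32, 1))
        def forward(self, x):
            return {"t": torch.sigmoid(self.net(x))}

    js = AdversarialJensenShannon(p_data, p_model, discriminator=D())
    print(js.loss_text)
    # The divergence estimate used on the generator side corresponds to g_loss;
    # the discriminator parameters are trained separately via d_loss.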

AdversarialKullbackLeibler

class pixyz.losses.AdversarialKullbackLeibler(p, q, discriminator, **kwargs)[source]

Bases: pixyz.losses.adversarial_loss.AdversarialLoss

Kullback-Leibler divergence (adversarial training).

D_{KL}[p(x)||q(x)] = \mathbb{E}_{p(x)}[\log \frac{p(x)}{q(x)}]
 = \mathbb{E}_{p(x)}[\log \frac{d^*(x)}{1-d^*(x)}],

where d^*(x) = \arg\max_{d} \mathbb{E}_{q(x)}[\log d(x)] + \mathbb{E}_{p(x)}[\log (1-d(x))].

Note that minimizing this divergence brings p closer to q.

loss_text
g_loss(y_p, batch_size)[source]
d_loss(y_p, y_q, batch_size)[source]

AdversarialWassersteinDistance

class pixyz.losses.AdversarialWassersteinDistance(p, q, discriminator, clip_value=0.01, **kwargs)[source]

Bases: pixyz.losses.adversarial_loss.AdversarialJensenShannon

Wasserstein distance (adversarial training).

W(p, q) = \sup_{||d||_{L} \leq 1} \mathbb{E}_{p(x)}[d(x)] - \mathbb{E}_{q(x)}[d(x)]

loss_text
d_loss(y_p, y_q, *args, **kwargs)[source]
g_loss(y_p, y_q, *args, **kwargs)[source]
train(train_x, **kwargs)[source]

Loss for sequential distributions

IterativeLoss

class pixyz.losses.IterativeLoss(step_loss, max_iter=1, input_var=None, series_var=None, update_value={}, slice_step=None, timestep_var=['t'])[source]

Bases: pixyz.losses.losses.Loss

Iterative loss.

This class allows implementing arbitrary models that require iteration over time steps (e.g., auto-regressive models).

\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_{step}(x_t, h_t),

where x_t = f_{slice\_step}(x, t).

loss_text
slice_step_fn(t, x)[source]
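
A sketch of a sequential VAE-style objective: an ordinary per-step loss over (x_t, z_t, z_prev) is summed over T steps, with series_var naming the sequence variable and update_value carrying the latent into the next step. All networks, shapes, and variable names are illustrative:

    import torch
    from torch import nn
    from torch.nn import functional as F
    from pixyz.distributions import Normal, Bernoulli
    from pixyz.losses import IterativeLoss, StochasticReconstructionLoss, KullbackLeibler

    x_dim, z_dim, T = 28, 8, 14   # e.g. one image row per time step

    class Q(Normal):              # q(z_t | x_t, z_prev)
        def __init__(self):
            super().__init__(cond_var=["x", "z_prev"], var=["z"], name="q")
            self.fc = nn.Linear(x_dim + z_dim, 2 * z_dim)
        def forward(self, x, z_prev):
            loc, scale = self.fc(torch.cat([x, z_prev], dim=-1)).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    class Prior(Normal):          # p(z_t | z_prev)
        def __init__(self):
            super().__init__(cond_var=["z_prev"], var=["z"], name="p_prior")
            self.fc = nn.Linear(z_dim, 2 * z_dim)
        def forward(self, z_prev):
            loc, scale = self.fc(z_prev).chunk(2, dim=-1)
            return {"loc": loc, "scale": F.softplus(scale)}

    class P(Bernoulli):           # p(x_t | z_t)
        def __init__(self):
            super().__init__(cond_var=["z"], var=["x"], name="p")
            self.fc = nn.Linear(z_dim, x_dim)
        def forward(self, z):
            return {"probs": torch.sigmoid(self.fc(z))}

    q, p, prior = Q(), P(), Prior()
    step_loss = StochasticReconstructionLoss(q, p) + KullbackLeibler(q, prior)

    # Sum the per-step loss over t = 1..T; series_var names the sequence variable that is
    # sliced to x_t at each step, and update_value feeds z back in as z_prev for the next step.
    loss = IterativeLoss(step_loss, max_iter=T, series_var=["x"],
                         update_value={"z": "z_prev"})
    print(loss.loss_text)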

Loss for special purpose

Parameter

class pixyz.losses.losses.Parameter(input_var)[source]

Bases: pixyz.losses.losses.Loss

loss_text
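
Parameter represents a value that is supplied at evaluation time, e.g. an annealed weight on a regularizer. A sketch, assuming the constructor accepts the variable name as a string (the toy distributions are illustrative):

    import torch
    from pixyz.distributions import Normal
    from pixyz.losses import KullbackLeibler
    from pixyz.losses.losses import Parameter

    z_dim = 16
    p = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"], name="p")
    q = Normal(loc=0.5 * torch.ones(z_dim), scale=torch.ones(z_dim), var=["z"], name="q")

    beta = Parameter("beta")                 # value is looked up in the eval() dictionary
    loss = beta * KullbackLeibler(q, p)      # e.g. a beta-VAE style weighted regularizer
    print(loss.loss_text)
    print(loss.eval({"beta": 0.1}))          # anneal beta from the training loop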

Operators

LossOperator

class pixyz.losses.losses.LossOperator(loss1, loss2)[source]

Bases: pixyz.losses.losses.Loss

loss_text
train(x, **kwargs)[source]

TODO: Fix

test(x, **kwargs)[source]

TODO: Fix

LossSelfOperator

class pixyz.losses.losses.LossSelfOperator(loss1)[source]

Bases: pixyz.losses.losses.Loss

train(x={}, **kwargs)[source]
test(x={}, **kwargs)[source]

AddLoss

class pixyz.losses.losses.AddLoss(loss1, loss2)[source]

Bases: pixyz.losses.losses.LossOperator

loss_text

SubLoss

class pixyz.losses.losses.SubLoss(loss1, loss2)[source]

Bases: pixyz.losses.losses.LossOperator

loss_text

MulLoss

class pixyz.losses.losses.MulLoss(loss1, loss2)[source]

Bases: pixyz.losses.losses.LossOperator

loss_text

DivLoss

class pixyz.losses.losses.DivLoss(loss1, loss2)[source]

Bases: pixyz.losses.losses.LossOperator

loss_text

NegLoss

class pixyz.losses.losses.NegLoss(loss1)[source]

Bases: pixyz.losses.losses.LossSelfOperator

loss_text

AbsLoss

class pixyz.losses.losses.AbsLoss(loss1)[source]

Bases: pixyz.losses.losses.LossSelfOperator

loss_text

BatchMean

class pixyz.losses.losses.BatchMean(loss1)[source]

Bases: pixyz.losses.losses.LossSelfOperator

Loss averaged over batch data.

\mathbb{E}_{p_{data}(x)}[\mathcal{L}(x)] \approx \frac{1}{N}\sum_{i=1}^N \mathcal{L}(x_i),

where x_i \sim p_{data}(x) and \mathcal{L} is a loss function.

loss_text

BatchSum

class pixyz.losses.losses.BatchSum(loss1)[source]

Bases: pixyz.losses.losses.LossSelfOperator

Loss summed over batch data.

\sum_{i=1}^N \mathcal{L}(x_i),

where x_i \sim p_{data}(x) and \mathcal{L} is a loss function.

loss_text
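
These operator classes are normally not instantiated directly: Python arithmetic on Loss objects builds them, and mean()/sum() wrap a loss in BatchMean/BatchSum. A short sketch (the fixed toy distributions are illustrative):

    import torch
    from pixyz.distributions import Normal
    from pixyz.losses import Entropy, KullbackLeibler

    z_dim = 16
    p = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"], name="p")
    q = Normal(loc=0.5 * torch.ones(z_dim), scale=2.0 * torch.ones(z_dim), var=["z"], name="q")

    kl, h = KullbackLeibler(q, p), Entropy(q)

    composed = -(kl - h)          # SubLoss wrapped in NegLoss
    print(type(composed).__name__, composed.loss_text)

    averaged = (kl + h).mean()    # AddLoss wrapped in BatchMean
    summed = (kl + h).sum()       # AddLoss wrapped in BatchSum
    print(averaged.loss_text)
    print(summed.loss_text)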