Source code for pixyz.models.vae

from torch import optim

from ..models.model import Model
from ..utils import tolist
from ..losses import StochasticReconstructionLoss


[docs]class VAE(Model): """ Variational Autoencoder. In VAE class, reconstruction loss on given distributions (encoder and decoder) is set as the default loss class. However, if you want to add additional terms, e.g., the KL divergence between encoder and prior, you need to set them to the `regularizer` argument, which defaults to None. References ---------- [Kingma+ 2013] Auto-Encoding Variational Bayes """
[docs] def __init__(self, encoder, decoder, other_distributions=[], regularizer=None, optimizer=optim.Adam, optimizer_params={}, clip_grad_norm=None, clip_grad_value=None): """ Parameters ---------- encoder : torch.distributions.Distribution Encoder distribution. decoder : torch.distributions.Distribution Decoder distribution. regularizer : torch.losses.Loss, defaults to None If you want to add additional terms to the loss, set them to this argument. optimizer : torch.optim Optimization algorithm. optimizer_params : dict Parameters of optimizer clip_grad_norm : float or int Maximum allowed norm of the gradients. clip_grad_value : float or int Maximum allowed value of the gradients. """ # set distributions (for training) distributions = [encoder, decoder] + tolist(other_distributions) # set losses reconstruction =\ StochasticReconstructionLoss(encoder, decoder) loss = (reconstruction + regularizer).mean() super().__init__(loss, test_loss=loss, distributions=distributions, optimizer=optimizer, optimizer_params=optimizer_params, clip_grad_norm=clip_grad_norm, clip_grad_value=clip_grad_value)
[docs] def train(self, train_x_dict={}, **kwargs): return super().train(train_x_dict, **kwargs)
[docs] def test(self, test_x_dict={}, **kwargs): return super().test(test_x_dict, **kwargs)