Source code for pixyz.models.vae

from torch import optim, nn

from ..models.model import Model
from ..utils import tolist
from ..losses import StochasticReconstructionLoss


class VAE(Model):
    """Variational Autoencoder [Kingma+ 2013].

    Auto-Encoding Variational Bayes.

    The training (and test) objective is the mean of a
    ``StochasticReconstructionLoss(encoder, decoder)`` plus any
    ``regularizer`` term(s).

    Parameters
    ----------
    encoder : pixyz distribution
        First argument to ``StochasticReconstructionLoss`` (presumably the
        approximate posterior q(z|x) -- confirm against that loss).
    decoder : pixyz distribution
        Second argument to ``StochasticReconstructionLoss`` (presumably the
        likelihood p(x|z) -- confirm against that loss).
    other_distributions : distribution or list of distributions, optional
        Extra distribution(s) appended after ``encoder`` and ``decoder`` to
        the ``distributions`` handed to ``Model`` for training.
        ``None`` (the default) means no extra distributions.
    regularizer : Loss or list/tuple of Loss, optional
        Term(s) added to the reconstruction loss before taking the mean
        (for example a KL divergence). ``None`` (the default) or an empty
        list means no regularization.
    optimizer : optimizer class, optional
        Optimizer class forwarded to ``Model``; defaults to ``optim.Adam``.
    optimizer_params : dict, optional
        Keyword arguments for the optimizer, forwarded to ``Model``.
        ``None`` (the default) means an empty dict.
    """

    def __init__(self, encoder, decoder, other_distributions=None,
                 regularizer=None, optimizer=optim.Adam,
                 optimizer_params=None):
        # Build fresh containers per call rather than sharing mutable
        # default arguments across every VAE instance.
        if other_distributions is None:
            other_distributions = []
        if optimizer_params is None:
            optimizer_params = {}

        # Normalize the regularizer into a list of loss terms. The previous
        # default (``[]``) was added directly to the reconstruction loss,
        # which is not a valid Loss operand; now nothing is added unless
        # the caller supplies term(s).
        if regularizer is None:
            regularizer_terms = []
        elif isinstance(regularizer, (list, tuple)):
            regularizer_terms = list(regularizer)
        else:
            regularizer_terms = [regularizer]

        # set distributions (for training)
        distributions = [encoder, decoder] + tolist(other_distributions)

        # set losses
        loss = StochasticReconstructionLoss(encoder, decoder)
        for term in regularizer_terms:
            loss = loss + term
        loss = loss.mean()

        super().__init__(loss, test_loss=loss,
                         distributions=distributions,
                         optimizer=optimizer,
                         optimizer_params=optimizer_params)

    def train(self, train_x=None, **kwargs):
        """Forward ``train_x`` and ``kwargs`` to ``Model.train``.

        ``train_x`` is presumably a dict of observed variables (name ->
        tensor) -- confirm against ``Model.train``. ``None`` (the default)
        is replaced by a fresh empty dict. Returns whatever
        ``Model.train`` returns.
        """
        if train_x is None:
            train_x = {}
        return super().train(train_x, **kwargs)

    def test(self, test_x=None, **kwargs):
        """Forward ``test_x`` and ``kwargs`` to ``Model.test``.

        ``test_x`` is presumably a dict of observed variables (name ->
        tensor) -- confirm against ``Model.test``. ``None`` (the default)
        is replaced by a fresh empty dict. Returns whatever
        ``Model.test`` returns.
        """
        if test_x is None:
            test_x = {}
        return super().test(test_x, **kwargs)