pixyz.models (Model API)

Model

class pixyz.models.Model(loss, test_loss=None, distributions=[], optimizer=torch.optim.Adam, optimizer_params={})

Bases: object

set_loss(loss, test_loss=None)
train(train_x={}, **kwargs)
test(test_x={}, **kwargs)
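A minimal, stdlib-only sketch of the contract these three methods suggest: a loss is evaluated on a dict of variables, `train` takes one update step, and `test` only evaluates (falling back to the training loss when no `test_loss` is given). `ToyModel`, its `params` dict, the finite-difference gradient and the fixed-step gradient descent are all hypothetical stand-ins; the real class works with PyTorch distributions and the `optimizer` class passed to its constructor.

```python
from typing import Callable, Dict, Optional

Params = Dict[str, float]
Batch = Dict[str, float]
LossFn = Callable[[Params, Batch], float]


class ToyModel:
    """Hypothetical stand-in for the Model train/test pattern (not pixyz code).

    Gradients are estimated by central finite differences and the
    "optimizer" is plain gradient descent with a fixed learning rate.
    """

    def __init__(self, loss: LossFn, params: Params,
                 test_loss: Optional[LossFn] = None, lr: float = 0.1) -> None:
        self.params = dict(params)
        self.lr = lr
        self.set_loss(loss, test_loss)

    def set_loss(self, loss: LossFn, test_loss: Optional[LossFn] = None) -> None:
        self.loss = loss
        # Fall back to the training loss when no separate test loss is given.
        self.test_loss = loss if test_loss is None else test_loss

    def _grad(self, batch: Batch, eps: float = 1e-6) -> Params:
        grad = {}
        for name in self.params:
            up = dict(self.params)
            up[name] += eps
            down = dict(self.params)
            down[name] -= eps
            grad[name] = (self.loss(up, batch) - self.loss(down, batch)) / (2 * eps)
        return grad

    def train(self, train_x: Batch) -> float:
        # Evaluate the loss, then take one gradient step; returns the pre-step loss.
        value = self.loss(self.params, train_x)
        for name, g in self._grad(train_x).items():
            self.params[name] -= self.lr * g
        return value

    def test(self, test_x: Batch) -> float:
        # Evaluation only: parameters are left unchanged.
        return self.test_loss(self.params, test_x)


# Usage: fit w so that (w - x)^2 is minimised for x = 3.
model = ToyModel(lambda p, b: (p["w"] - b["x"]) ** 2, {"w": 0.0})
for _ in range(100):
    model.train({"x": 3.0})
print(round(model.params["w"], 4))  # → 3.0
```

Returning the pre-step loss from `train` mirrors the usual evaluate, backpropagate, step ordering; that choice is an assumption of this sketch.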

Pre-implemented models

ML

class pixyz.models.ML(p, other_distributions=[], optimizer=torch.optim.Adam, optimizer_params={})

Bases: pixyz.models.model.Model

Maximum likelihood: trains p by maximizing the log-likelihood of the data.

train(train_x={}, **kwargs)
test(test_x={}, **kwargs)
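To make the maximum-likelihood objective concrete, here is a small sketch, assuming a 1-D Gaussian model with known standard deviation: minimizing the mean negative log-likelihood over the mean by gradient descent recovers the sample mean. The helper names `gaussian_nll` and `fit_mean_ml` are hypothetical; pixyz would instead take a distribution `p` and compute gradients with autograd.

```python
import math
from typing import List


def gaussian_nll(mu: float, sigma: float, data: List[float]) -> float:
    """Mean negative log-likelihood of data under N(mu, sigma^2)."""
    n = len(data)
    return sum(0.5 * math.log(2 * math.pi * sigma ** 2)
               + (x - mu) ** 2 / (2 * sigma ** 2) for x in data) / n


def fit_mean_ml(data: List[float], sigma: float = 1.0,
                lr: float = 0.5, steps: int = 200) -> float:
    """Maximum likelihood for the mean via gradient descent on the mean NLL."""
    mu = 0.0
    n = len(data)
    for _ in range(steps):
        # d/dmu of the mean NLL is -(mean(x) - mu) / sigma^2.
        grad = -sum(x - mu for x in data) / (n * sigma ** 2)
        mu -= lr * grad
    return mu
```

For a Gaussian with fixed variance the ML estimate of the mean has the closed form mean(x), so the loop is only there to show the "minimize the negative log-likelihood" pattern that an optimizer-driven model follows.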

VAE

class pixyz.models.VAE(encoder, decoder, other_distributions=[], regularizer=[], optimizer=torch.optim.Adam, optimizer_params={})

Bases: pixyz.models.model.Model

Variational Autoencoder

[Kingma+ 2013] Auto-Encoding Variational Bayes

train(train_x={}, **kwargs)
test(test_x={}, **kwargs)
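The VAE of [Kingma+ 2013] maximizes the evidence lower bound (ELBO): a reconstruction term estimated with reparameterized samples from the encoder, minus the KL divergence from the encoder to the prior. The sketch below assumes a scalar latent, a Gaussian encoder q(z|x) = N(mu, sigma^2) and a standard normal prior, so the KL term has a closed form; `encoder` and `log_lik` are hypothetical plain callables rather than pixyz distribution objects, and the `regularizer` argument is not modeled.

```python
import math
import random
from typing import Callable, Tuple

Encoder = Callable[[float], Tuple[float, float]]  # x -> (mu, sigma) of q(z|x)
LogLik = Callable[[float, float], float]          # (x, z) -> log p(x|z)


def kl_to_standard_normal(mu: float, sigma: float) -> float:
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)) for a scalar latent."""
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0) - math.log(sigma)


def reparameterize(mu: float, sigma: float, rng: random.Random) -> float:
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) (reparameterization trick)."""
    return mu + sigma * rng.gauss(0.0, 1.0)


def elbo_estimate(x: float, encoder: Encoder, log_lik: LogLik,
                  rng: random.Random, n_samples: int = 1) -> float:
    """Monte Carlo reconstruction term minus the analytic KL to the prior."""
    mu, sigma = encoder(x)
    rec = sum(log_lik(x, reparameterize(mu, sigma, rng))
              for _ in range(n_samples)) / n_samples
    return rec - kl_to_standard_normal(mu, sigma)
```

Writing z as a deterministic function of (mu, sigma) and independent noise is what lets gradients of the reconstruction term flow back into the encoder; training would maximize this estimate (equivalently, minimize its negative).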

VI

class pixyz.models.VI(p, approximate_dist, other_distributions=[], optimizer=torch.optim.Adam, optimizer_params={})

Bases: pixyz.models.model.Model

Variational Inference (Amortized inference)

train(train_x={}, **kwargs)
test(test_x={}, **kwargs)
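Variational inference fits an approximate distribution q to the posterior of a model p by maximizing the ELBO; "amortized" means the parameters of q(z|x) come from one shared function of x instead of being fitted separately per data point. The sketch assumes a conjugate toy model, p(z) = N(0, 1) and p(x|z) = N(z, 1), for which the ELBO has a closed form and the exact posterior is N(x/2, 1/2). The linear amortization q(z|x) = N(a·x, s2) and all function names are hypothetical illustrations, not the pixyz `approximate_dist` interface.

```python
import math
from typing import Sequence

LOG_2PI = math.log(2 * math.pi)


def elbo(x: float, m: float, s2: float) -> float:
    """Exact ELBO for p(z)=N(0,1), p(x|z)=N(z,1), q(z)=N(m, s2)."""
    e_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - m) ** 2 + s2)   # E_q[log p(x|z)]
    e_log_prior = -0.5 * LOG_2PI - 0.5 * (m ** 2 + s2)       # E_q[log p(z)]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)      # H[q]
    return e_log_lik + e_log_prior + entropy


def log_evidence(x: float) -> float:
    """log p(x) = log N(x; 0, 2) for this model."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0


def amortized_elbo(data: Sequence[float], a: float, s2: float) -> float:
    """Average ELBO when q(z|x) = N(a * x, s2) shares (a, s2) across all x."""
    return sum(elbo(x, a * x, s2) for x in data) / len(data)
```

With a = 0.5 and s2 = 0.5 the amortized q matches the exact posterior for every x, so the ELBO equals log p(x); any other choice gives a strictly lower bound, which is the gap a VI optimizer closes.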

GAN

class pixyz.models.GAN(p_data, p, discriminator, optimizer=torch.optim.Adam, optimizer_params={}, d_optimizer=torch.optim.Adam, d_optimizer_params={})

Bases: pixyz.models.model.Model

Generative Adversarial Network

train(train_x={}, adversarial_loss=True, **kwargs)
test(test_x={}, adversarial_loss=True, **kwargs)
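The `GAN` constructor takes two optimizers, `optimizer` for the generator side `p` and `d_optimizer` for `discriminator`, which fits the usual scheme of alternating updates on two losses. As a sketch only, the functions below compute the standard losses from Goodfellow et al. (2014) on discriminator outputs in (0, 1): binary cross-entropy for the discriminator and the common non-saturating -log D(G(z)) loss for the generator. Whether pixyz uses exactly these variants is an assumption, and the function names are hypothetical.

```python
import math
from typing import Sequence


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Binary cross-entropy with real samples labelled 1 and generated samples 0."""
    real = -sum(math.log(d) for d in d_real) / len(d_real)
    fake = -sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real + fake


def generator_loss(d_fake: Sequence[float]) -> float:
    """Non-saturating generator loss: mean of -log D(G(z))."""
    return -sum(math.log(d) for d in d_fake) / len(d_fake)
```

When the discriminator cannot tell samples apart (D = 0.5 everywhere) its loss is 2 log 2; it drops as D separates real from generated samples, while the generator loss drops as D is fooled into outputting values near 1 on generated samples.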