# Source code for pixyz.models.gan

from torch import optim

from ..models.model import Model
from ..losses import AdversarialJensenShannon


class GAN(Model):
    """Generative Adversarial Network.

    Couples a generator distribution ``p`` with a ``discriminator`` through an
    ``AdversarialJensenShannon`` loss (batch-averaged via ``.mean()``). The
    same loss object is used for both training and testing.

    Parameters
    ----------
    p_data
        Distribution of the real data; first argument of the adversarial loss.
    p
        Generator distribution. It is the only entry of ``distributions``
        handed to ``Model`` (presumably the set whose parameters ``optimizer``
        updates -- confirm against ``Model``).
    discriminator
        Discriminator handed to the adversarial loss.
    optimizer : callable, optional
        Optimizer class for the model side (default ``optim.Adam``).
    optimizer_params : dict or None, optional
        Keyword arguments for ``optimizer``. ``None`` (default) means ``{}``.
    d_optimizer : callable, optional
        Optimizer class handed to the adversarial loss (default ``optim.Adam``).
    d_optimizer_params : dict or None, optional
        Keyword arguments for ``d_optimizer``. ``None`` (default) means ``{}``.
    """

    def __init__(self, p_data, p, discriminator,
                 optimizer=optim.Adam, optimizer_params=None,
                 d_optimizer=optim.Adam, d_optimizer_params=None):
        # Fresh dicts per instance: a shared mutable ``{}`` default would be
        # reused by every GAN built without explicit params.
        if optimizer_params is None:
            optimizer_params = {}
        if d_optimizer_params is None:
            d_optimizer_params = {}

        # set distributions (for training)
        distributions = [p]

        # set losses: the discriminator's optimizer settings travel with the loss
        loss = AdversarialJensenShannon(p_data, p, discriminator,
                                        optimizer=d_optimizer,
                                        optimizer_params=d_optimizer_params).mean()

        super().__init__(loss, test_loss=loss,
                         distributions=distributions,
                         optimizer=optimizer,
                         optimizer_params=optimizer_params)

    def train(self, train_x=None, adversarial_loss=True, **kwargs):
        """Run one training step.

        Parameters
        ----------
        train_x : dict or None, optional
            Input variables for the step. ``None`` (default) means ``{}``.
            The same dict object is passed to both training calls.
        adversarial_loss : bool, optional
            If true (default), first call ``self.loss_cls.train`` (the
            adversarial loss's own step, presumably the discriminator update
            with ``d_optimizer``), then ``Model.train``.
        **kwargs
            Forwarded unchanged to both calls.

        Returns
        -------
        ``loss`` from ``Model.train``, or ``(loss, d_loss)`` when
        ``adversarial_loss`` is true.
        """
        if train_x is None:
            train_x = {}
        if not adversarial_loss:
            return super().train(train_x, **kwargs)
        # Order matters: the adversarial-loss step runs before the model step.
        d_loss = self.loss_cls.train(train_x, **kwargs)
        loss = super().train(train_x, **kwargs)
        return loss, d_loss

    def test(self, test_x=None, adversarial_loss=True, **kwargs):
        """Evaluate on test inputs.

        Parameters
        ----------
        test_x : dict or None, optional
            Input variables. ``None`` (default) means ``{}``. The same dict
            object is passed to both evaluation calls.
        adversarial_loss : bool, optional
            If true (default), also return ``self.loss_cls.test(...)``.
        **kwargs
            Forwarded unchanged to both calls.

        Returns
        -------
        ``loss`` from ``Model.test``, or ``(loss, d_loss)`` when
        ``adversarial_loss`` is true.
        """
        if test_x is None:
            test_x = {}
        # Unlike ``train``, the model loss is evaluated first here; kept as-is.
        loss = super().test(test_x, **kwargs)
        if not adversarial_loss:
            return loss
        d_loss = self.loss_cls.test(test_x, **kwargs)
        return loss, d_loss