Source code for pixyz.losses.losses

import abc
import sympy
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

import numbers
from copy import deepcopy

from ..utils import tolist, get_dict_values


[docs]class Loss(torch.nn.Module, metaclass=abc.ABCMeta): """Loss class. In Pixyz, all loss classes are required to inherit this class. Examples -------- >>> import torch >>> from torch.nn import functional as F >>> from pixyz.distributions import Bernoulli, Normal >>> from pixyz.losses import KullbackLeibler ... >>> # Set distributions >>> class Inference(Normal): ... def __init__(self): ... super().__init__(cond_var=["x"], var=["z"], name="q") ... self.model_loc = torch.nn.Linear(128, 64) ... self.model_scale = torch.nn.Linear(128, 64) ... def forward(self, x): ... return {"loc": self.model_loc(x), "scale": F.softplus(self.model_scale(x))} ... >>> class Generator(Bernoulli): ... def __init__(self): ... super().__init__(cond_var=["z"], var=["x"], name="p") ... self.model = torch.nn.Linear(64, 128) ... def forward(self, z): ... return {"probs": torch.sigmoid(self.model(z))} ... >>> p = Generator() >>> q = Inference() >>> prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), ... var=["z"], features_shape=[64], name="p_{prior}") ... >>> # Define a loss function (VAE) >>> reconst = -p.log_prob().expectation(q) >>> kl = KullbackLeibler(q, prior) >>> loss_cls = (reconst - kl).mean() >>> print(loss_cls) mean \\left(- D_{KL} \\left[q(z|x)||p_{prior}(z) \\right] - \\mathbb{E}_{q(z|x)} \\left[\\log p(x|z) \\right] \\right) >>> # Evaluate this loss function >>> data = torch.randn(1, 128) # Pseudo data >>> loss = loss_cls.eval({"x": data}) >>> print(loss) # doctest: +SKIP tensor(65.5939, grad_fn=<MeanBackward0>) """
[docs] def __init__(self, input_var=None): """ Parameters ---------- input_var : :obj:`list` of :obj:`str`, defaults to None Input variables of this loss function. In general, users do not need to set them explicitly because these depend on the given distributions and each loss function. """ super().__init__() self._input_var = deepcopy(input_var)
@property def input_var(self): """list: Input variables of this distribution.""" return self._input_var @property @abc.abstractmethod def _symbol(self): raise NotImplementedError() @property def loss_text(self): return sympy.latex(self._symbol) def __str__(self): return self.loss_text def __repr__(self): return self.loss_text def __add__(self, other): return AddLoss(self, other) def __radd__(self, other): return AddLoss(other, self) def __sub__(self, other): return SubLoss(self, other) def __rsub__(self, other): return SubLoss(other, self) def __mul__(self, other): return MulLoss(self, other) def __rmul__(self, other): return MulLoss(other, self) def __truediv__(self, other): return DivLoss(self, other) def __rtruediv__(self, other): return DivLoss(other, self) def __neg__(self): return NegLoss(self)
[docs] def abs(self): """Return an instance of :class:`pixyz.losses.losses.AbsLoss`. Returns ------- pixyz.losses.losses.AbsLoss An instance of :class:`pixyz.losses.losses.AbsLoss` """ return AbsLoss(self)
[docs] def mean(self): """Return an instance of :class:`pixyz.losses.losses.BatchMean`. Returns ------- pixyz.losses.losses.BatchMean An instance of :class:`pixyz.losses.BatchMean` """ return BatchMean(self)
[docs] def sum(self): """Return an instance of :class:`pixyz.losses.losses.BatchSum`. Returns ------- pixyz.losses.losses.BatchSum An instance of :class:`pixyz.losses.losses.BatchSum` """ return BatchSum(self)
[docs] def detach(self): """Return an instance of :class:`pixyz.losses.losses.Detach`. Returns ------- pixyz.losses.losses.Detach An instance of :class:`pixyz.losses.losses.Detach` """ return Detach(self)
[docs] def expectation(self, p, input_var=None, sample_shape=torch.Size()): """Return an instance of :class:`pixyz.losses.Expectation`. Parameters ---------- p : pixyz.distributions.Distribution Distribution for sampling. input_var : list Input variables of this loss. sample_shape : :obj:`list` or :obj:`NoneType`, defaults to torch.Size() Shape of generating samples. Returns ------- pixyz.losses.Expectation An instance of :class:`pixyz.losses.Expectation` """ return Expectation(p, self, input_var=input_var, sample_shape=sample_shape)
[docs] def eval(self, x_dict={}, return_dict=False, return_all=True, **kwargs): """Evaluate the value of the loss function given inputs (:attr:`x_dict`). Parameters ---------- x_dict : :obj:`dict`, defaults to {} Input variables. return_dict : bool, default to False. Whether to return samples along with the evaluated value of the loss function. return_all : bool, default to True. Whether to return all samples, including those that have not been updated. Returns ------- loss : torch.Tensor the evaluated value of the loss function. x_dict : :obj:`dict` All samples generated when evaluating the loss function. If :attr:`return_dict` is False, it is not returned. """ if not(set(list(x_dict.keys())) >= set(self._input_var)): raise ValueError("Input keys are not valid, expected {} but got {}.".format(self._input_var, list(x_dict.keys()))) input_dict = get_dict_values(x_dict, self.input_var, return_dict=True) loss, eval_dict = self(input_dict, **kwargs) if return_dict: output_dict = x_dict.copy() if return_all else {} output_dict.update(eval_dict) return loss, output_dict return loss
[docs] @abc.abstractmethod def forward(self, x_dict, **kwargs): """ Parameters ---------- x_dict : dict Input variables. Returns ------- a tuple of :class:`pixyz.losses.Loss` and dict deterministically calcurated loss and updated all samples. """ raise NotImplementedError()
class Divergence(Loss, abc.ABC): def __init__(self, p, q=None, input_var=None): """ Parameters ---------- p : pixyz.distributions.Distribution Distribution. q : pixyz.distributions.Distribution, defaults to None Distribution. input_var : :obj:`list` of :obj:`str`, defaults to None Input variables of this loss function. In general, users do not need to set them explicitly because these depend on the given distributions and each loss function. """ if input_var is not None: _input_var = deepcopy(input_var) else: _input_var = deepcopy(p.input_var) if q is not None: _input_var += deepcopy(q.input_var) _input_var = sorted(set(_input_var), key=_input_var.index) super().__init__(_input_var) self.p = p self.q = q
[docs]class ValueLoss(Loss): """ This class contains a scalar as a loss value. If multiplying a scalar by an arbitrary loss class, this scalar is converted to the :class:`ValueLoss`. Examples -------- >>> loss_cls = ValueLoss(2) >>> print(loss_cls) 2 >>> loss = loss_cls.eval() >>> print(loss) tensor(2.) """ def __init__(self, loss1): super().__init__() self.original_value = loss1 self.register_buffer('value', torch.tensor(loss1, dtype=torch.float)) self._input_var = []
[docs] def forward(self, x_dict={}, **kwargs): return self.value, {}
@property def _symbol(self): return self.original_value
[docs]class Parameter(Loss): """ This class defines a single variable as a loss class. It can be used such as a coefficient parameter of a loss class. Examples -------- >>> loss_cls = Parameter("x") >>> print(loss_cls) x >>> loss = loss_cls.eval({"x": 2}) >>> print(loss) 2 """ def __init__(self, input_var): if not isinstance(input_var, str): raise ValueError() super().__init__(tolist(input_var))
[docs] def forward(self, x_dict={}, **kwargs): return x_dict[self._input_var[0]], {}
@property def _symbol(self): return sympy.Symbol(self._input_var[0])
[docs]class LossOperator(Loss): def __init__(self, loss1, loss2): super().__init__() _input_var = [] if isinstance(loss1, Loss): _input_var += deepcopy(loss1.input_var) elif isinstance(loss1, numbers.Number): loss1 = ValueLoss(loss1) elif isinstance(loss2, type(None)): pass else: raise ValueError("{} cannot be operated with {}.".format(type(loss1), type(loss2))) if isinstance(loss2, Loss): _input_var += deepcopy(loss2.input_var) elif isinstance(loss2, numbers.Number): loss2 = ValueLoss(loss2) elif isinstance(loss2, type(None)): pass else: raise ValueError("{} cannot be operated with {}.".format(type(loss2), type(loss1))) _input_var = sorted(set(_input_var), key=_input_var.index) self._input_var = _input_var self.loss1 = loss1 self.loss2 = loss2
[docs] def forward(self, x_dict={}, **kwargs): if not isinstance(self.loss1, type(None)): loss1, x1 = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) else: loss1 = 0 x1 = {} if not isinstance(self.loss2, type(None)): loss2, x2 = self.loss2.eval(x_dict, return_dict=True, return_all=False, **kwargs) else: loss2 = 0 x2 = {} x1.update(x2) return loss1, loss2, x1
[docs]class AddLoss(LossOperator): """ Apply the `add` operation to the two losses. Examples -------- >>> loss_cls_1 = ValueLoss(2) >>> loss_cls_2 = Parameter("x") >>> loss_cls = loss_cls_1 + loss_cls_2 # equals to AddLoss(loss_cls_1, loss_cls_2) >>> print(loss_cls) x + 2 >>> loss = loss_cls.eval({"x": 3}) >>> print(loss) tensor(5.) """ @property def _symbol(self): return self.loss1._symbol + self.loss2._symbol
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return loss1 + loss2, x_dict
[docs]class SubLoss(LossOperator): """ Apply the `sub` operation to the two losses. Examples -------- >>> loss_cls_1 = ValueLoss(2) >>> loss_cls_2 = Parameter("x") >>> loss_cls = loss_cls_1 - loss_cls_2 # equals to SubLoss(loss_cls_1, loss_cls_2) >>> print(loss_cls) 2 - x >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) tensor(-2.) >>> loss_cls = loss_cls_2 - loss_cls_1 # equals to SubLoss(loss_cls_2, loss_cls_1) >>> print(loss_cls) x - 2 >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) tensor(2.) """ @property def _symbol(self): return self.loss1._symbol - self.loss2._symbol
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return loss1 - loss2, x_dict
[docs]class MulLoss(LossOperator): """ Apply the `mul` operation to the two losses. Examples -------- >>> loss_cls_1 = ValueLoss(2) >>> loss_cls_2 = Parameter("x") >>> loss_cls = loss_cls_1 * loss_cls_2 # equals to MulLoss(loss_cls_1, loss_cls_2) >>> print(loss_cls) 2 x >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) tensor(8.) """ @property def _symbol(self): return self.loss1._symbol * self.loss2._symbol
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return loss1 * loss2, x_dict
[docs]class DivLoss(LossOperator): """ Apply the `div` operation to the two losses. Examples -------- >>> loss_cls_1 = ValueLoss(2) >>> loss_cls_2 = Parameter("x") >>> loss_cls = loss_cls_1 / loss_cls_2 # equals to DivLoss(loss_cls_1, loss_cls_2) >>> print(loss_cls) \\frac{2}{x} >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) tensor(0.5000) >>> loss_cls = loss_cls_2 / loss_cls_1 # equals to DivLoss(loss_cls_2, loss_cls_1) >>> print(loss_cls) \\frac{x}{2} >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) tensor(2.) """ @property def _symbol(self): return self.loss1._symbol / self.loss2._symbol
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return loss1 / loss2, x_dict
[docs]class MinLoss(LossOperator): r""" Apply the `min` operation to the loss. Examples -------- >>> import torch >>> from pixyz.distributions import Normal >>> from pixyz.losses.losses import ValueLoss, Parameter, MinLoss >>> loss_min= MinLoss(ValueLoss(3), ValueLoss(1)) >>> print(loss_min) min \left(3, 1\right) >>> print(loss_min.eval()) tensor(1.) """ def __init__(self, loss1, loss2): super().__init__(loss1, loss2) @property def _symbol(self): return sympy.Symbol("min \\left({}, {}\\right)".format(self.loss1.loss_text, self.loss2.loss_text))
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return torch.min(loss1, loss2), x_dict
[docs]class MaxLoss(LossOperator): r""" Apply the `max` operation to the loss. Examples -------- >>> import torch >>> from pixyz.distributions import Normal >>> from pixyz.losses.losses import ValueLoss, MaxLoss >>> loss_max= MaxLoss(ValueLoss(3), ValueLoss(1)) >>> print(loss_max) max \left(3, 1\right) >>> print(loss_max.eval()) tensor(3.) """ def __init__(self, loss1, loss2): super().__init__(loss1, loss2) @property def _symbol(self): return sympy.Symbol("max \\left({}, {}\\right)".format(self.loss1.loss_text, self.loss2.loss_text))
[docs] def forward(self, x_dict={}, **kwargs): loss1, loss2, x_dict = super().forward(x_dict, **kwargs) return torch.max(loss1, loss2), x_dict
[docs]class LossSelfOperator(Loss): def __init__(self, loss1): super().__init__() _input_var = [] if isinstance(loss1, type(None)): raise ValueError() if isinstance(loss1, Loss): _input_var = deepcopy(loss1.input_var) elif isinstance(loss1, numbers.Number): loss1 = ValueLoss(loss1) else: raise ValueError() self._input_var = _input_var self.loss1 = loss1
[docs] def loss_train(self, x_dict={}, **kwargs): return self.loss1.loss_train(x_dict, **kwargs)
[docs] def loss_test(self, x_dict={}, **kwargs): return self.loss1.loss_test(x_dict, **kwargs)
[docs]class NegLoss(LossSelfOperator): """ Apply the `neg` operation to the loss. Examples -------- >>> loss_cls_1 = Parameter("x") >>> loss_cls = -loss_cls_1 # equals to NegLoss(loss_cls_1) >>> print(loss_cls) - x >>> loss = loss_cls.eval({"x": 4}) >>> print(loss) -4 """ @property def _symbol(self): return -self.loss1._symbol
[docs] def forward(self, x_dict={}, **kwargs): loss, x_dict = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) return -loss, x_dict
[docs]class AbsLoss(LossSelfOperator): """ Apply the `abs` operation to the loss. Examples -------- >>> import torch >>> from pixyz.distributions import Normal >>> from pixyz.losses import LogProb >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], ... features_shape=[10]) >>> loss_cls = LogProb(p).abs() # equals to AbsLoss(LogProb(p)) >>> print(loss_cls) |\\log p(x)| >>> sample_x = torch.randn(2, 10) # Psuedo data >>> loss = loss_cls.eval({"x": sample_x}) >>> print(loss) # doctest: +SKIP tensor([12.9894, 15.5280]) """ @property def _symbol(self): return sympy.Symbol("|{}|".format(self.loss1.loss_text))
[docs] def forward(self, x_dict={}, **kwargs): loss, x_dict = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) return loss.abs(), x_dict
[docs]class BatchMean(LossSelfOperator): r""" Average a loss class over given batch data. .. math:: \mathbb{E}_{p_{data}(x)}[\mathcal{L}(x)] \approx \frac{1}{N}\sum_{i=1}^N \mathcal{L}(x_i), where :math:`x_i \sim p_{data}(x)` and :math:`\mathcal{L}` is a loss function. Examples -------- >>> import torch >>> from pixyz.distributions import Normal >>> from pixyz.losses import LogProb >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], ... features_shape=[10]) >>> loss_cls = LogProb(p).mean() # equals to BatchMean(LogProb(p)) >>> print(loss_cls) mean \left(\log p(x) \right) >>> sample_x = torch.randn(2, 10) # Psuedo data >>> loss = loss_cls.eval({"x": sample_x}) >>> print(loss) # doctest: +SKIP tensor(-14.5038) """ @property def _symbol(self): return sympy.Symbol("mean \\left({} \\right)".format(self.loss1.loss_text)) # TODO: fix it
[docs] def forward(self, x_dict={}, **kwargs): loss, x_dict = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) return loss.mean(), x_dict
[docs]class BatchSum(LossSelfOperator): r""" Summation a loss class over given batch data. .. math:: \sum_{i=1}^N \mathcal{L}(x_i), where :math:`x_i \sim p_{data}(x)` and :math:`\mathcal{L}` is a loss function. Examples -------- >>> import torch >>> from pixyz.distributions import Normal >>> from pixyz.losses import LogProb >>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], ... features_shape=[10]) >>> loss_cls = LogProb(p).sum() # equals to BatchSum(LogProb(p)) >>> print(loss_cls) sum \left(\log p(x) \right) >>> sample_x = torch.randn(2, 10) # Psuedo data >>> loss = loss_cls.eval({"x": sample_x}) >>> print(loss) # doctest: +SKIP tensor(-31.9434) """ @property def _symbol(self): return sympy.Symbol("sum \\left({} \\right)".format(self.loss1.loss_text)) # TODO: fix it
[docs] def forward(self, x_dict={}, **kwargs): loss, x_dict = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) return loss.sum(), x_dict
[docs]class Detach(LossSelfOperator): r""" Apply the `detach` method to the loss. """ @property def _symbol(self): return sympy.Symbol("detach \\left({} \\right)".format(self.loss1.loss_text)) # TODO: fix it?
[docs] def forward(self, x_dict={}, **kwargs): loss, x_dict = self.loss1.eval(x_dict, return_dict=True, return_all=False, **kwargs) return loss.detach(), x_dict
[docs]class Expectation(Loss): r""" Expectation of a given function (Monte Carlo approximation). .. math:: \mathbb{E}_{p(x)}[f(x)] \approx \frac{1}{L}\sum_{l=1}^L f(x_l), \quad \text{where}\quad x_l \sim p(x). Note that :math:`f` doesn't need to be able to sample, which is known as the law of the unconscious statistician (LOTUS). Therefore, in this class, :math:`f` is assumed to :attr:`pixyz.Loss`. Examples -------- >>> import torch >>> from pixyz.distributions import Normal, Bernoulli >>> from pixyz.losses import LogProb >>> q = Normal(loc="x", scale=torch.tensor(1.), var=["z"], cond_var=["x"], ... features_shape=[10]) # q(z|x) >>> p = Normal(loc="z", scale=torch.tensor(1.), var=["x"], cond_var=["z"], ... features_shape=[10]) # p(x|z) >>> loss_cls = LogProb(p).expectation(q) # equals to Expectation(q, LogProb(p)) >>> print(loss_cls) \mathbb{E}_{p(z|x)} \left[\log p(x|z) \right] >>> sample_x = torch.randn(2, 10) # Psuedo data >>> loss = loss_cls.eval({"x": sample_x}) >>> print(loss) # doctest: +SKIP tensor([-12.8181, -12.6062]) >>> loss_cls = LogProb(p).expectation(q, sample_shape=(5,)) >>> loss = loss_cls.eval({"x": sample_x}) >>> print(loss) # doctest: +SKIP >>> q = Bernoulli(probs=torch.tensor(0.5), var=["x"], cond_var=[], features_shape=[10]) # q(x) >>> p = Bernoulli(probs=torch.tensor(0.3), var=["x"], cond_var=[], features_shape=[10]) # p(x) >>> loss_cls = p.log_prob().expectation(q, sample_shape=[64]) >>> train_loss = loss_cls.eval() >>> print(train_loss) # doctest: +SKIP tensor([46.7559]) >>> eval_loss = loss_cls.eval(test_mode=True) >>> print(eval_loss) # doctest: +SKIP tensor([-7.6047]) """ def __init__(self, p, f, input_var=None, sample_shape=torch.Size([1]), reparam=True): if input_var is None: input_var = list(set(p.input_var) | set(f.input_var) - set(p.var)) super().__init__(input_var=input_var) self.p = p self.f = f self.sample_shape = torch.Size(sample_shape) self.reparam = reparam @property def _symbol(self): p_text = "{" + self.p.prob_text + "}" return sympy.Symbol("\\mathbb{{E}}_{} \\left[{} \\right]".format(p_text, self.f.loss_text))
[docs] def forward(self, x_dict={}, **kwargs): samples_dicts = [self.p.sample(x_dict, reparam=self.reparam, return_all=False) for i in range(self.sample_shape.numel())] loss_and_dicts = [] for samples_dict in samples_dicts: input_dict = x_dict.copy() input_dict.update(samples_dict) loss_and_dicts.append(self.f.eval(input_dict, return_dict=True, return_all=False, **kwargs)) losses = [loss for loss, loss_sample_dict in loss_and_dicts] # sum over sample_shape loss = torch.stack(losses).mean(dim=0) output_dict = {} output_dict.update(samples_dicts[0]) output_dict.update(loss_and_dicts[0][1]) return loss, output_dict
[docs]def REINFORCE(p, f, b=ValueLoss(0), input_var=None, sample_shape=torch.Size([1]), reparam=True): r""" Surrogate Loss for Policy Gradient Method (REINFORCE) with a given reward function :math:`f` and a given baseline :math:`b`. .. math:: \mathbb{E}_{p(x)}[detach(f(x)-b(x))\log p(x)+f(x)-b(x)]. in this function, :math:`f` and :math:`b` is assumed to :attr:`pixyz.Loss`. Parameters ---------- p : :class:`pixyz.distributions.Distribution` Distribution for expectation. f : :class:`pixyz.losses.Loss` reward function b : :class:`pixyz.losses.Loss` default to pixyz.losses.ValueLoss(0) baseline function input_var : :obj:`list` of :obj:`str`, defaults to None Input variables of this loss function. In general, users do not need to set them explicitly because these depend on the given distributions and each loss function. sample_shape : :class:`torch.Size` default to torch.Size([1]) sample size for expectation reparam : :obj: bool default to True using reparameterization in internal sampling Returns ------- surrogate_loss : :class:`pixyz.losses.Loss` policy gradient can be calcurated from a gradient of this surrogate loss. Examples -------- >>> import torch >>> from pixyz.distributions import Normal, Bernoulli >>> from pixyz.losses import LogProb >>> q = Bernoulli(probs=torch.tensor(0.5), var=["x"], cond_var=[], features_shape=[10]) # q(x) >>> p = Bernoulli(probs=torch.tensor(0.3), var=["x"], cond_var=[], features_shape=[10]) # p(x) >>> loss_cls = REINFORCE(q, p.log_prob(), sample_shape=[64]) >>> train_loss = loss_cls.eval(test_mode=True) >>> print(train_loss) # doctest: +SKIP tensor([46.7559]) >>> loss_cls = p.log_prob().expectation(q, sample_shape=[64]) >>> test_loss = loss_cls.eval() >>> print(test_loss) # doctest: +SKIP tensor([-7.6047]) """ return Expectation(p, (f - b).detach() * p.log_prob() + (f - b), None, sample_shape, reparam=reparam)
[docs]class DataParalleledLoss(Loss): r""" Loss class wrapper of torch.nn.DataParallel. It can be used as the original loss class. `eval` & `forward` methods support data-parallel running. Examples -------- >>> import torch >>> from torch import optim >>> from torch.nn import functional as F >>> from pixyz.distributions import Bernoulli, Normal >>> from pixyz.losses import KullbackLeibler, DataParalleledLoss >>> from pixyz.models import Model >>> used_gpu_i = set() >>> used_gpu_g = set() >>> # Set distributions (Distribution API) >>> class Inference(Normal): ... def __init__(self): ... super().__init__(cond_var=["x"], var=["z"], name="q") ... self.model_loc = torch.nn.Linear(128, 64) ... self.model_scale = torch.nn.Linear(128, 64) ... def forward(self, x): ... used_gpu_i.add(x.device.index) ... return {"loc": self.model_loc(x), "scale": F.softplus(self.model_scale(x))} >>> class Generator(Bernoulli): ... def __init__(self): ... super().__init__(cond_var=["z"], var=["x"], name="p") ... self.model = torch.nn.Linear(64, 128) ... def forward(self, z): ... used_gpu_g.add(z.device.index) ... return {"probs": torch.sigmoid(self.model(z))} >>> p = Generator() >>> q = Inference() >>> prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), ... var=["z"], features_shape=[64], name="p_{prior}") >>> # Define a loss function (Loss API) >>> reconst = -p.log_prob().expectation(q) >>> kl = KullbackLeibler(q, prior) >>> batch_loss_cls = (reconst - kl) >>> # device settings >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") >>> device_count = torch.cuda.device_count() >>> if device_count > 1: ... loss_cls = DataParalleledLoss(batch_loss_cls).mean().to(device) ... else: ... loss_cls = batch_loss_cls.mean().to(device) >>> # Set a model (Model API) >>> model = Model(loss=loss_cls, distributions=[p, q], ... optimizer=optim.Adam, optimizer_params={"lr": 1e-3}) >>> # Train and test the model >>> data = torch.randn(2, 128).to(device) # Pseudo data >>> train_loss = model.train({"x": data}) >>> expected = set(range(device_count)) if torch.cuda.is_available() else {None} >>> assert used_gpu_i==expected >>> assert used_gpu_g==expected """ def __init__(self, loss, distributed=False, **kwargs): super().__init__(loss.input_var) if distributed: self.paralleled = DistributedDataParallel(loss, **kwargs) else: self.paralleled = DataParallel(loss, **kwargs)
[docs] def forward(self, x_dict, **kwargs): return self.paralleled.forward(x_dict, **kwargs)
@property def _symbol(self): return self.paralleled.module._symbol def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self.paralleled.module, name)