Source code for pixyz.losses.iteration

from copy import deepcopy
import sympy

from .losses import Loss
from ..utils import get_dict_values


class IterativeLoss(Loss):
    r"""
    Iterative loss.

    This class allows implementing an arbitrary model that requires iteration.

    .. math::

        \mathcal{L} = \sum_{t=1}^{T}\mathcal{L}_{step}(x_t, h_t),

    where :math:`x_t = f_{slice\_step}(x, t)`.

    Examples
    --------
    >>> import torch
    >>> from torch.nn import functional as F
    >>> from pixyz.distributions import Normal, Bernoulli, Deterministic
    >>>
    >>> # Set distributions
    >>> x_dim = 128
    >>> z_dim = 64
    >>> h_dim = 32
    >>>
    >>> # p(x|z,h_{prev})
    >>> class Decoder(Bernoulli):
    ...     def __init__(self):
    ...         super().__init__(cond_var=["z", "h_prev"], var=["x"], name="p")
    ...         self.fc = torch.nn.Linear(z_dim + h_dim, x_dim)
    ...     def forward(self, z, h_prev):
    ...         return {"probs": torch.sigmoid(self.fc(torch.cat((z, h_prev), dim=-1)))}
    ...
    >>> # q(z|x,h_{prev})
    >>> class Encoder(Normal):
    ...     def __init__(self):
    ...         super().__init__(cond_var=["x", "h_prev"], var=["z"], name="q")
    ...         self.fc_loc = torch.nn.Linear(x_dim + h_dim, z_dim)
    ...         self.fc_scale = torch.nn.Linear(x_dim + h_dim, z_dim)
    ...     def forward(self, x, h_prev):
    ...         xh = torch.cat((x, h_prev), dim=-1)
    ...         return {"loc": self.fc_loc(xh), "scale": F.softplus(self.fc_scale(xh))}
    ...
    >>> # f(h|x,z,h_{prev}) (update h)
    >>> class Recurrence(Deterministic):
    ...     def __init__(self):
    ...         super().__init__(cond_var=["x", "z", "h_prev"], var=["h"], name="f")
    ...         self.rnncell = torch.nn.GRUCell(x_dim + z_dim, h_dim)
    ...     def forward(self, x, z, h_prev):
    ...         return {"h": self.rnncell(torch.cat((z, x), dim=-1), h_prev)}
    >>>
    >>> p = Decoder()
    >>> q = Encoder()
    >>> f = Recurrence()
    >>>
    >>> # Set the loss class
    >>> step_loss_cls = p.log_prob().expectation(q * f).mean()
    >>> print(step_loss_cls)
    mean \left(\mathbb{E}_{p(h,z|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right)
    >>> loss_cls = IterativeLoss(step_loss=step_loss_cls,
    ...                          series_var=["x"], update_value={"h": "h_prev"})
    >>> print(loss_cls)
    \sum_{t=1}^{t_{max}} mean \left(\mathbb{E}_{p(h,z|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right)
    >>>
    >>> # Evaluate
    >>> x_sample = torch.randn(30, 2, 128)  # (timestep_size, batch_size, feature_size)
    >>> h_init = torch.zeros(2, 32)  # (batch_size, h_dim)
    >>> loss = loss_cls.eval({"x": x_sample, "h_prev": h_init})
    >>> print(loss)  # doctest: +SKIP
    tensor(-2826.0906, grad_fn=<AddBackward0>)
    """

    def __init__(self, step_loss, max_iter=None, input_var=None, series_var=None,
                 update_value={}, slice_step=None, timestep_var=["t"]):
        self.step_loss = step_loss
        self.max_iter = max_iter
        self.update_value = update_value
        self.timestep_var = timestep_var
        self.timestep_symbol = sympy.Symbol(self.timestep_var[0])

        if (series_var is None) and (max_iter is None):
            raise ValueError("Either series_var or max_iter must be given.")

        self.slice_step = slice_step
        if self.slice_step:
            self.step_loss = self.step_loss.expectation(self.slice_step)

        if input_var is not None:
            self._input_var = input_var
        else:
            _input_var = []
            _input_var += deepcopy(self.step_loss.input_var)
            self._input_var = sorted(set(_input_var), key=_input_var.index)

            if slice_step:
                # delete the time-step variable from input_var
                self._input_var.remove(timestep_var[0])

        self.series_var = series_var

    @property
    def _symbol(self):
        # TODO: naive implementation
        dummy_loss = sympy.Symbol("dummy_loss")

        if self.max_iter:
            max_iter = self.max_iter
        else:
            max_iter = sympy.Symbol(sympy.latex(self.timestep_symbol) + "_{max}")

        _symbol = sympy.Sum(dummy_loss, (self.timestep_symbol, 1, max_iter))
        _symbol = _symbol.subs({dummy_loss: self.step_loss._symbol})
        return _symbol
    def slice_step_fn(self, t, x):
        return {k: v[t] for k, v in x.items()}
    def _get_eval(self, x_dict, **kwargs):
        series_x_dict = get_dict_values(x_dict, self.series_var, return_dict=True)
        step_loss_sum = 0

        # set max_iter from the explicit argument or the length of the series variable
        if self.max_iter:
            max_iter = self.max_iter
        else:
            max_iter = len(series_x_dict[self.series_var[0]])

        # an optional mask can zero out the loss of padded timesteps
        # (see the masked-evaluation sketch after this class)
        if "mask" in kwargs:
            mask = kwargs["mask"].float()
        else:
            mask = None

        for t in range(max_iter):
            if self.slice_step:
                x_dict.update({self.timestep_var[0]: t})
            else:
                # update series inputs & use slice_step_fn
                x_dict.update(self.slice_step_fn(t, series_x_dict))

            # evaluate the per-step loss and keep the sampled values
            step_loss, samples = self.step_loss.eval(x_dict, return_dict=True)
            x_dict.update(samples)
            if mask is not None:
                step_loss *= mask[t]
            step_loss_sum += step_loss

            # carry sampled values over to the next step (e.g., "h" -> "h_prev")
            for key, value in self.update_value.items():
                x_dict.update({value: x_dict[key]})

        loss = step_loss_sum

        x_dict.update(series_x_dict)

        return loss, x_dict
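
# ---------------------------------------------------------------------------
# Usage sketch (not part of the library): the default slice_step_fn simply
# indexes every series variable at timestep t. Continuing the docstring
# example above, where x_sample has shape
# (timestep_size, batch_size, feature_size):
#
#     sliced = loss_cls.slice_step_fn(3, {"x": x_sample})
#     assert torch.equal(sliced["x"], x_sample[3])  # shape (batch_size, feature_size)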
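
# Usage sketch (not part of the library): masked evaluation over
# variable-length sequences via the "mask" branch of _get_eval. Two
# assumptions beyond the docstring example: the per-step loss keeps the
# batch dimension (no .mean()), so it broadcasts against mask[t], and
# Loss.eval forwards extra keyword arguments to _get_eval, as its
# **kwargs signature suggests.
#
#     step_loss_batch = p.log_prob().expectation(q * f)  # shape: (batch_size,)
#     masked_loss_cls = IterativeLoss(step_loss=step_loss_batch,
#                                     series_var=["x"],
#                                     update_value={"h": "h_prev"})
#     mask = torch.ones(30, 2)  # (timestep_size, batch_size); 1 marks valid steps
#     mask[20:, 1] = 0          # the second sequence ends after t=19
#     loss = masked_loss_cls.eval({"x": x_sample, "h_prev": h_init}, mask=mask)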