Source code for pixyz.losses.iteration

from copy import deepcopy
import sympy

from .losses import Loss
from ..utils import get_dict_values, replace_dict_keys


class IterativeLoss(Loss):
    r"""
    Iterative loss.

    This class allows implementing an arbitrary model which requires iteration.

    .. math::

        \mathcal{L} = \sum_{t=0}^{T-1} \mathcal{L}_{step}(x_t, h_t),

    where :math:`x_t = f_{slice\_step}(x, t)`.

    Examples
    --------
    >>> import torch
    >>> from torch.nn import functional as F
    >>> from pixyz.distributions import Normal, Bernoulli, Deterministic
    >>>
    >>> # Set distributions
    >>> x_dim = 128
    >>> z_dim = 64
    >>> h_dim = 32
    >>>
    >>> # p(x|z,h_{prev})
    >>> class Decoder(Bernoulli):
    ...     def __init__(self):
    ...         super().__init__(var=["x"], cond_var=["z", "h_prev"], name="p")
    ...         self.fc = torch.nn.Linear(z_dim + h_dim, x_dim)
    ...     def forward(self, z, h_prev):
    ...         return {"probs": torch.sigmoid(self.fc(torch.cat((z, h_prev), dim=-1)))}
    ...
    >>> # q(z|x,h_{prev})
    >>> class Encoder(Normal):
    ...     def __init__(self):
    ...         super().__init__(var=["z"], cond_var=["x", "h_prev"], name="q")
    ...         self.fc_loc = torch.nn.Linear(x_dim + h_dim, z_dim)
    ...         self.fc_scale = torch.nn.Linear(x_dim + h_dim, z_dim)
    ...     def forward(self, x, h_prev):
    ...         xh = torch.cat((x, h_prev), dim=-1)
    ...         return {"loc": self.fc_loc(xh), "scale": F.softplus(self.fc_scale(xh))}
    ...
    >>> # f(h|x,z,h_{prev}) (update h)
    >>> class Recurrence(Deterministic):
    ...     def __init__(self):
    ...         super().__init__(var=["h"], cond_var=["x", "z", "h_prev"], name="f")
    ...         self.rnncell = torch.nn.GRUCell(x_dim + z_dim, h_dim)
    ...     def forward(self, x, z, h_prev):
    ...         return {"h": self.rnncell(torch.cat((z, x), dim=-1), h_prev)}
    >>>
    >>> p = Decoder()
    >>> q = Encoder()
    >>> f = Recurrence()
    >>>
    >>> # Set the loss class
    >>> step_loss_cls = p.log_prob().expectation(q * f).mean()
    >>> print(step_loss_cls)
    mean \left(\mathbb{E}_{q(z,h|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right)
    >>> loss_cls = IterativeLoss(step_loss=step_loss_cls,
    ...                          series_var=["x"], update_value={"h": "h_prev"})
    >>> print(loss_cls)
    \sum_{t=0}^{t_{max} - 1} mean \left(\mathbb{E}_{q(z,h|x,h_{prev})} \left[\log p(x|z,h_{prev}) \right] \right)
    >>>
    >>> # Evaluate
    >>> x_sample = torch.randn(30, 2, 128)  # (timestep_size, batch_size, feature_size)
    >>> h_init = torch.zeros(2, 32)  # (batch_size, h_dim)
    >>> loss = loss_cls.eval({"x": x_sample, "h_prev": h_init})
    >>> print(loss)  # doctest: +SKIP
    tensor(-2826.0906, grad_fn=<AddBackward0>)
    """

    def __init__(self, step_loss, max_iter=None,
                 series_var=(), update_value={}, slice_step=None, timestep_var=()):
        super().__init__()
        self.step_loss = step_loss
        self.max_iter = max_iter
        self.update_value = update_value
        self.timestep_var = timestep_var

        if timestep_var:
            self.timestep_symbol = sympy.Symbol(self.timestep_var[0])
        else:
            self.timestep_symbol = sympy.Symbol("t")

        if not series_var and (max_iter is None):
            raise ValueError("Either series_var or max_iter must be set.")

        self.slice_step = slice_step
        if self.slice_step:
            self.step_loss = self.step_loss.expectation(self.slice_step)

        _input_var = []
        _input_var += deepcopy(self.step_loss.input_var)
        _input_var += series_var
        _input_var += update_value.values()

        # Deduplicate while preserving the original order.
        self._input_var = sorted(set(_input_var), key=_input_var.index)

        if timestep_var:
            # The time-step variable is set internally, so remove it from input_var.
            self._input_var.remove(timestep_var[0])

        self.series_var = series_var

    @property
    def _symbol(self):
        # TODO: naive implementation
        dummy_loss = sympy.Symbol("dummy_loss")

        if self.max_iter:
            max_iter = self.max_iter
        else:
            max_iter = sympy.Symbol(sympy.latex(self.timestep_symbol) + "_{max}")

        _symbol = sympy.Sum(dummy_loss, (self.timestep_symbol, 0, max_iter - 1))
        return _symbol.subs({dummy_loss: self.step_loss._symbol})
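    # A hedged sketch (hypothetical names, not from this module) of the
    # ``slice_step`` pathway: when a slice_step distribution is given,
    # ``step_loss`` is wrapped as ``step_loss.expectation(slice_step)``, and the
    # current step index t is injected through ``timestep_var`` each iteration
    # instead of slicing the series tensors with ``slice_step_fn``. The step
    # loss would then be written over the sliced variable (here "x_t"):
    #
    # >>> from pixyz.distributions import Deterministic
    # >>> class SliceStep(Deterministic):
    # ...     def __init__(self):
    # ...         super().__init__(var=["x_t"], cond_var=["x", "t"])
    # ...     def forward(self, x, t):
    # ...         return {"x_t": x[t]}  # pick out the slice for step t
    # >>> loss_cls = IterativeLoss(step_loss=step_loss_over_x_t, max_iter=30,
    # ...                          series_var=["x"], timestep_var=["t"],
    # ...                          slice_step=SliceStep())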
    def slice_step_fn(self, t, x):
        # Default slicing: take the t-th step along the leading (time) axis
        # of every series variable.
        return {k: v[t] for k, v in x.items()}
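    # A minimal illustration of ``slice_step_fn`` (a sketch, assuming series
    # tensors are laid out as (timestep_size, batch_size, feature_size) as in
    # the class docstring): it indexes the leading time axis of every series
    # tensor.
    #
    # >>> series = {"x": torch.randn(30, 2, 128)}
    # >>> loss_cls.slice_step_fn(5, series)["x"].shape
    # torch.Size([2, 128])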
    def forward(self, x_dict, **kwargs):
        series_x_dict = get_dict_values(x_dict, self.series_var, return_dict=True)
        updated_x_dict = get_dict_values(x_dict, list(self.update_value.values()), return_dict=True)

        step_loss_sum = 0

        # Set the number of iterations.
        if self.max_iter:
            max_iter = self.max_iter
        else:
            max_iter = len(series_x_dict[self.series_var[0]])

        if "mask" in kwargs:
            mask = kwargs["mask"].float()
        else:
            mask = None

        for t in range(max_iter):
            if self.timestep_var:
                x_dict.update({self.timestep_var[0]: t})

            if not self.slice_step:
                # Update series inputs with slice_step_fn.
                x_dict.update(self.slice_step_fn(t, series_x_dict))

            # Evaluate the loss for this step.
            step_loss, samples = self.step_loss.eval(x_dict, return_dict=True, return_all=False)
            x_dict.update(samples)

            if mask is not None:
                step_loss *= mask[t]

            step_loss_sum += step_loss

            # Rename updated variables (e.g., h -> h_prev) for the next step.
            x_dict = replace_dict_keys(x_dict, self.update_value)

        loss = step_loss_sum

        # Restore the original values.
        x_dict.update(series_x_dict)
        x_dict.update(updated_x_dict)  # TODO: x_dict contains non-updated variables.

        return loss, x_dict
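# A hedged usage sketch (not part of this module), continuing the example in
# the class docstring: masking variable-length sequences. It assumes that
# ``Loss.eval`` forwards extra keyword arguments to ``forward``; ``mask[t]``
# multiplies the step loss above, so zero entries drop padded time steps from
# the sum.
#
# >>> mask = torch.cat([torch.ones(25, 2), torch.zeros(5, 2)])  # (timestep_size, batch_size)
# >>> loss = loss_cls.eval({"x": x_sample, "h_prev": h_init}, mask=mask)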