Source code for pixyz.flows.flows

from torch import nn


[docs]class Flow(nn.Module): """Flow class. In Pixyz, all flows are required to inherit this class."""
[docs] def __init__(self, in_features): """ Parameters ---------- in_features : int Size of input data. """ super().__init__() self._in_features = in_features self._logdet_jacobian = None
@property def in_features(self): return self._in_features
[docs] def forward(self, x, y=None, compute_jacobian=True): """ Forward propagation of flow layers. Parameters ---------- x : torch.Tensor Input data. y : torch.Tensor, defaults to None Data for conditioning. compute_jacobian : bool, defaults to True Whether to calculate and store log-determinant Jacobian. If true, calculated Jacobian values are stored in :attr:`logdet_jacobian`. Returns ------- z : torch.Tensor """ z = x return z
[docs] def inverse(self, z, y=None): """ Backward (inverse) propagation of flow layers. In this method, log-determinant Jacobian is not calculated. Parameters ---------- z : torch.Tensor Input data. y : torch.Tensor, defaults to None Data for conditioning. Returns ------- x : torch.Tensor """ x = z return x
@property def logdet_jacobian(self): """ Get log-determinant Jacobian. Before calling this, you should run :attr:`forward` or :attr:`update_jacobian` methods to calculate and store log-determinant Jacobian. """ return self._logdet_jacobian
[docs]class FlowList(Flow):
[docs] def __init__(self, flow_list): """ Hold flow modules in a list. Once initializing, it can be handled as a single flow module. Notes ----- Indexing is not supported for now. Parameters ---------- flow_list : list """ super().__init__(flow_list[0].in_features) self.flow_list = nn.ModuleList(flow_list)
[docs] def forward(self, x, y=None, compute_jacobian=True): logdet_jacobian = 0 for flow in self.flow_list: x = flow.forward(x, y, compute_jacobian) if compute_jacobian: logdet_jacobian = logdet_jacobian + flow.logdet_jacobian if compute_jacobian: self._logdet_jacobian = logdet_jacobian return x
[docs] def inverse(self, z, y=None): for flow in self.flow_list[::-1]: z = flow.inverse(z, y) return z
def __repr__(self): # rename "ModuleList" to "FlowList" flow_list_repr = self.flow_list.__repr__().replace("ModuleList", "FlowList") return flow_list_repr