Source code for pixyz.distributions.flows

from __future__ import print_function
import math

import torch
from torch import nn
import torch.nn.functional as F

from ..utils import get_dict_values, epsilon
from .distributions import Distribution


class Flow(Distribution):
    def __init__(self, prior, dim, num_layers=1, var=[],
                 flow_layer=None, flow_name=None, name="p"):
        super().__init__(cond_var=prior.cond_var, var=var,
                         name=name, dim=dim)
        self.prior = prior
        self.flows = nn.ModuleList([flow_layer(dim)
                                    for _ in range(num_layers)])
        self._flow_name = flow_name

    @property
    def prob_text(self):
        _var_text = []
        _text = "{}={}({})".format(','.join(self._var),
                                   self._flow_name,
                                   ','.join(self.prior.var))
        _var_text += [_text]
        if len(self._cond_var) != 0:
            _var_text += [','.join(self._cond_var)]

        _prob_text = "{}({})".format(
            self._name,
            "|".join(_var_text)
        )

        return _prob_text

    def forward(self, x, jacobian=False):
        if not jacobian:
            # Apply each flow layer in turn: x -> f_K(...f_1(x)...).
            for flow in self.flows:
                x = flow(x)
            output = x

        else:
            # Accumulate the log-determinant of the Jacobian over all layers:
            # log|det df/dx| = sum_k log|det df_k/dx_k|.
            logdet_jacobian = 0
            for flow in self.flows:
                x, _logdet_jacobian = flow(x, jacobian)
                logdet_jacobian += _logdet_jacobian
            output = logdet_jacobian

        return output

    def sample(self, x={}, only_flow=False, return_all=True, **kwargs):
        if only_flow:
            # Transform the given values of self.var directly, without
            # sampling from the prior.
            samples_dict = dict(x)
            _samples = get_dict_values(x, self.var)
        else:
            samples_dict = self.prior.sample(x, **kwargs)
            _samples = get_dict_values(samples_dict, self.prior.var)
        output = self.forward(_samples[0], jacobian=False)
        output_dict = {self.var[0]: output}

        if return_all:
            # Merge so that the flow output wins on key collisions.
            samples_dict.update(output_dict)
            output_dict = samples_dict

        return output_dict

    def log_likelihood(self, x):
        # Change of variables: log p(z) = log p_prior(x) - log|det df/dx|,
        # where z = f(x) is the output of the flow.
        log_dist = self.prior.log_likelihood(x)

        x_values = get_dict_values(x, self.prior.var)
        logdet_jacobian = self.forward(x_values[0], jacobian=True)

        return log_dist - logdet_jacobian
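
The flow_layer argument above can be any nn.Module constructed as
flow_layer(dim) whose forward(x, jacobian=False) returns the transformed
tensor and, when jacobian=True, an (output, logdet_jacobian) pair. A minimal
sketch of that contract, using a hypothetical diagonal-scaling layer
(ScaleLayer is not part of pixyz):

class ScaleLayer(nn.Module):
    # Hypothetical layer, for illustration only: y = x * exp(log_scale).
    def __init__(self, in_features):
        super(ScaleLayer, self).__init__()
        self.log_scale = nn.Parameter(torch.zeros(in_features))

    def forward(self, x, jacobian=False):
        output = x * torch.exp(self.log_scale)
        if jacobian:
            # The Jacobian is diagonal, so its log-determinant is just the
            # sum of log_scale, broadcast to shape (batch_size, 1).
            logdet_jacobian = self.log_scale.sum().expand(x.size(0), 1)
            return output, logdet_jacobian
        return output

# e.g. Flow(prior, dim, num_layers=2, var=["z"],
#           flow_layer=ScaleLayer, flow_name="ScaleFlow")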


class PlanarFlow(Flow):
    def __init__(self, prior, dim, num_layers=1, var=[], **kwargs):
        super(PlanarFlow, self).__init__(prior, dim, num_layers=num_layers,
                                         var=var, flow_layer=PlanarFlowLayer,
                                         flow_name="PlanarFlow", **kwargs)

class PlanarFlowLayer(nn.Module):
    def __init__(self, in_features):
        super(PlanarFlowLayer, self).__init__()
        self.in_features = in_features

        self.weight = nn.Parameter(torch.Tensor(1, in_features))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.u = nn.Parameter(torch.Tensor(1, in_features))

        self.reset_params()

    def reset_params(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)
        self.u.data.uniform_(-stdv, stdv)

    def forward(self, x, jacobian=False):
        # Planar transform: f(x) = x + u * tanh(w.x + b).
        z = torch.tanh(F.linear(x, self.weight, self.bias))
        output = x + self.u * z

        if jacobian:
            # TODO: use autograd
            # Analytic log-determinant (matrix determinant lemma):
            # |det df/dx| = |1 + psi.u^T|, where psi = (1 - tanh^2(w.x + b)) * w.
            z_grad = (1 - z ** 2)
            psi = z_grad * self.weight  # shape: (batch_size, in_features)
            det_grad = 1 + torch.mm(psi, self.u.t())
            logdet_jacobian = torch.log(torch.abs(det_grad) + epsilon())
            return output, logdet_jacobian

        return output
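
As a minimal check of the analytic log-determinant above (and the route the
TODO comment suggests), the same quantity can be recovered with autograd by
building the full Jacobian row by row; the width-3 layer and batch size
below are illustrative only:

import torch

layer = PlanarFlowLayer(in_features=3)
x = torch.randn(1, 3, requires_grad=True)

output, logdet = layer(x, jacobian=True)

# Row i of the Jacobian df/dx is the gradient of output[0, i] w.r.t. x.
rows = [torch.autograd.grad(output[0, i], x, retain_graph=True)[0][0]
        for i in range(3)]
jacobian = torch.stack(rows)
logdet_autograd = torch.log(torch.abs(torch.det(jacobian)))

# The two values should agree up to the epsilon() added inside the layer.
print(logdet.item(), logdet_autograd.item())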