from __future__ import print_function
import torch
from torch import nn
import torch.nn.functional as F
from ..utils import get_dict_values, epsilon
from .distributions import Distribution
class RealNVP(Distribution):
    """RealNVP normalizing flow (Dinh et al., "Density estimation using
    Real NVP", 2017).

    Wraps a base ``prior`` distribution with a stack of multi-scale
    affine-coupling layers, giving a bijection f so that
    x = f^{-1}(z) can be sampled and log p(x) = log p(z) + log|det J_f|
    can be evaluated.

    Args:
        prior: base distribution p(z); its ``cond_var``/``var`` drive
            sampling and likelihood evaluation.
        dim (int): feature dimensionality of the flowed variable.
        num_multiscale_layers (int): number of ``MultiScaleLayer1D`` stages.
        var (list of str): names of this distribution's output variable(s).
        image (bool): if True, inputs are treated as images in [0, 1] and
            are dequantized and mapped to logit space before flowing.
        name (str): distribution name used in ``prob_text``.
        **kwargs: forwarded to each ``MultiScaleLayer1D``.
    """

    def __init__(self, prior, dim,
                 num_multiscale_layers=2,
                 var=None, image=False, name="p",
                 **kwargs):
        # Bug fix: `var=[]` was a shared mutable default argument.
        var = [] if var is None else var
        super(RealNVP, self).__init__(cond_var=prior.cond_var, var=var,
                                      name=name, dim=dim)
        self.prior = prior
        flow_list = [MultiScaleLayer1D(dim, layer_id=layer_id, **kwargs)
                     for layer_id in range(num_multiscale_layers)]
        self.image = image
        self.flows = nn.ModuleList(flow_list)
        self._flow_name = "RealNVP"

    @property
    def prob_text(self):
        # e.g. "p(x=RealNVP(z)|y)" — output var, flow name, prior var,
        # then any conditioning variables after the bar.
        _var_text = []
        _text = "{}={}({})".format(','.join(self._var),
                                   self._flow_name,
                                   ','.join(self.prior.var))
        _var_text += [_text]
        if len(self._cond_var) != 0:
            _var_text += [','.join(self._cond_var)]

        _prob_text = "{}({})".format(
            self._name,
            "|".join(_var_text)
        )
        return _prob_text

    def forward(self, x, inverse=False, jacobian=False):
        """Apply the flow (x -> z) or its inverse (z -> x).

        Args:
            x (torch.Tensor): input batch; assumed (batch, dim) — the
                image preprocessing sums its Jacobian over dim=1.
            inverse (bool): if True, run the flow stages in reverse order.
            jacobian (bool): if True, also return the accumulated
                log-determinant of the Jacobian.

        Returns:
            torch.Tensor, or ``(tensor, logdet_jacobian)`` when
            ``jacobian=True``.
        """
        logdet_jacobian = 0
        if inverse is False:
            _flows = self.flows
            x_use = x
            x_disuse = None
            if self.image:
                # Corrupt data (Tapani Raiko's dequantization): add uniform
                # noise at the 8-bit scale, then renormalize to (0, 1).
                x_use = x_use * 255.0
                corruption_level = 1.0
                x_use = x_use + \
                    corruption_level * torch.empty_like(x_use).uniform_(0, 1)
                x_use = x_use / (255.0 + corruption_level)

                # Model in logit space so the flow sees an unbounded domain.
                alpha = .05
                # avoid 0 when x_use = 1
                x_use = x_use * (1 - alpha) + alpha * epsilon()
                jac = torch.sum(-torch.log(x_use)
                                - torch.log(1 - x_use), dim=1)
                x_use = torch.log(x_use) - torch.log(1 - x_use)
                logdet_jacobian += jac
        else:
            _flows = self.flows[::-1]
            x_use = None
            x_disuse = x

        if jacobian is False:
            for flow in _flows:
                x_use, x_disuse = flow(x_use, x_disuse, inverse=inverse)
        else:
            for flow in _flows:
                x_use, x_disuse, _logdet_jacobian = flow(x_use, x_disuse,
                                                         jacobian=jacobian,
                                                         inverse=inverse)
                logdet_jacobian += _logdet_jacobian

        if inverse is False:
            # Forward direction factors features out stage by stage;
            # stitch the factored-out part back onto the output.
            x = torch.cat((x_disuse, x_use), dim=1)
        else:
            x = x_use
            if self.image:
                # Inverse logit (sigmoid) back to pixel space [0, 1].
                x = 1. / (1 + torch.exp(-x))

        if jacobian is False:
            return x
        return x, logdet_jacobian

    def sample(self, x=None, only_flow=False, return_all=True, **kwargs):
        """Sample x ~ p(x): draw z from the prior, then invert the flow.

        Args:
            x (dict): conditioning inputs, or — with ``only_flow=True`` —
                a dict that already contains prior samples.
            only_flow (bool): skip prior sampling and only run the
                inverse flow on the given samples.
            return_all (bool): merge the prior's samples into the output.
        """
        # Bug fix: `x={}` was a shared mutable default argument.
        x = {} if x is None else x
        if only_flow:
            samples_dict = x
        else:
            samples_dict = self.prior.sample(x, **kwargs)

        _samples = get_dict_values(samples_dict, self.prior.var)
        output = self.forward(_samples[0],
                              inverse=True, jacobian=False)
        output_dict = {self.var[0]: output}

        if return_all:
            output_dict.update(samples_dict)
        return output_dict

    def sample_inv(self, x, return_all=True, **kwargs):
        """Map data x through the forward flow to z = f(x).

        Args:
            x (dict): must contain this distribution's variable.
            return_all (bool): merge the inputs into the output dict.
        """
        samples_dict = x
        _samples = get_dict_values(x, self.var)
        output = self.forward(_samples[0], jacobian=False)
        output_dict = {self.prior.var[0]: output}

        if return_all:
            output_dict.update(samples_dict)
        return output_dict

    def log_likelihood(self, x):
        """Return log p(x) = log p(z) + log|det J_f(x)| with z = f(x)."""
        _x = get_dict_values(x, self.var)
        z, logdet_jacobian = self.forward(_x[0], jacobian=True)

        # log p(z) under the base distribution.
        log_dist = self.prior.log_likelihood({self.prior.var[0]: z})
        return log_dist + logdet_jacobian
class MultiScaleLayer1D(nn.Module):
    """One multi-scale flow stage: a stack of 1D affine-coupling layers,
    followed by a factor-out split (forward) or preceded by the matching
    factor-in merge (inverse).

    Args:
        in_features (int): feature size of the full (stage-0) input.
        layer_id (int): stage index; each stage operates on half the
            features of the previous one.
        hidden_features (int): hidden width of each coupling network.
        num_nn_layers (int): depth of each coupling network.
        num_flow_layers (int): number of coupling layers in this stage.
    """

    def __init__(self, in_features, layer_id,
                 hidden_features=64,
                 num_nn_layers=2,
                 num_flow_layers=3):
        super(MultiScaleLayer1D, self).__init__()
        self.in_features = in_features // (2 ** layer_id)
        couplings = [
            AffineCouplingLayer1D(self.in_features,
                                  hidden_features=hidden_features,
                                  num_layers=num_nn_layers,
                                  pattern=idx % 2)  # alternate mask phase
            for idx in range(num_flow_layers)
        ]
        self.flows = nn.ModuleList(couplings)
        self.split = SplitLayer(layer_id)

    def forward(self, x_use, x_disuse, inverse=False, jacobian=False):
        """Run the couplings (and the split) in the requested direction,
        optionally accumulating the log-determinant of the Jacobian."""
        # Inverse direction reverses the coupling order and re-merges the
        # factored-out features *before* flowing.
        ordered = self.flows[::-1] if inverse else self.flows

        if inverse:
            x_use, x_disuse = self.split.forward(x_use, x_disuse,
                                                 inverse=inverse)

        if jacobian:
            logdet_jacobian = 0
            for coupling in ordered:
                x_use, _ld = coupling(x_use, jacobian=jacobian,
                                      inverse=inverse)
                logdet_jacobian += _ld
        else:
            for coupling in ordered:
                x_use = coupling(x_use, inverse=inverse)

        if not inverse:
            x_use, x_disuse = self.split.forward(x_use, x_disuse,
                                                 inverse=inverse)

        if jacobian:
            return x_use, x_disuse, logdet_jacobian
        return x_use, x_disuse
class AffineCouplingLayer(nn.Module):
    """Abstract affine coupling layer (RealNVP).

    Subclasses implement ``_scale_translation`` (the scale/translation
    networks, returning already-masked tensors) and ``_masking`` (the
    binary mask). ``forward`` then computes, with mask b,
        z = b*x + (1-b)*x * exp(s(b*x)) + t(b*x)
    and the inverse undoes it.

    Args:
        in_features (int): feature size of the input.
        masked_type (str): masking scheme; only "checkerboard" is
            implemented by the concrete subclasses.
        pattern (int): 0 or 1 — which checkerboard phase is kept fixed.
    """

    def __init__(self, in_features,
                 masked_type="checkerboard",
                 pattern=0  # 0 or 1
                 ):
        super(AffineCouplingLayer, self).__init__()
        self.in_features = in_features
        self.masked_type = masked_type
        self.pattern = pattern

    def _scale_translation(self, x):
        # Bug fix: was a bare `NotImplementedError` expression — a no-op
        # that silently returned None instead of raising.
        raise NotImplementedError

    def _masking(self, x, reverse=False):
        # Bug fix: was a bare `NotImplementedError` expression (no-op).
        raise NotImplementedError

    def forward(self, x, inverse=False, jacobian=False):
        """Apply the coupling. forward: (x->z); inverse: (z->x).

        When ``jacobian`` is True, also returns the per-sample
        log-determinant (the sum of the masked log-scales).
        """
        x_0 = self._masking(x, False)  # pass-through coordinates
        x_1 = self._masking(x, True)   # coordinates to be transformed
        scale, trans = self._scale_translation(x_0)

        if inverse:
            x_1 = (x_1 - trans) / torch.exp(scale)
        else:
            x_1 = x_1 * torch.exp(scale) + trans

        # The two halves occupy disjoint coordinates (complementary masks),
        # so addition simply merges them.
        output = x_0 + x_1

        if jacobian:
            logdet_jacobian = torch.sum(scale, dim=1)  # 1D
            return output, logdet_jacobian
        return output

    def extra_repr(self):
        return 'in_features={}, pattern={}'.format(self.in_features,
                                                   self.pattern)
class AffineCouplingLayer1D(AffineCouplingLayer):
    """Affine coupling for flat (batch, features) inputs.

    The scale/translation function is a fully connected MLP with batch
    normalization, emitting 2*in_features outputs that are chunked into
    the log-scale and the translation. Masking is a 1D checkerboard.

    Args:
        in_features (int): input feature size.
        hidden_features (int): hidden width of the MLP.
        num_layers (int): total number of Linear layers in the MLP.
        masked_type (str): only "checkerboard" is supported.
        pattern (int): 0 or 1 — checkerboard phase that stays fixed.
    """

    def __init__(self, in_features,
                 hidden_features=512,
                 num_layers=2,
                 masked_type="checkerboard",
                 pattern=0  # 0 or 1
                 ):
        super(AffineCouplingLayer1D,
              self).__init__(in_features,
                             masked_type=masked_type,
                             pattern=pattern)
        self.hidden_features = hidden_features

        # MLP: in -> hidden -> ... -> 2*in (scale and translation stacked).
        layer_list = [nn.Linear(in_features,
                                hidden_features)]
        layer_list += [nn.Linear(hidden_features,
                                 hidden_features)
                       for _ in range(num_layers - 2)]
        layer_list += [nn.Linear(hidden_features,
                                 2 * in_features)]
        self.layers = nn.ModuleList(layer_list)

        # One batch norm per hidden activation (all layers but the last).
        batch_norms = [nn.BatchNorm1d(hidden_features)
                       for _ in range(num_layers - 1)]
        self.batch_norms = nn.ModuleList(batch_norms)

    def _scale_translation(self, x):
        """Return the (masked) log-scale and translation for input x."""
        for layer, batch_norm in zip(self.layers[:-1],
                                     self.batch_norms):
            x = F.relu(batch_norm(layer(x)))
        x = self.layers[-1](x)

        scale, trans = torch.chunk(x, chunks=2, dim=-1)
        # tanh keeps the log-scale bounded; mask (reverse=True) so only the
        # transformed half receives a scale/translation.
        scale = self._masking(torch.tanh(scale), True)
        trans = self._masking(trans, True)
        return scale, trans

    def _masking(self, x, reverse=False):
        """Zero out one checkerboard phase of x (the other when reverse)."""
        if self.masked_type != "checkerboard":
            # Bug fix: previously a bare `NotImplementedError` expression
            # (a no-op) that fell through to a NameError on `mask`.
            raise NotImplementedError(
                "masked_type '{}' is not supported".format(self.masked_type))

        mask = torch.zeros(x.shape[1]).to(x.device)
        mask[self.pattern::2] = 1

        if reverse:
            return x * (1 - mask).unsqueeze(0)
        return x * mask.unsqueeze(0)
class AffineCouplingLayer2D(AffineCouplingLayer):
    """Affine coupling variant with a deeper Linear stack and no batch
    norm / tanh on the scale.

    NOTE(review): despite the "2D" name, this uses nn.Linear layers just
    like the 1D variant — presumably a placeholder for a convolutional
    implementation; confirm intended use before relying on it for images.

    Args:
        in_features (int): input feature size.
        hidden_features (int): hidden width of the network.
        num_layers (int): controls network depth (num_layers hidden blocks).
        masked_type (str): only "checkerboard" is supported.
        pattern (int): 0 or 1 — checkerboard phase that stays fixed.
    """

    def __init__(self, in_features,
                 hidden_features=512,
                 num_layers=2,
                 masked_type="checkerboard",
                 pattern=0  # 0 or 1
                 ):
        super(AffineCouplingLayer2D,
              self).__init__(in_features,
                             masked_type=masked_type,
                             pattern=pattern)
        self.hidden_features = hidden_features

        flow_list = [nn.Linear(in_features,
                               hidden_features)]
        flow_list += [nn.Linear(hidden_features,
                                hidden_features)
                      for _ in range(num_layers - 1)]
        flow_list += [nn.Linear(hidden_features,
                                2 * in_features)]
        self.flows = nn.ModuleList(flow_list)

    def _scale_translation(self, x):
        """Return the (masked) log-scale and translation for input x."""
        for flow in self.flows[:-1]:
            x = F.relu(flow(x))
        x = self.flows[-1](x)

        scale, trans = torch.chunk(x, chunks=2, dim=-1)
        scale = self._masking(scale, True)
        trans = self._masking(trans, True)
        return scale, trans

    def _masking(self, x, reverse=False):
        """Zero out one checkerboard phase of x (the other when reverse)."""
        if self.masked_type != "checkerboard":
            # Bug fix: previously a bare `NotImplementedError` expression
            # (a no-op) that fell through to a NameError on `mask`.
            raise NotImplementedError(
                "masked_type '{}' is not supported".format(self.masked_type))

        # Bug fix: keep the mask on the same device as the input
        # (consistent with the 1D layer; the old code broke on GPU).
        mask = torch.zeros(x.shape[1]).to(x.device)
        mask[self.pattern::2] = 1

        if reverse:
            return x * (1 - mask).unsqueeze(0)
        return x * mask.unsqueeze(0)
class SplitLayer(object):
    """Factorizing out/in layer for the multi-scale architecture.

    Forward moves the first half of ``x_use`` into ``x_disuse``
    (factor-out); inverse moves the trailing chunk of ``x_disuse`` back
    into ``x_use`` (factor-in). Supports 2D (batch, features) and 4D
    (batch, ch, h, w) tensors, splitting along the last dimension.
    Stateless (no parameters), so it is a plain object, not an nn.Module.
    """

    def __init__(self, layer_id):
        # Stage index; determines how much to factor back in when the
        # inverse pass starts from prior samples only (x_use is None).
        self.layer_id = layer_id

    def get_split(self, x, split):
        """Split x along its last dimension at index `split` (may be
        negative). Raises NotImplementedError for unsupported ranks."""
        if x.dim() not in (2, 4):
            # Bug fix: the original `assert x_shape != 2 and x_shape != 4`
            # compared a torch.Size to an int, so it could never fail.
            raise NotImplementedError(
                "only 2D or 4D inputs are supported, got {}D".format(x.dim()))
        if x.dim() == 2:
            return x[:, :split], x[:, split:]
        return x[:, :, :, :split], x[:, :, :, split:]

    def forward(self, x_use, x_disuse, inverse=False):
        """Move features between x_use and x_disuse.

        Forward: halves x_use, appending the removed half to x_disuse.
        Inverse: takes the trailing features of x_disuse and prepends
        them to x_use. Returns the updated (x_use, x_disuse) pair.
        """
        if inverse is False:
            # Increase x_disuse, decrease x_use.
            x_use_shape = x_use.shape
            if len(x_use_shape) not in (2, 4):
                # Bug fix: same always-true assert as in get_split.
                raise NotImplementedError(
                    "only 2D or 4D inputs are supported")
            split_dim = len(x_use_shape) - 1
            split = x_use_shape[split_dim] // 2

            _x_disuse, x_use = self.get_split(x_use, split)
            if x_disuse is not None:
                x_disuse = torch.cat((x_disuse, _x_disuse),
                                     dim=split_dim)
            else:
                x_disuse = _x_disuse
        else:
            # Increase x_use, decrease x_disuse.
            x_disuse_shape = x_disuse.shape
            if len(x_disuse_shape) not in (2, 4):
                raise NotImplementedError(
                    "only 2D or 4D inputs are supported")
            split_dim = len(x_disuse_shape) - 1
            if x_use is not None:
                split = x_use.shape[split_dim]
            else:
                # First inverse step: reconstruct the amount this stage
                # originally factored out.
                split = x_disuse_shape[split_dim] // (2 ** self.layer_id)

            x_disuse, _x_use = self.get_split(x_disuse, -split)
            if x_use is not None:
                x_use = torch.cat((_x_use, x_use), dim=split_dim)
            else:
                x_use = _x_use

        return x_use, x_disuse