Source code for pixyz.flows.normalizations

import torch
from torch import nn
import numpy as np

from .flows import Flow
from ..utils import epsilon


class BatchNorm1d(Flow):
    """
    A batch normalization with the inverse transformation.

    Notes
    -----
    This is implemented with reference to the following code.
    https://github.com/ikostrikov/pytorch-flows/blob/master/flows.py#L205

    Examples
    --------
    >>> x = torch.randn(20, 100)
    >>> f = BatchNorm1d(100)
    >>> # transformation
    >>> z = f(x)
    >>> # reconstruction
    >>> _x = f.inverse(f(x))
    >>> # check this reconstruction
    >>> diff = torch.sum(torch.abs(_x-x)).data
    >>> diff < 0.1
    tensor(1, dtype=torch.uint8)
    """

    def __init__(self, in_features, momentum=0.0):
        super().__init__(in_features)

        self.log_gamma = nn.Parameter(torch.zeros(in_features))
        self.beta = nn.Parameter(torch.zeros(in_features))
        self.momentum = momentum

        self.register_buffer('running_mean', torch.zeros(in_features))
        self.register_buffer('running_var', torch.ones(in_features))
    def forward(self, x, y=None, compute_jacobian=True):
        if self.training:
            # use statistics of the current minibatch and update the running averages
            self.batch_mean = x.mean(0)
            self.batch_var = (x - self.batch_mean).pow(2).mean(0) + epsilon()

            self.running_mean = self.running_mean * self.momentum
            self.running_var = self.running_var * self.momentum

            self.running_mean = self.running_mean + (self.batch_mean.data * (1 - self.momentum))
            self.running_var = self.running_var + (self.batch_var.data * (1 - self.momentum))

            mean = self.batch_mean
            var = self.batch_var
        else:
            # use the running averages at evaluation time
            mean = self.running_mean
            var = self.running_var

        x_hat = (x - mean) / var.sqrt()
        z = torch.exp(self.log_gamma) * x_hat + self.beta

        if compute_jacobian:
            self._logdet_jacobian = (self.log_gamma - 0.5 * torch.log(var)).sum(-1)

        return z
    def inverse(self, z, y=None):
        if self.training:
            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (z - self.beta) / torch.exp(self.log_gamma)
        x = x_hat * var.sqrt() + mean

        return x
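
# Note (worked derivation, not part of the original source): BatchNorm1d above
# computes z = exp(log_gamma) * (x - mean) / sqrt(var) + beta, so each output
# dimension has dz_i/dx_i = exp(log_gamma_i) / sqrt(var_i) and the
# log-determinant of the Jacobian is sum_i (log_gamma_i - 0.5 * log(var_i)),
# which is exactly the quantity stored in `self._logdet_jacobian` in forward().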
class BatchNorm2d(BatchNorm1d):
    """
    A batch normalization with the inverse transformation.

    Notes
    -----
    This is implemented with reference to the following code.
    https://github.com/ikostrikov/pytorch-flows/blob/master/flows.py#L205

    Examples
    --------
    >>> x = torch.randn(20, 100, 35, 45)
    >>> f = BatchNorm2d(100)
    >>> # transformation
    >>> z = f(x)
    >>> # reconstruction
    >>> _x = f.inverse(f(x))
    >>> # check this reconstruction
    >>> diff = torch.sum(torch.abs(_x-x)).data
    >>> diff < 0.1
    tensor(1, dtype=torch.uint8)
    """

    def __init__(self, in_features, momentum=0.0):
        super().__init__(in_features, momentum)

        self.log_gamma = nn.Parameter(self._unsqueeze(self.log_gamma.data))
        self.beta = nn.Parameter(self._unsqueeze(self.beta.data))

        self.register_buffer('running_mean', self._unsqueeze(self.running_mean))
        self.register_buffer('running_var', self._unsqueeze(self.running_var))

    def _unsqueeze(self, x):
        # reshape (C,) parameters/buffers to (C, 1, 1) so they broadcast over
        # inputs of shape (B, C, H, W)
        return x.unsqueeze(1).unsqueeze(2)
class ActNorm2d(Flow):
    """
    Activation normalization.

    Initializes the bias and scale from a given minibatch, so that the output
    has zero mean and unit variance per channel for that minibatch.
    After initialization, `bias` and `logs` are trained as regular parameters.

    Notes
    -----
    This is implemented with reference to the following code.
    https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py
    """

    def __init__(self, in_features, scale=1.):
        super().__init__(in_features)

        # register mean and scale
        size = [1, in_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.scale = float(scale)
        self.inited = False
    def initialize_parameters(self, x):
        if not self.training:
            return

        assert x.device == self.bias.device

        with torch.no_grad():
            bias = torch.mean(x.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
            vars = torch.mean((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(vars) + epsilon()))

            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)

            self.inited = True
    def _center(self, x, inverse=False):
        if not inverse:
            return x + self.bias
        else:
            return x - self.bias

    def _scale(self, x, compute_jacobian=True, inverse=False):
        logs = self.logs

        if not inverse:
            x = x * torch.exp(logs)
        else:
            x = x * torch.exp(-logs)

        if compute_jacobian:
            # `logs` holds one log-scale per channel, so the log-determinant is
            # the sum over channels multiplied by the number of spatial
            # positions (pixels).
            pixels = np.prod(x.size()[2:])
            logdet_jacobian = torch.sum(logs) * pixels
            return x, logdet_jacobian

        return x, None
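
    # Worked example (illustrative, not from the original source): for an input
    # of shape (B, C, H, W) = (16, 3, 32, 32), `pixels` above is 32 * 32 = 1024,
    # so the log-determinant returned by _scale() is torch.sum(self.logs) * 1024.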
    def forward(self, x, y=None, compute_jacobian=True):
        if not self.inited:
            self.initialize_parameters(x)

        # center and scale
        x = self._center(x, inverse=False)
        x, logdet_jacobian = self._scale(x, compute_jacobian, inverse=False)

        if compute_jacobian:
            self._logdet_jacobian = logdet_jacobian

        return x
    def inverse(self, x, y=None):
        if not self.inited:
            self.initialize_parameters(x)

        # scale and center
        x, _ = self._scale(x, compute_jacobian=False, inverse=True)
        x = self._center(x, inverse=True)

        return x
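
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pixyz source). It mirrors the doctest
# style of the classes above; the input shape (16, 3, 32, 32) and the tolerance
# 1e-4 are illustrative assumptions. It can be run with
# `python -m pixyz.flows.normalizations` if the package is installed.
if __name__ == "__main__":
    x = torch.randn(16, 3, 32, 32)
    f = ActNorm2d(3)

    # The first forward pass in training mode initializes `bias` and `logs`
    # from the minibatch statistics, then applies the affine transformation.
    z = f(x)

    # The inverse transformation reconstructs the input up to rounding error.
    x_rec = f.inverse(z)
    assert torch.max(torch.abs(x_rec - x)) < 1e-4

    # The stored log-determinant equals torch.sum(f.logs) * (32 * 32).
    print(f._logdet_jacobian)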