Source code for pixyz.flows.operations

import torch
import torch.nn.functional as F
import numpy as np

from .flows import Flow
from ..utils import sum_samples


class Squeeze(Flow):
    """
    Space-to-depth rearrangement: every 2x2 spatial patch is moved into channels.

    A ``(N, c, s, s)`` tensor becomes ``(N, 4c, s/2, s/2)``. No values are
    changed, only moved, and the stored log-det Jacobian is the constant 0.

    Examples
    --------
    >>> import torch
    >>> a = torch.tensor([i+1 for i in range(16)]).view(1,1,4,4)
    >>> print(a)
    tensor([[[[ 1,  2,  3,  4],
              [ 5,  6,  7,  8],
              [ 9, 10, 11, 12],
              [13, 14, 15, 16]]]])
    >>> f = Squeeze()
    >>> print(f(a))
    tensor([[[[ 1,  3],
              [ 9, 11]],
    <BLANKLINE>
             [[ 2,  4],
              [10, 12]],
    <BLANKLINE>
             [[ 5,  7],
              [13, 15]],
    <BLANKLINE>
             [[ 6,  8],
              [14, 16]]]])
    >>> print(f.inverse(f(a)))
    tensor([[[[ 1,  2,  3,  4],
              [ 5,  6,  7,  8],
              [ 9, 10, 11, 12],
              [13, 14, 15, 16]]]])

    """

    def __init__(self):
        super().__init__(None)
        # A pure rearrangement of elements: nothing to add to the log-density.
        self._logdet_jacobian = 0

    def forward(self, x, y=None, compute_jacobian=True):
        """Fold 2x2 spatial patches of a 4-D tensor into the channel axis.

        ``y`` and ``compute_jacobian`` are accepted but not used here.

        Raises
        ------
        ValueError
            If the height or the width of ``x`` is odd (or ``x`` is not 4-D,
            in which case the shape unpacking fails).
        """
        _, channels, height, width = x.shape
        if height % 2 or width % 2:
            raise ValueError

        half_h, half_w = height // 2, width // 2

        # Channels-last, then split each spatial axis into (coarse, offset-in-patch).
        patches = x.permute(0, 2, 3, 1).view(-1, half_h, 2, half_w, 2, channels)
        # Bring (channel, row offset, col offset) together so they merge into
        # a single channel axis of size 4 * channels.
        merged = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        merged = merged.view(-1, half_h, half_w, channels * 4)
        # Back to channels-first.
        return merged.permute(0, 3, 1, 2)

    def inverse(self, z, y=None):
        """Unfold groups of four channels back into 2x2 spatial patches.

        ``y`` is accepted but not used here.

        Raises
        ------
        ValueError
            If the channel count of ``z`` is not a multiple of four (or ``z``
            is not 4-D, in which case the shape unpacking fails).
        """
        _, channels, height, width = z.shape
        if channels % 4:
            raise ValueError

        base_channels = channels // 4

        # Channels-last, then split the channel axis into (channel, row offset, col offset).
        groups = z.permute(0, 2, 3, 1).view(-1, height, width, base_channels, 2, 2)
        # Interleave each offset with its spatial axis before merging them.
        spread = groups.permute(0, 1, 4, 2, 5, 3).contiguous()
        spread = spread.view(-1, 2 * height, 2 * width, base_channels)
        # Back to channels-first.
        return spread.permute(0, 3, 1, 2)
class Unsqueeze(Squeeze):
    """
    Depth-to-space rearrangement; :class:`Squeeze` with its two directions swapped.

    A ``(N, c, s, s)`` tensor becomes ``(N, c/4, 2s, 2s)``.

    Examples
    --------
    >>> import torch
    >>> a = torch.tensor([i+1 for i in range(16)]).view(1,4,2,2)
    >>> print(a)
    tensor([[[[ 1,  2],
              [ 3,  4]],
    <BLANKLINE>
             [[ 5,  6],
              [ 7,  8]],
    <BLANKLINE>
             [[ 9, 10],
              [11, 12]],
    <BLANKLINE>
             [[13, 14],
              [15, 16]]]])
    >>> f = Unsqueeze()
    >>> print(f(a))
    tensor([[[[ 1,  5,  2,  6],
              [ 9, 13, 10, 14],
              [ 3,  7,  4,  8],
              [11, 15, 12, 16]]]])
    >>> print(f.inverse(f(a)))
    tensor([[[[ 1,  2],
              [ 3,  4]],
    <BLANKLINE>
             [[ 5,  6],
              [ 7,  8]],
    <BLANKLINE>
             [[ 9, 10],
              [11, 12]],
    <BLANKLINE>
             [[13, 14],
              [15, 16]]]])

    """

    def forward(self, x, y=None, compute_jacobian=True):
        """Apply the parent's ``inverse`` to ``x``.

        ``y`` and ``compute_jacobian`` are accepted but not forwarded.
        """
        unfolded = super().inverse(x)
        return unfolded

    def inverse(self, z, y=None):
        """Apply the parent's ``forward`` to ``z``.

        ``y`` is accepted but not forwarded; the parent is called with its
        default arguments.
        """
        folded = super().forward(z)
        return folded
class Permutation(Flow):
    """
    Reorder the entries along dimension 1 of a 2-D or 4-D tensor.

    The order is given once at construction; the inverse order is derived
    from it with ``np.argsort``. The stored log-det Jacobian is the constant 0.

    Examples
    --------
    >>> import torch
    >>> a = torch.tensor([i+1 for i in range(16)]).view(1,4,2,2)
    >>> print(a)
    tensor([[[[ 1,  2],
              [ 3,  4]],
    <BLANKLINE>
             [[ 5,  6],
              [ 7,  8]],
    <BLANKLINE>
             [[ 9, 10],
              [11, 12]],
    <BLANKLINE>
             [[13, 14],
              [15, 16]]]])
    >>> perm = [0,3,1,2]
    >>> f = Permutation(perm)
    >>> f(a)
    tensor([[[[ 1,  2],
              [ 3,  4]],
    <BLANKLINE>
             [[13, 14],
              [15, 16]],
    <BLANKLINE>
             [[ 5,  6],
              [ 7,  8]],
    <BLANKLINE>
             [[ 9, 10],
              [11, 12]]]])
    >>> f.inverse(f(a))
    tensor([[[[ 1,  2],
              [ 3,  4]],
    <BLANKLINE>
             [[ 5,  6],
              [ 7,  8]],
    <BLANKLINE>
             [[ 9, 10],
              [11, 12]],
    <BLANKLINE>
             [[13, 14],
              [15, 16]]]])

    """

    def __init__(self, permute_indices):
        # The base class is given the number of permuted entries.
        super().__init__(len(permute_indices))

        self.permute_indices = permute_indices
        # argsort of a permutation yields the permutation that undoes it.
        self.inv_permute_indices = np.argsort(self.permute_indices)
        # Reordering entries contributes nothing to the log-density.
        self._logdet_jacobian = 0

    def forward(self, x, y=None, compute_jacobian=True):
        """Return ``x`` with dimension 1 reordered by ``permute_indices``.

        ``y`` and ``compute_jacobian`` are accepted but not used here.

        Raises
        ------
        ValueError
            If ``x`` is neither 2-D nor 4-D.
        """
        rank = x.dim()
        if rank == 2:
            return x[:, self.permute_indices]
        if rank == 4:
            return x[:, self.permute_indices, :, :]
        raise ValueError

    def inverse(self, z, y=None):
        """Return ``z`` with dimension 1 reordered by ``inv_permute_indices``.

        ``y`` is accepted but not used here.

        Raises
        ------
        ValueError
            If ``z`` is neither 2-D nor 4-D.
        """
        rank = z.dim()
        if rank == 2:
            return z[:, self.inv_permute_indices]
        if rank == 4:
            return z[:, self.inv_permute_indices, :, :]
        raise ValueError
class Shuffle(Permutation):
    """Permutation whose order is drawn once, at construction, from NumPy's global RNG."""

    def __init__(self, in_features):
        # The random order is fixed here and handed to Permutation unchanged.
        super().__init__(np.random.permutation(in_features))
class Reverse(Permutation):
    """Permutation that reverses the order of the ``in_features`` entries."""

    def __init__(self, in_features):
        ascending = np.arange(0, in_features)
        # Copy the reversed view so the indices are a freshly allocated array
        # rather than a view onto ``ascending``.
        descending = ascending[::-1].copy()
        super().__init__(descending)
class Flatten(Flow):
    """
    Flatten every non-batch dimension into one, and restore it on the way back.

    Parameters
    ----------
    in_size : sequence of int, optional
        Shape of one sample (without the batch dimension) used by
        :meth:`inverse`. It is overwritten by every call to :meth:`forward`,
        so it only needs to be given when :meth:`inverse` is called first.
    """

    def __init__(self, in_size=None):
        super().__init__(None)
        self.in_size = in_size
        # Reshaping moves no values, so the stored log-det Jacobian is 0.
        self._logdet_jacobian = 0

    def forward(self, x, y=None, compute_jacobian=True):
        """Return ``x`` reshaped to ``(batch, -1)`` and remember its per-sample shape.

        ``y`` and ``compute_jacobian`` are accepted but not used here.
        """
        self.in_size = x.shape[1:]
        # reshape (not view) so non-contiguous inputs, e.g. the permuted
        # output of a preceding layer, are accepted; it still returns a view
        # whenever one is possible.
        return x.reshape(x.size(0), -1)

    def inverse(self, z, y=None):
        """Return ``z`` reshaped to ``(batch, *in_size)``.

        ``y`` is accepted but not used here.

        Raises
        ------
        ValueError
            If no per-sample shape is known yet (``in_size`` is ``None``).
        """
        if self.in_size is None:
            raise ValueError(
                "in_size is unknown: pass it to the constructor or call forward() first"
            )
        # Unpack the whole recorded shape so any number of dimensions round-trips,
        # not just three.
        return z.reshape(z.size(0), *self.in_size)
class Preprocess(Flow):
    """
    Dequantize inputs with uniform noise and map them to logit space.

    ``forward`` treats ``x`` as values in [0, 1]; ``inverse`` applies a sigmoid.
    """

    def __init__(self):
        super().__init__(None)
        # Stored as a buffer so it follows the module across devices.
        constraint = torch.tensor([0.05], dtype=torch.float32)
        self.register_buffer('data_constraint', constraint)

    @staticmethod
    def logit(x):
        """Return ``log(x) - log(1 - x)`` element-wise."""
        log_p = x.log()
        log_not_p = (1. - x).log()
        return log_p - log_not_p

    def forward(self, x, y=None, compute_jacobian=True):
        """Dequantize ``x``, pull it away from 0 and 1, and return its logit.

        When ``compute_jacobian`` is true, the log-det Jacobian is also stored
        in ``self._logdet_jacobian``; otherwise that attribute is left as is.
        ``y`` is accepted but not used here. Fresh uniform noise is drawn on
        every call, so the result is not deterministic.
        """
        # 1. rescale from [0, 1] to [0, 255]
        scaled = x * 255

        # 2-1. add uniform noise in [0, 1) to dequantize, then go back to [0, 1)
        dequantized = (scaled + torch.rand_like(scaled)) / 256.

        # 2-2. shrink the range symmetrically around 0.5 by a factor
        #      (1 - data_constraint) so the logit below stays finite
        bounded = (1 + (2 * dequantized - 1) * (1 - self.data_constraint)) / 2.

        # 2-3. logit: (0, 1) -> (-inf, inf)
        z = self.logit(bounded)

        if not compute_jacobian:
            return z

        # log-det Jacobian of transformation (2), element-wise:
        #   softplus(z) + softplus(-z) is the log-derivative of the logit, and
        #   -softplus(logit(data_constraint)) equals log(1 - data_constraint),
        #   the log of the scale applied in step 2-2.
        constraint_logit = self.data_constraint.log() - (1. - self.data_constraint).log()
        elementwise = F.softplus(z) + F.softplus(-z) - F.softplus(constraint_logit)
        # presumably sum_samples reduces over the non-batch dimensions -- confirm in ..utils
        per_sample = sum_samples(elementwise)

        # log-det Jacobian of transformation (1): log(256) per element of one sample
        per_sample = per_sample - np.log(256.) * z[0].numel()

        self._logdet_jacobian = per_sample
        return z

    def inverse(self, z, y=None):
        """Map ``z`` from (-inf, inf) to (0, 1) with a sigmoid.

        ``y`` is accepted but not used here.
        """
        # NOTE(review): only the logit (step 2-3) is undone; the rescaling by
        # (1 - data_constraint) and the dequantization are not reverted --
        # confirm this approximation is intended.
        return torch.sigmoid(z)