import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import scipy as sp
from .flows import Flow
class ChannelConv(Flow):
    """
    Invertible 1 × 1 convolution.

    A square ``in_channels × in_channels`` matrix is applied to the input
    through ``F.conv2d`` with a 1 × 1 kernel.  The matrix is initialised to a
    random orthogonal matrix and is either learned directly
    (``decomposed=False``) or learned through its LU factors
    (``decomposed=True``), in which case the log-determinant is the sum of the
    learned ``log_s`` entries.

    Parameters
    ----------
    in_channels : int
        Size of the square weight matrix; also forwarded to ``Flow.__init__``.
    decomposed : bool, optional
        If True, parameterise the weight as ``P @ L @ U`` with a fixed
        permutation ``P`` and learned ``L``, ``U`` and log-diagonal ``log_s``.

    Notes
    -----
    This is implemented with reference to the following code.
    https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py
    """

    def __init__(self, in_channels, decomposed=False):
        super().__init__(in_channels)

        w_shape = [in_channels, in_channels]
        # Sample a random orthogonal matrix (Q factor of a Gaussian matrix).
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)

        if not decomposed:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            # A bare ``import scipy`` does not guarantee that the ``linalg``
            # submodule is loaded on every SciPy version, so import it
            # explicitly before use.
            import scipy.linalg

            # LU decomposition: w_init = P @ L @ U
            np_p, np_l, np_u = scipy.linalg.lu(w_init)

            # Split the diagonal of U into a fixed sign and a learned log-magnitude.
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            # Keep only the strictly upper-triangular part of U as a parameter.
            np_u = np.triu(np_u, k=1)

            # Strictly lower-triangular mask (its transpose masks U) and identity.
            l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
            eye = np.eye(*w_shape, dtype=np.float32)

            # Fixed (non-trainable) parts of the factorisation.
            self.register_buffer('p', torch.Tensor(np_p.astype(np.float32)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32)))

            # Trainable parts of the factorisation.
            self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32)))

            # Plain tensor attributes (not registered buffers); they are moved
            # to the input's device lazily in ``get_parameters``.
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)

        self.w_shape = w_shape
        self.decomposed = decomposed

    def get_parameters(self, x, inverse):
        """
        Build the 1 × 1 convolution kernel and the log-determinant term.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; only its device and its sizes from dimension 2
            onwards are used here.
        inverse : bool
            If True, the returned kernel is built from the inverse matrix.

        Returns
        -------
        weight : torch.Tensor
            Kernel viewed as ``(in_channels, in_channels, 1, 1)``.
        logdet_jacobian : torch.Tensor
            ``log|det W|`` multiplied by the product of ``x``'s sizes from
            dimension 2 onwards.  This is computed from the forward weight
            regardless of ``inverse``.
        """
        w_shape = self.w_shape
        # Number of positions the same matrix is applied to; cast to a plain
        # int so the tensor below is scaled by a Python scalar.
        pixels = int(np.prod(x.size()[2:]))
        device = x.device

        if not self.decomposed:
            # slogdet is evaluated on a CPU copy and the result moved back to
            # the input's device (kept as in the reference implementation).
            logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(device) * pixels
            if not inverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                # Invert in double precision, then cast back to float32.
                weight = torch.inverse(self.weight.double()).float().view(w_shape[0], w_shape[1], 1, 1)
            return weight, logdet_jacobian

        # l_mask and eye are not registered buffers, so Module.to() does not
        # move them; bring every fixed tensor onto the input's device here.
        self.p = self.p.to(device)
        self.sign_s = self.sign_s.to(device)
        self.l_mask = self.l_mask.to(device)
        self.eye = self.eye.to(device)

        # Rebuild L (unit diagonal) and U (signed exp(log_s) diagonal) from the
        # masked parameters.
        l = self.l * self.l_mask + self.eye
        u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        logdet_jacobian = torch.sum(self.log_s) * pixels

        if not inverse:
            w = torch.matmul(self.p, torch.matmul(l, u))
        else:
            # (P L U)^-1 = U^-1 L^-1 P^-1; factors are inverted in double precision.
            l = torch.inverse(l.double()).float()
            u = torch.inverse(u.double()).float()
            w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
        return w.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian

    def forward(self, x, y=None, compute_jacobian=True):
        """
        Apply the 1 × 1 convolution to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to ``F.conv2d``.
        y : optional
            Not used by this layer.
        compute_jacobian : bool, optional
            If True, store the log-determinant in ``self._logdet_jacobian``.

        Returns
        -------
        z : torch.Tensor
            Result of the convolution.
        """
        weight, logdet_jacobian = self.get_parameters(x, inverse=False)
        z = F.conv2d(x, weight)
        if compute_jacobian:
            self._logdet_jacobian = logdet_jacobian
        return z

    def inverse(self, x, y=None):
        """
        Apply the inverse 1 × 1 convolution to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to ``F.conv2d`` with the inverted kernel.
        y : optional
            Not used by this layer.

        Returns
        -------
        z : torch.Tensor
            Result of the convolution with the inverse weight.
        """
        weight, _ = self.get_parameters(x, inverse=True)
        z = F.conv2d(x, weight)
        return z