# Source code for pixyz.flows.coupling

import torch
import numpy as np

from .flows import Flow

[docs]class AffineCoupling(Flow):
r"""
Affine coupling layer

.. math::
:nowrap:

\begin{eqnarray*}
\mathbf{y}_{1:d} &=& \mathbf{x}_{1:d} \\
\mathbf{y}_{d+1:D} &=& \mathbf{x}_{d+1:D} \odot \exp(s(\mathbf{x}_{1:d})) + t(\mathbf{x}_{1:d})
\end{eqnarray*}

"""

scale_net=None, translate_net=None, scale_translate_net=None,
super().__init__(in_features)

else:
raise ValueError

self.scale_net = None
self.translate_net = None
self.scale_translate_net = None

if scale_net and translate_net:
self.scale_net = scale_net
self.translate_net = translate_net
elif scale_translate_net:
self.scale_translate_net = scale_translate_net
else:
raise ValueError

"""
Parameters
----------
x : torch.Tensor

Returns
-------

Examples
--------
>>> scale_translate_net = lambda x: (x, x)
>>> f1 = AffineCoupling(4, mask_type="channel_wise", scale_translate_net=scale_translate_net,
>>> x1 = torch.randn([1,4,3,3])
tensor([[[[1.]],
<BLANKLINE>
[[1.]],
<BLANKLINE>
[[0.]],
<BLANKLINE>
[[0.]]]])
>>> f2 = AffineCoupling(2, mask_type="checkerboard", scale_translate_net=scale_translate_net,
>>> x2 = torch.randn([1,2,5,5])
tensor([[[[0., 1., 0., 1., 0.],
[1., 0., 1., 0., 1.],
[0., 1., 0., 1., 0.],
[1., 0., 1., 0., 1.],
[0., 1., 0., 1., 0.]]]])

"""
if x.dim() == 4:
[_, channels, height, width] = x.shape
else:

elif x.dim() == 2:
[_, n_features] = x.shape

raise ValueError

def get_parameters(self, x, y=None):
    r"""
    Compute the log-scale ``log_s`` and translation ``t`` of the coupling.

    If ``scale_translate_net`` is set, it is called once and must return the
    pair ``(log_s, t)``. Otherwise ``scale_net`` and ``translate_net`` are
    called separately. When ``y`` is given, it is passed to every network as
    a second positional argument.

    Parameters
    ----------
    x : torch.Tensor
        Input fed to the network(s).
    y : torch.Tensor, optional
        Extra input forwarded to the network(s); left out of the call when None.

    Returns
    -------
    log_s : torch.Tensor
        Log-scale term, as produced by the network(s).
    t : torch.Tensor
        Translation term, as produced by the network(s).

    Examples
    --------
    >>> # In case of using scale_translate_net
    >>> scale_translate_net = lambda x: (x, x)
    >>> f1 = AffineCoupling(4, mask_type="channel_wise",
    ...                     scale_translate_net=scale_translate_net)
    >>> x1 = torch.randn([1,4,3,3])
    >>> log_s, t = f1.get_parameters(x1)
    >>> # In case of using scale_net and translate_net
    >>> scale_net = lambda x: x
    >>> translate_net = lambda x: x
    >>> f2 = AffineCoupling(4, mask_type="channel_wise",
    ...                     scale_net=scale_net, translate_net=translate_net)
    >>> x2 = torch.randn([1,4,3,3])
    >>> log_s, t = f2.get_parameters(x2)
    """
    # The optional conditioning input is appended only when present, so the
    # networks are called as net(x) or net(x, y).
    extra = () if y is None else (y,)

    # Only one network option is stored; the unused attribute stays None.
    # Test identity rather than truthiness: container modules such as
    # nn.Sequential define __len__, so an attached network could otherwise be
    # mistaken for a missing one.
    if self.scale_translate_net is not None:
        # Unpack explicitly so a network that does not return a pair fails here.
        log_s, t = self.scale_translate_net(x, *extra)
    else:
        log_s = self.scale_net(x, *extra)
        t = self.translate_net(x, *extra)

    return log_s, t

[docs]    def forward(self, x, y=None, compute_jacobian=True):
# Forward pass of the affine coupling layer.
# NOTE(review): this copy is incomplete -- the statements that define `mask`,
# `log_s` and `t`, and that transform `x`, are missing. As written, `mask`/`log_s`/`t`
# are unbound below and `x` would be returned unchanged; restore them from upstream.
# Presumably: build the mask from `x`, call `self.get_parameters` on the masked
# input (passing `y`), then apply `x * exp(log_s) + t` -- TODO confirm upstream.

# Zero the log-scale and translation wherever mask is 1; assuming mask is a 0/1
# tensor, those positions are left untransformed.
log_s = log_s * (1 - mask)
t = t * (1 - mask)

# Log-determinant of the Jacobian: log_s flattened over every dim except the
# first and summed, giving one value per entry along dim 0; stored on
# self._logdet_jacobian only when requested.
if compute_jacobian:
self._logdet_jacobian = log_s.contiguous().view(log_s.size(0), -1).sum(-1)

return x

[docs]    def inverse(self, z, y=None):
# Inverse pass of the affine coupling layer.
# NOTE(review): this copy is incomplete -- the statements that define `mask`,
# `log_s` and `t`, and that update `z`, are missing. As written the names are
# unbound and `z` is returned unchanged; restore them from upstream.
# Presumably the inverse affine map, (z - t) * exp(-log_s) on the unmasked
# positions -- TODO confirm upstream.

# Zero the log-scale and translation wherever mask is 1 (assuming a 0/1 mask),
# so those positions are left untouched.
log_s = log_s * (1 - mask)
t = t * (1 - mask)

return z

[docs]    def extra_repr(self):
# NOTE(review): the body of this method is missing from this copy; only the
# closing parenthesis of what was presumably a multi-line return expression
# survives. Restore from upstream.
)

r"""
Parameters
----------
height : int
width : int

Returns
-------

Examples
--------
array([[1., 0., 1., 0.],
[0., 1., 0., 1.],
[1., 0., 1., 0.],
[0., 1., 0., 1.],
[1., 0., 1., 0.]], dtype=float32)
array([[0., 1., 0., 1.],
[1., 0., 1., 0.],
[0., 1., 0., 1.],
[1., 0., 1., 0.],
[0., 1., 0., 1.]], dtype=float32)

"""
mask = np.arange(height).reshape(-1, 1) + np.arange(width)

r"""
Parameters
----------
channels : int

Returns
-------

Examples
--------