Skip to content

Instantly share code, notes, and snippets.

@nazfox
Created February 27, 2022 08:31
Show Gist options
  • Save nazfox/b1adaf9419de16e12cba9c959d6f626e to your computer and use it in GitHub Desktop.
Save nazfox/b1adaf9419de16e12cba9c959d6f626e to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class BaseCouplingLayer(nn.Module):
    """Mask-based coupling layer; subclasses supply ``coupling_law``.

    The input ``x`` is split into a pass-through part ``x1 = x * mask`` and
    a transformed part ``x2 = x * (1 - mask)``. ``x2`` is transformed by
    ``coupling_law`` conditioned on ``m(x1)``, re-masked with ``1 - mask``,
    and the two parts are summed back together. Re-masking means the
    pass-through positions of the output equal those of the input even when
    ``m`` writes nonzero values there, so for a 0/1 mask
    ``forward(forward(x), inverse=True)`` recovers ``x`` whenever
    ``coupling_law`` is invertible.

    Args:
        mask: Value broadcastable against the input; 1 marks pass-through
            positions and 0 marks transformed ones. Presumably a 0/1
            tensor -- TODO confirm with callers.
        m: Module (or callable) applied to the pass-through part; its
            output is passed to ``coupling_law`` as ``b``.
    """

    def __init__(self, mask, m):
        super().__init__()
        if isinstance(mask, torch.Tensor) and not isinstance(mask, nn.Parameter):
            # A buffer moves with .to()/.cuda() together with the module;
            # non-persistent so existing state_dict keys are unchanged.
            self.register_buffer("mask", mask, persistent=False)
        else:
            self.mask = mask
        self.m = m

    def first(self, x):
        """Return the pass-through part ``x * mask``."""
        return x * self.mask

    def second(self, x):
        """Return the transformed part ``x * (1 - mask)``."""
        return x * (1. - self.mask)

    def split(self, x):
        """Return ``(first(x), second(x))``; both parts keep ``x``'s shape."""
        return self.first(x), self.second(x)

    def merge(self, x1, x2):
        """Recombine the two masked parts by summation."""
        return x1 + x2

    def coupling_law(self, a, b, inverse=False):
        """Transform ``a`` conditioned on ``b`` (or undo it if ``inverse``).

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def transform(self, x1, x2, inverse=False):
        """Apply the coupling to ``x2``; ``x1`` passes through unchanged."""
        y1 = x1
        # Re-mask the coupled part: m(x1) is not guaranteed to be zero on
        # the pass-through positions, and any leak there would make the
        # layer non-invertible.
        y2 = self.second(self.coupling_law(x2, self.m(x1), inverse=inverse))
        return y1, y2

    def forward(self, x, inverse=False):
        """Split, couple and merge ``x``; ``inverse=True`` runs the inverse."""
        x1, x2 = self.split(x)
        y1, y2 = self.transform(x1, x2, inverse=inverse)
        return self.merge(y1, y2)
class AdditiveCouplingLayer(BaseCouplingLayer):
    """Coupling layer with the additive law ``y2 = x2 + m(x1)``.

    The inverse subtracts the same shift, so it is exact for any ``m``.
    """

    def coupling_law(self, a, b, inverse=False):
        """Shift ``a`` by ``b`` (forward) or by ``-b`` (inverse).

        Args:
            a: Part being transformed.
            b: Shift produced by ``m``; must broadcast against ``a``.
            inverse: Truthy to undo the forward shift.

        Returns:
            ``a + b`` when ``inverse`` is falsy, otherwise ``a - b``.
        """
        # Truthiness, not ``is False``: 0, None or a numpy/torch bool would
        # otherwise silently select the inverse branch.
        if inverse:
            return a - b
        return a + b
class MultiplicativeCouplingLayer(BaseCouplingLayer):
    """Coupling layer with the multiplicative law ``y2 = x2 * m(x1)``.

    The inverse divides by the same factor. ``m``'s output must therefore be
    nonzero at every element: the division runs over the whole tensor, and
    a zero divisor yields inf/nan.
    """

    def coupling_law(self, a, b, inverse=False):
        """Scale ``a`` by ``b`` (forward) or by ``1 / b`` (inverse).

        Args:
            a: Part being transformed.
            b: Elementwise factor produced by ``m``; must broadcast against
                ``a``.
            inverse: Truthy to undo the forward scaling.

        Returns:
            ``a * b`` when ``inverse`` is falsy, otherwise ``a / b``.
        """
        # Truthiness, not ``is False``, so non-bool flags behave as expected.
        if inverse:
            return torch.div(a, b)
        return torch.mul(a, b)
class AffineCouplingLayer(BaseCouplingLayer):
    """Coupling layer with the affine law ``y2 = x2 * scale + shift``.

    ``m`` must return twice as many features as the transformed part along
    the last dimension: the first half is ``scale`` and the second half is
    ``shift``. ``scale`` must be nonzero for the inverse to exist.
    """

    def coupling_law(self, a, b, inverse=False):
        """Apply ``a * scale + shift`` or its exact inverse.

        Args:
            a: Part being transformed.
            b: Output of ``m``; its last dimension is split in half into
                ``scale`` and ``shift``, each of which must broadcast
                against ``a``.
            inverse: Truthy to undo the forward transform.

        Returns:
            ``a * scale + shift`` when ``inverse`` is falsy, otherwise
            ``(a - shift) / scale``.
        """
        d = b.size(-1) // 2
        # Slice the same (last) dimension that ``d`` was measured on, so any
        # input rank works; for 2-D input this equals b[:, :d] / b[:, d:].
        scale = b[..., :d]
        shift = b[..., d:]
        if inverse:
            # Exact inverse of a * scale + shift: remove the shift first,
            # then undo the scale.
            return torch.div(a - shift, scale)
        return torch.mul(a, scale) + shift
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment