Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active August 20, 2019 11:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vadimkantorov/67249eeca48a081a592622cc1a3cd14c to your computer and use it in GitHub Desktop.
Save vadimkantorov/67249eeca48a081a592622cc1a3cd14c to your computer and use it in GitHub Desktop.
Invertible 1x1 convolution in pure PyTorch (extracted from Glow packages)
# Original code from OpenAI Glow: https://github.com/openai/glow/blob/master/model.py
# This implementation is inspired by the PyTorch port at: https://github.com/rosinality/glow-pytorch/blob/master/model.py
# It does not implement inverse() or log_abs_det_jacobian().
import torch
class InvConvNd(torch.nn.Module):
    """Invertible 1x1 convolution (Glow) with an LU-parameterized weight.

    The (C, C) channel-mixing matrix is assembled as
    ``W = P @ L @ (U + diag(sign(s) * exp(log|s|)))`` where:

    * ``P`` (buffer ``w_p``) is the fixed permutation from the LU factorization
      of the initial matrix;
    * ``L`` is learned through ``w_l``; only its strictly lower triangle is used,
      and its unit diagonal comes from the buffer ``l_eye``;
    * ``U`` is learned through ``w_u``; only its strictly upper triangle is used;
    * the diagonal of the upper factor is ``s_sign * exp(w_s_abs_log)``, with the
      sign held fixed (buffer) and the log-magnitude learned.

    At construction ``W`` reproduces (up to rounding) a random orthogonal matrix
    scaled by ``gain``.

    Args:
        in_channels: number of input (and output) channels ``C``.
        gain: scale passed to ``torch.nn.init.orthogonal_`` for the initial matrix.
    """

    def __init__(self, in_channels, gain=1e-3):
        super().__init__()
        Q = torch.empty(in_channels, in_channels)
        torch.nn.init.orthogonal_(Q, gain=gain)
        # Q ~= w_p @ w_l @ w_u, with w_l unit lower triangular and w_u upper triangular.
        # NOTE: Tensor.lu() is deprecated in recent PyTorch in favour of torch.linalg.lu_factor.
        w_p, w_l, w_u = torch.lu_unpack(*Q.lu())
        w_s = w_u.diag()
        # Strictly upper-triangular 0/1 mask; its transpose selects the strictly lower triangle.
        u_mask = torch.ones_like(w_u).triu(1)
        self.register_buffer('w_p', w_p)
        self.register_buffer('u_mask', u_mask)
        # Diagonal of the unit lower factor (a 1-D vector of ones). Kept 1-D so existing
        # state_dicts still load; `weight` expands it to an identity matrix with torch.diag.
        self.register_buffer('l_eye', w_l.diag())
        self.register_buffer('s_sign', w_s.sign())
        self.w_l = torch.nn.Parameter(w_l)
        # Learn log|s| so the diagonal magnitude stays strictly positive (W stays invertible).
        self.w_s_abs_log = torch.nn.Parameter(w_s.abs().log())
        self.w_u = torch.nn.Parameter(w_u)

    def forward(self, x):
        """Mix the channels of ``x`` with a 1x1 convolution.

        Args:
            x: tensor of shape (N, C, L), (N, C, H, W) or (N, C, D, H, W).

        Returns:
            Tensor with the same shape as ``x`` whose channels are mixed by ``self.weight``.

        Raises:
            ValueError: if ``x`` is not 3-, 4- or 5-dimensional.
        """
        convs = {
            3: torch.nn.functional.conv1d,
            4: torch.nn.functional.conv2d,
            5: torch.nn.functional.conv3d,
        }
        conv = convs.get(x.ndim)
        if conv is None:
            raise ValueError(
                'InvConvNd expects a 3-D, 4-D or 5-D input, got {}-D'.format(x.ndim))
        weight = self.weight
        # (C_out, C_in) -> (C_out, C_in, 1, ..., 1): one unit kernel dim per spatial dim.
        kernel = weight.reshape(*weight.shape, *((1,) * (x.ndim - 2)))
        return conv(x, kernel)

    @property
    def weight(self):
        """Assemble the (C, C) matrix ``P @ L @ U`` from its learned factors."""
        # Unit lower-triangular factor: learned strictly-lower part plus the identity.
        # torch.diag turns the 1-D l_eye vector into a matrix; adding the vector directly
        # would broadcast +1 onto every entry.
        lower = self.w_l * self.u_mask.t() + torch.diag(self.l_eye)
        # Upper-triangular factor: learned strictly-upper part plus sign(s) * |s| on the diagonal.
        upper = self.w_u * self.u_mask + torch.diag(self.s_sign * self.w_s_abs_log.exp())
        return self.w_p @ lower @ upper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment