# Invertible 1x1 convolution in pure PyTorch (extracted from Glow packages).
# Source: GitHub gist vadimkantorov/67249eeca48a081a592622cc1a3cd14c
# (last active August 20, 2019 11:21).
# Original code from OpenAI Glow: https://github.com/openai/glow/blob/master/model.py
# This implementation is inspired by this PyTorch reference: https://github.com/rosinality/glow-pytorch/blob/master/model.py
# This implementation does not include inverse() or log_abs_det_jacobian() computation.
import torch
class InvConvNd(torch.nn.Module):
    """Invertible 1x1 convolution (as in Glow) with an LU-decomposed weight.

    The square (C, C) channel-mixing matrix is assembled as

        W = P @ (L_strict + I) @ (U_strict + diag(s_sign * exp(w_s_abs_log)))

    where ``P`` is a fixed permutation and ``s_sign`` a fixed sign vector
    (both buffers). ``w_l``, ``w_u`` and ``w_s_abs_log`` are trainable.
    Only the strictly-lower part of ``w_l`` and the strictly-upper part of
    ``w_u`` enter ``W`` (the rest is masked out and gets zero gradient).
    """

    def __init__(self, in_channels, gain=1e-3):
        """Initialise from an LU decomposition of a scaled random orthogonal matrix.

        Args:
            in_channels: number of channels C; the layer maps C -> C channels.
            gain: scale passed to ``torch.nn.init.orthogonal_``, so the initial
                weight equals ``gain`` times a random orthogonal matrix.
        """
        super().__init__()
        # NOTE(review): the Glow / rosinality references initialise with a plain
        # rotation (gain 1); the 1e-3 default shrinks the initial output. It is
        # kept unchanged for interface compatibility.
        q = torch.empty(in_channels, in_channels)
        torch.nn.init.orthogonal_(q, gain=gain)
        w_p, w_l, w_u = self._plu(q)
        w_s = w_u.diag()

        self.register_buffer('w_p', w_p)
        # Strictly-upper-triangular mask; its transpose masks the strictly-lower part.
        self.register_buffer('u_mask', torch.ones_like(w_u).triu(1))
        # Diagonal of L (all ones). Kept as a 1-D vector so state_dict shapes are
        # unchanged; it is expanded to a diagonal matrix in `weight`.
        self.register_buffer('l_eye', w_l.diag())
        # The sign of diag(U) is frozen; only its log-magnitude is learned.
        self.register_buffer('s_sign', w_s.sign())
        self.w_l = torch.nn.Parameter(w_l)
        self.w_s_abs_log = torch.nn.Parameter(w_s.abs().log())
        self.w_u = torch.nn.Parameter(w_u)

    @staticmethod
    def _plu(matrix):
        """Return ``(P, L, U)`` with ``matrix == P @ L @ U`` and unit-diagonal ``L``."""
        linalg_lu = getattr(getattr(torch, 'linalg', None), 'lu', None)
        if linalg_lu is not None:
            return linalg_lu(matrix)
        # Fallback for PyTorch releases without torch.linalg.lu
        # (Tensor.lu is deprecated in newer releases).
        return torch.lu_unpack(*matrix.lu())

    def forward(self, x):
        """Mix the channels of ``x`` with the 1x1 kernel ``self.weight``.

        Args:
            x: tensor of shape (N, C, *spatial) with 1, 2 or 3 spatial dims,
                dispatched to conv1d / conv2d / conv3d respectively.

        Raises:
            ValueError: if ``x`` is not 3-, 4- or 5-dimensional.
        """
        convs = {
            3: torch.nn.functional.conv1d,
            4: torch.nn.functional.conv2d,
            5: torch.nn.functional.conv3d,
        }
        if x.ndim not in convs:
            raise ValueError(
                'InvConvNd expects a 3-D, 4-D or 5-D input (N, C, *spatial), '
                'got a {}-D tensor'.format(x.ndim))
        weight = self.weight
        # Append one singleton kernel dimension per spatial dimension.
        kernel = weight.reshape(*weight.shape, *((1,) * (x.ndim - 2)))
        return convs[x.ndim](x, kernel)

    @property
    def weight(self):
        """The (C, C) channel-mixing matrix ``P @ L @ U`` built from the parameters."""
        # torch.diag turns the 1-D `l_eye` into an identity matrix. Adding the
        # 1-D vector directly would broadcast +1 onto every entry of L.
        lower = self.w_l * self.u_mask.t() + torch.diag(self.l_eye)
        upper = self.w_u * self.u_mask + torch.diag(self.s_sign * self.w_s_abs_log.exp())
        return self.w_p @ lower @ upper
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.