@el-hult · Last active February 25, 2021 08:15
import torch
import torch.nn as nn


class MemorizingNormalizer(nn.Module):
    """Normalizer with running (EMA) statistics; returns the log Jacobian
    determinant of the transform, e.g. for use in normalizing flows."""

    def __init__(self, d, eps, rho):
        super().__init__()
        # Running statistics as non-trainable parameters, so they follow
        # the module across devices and appear in state_dict.
        self.means = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.vars = nn.Parameter(torch.ones(d), requires_grad=False)
        self.eps = nn.Parameter(torch.tensor(eps, dtype=torch.float64), requires_grad=False)
        self.rho = nn.Parameter(torch.tensor(rho, dtype=torch.float64), requires_grad=False)

    def forward(self, x):
        # Exponential moving average update of the running statistics.
        self.means.data = self.means * self.rho + (1 - self.rho) * x.mean(axis=0)
        self.vars.data = self.vars * self.rho + (1 - self.rho) * x.var(axis=0)
        varse = self.vars + self.eps
        # Log Jacobian determinant; sum of logs is more stable than log(prod).
        lj = -0.5 * torch.sum(torch.log(varse))
        z = (x - self.means) / torch.sqrt(varse)
        return z, lj

    def inverse(self, z):
        # Undo the normalization using the stored statistics.
        return z * torch.sqrt(self.vars + self.eps) + self.means
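
A minimal usage sketch, assuming the class as defined above; the feature dimension, batch size, eps, rho, and tolerance below are illustrative choices, not from the original gist:

normalizer = MemorizingNormalizer(d=3, eps=1e-5, rho=0.9)

x = torch.randn(128, 3) * 5 + 2   # batch with nonzero mean and scale
z, log_jac = normalizer(x)        # normalizes and updates the running stats
x_rec = normalizer.inverse(z)     # maps back to the data scale

# Reconstruction is exact up to floating-point rounding, since inverse()
# uses the same stored statistics that forward() just applied.
assert torch.allclose(x_rec.to(x.dtype), x, atol=1e-4)

Note that each forward call nudges the stored means and variances toward the current batch statistics with decay rho, so the output is only approximately zero-mean and unit-variance until the running averages have converged.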