Skip to content

Instantly share code, notes, and snippets.

@kashif
Last active Apr 10, 2020
Embed
What would you like to do?
EvoNorm-S0 in PyTorch from https://arxiv.org/pdf/2004.02967.pdf
import torch
import torch.nn as nn
class EvoNorm2d(nn.Module):
    """EvoNorm-S0 normalization-activation layer for 4-D (N, C, H, W) inputs.

    With ``nonlinearity=True`` the layer computes::

        y = x * sigmoid(v * x) / group_std(x) * weight + bias

    With ``nonlinearity=False`` it reduces to the per-channel affine map
    ``y = x * weight + bias``.

    Args:
        num_features: number of channels ``C`` of the input.
        eps: added to the group variance before the square root.
        nonlinearity: if True, use the sigmoid-gated, group-normalized form.
            Otherwise apply only the affine transform.
        groups: number of channel groups used by ``group_std``. ``C`` must be
            divisible by it whenever the nonlinear branch runs.
    """

    __constants__ = ['num_features', 'eps', 'nonlinearity', 'groups']

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, groups=32):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.groups = groups
        # Per-channel parameters shaped (1, C, 1, 1) so they broadcast over N, H, W.
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if self.nonlinearity:
            self.v = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and v to ones and bias to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=None):
        """Return the per-group standard deviation, broadcast back to ``x``'s shape.

        The channels of ``x`` are split into ``groups`` contiguous groups. For
        each sample and group, the population variance is taken over
        (channels-in-group, H, W), and ``sqrt(var + eps)`` is returned.

        Args:
            x: tensor of shape (N, C, H, W).
            groups: number of channel groups. Defaults to ``self.groups``.

        Returns:
            Tensor of shape (N, C, H, W).

        Raises:
            ValueError: if ``C`` is not divisible by ``groups``.
        """
        if groups is None:
            groups = self.groups
        N, C, H, W = x.shape
        if groups <= 0 or C % groups != 0:
            raise ValueError(
                f"num channels ({C}) must be divisible by groups ({groups})"
            )
        grouped = torch.reshape(x, (N, groups, C // groups, H, W))
        # Population (biased) variance with eps inside the sqrt, as in the
        # paper's pseudocode. An unbiased std would be NaN for 1-element groups.
        var = torch.var(grouped, dim=(2, 3, 4), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps).expand_as(grouped)
        return torch.reshape(std, (N, C, H, W))

    def forward(self, x):
        """Apply EvoNorm-S0 (or the affine-only fallback) to ``x`` of shape (N, C, H, W)."""
        if self.nonlinearity:
            num = x * torch.sigmoid(self.v * x)
            return num / self.group_std(x) * self.weight + self.bias
        return x * self.weight + self.bias
@pinouchon
Copy link

pinouchon commented Apr 9, 2020

Do you have EvoNorm-B0 by any chance? This looks like EvoNorm-S0. I adapted it to 1d and it seems to work fine, but I would still prefer the batch version.

@digantamisra98
Copy link

digantamisra98 commented Apr 9, 2020

@pinouchon I have EvoNorm-B0 at https://github.com/digantamisra98/EvoNorm, but I still need to fix one shape-mismatch error in the running-variance calculation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment