EvoNorm-S0 in PyTorch from https://arxiv.org/pdf/2004.02967.pdf
import torch | |
import torch.nn as nn | |
class EvoNorm2d(nn.Module):
    """EvoNorm-S0 layer for NCHW inputs (https://arxiv.org/abs/2004.02967).

    With ``nonlinearity=True`` (the S0 variant) computes
        y = (x * sigmoid(v * x)) / group_std(x) * weight + bias
    otherwise reduces to a plain per-channel affine transform
        y = x * weight + bias.

    Args:
        num_features: number of input channels C.
        eps: small constant added to the per-group variance for
            numerical stability (inside the sqrt, as in the paper).
        nonlinearity: if True, use the swish-like numerator divided by
            the grouped standard deviation; if False, affine only.
    """

    __constants__ = ['num_features', 'eps', 'nonlinearity']

    def __init__(self, num_features, eps=1e-5, nonlinearity=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        # Per-channel affine parameters, shaped for broadcasting over N, H, W.
        self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        if self.nonlinearity:
            # Learnable gate scale for the swish-like numerator.
            self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize to the identity affine (weight=1, bias=0, v=1)."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=32):
        """Grouped standard deviation of x, broadcast back to x's shape.

        Returns sqrt(var + eps) per (sample, group), expanded to (N, C, H, W).
        ``groups`` is clamped to the channel count so inputs with fewer than
        ``groups`` channels work; C must be divisible by the effective groups.

        Raises:
            ValueError: if C is not divisible by the effective group count.
        """
        N, C, H, W = x.shape
        groups = min(groups, C)  # generalize: small channel counts get fewer groups
        if C % groups != 0:
            raise ValueError(
                f'num channels {C} must be divisible by groups {groups}')
        grouped = x.reshape(N, groups, C // groups, H, W)
        # Biased variance with eps inside the sqrt, per the paper's
        # sqrt(var + eps) formulation (not std + eps as is sometimes seen).
        var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps).expand_as(grouped)
        return std.reshape(N, C, H, W)

    def forward(self, x):
        """Apply EvoNorm-S0 (or the affine-only variant) to an NCHW tensor."""
        if self.nonlinearity:
            num = x * torch.sigmoid(self.v * x)
            return num / self.group_std(x) * self.weight + self.bias
        return x * self.weight + self.bias
This comment has been minimized.
This comment has been minimized.
@pinouchon https://github.com/digantamisra98/EvoNorm I have EvoNorm B0 here, however I just have one error of shape mismatch in the running variance calculation to solve. |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
This comment has been minimized.
Do you have the EvoNorm-B0 by any chance? It looks like this is EvoNorm-S0. I changed it to 1d and it seems to work fine, but I would still prefer the batch version