Skip to content

Instantly share code, notes, and snippets.

@sumanmichael
Created June 21, 2021 11:01
Show Gist options
  • Save sumanmichael/2e5108ae084ade5731f18f17eb6ea961 to your computer and use it in GitHub Desktop.
Save sumanmichael/2e5108ae084ade5731f18f17eb6ea961 to your computer and use it in GitHub Desktop.
TensorFlow BatchNormalization's Equivalent in PyTorch (incl. Loading Weights)
class BatchNorm2d(nn.Module):
    """Batch normalization for 4-D input using TensorFlow/Keras conventions.

    Statistics are reduced over dimensions ``(0, 2, 3)``, so dimension 1 is
    treated as the channel axis. ``eps=1e-3`` and ``momentum=0.99`` match the
    defaults of Keras ``BatchNormalization``, where ``momentum`` is the weight
    kept on the *old* moving statistic (the opposite of PyTorch's convention).

    Mode selection follows ``torch.is_grad_enabled()`` rather than
    ``self.training``: batch statistics are used (and the moving statistics
    updated) whenever autograd is enabled; the stored moving statistics are
    used inside ``torch.no_grad()``.

    Attributes:
        weight: Learnable scale (gamma), shape ``(1, num_features, 1, 1)``.
        bias: Learnable shift (beta), shape ``(1, num_features, 1, 1)``.
        moving_mean: Running mean buffer, same shape.
        moving_variance: Running (biased) variance buffer, same shape.
    """

    def __init__(self, num_features):
        """Create the layer.

        Args:
            num_features: Number of channels of the input (the number of
                output channels of the preceding convolutional layer).
        """
        super().__init__()
        shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        # Buffers (not plain attributes) so the running statistics follow
        # ``.to()`` / ``.cuda()`` and are saved in ``state_dict``.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_variance", torch.ones(shape))

    def forward(self, x):
        """Normalize ``x`` and apply the learned scale and shift.

        Args:
            x: Input tensor; must be 4-D when autograd is enabled.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The statistics may have been re-assigned from outside (e.g. when
        # loading converted weights) on a different device than the input.
        if self.moving_mean.device != x.device:
            self.moving_mean = self.moving_mean.to(x.device)
            self.moving_variance = self.moving_variance.to(x.device)
        return self._batch_norm(x, eps=1e-3, momentum=0.99)

    def _batch_norm(self, x, eps, momentum):
        """Apply batch normalization to ``x``.

        Args:
            x: Input tensor.
            eps: Constant added to the variance for numerical stability.
            momentum: Weight kept on the previous moving statistics.

        Returns:
            ``gamma * x_hat + beta`` where ``x_hat`` is the normalized input.

        Raises:
            ValueError: If autograd is enabled and ``x`` is not 4-D.
        """
        # Corresponding equivalents of the TensorFlow variable names.
        gamma = self.weight
        beta = self.bias
        if not torch.is_grad_enabled():
            # Inference path: normalize with the stored moving statistics.
            x_hat = (x - self.moving_mean) / torch.sqrt(self.moving_variance + eps)
        else:
            if x.dim() != 4:
                raise ValueError(
                    f"BatchNorm2d expects a 4-D input, got {x.dim()}-D"
                )
            # Per-channel mean and biased variance over every other axis.
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
            x_hat = (x - mean) / torch.sqrt(var + eps)
            # Update outside autograd; otherwise each step would keep the
            # graph of every previous batch alive through the moving stats.
            with torch.no_grad():
                self.moving_mean = (
                    momentum * self.moving_mean + (1.0 - momentum) * mean
                )
                self.moving_variance = (
                    momentum * self.moving_variance + (1.0 - momentum) * var
                )
        return gamma * x_hat + beta  # Scale and shift
# Loading weights
bn = BatchNorm2d(256)
# TF weights with rank 1: reshape each to (1, C, 1, 1) so it broadcasts over
# the batch and spatial axes. `gamma`, `beta`, `moving_mean` and
# `moving_variance` are presumably the arrays exported from the TensorFlow
# BatchNormalization layer — they are not defined here and must be provided.
bn.weight = torch.nn.Parameter(torch.tensor(gamma).view(1, -1, 1, 1))
bn.bias = torch.nn.Parameter(torch.tensor(beta).view(1, -1, 1, 1))
bn.moving_mean = torch.tensor(moving_mean).view(1, -1, 1, 1)
bn.moving_variance = torch.tensor(moving_variance).view(1, -1, 1, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment