@bkj
Last active May 31, 2019 19:18
simple_batchnorm.py
#!/usr/bin/env python
"""
simple_batchnorm.py

Minimal re-implementation of 1D batch normalization for reference.
"""

import torch
from torch import nn


class SimpleBatchNorm1d(nn.Module):
    def __init__(self, dim, momentum=0.5, eps=1e-5, affine=True, track_running_stats=True):
        super().__init__()

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Running statistics used at inference time (kept as plain tensors, not registered buffers)
        self.running_mean = torch.zeros(1, dim)
        self.running_var = torch.ones(1, dim)

        if self.affine:
            # Learnable scale and shift. Note: nn.BatchNorm1d initializes its scale
            # to ones; this gist uses a uniform random initialization for gamma.
            self.gamma = nn.Parameter(torch.rand(1, dim))
            self.beta = nn.Parameter(torch.zeros(1, dim))

    def forward(self, x):
        if self.training or not self.track_running_stats:
            # Normalize with the statistics of the current batch
            x_mean = x.mean(dim=0, keepdim=True)
            x_var = x.var(dim=0, keepdim=True, unbiased=False)

            if self.track_running_stats:
                # Exponential moving average of the batch statistics. The running
                # variance uses the unbiased estimate (hence the N / (N - 1) factor),
                # matching nn.BatchNorm1d. Updating under no_grad keeps the running
                # tensors from accumulating autograd history across iterations.
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * x_mean
                    self.running_var = (1 - self.momentum) * self.running_var + \
                        self.momentum * (x.shape[0] / (x.shape[0] - 1)) * x_var

            out = (x - x_mean) / (x_var + self.eps).sqrt()
        else:
            # Evaluation mode: normalize with the accumulated running statistics
            out = (x - self.running_mean) / (self.running_var + self.eps).sqrt()

        if self.affine:
            out = out * self.gamma + self.beta

        return out
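
A quick way to sanity-check the module is to compare it against nn.BatchNorm1d with the same momentum, after copying the (randomly initialized) affine parameters across. The snippet below is a minimal sketch of such a check and is not part of the original gist; the batch size, feature dimension, and tolerance are illustrative choices.

# Illustrative sanity check against nn.BatchNorm1d (assumes SimpleBatchNorm1d from above)
import torch
from torch import nn

torch.manual_seed(123)

dim = 8
x = torch.randn(32, dim)

custom = SimpleBatchNorm1d(dim, momentum=0.5)
ref = nn.BatchNorm1d(dim, momentum=0.5)

# Copy gamma/beta into the reference module so the affine transforms match
with torch.no_grad():
    ref.weight.copy_(custom.gamma.view(-1))
    ref.bias.copy_(custom.beta.view(-1))

# Training mode: both normalize with batch statistics and update running stats
custom.train()
ref.train()
assert torch.allclose(custom(x), ref(x), atol=1e-5)
assert torch.allclose(custom.running_mean.view(-1), ref.running_mean, atol=1e-5)
assert torch.allclose(custom.running_var.view(-1), ref.running_var, atol=1e-5)

# Eval mode: both normalize with the accumulated running statistics
custom.eval()
ref.eval()
assert torch.allclose(custom(x), ref(x), atol=1e-5)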