# Gist by @ajbrock, created February 12, 2019
# https://gist.github.com/ajbrock/d4a52ea75a7284ebafbccf87cf63414c
import torch
import torch.nn as nn

# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x over batch and spatial dims (NCHW input)
    m = torch.mean(x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared
    var = (m2 - m ** 2)
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)
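
# Sanity-check sketch (not part of the original gist): with no gain or bias,
# manual_bn should match F.batch_norm's training-mode output, since both
# normalize by the biased batch variance. It is wrapped in a function because
# fused_bn is defined below; call it after the whole file has loaded. The
# tolerance is loose because the mean-of-squares formulation is slightly less
# numerically stable than a direct variance computation.
def _check_manual_bn():
    import torch.nn.functional as F
    x = torch.randn(8, 3, 16, 16)
    ref = F.batch_norm(x, None, None, training=True, eps=1e-5)
    assert torch.allclose(manual_bn(x), ref, atol=1e-4)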
# Apply scale and shift; if gain and bias are provided, fuse them here
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, fold it into the scale
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If a bias is provided, fold it into the shift
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias
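
# Equivalence sketch (not part of the original gist): folding gain into the
# multiplicative scale and bias into the additive shift turns the reference
# formula in the comment above into a single multiply and subtract, which is
# the point of the fused form.
def _check_fused_bn():
    x = torch.randn(4, 3, 8, 8)
    mean = x.mean([0, 2, 3], keepdim=True)
    var = x.var([0, 2, 3], unbiased=False, keepdim=True)
    gain = torch.rand(1, 3, 1, 1) + 0.5
    bias = torch.rand(1, 3, 1, 1)
    ref = ((x - mean) / ((var + 1e-5) ** 0.5)) * gain + bias
    assert torch.allclose(fused_bn(x, mean, var, gain, bias), ref, atol=1e-5)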

#### This is just a module wrapper that does bookkeeping; the two functions above do the batchnorming
# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # Momentum for updating running stats
        self.momentum = momentum
        # Epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers for the stored statistics
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # Reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0
    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, sum the batch stats; the sums
            # are divided by the accumulation counter at eval time
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            # (use .data so the buffers stay out of the autograd graph)
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean.data * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var.data * self.momentum
            return out
        # If not in training mode, use the stored stats without updating them
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)
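
# Usage sketch (not part of the original gist): accumulate standing stats by
# summing batch statistics over several forward passes in training mode, then
# switch to eval, where the sums are divided by the accumulation counter.
# Gain/bias shapes are illustrative; they just need to broadcast against NCHW.
def _demo_standing_stats():
    bn = myBN(num_channels=3)
    gain = torch.ones(1, 3, 1, 1)
    bias = torch.zeros(1, 3, 1, 1)
    bn.train()
    bn.accumulate_standing = True
    bn.reset_stats()
    for _ in range(10):
        bn(torch.randn(16, 3, 8, 8), gain, bias)
    bn.eval()
    return bn(torch.randn(16, 3, 8, 8), gain, bias)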