Skip to content

Instantly share code, notes, and snippets.

@SunDoge
Created March 17, 2021 07:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SunDoge/1cfa4e219f349cba36659b19577f172a to your computer and use it in GitHub Desktop.
from torch import nn
import math
import torch
class MomentumBatchNorm3d(nn.BatchNorm3d):
    """BatchNorm3d that normalizes with a momentum blend of batch and running stats.

    On every forward pass the per-channel batch statistics of ``x`` are mixed
    with the running statistics using ``self.momentum``, and the input is
    normalized with that blend. The blended statistics are treated as
    constants, so no gradient flows through them.

    The running statistics are updated once per *pair* of forward calls, from
    the average of the two batches' statistics. After each such update,
    ``momentum`` is decayed from 1 toward 0 along a cosine schedule spanning
    ``total_iters`` updates.

    Batch statistics are used, and the running statistics updated, regardless
    of ``self.training``: the usual eval-mode passthrough is deliberately
    disabled. Requires ``track_running_stats=True`` (the running buffers are
    read unconditionally).
    """

    def __init__(self, num_features, eps=1e-5, momentum=1.0, affine=True,
                 track_running_stats=True, total_iters=100):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        # Number of schedule steps after which momentum reaches (and stays at) 0.
        self.total_iters = total_iters
        self.cur_iter = 0
        # Detached statistics of the first batch of the current pair,
        # or None when no batch is pending.
        self.mean_last_batch = None
        self.var_last_batch = None

    @staticmethod
    def _channel_view(t):
        """Reshape a per-channel vector of shape (C,) to (1, C, 1, 1, 1) for broadcasting."""
        return t.reshape(1, -1, 1, 1, 1)

    def momentum_cosine_decay(self):
        """Advance the schedule by one step and recompute ``self.momentum``.

        momentum = (cos(pi * p) + 1) / 2 with p = cur_iter / total_iters,
        where p is clamped to 1. Without the clamp, momentum would rise back
        toward 1 after ``total_iters`` steps instead of staying at 0.
        """
        self.cur_iter += 1
        if self.total_iters > 0:
            progress = min(self.cur_iter / self.total_iters, 1.0)
        else:
            # A non-positive horizon means the schedule is already finished.
            progress = 1.0
        self.momentum = (math.cos(math.pi * progress) + 1) * 0.5

    def forward(self, x):
        """Normalize a 5-D input (N, C, D, H, W) and update the momentum statistics.

        Raises:
            ValueError: if ``x`` is not 5-D, or has at most one value per channel
                (its variance would be undefined).
        """
        if x.dim() != 5:
            raise ValueError(f"expected 5D input (got {x.dim()}D input)")
        reduce_dims = [0, 2, 3, 4]
        n = x.numel() // x.size(1)  # number of values per channel
        if n <= 1:
            raise ValueError(
                f"Expected more than 1 value per channel, got input size {tuple(x.shape)}"
            )
        # Bessel factor that turns the biased batch variance into an unbiased estimate.
        bessel = n / (n - 1)

        with torch.no_grad():
            # The statistics are used as constants, so computing them without
            # autograd does not change gradients. It also stops the
            # running buffers and mean_last_batch from holding the graph alive.
            mean = torch.mean(x, dim=reduce_dims)
            # Biased variance: `bessel` is applied exactly once below.
            var = torch.var(x, dim=reduce_dims, unbiased=False)
            tmp_running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean
            tmp_running_var = (
                self.momentum * var * bessel + (1 - self.momentum) * self.running_var
            )

        x = (x - self._channel_view(tmp_running_mean)) / torch.sqrt(
            self._channel_view(tmp_running_var) + self.eps
        )
        if self.affine:
            x = x * self._channel_view(self.weight) + self._channel_view(self.bias)

        if self.mean_last_batch is None and self.var_last_batch is None:
            # First batch of a pair: remember its statistics for the next call.
            self.mean_last_batch = mean
            self.var_last_batch = var
        else:
            # Second batch of a pair: fold the pair's averaged statistics into
            # the running buffers in place, without autograd tracking.
            with torch.no_grad():
                pair_mean = (mean + self.mean_last_batch) * 0.5
                pair_var = (var + self.var_last_batch) * 0.5 * bessel
                self.running_mean.copy_(
                    self.momentum * pair_mean + (1 - self.momentum) * self.running_mean
                )
                self.running_var.copy_(
                    self.momentum * pair_var + (1 - self.momentum) * self.running_var
                )
            self.mean_last_batch = None
            self.var_last_batch = None
            self.momentum_cosine_decay()
        return x
if __name__ == '__main__':
    # Smoke test: a single forward pass must preserve the input shape.
    layer = MomentumBatchNorm3d(3)
    sample = torch.rand(2, 3, 4, 32, 32)
    result = layer(sample)
    assert result.shape == sample.shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment