@yangkky
Last active December 10, 2022 20:24
Masked 1D batchnorm in PyTorch.
import torch
import torch.nn as nn
from torch.nn import init


class MaskedBatchNorm1d(nn.Module):
    """A masked version of nn.BatchNorm1d. Only tested for 3D inputs.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to
            ``False``, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, L)`
        - input_mask: :math:`(N, 1, L)` tensor of ones and zeros, where the
          zeros indicate locations not to use
        - Output: :math:`(N, C, L)` (same shape as input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MaskedBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.Tensor(num_features, 1))
            self.bias = nn.Parameter(torch.Tensor(num_features, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.track_running_stats = track_running_stats
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, 1))
            self.register_buffer('running_var', torch.ones(num_features, 1))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input, input_mask=None):
        # Calculate the masked mean and variance
        B, C, L = input.shape
        if input_mask is not None and input_mask.shape != (B, 1, L):
            raise ValueError('Mask should have shape (B, 1, L).')
        if C != self.num_features:
            raise ValueError('Expected %d channels but input has %d channels' % (self.num_features, C))
        if input_mask is not None:
            # Zero out the padded positions and count the real ones
            masked = input * input_mask
            n = input_mask.sum()
        else:
            masked = input
            n = B * L
        # Sum over the batch and length dimensions, keeping shape (1, C, 1)
        masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
        # Divide by the number of positions actually used
        current_mean = masked_sum / n
        current_var = ((masked - current_mean) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
        # Update running stats
        if self.track_running_stats and self.training:
            if self.num_batches_tracked == 0:
                self.running_mean = current_mean
                self.running_var = current_var
            else:
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * current_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * current_var
            self.num_batches_tracked += 1
        # Normalize the input
        if self.track_running_stats and not self.training:
            normed = (masked - self.running_mean) / (torch.sqrt(self.running_var + self.eps))
        else:
            normed = (masked - current_mean) / (torch.sqrt(current_var + self.eps))
        # Apply affine parameters
        if self.affine:
            normed = normed * self.weight + self.bias
        return normed
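
For anyone trying the gist, here is a minimal usage sketch (mine, not part of the original): it assumes a padded batch where each sequence's true length is known, and builds the (N, 1, L) mask of ones and zeros that forward() expects. The lengths tensor is a hypothetical example.

# Usage sketch (assumed, not from the gist): normalize a padded batch
# of variable-length sequences.
B, C, L = 4, 8, 10
x = torch.randn(B, C, L)                # padded batch
lengths = torch.tensor([10, 7, 5, 3])   # hypothetical true sequence lengths
# 1 where a position is real, 0 where it is padding; shape (B, 1, L)
mask = (torch.arange(L)[None, :] < lengths[:, None]).float().unsqueeze(1)
bn = MaskedBatchNorm1d(C)
out = bn(x, input_mask=mask)            # same shape as x: (B, C, L)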
EZ4NO1 commented May 8, 2021

It has a few bugs:
self.running_mean changes size from (C, 1) to (1, C, 1) when current_mean is assigned to it in forward(), which causes errors when saving and loading the model state_dict.
self.running_var has the same problem.
The variance calculation does not apply the mask properly: positions zeroed out by the mask still contribute (0 - mean)^2 terms to the sum.
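
One possible fix for both issues, sketched here as a replacement for the corresponding part of forward() — this is an assumption about the intended behavior, not a patch from the gist author. It zeros out masked positions before squaring so they no longer contribute to the variance, and it updates the running buffers in place so they keep their registered (C, 1) shape.

# Sketch of a fix (assumed, not the gist author's code).
# Variance: exclude masked positions from the squared deviations.
diff = input - current_mean
if input_mask is not None:
    diff = diff * input_mask
current_var = (diff ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
# Running stats: update the (C, 1) buffers in place instead of rebinding
# them to new (1, C, 1) tensors, so state_dict shapes stay consistent.
if self.track_running_stats and self.training:
    with torch.no_grad():
        if self.num_batches_tracked == 0:
            self.running_mean.copy_(current_mean.squeeze(0))
            self.running_var.copy_(current_var.squeeze(0))
        else:
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * current_mean.squeeze(0))
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * current_var.squeeze(0))
        self.num_batches_tracked += 1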