@amiasato
Last active April 16, 2024 06:58
import torch
import torch.nn as nn
import torch.nn.functional as F


def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.
    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397

    :lengths: 1-D tensor of N sequence lengths
    :returns: N x max_len mask tensor. If max_len is None, max_len = max(lengths)
    """
    assert len(lengths.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or lengths.max().item()
    mask = torch.arange(
        max_len,
        device=lengths.device,
        dtype=lengths.dtype)\
        .expand(len(lengths), max_len) < lengths.unsqueeze(1)
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=lengths.device)
    return mask


class MaskedBatchNorm1d(nn.BatchNorm1d):
    """
    Masked version of the 1D Batch normalization.
    Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py

    Receives an N-element tensor of sequence lengths per batch element
    along with the regular input, for masking.

    Check pytorch's BatchNorm1d implementation for argument details.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MaskedBatchNorm1d, self).__init__(
            num_features,
            eps,
            momentum,
            affine,
            track_running_stats
        )

    def forward(self, inp, lengths):
        self._check_input_dim(inp)

        exponential_average_factor = 0.0

        # We transform the mask into a sort of P(inp) with equal probabilities
        # for all unmasked elements of the tensor, and 0 probability for masked
        # ones.
        mask = lengths_to_mask(lengths, max_len=inp.shape[-1], dtype=inp.dtype)
        n = mask.sum()
        mask = mask / n
        mask = mask.unsqueeze(1).expand(inp.shape)

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if not self.track_running_stats:  # Should raise an exception if n == 1
            mean = (mask * inp).sum([0, 2])
            var = ((mask * inp ** 2).sum([0, 2]) - mean ** 2) * n / (n - 1)
        elif self.training and n > 1:
            # Here lies the trick. Using Var(X) = E[X^2] - E[X]^2 as the biased
            # variance, we do not need to make any tensor shape manipulation.
            # mean = E[X] is simply the sum-product of our "probability" mask with the input...
            mean = (mask * inp).sum([0, 2])
            # ...whereas Var(X) is directly derived from the above formula.
            # This should be numerically equivalent to the biased sample variance.
            var = (mask * inp ** 2).sum([0, 2]) - mean ** 2
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # Update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        inp = (inp - mean[None, :, None]) / (torch.sqrt(var[None, :, None] + self.eps))
        if self.affine:
            inp = inp * self.weight[None, :, None] + self.bias[None, :, None]

        return inp
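
A minimal usage sketch, assuming the (batch, channels, length) layout that nn.BatchNorm1d expects; the shapes and length values below are made up for illustration:

# Batch of 3 sequences with 4 channels, zero-padded to a common length of 10;
# `lengths` holds the number of valid time steps per batch element.
bn = MaskedBatchNorm1d(num_features=4)
inp = torch.randn(3, 4, 10)
lengths = torch.tensor([10, 7, 3])
inp[1, :, 7:] = 0.0  # padded positions: they are excluded from the statistics
inp[2, :, 3:] = 0.0
out = bn(inp, lengths)  # same shape as inp
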
@seovchinnikov

Thanks!
Did you test the convergence of your version compared to the original one with a full-length mask?

@amiasato (Author) commented May 5, 2020

I ran only one apples-to-apples comparison, in an audio classification task. Training loss was consistently lower with the masked version, with no visible impact on step processing times, though I don't have the loss curves right now. Inputs varied in length from 1s to 15s, but most examples were within the 3~7s range. I've been using this masked batch norm as an almost drop-in replacement since then.

@CharlieCheckpt

Hello @amiasato, thank you for the piece of code!

I am struggling a bit to understand whether it is supposed to work for a 3D input tensor. For instance, what lengths should I give to MaskedBatchNorm1d for the x given below?

import torch 

x_no_pad = torch.ones(2, 5, 3)  # non-padded samples (n=2)
x_pad = torch.zeros(1, 5, 3)  # padded sample (n=1)
x_pad[:, 0, :] = 1  # we add one non-padded entry for this sample
x = torch.cat([x_no_pad, x_pad])
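
For reference, lengths_to_mask builds its mask against inp.shape[-1], so MaskedBatchNorm1d expects the padded (time) axis last, as in nn.BatchNorm1d's (batch, channels, length) layout. A possible way to adapt the example above, assuming its second dimension is the padded time axis:

# Sketch: treat x as (batch, time, features) and move time to the last axis.
x_ncl = x.transpose(1, 2)          # -> (batch=3, channels=3, length=5)
lengths = torch.tensor([5, 5, 1])  # valid time steps per batch element
bn = MaskedBatchNorm1d(num_features=3)
out = bn(x_ncl, lengths)           # statistics ignore the padded time steps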

@YooSungHyun

Please make a layer_norm version too~~~

@4722794 commented Aug 7, 2023

Hi, could you share some sample code that uses this masked batch norm?
I am struggling to understand your code, and a working example would greatly help!

@hypnopump

@YooSungHyun layer norm does not need a mask, since the normalization runs only over the last dimension. Each masked element is normalized using only its own mean and variance, and each non-masked element likewise, so the statistics of masked and non-masked positions are never mixed.
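
A quick sketch of that claim (the shapes and values here are just for illustration):

import torch
import torch.nn as nn

ln = nn.LayerNorm(8)
x = torch.randn(2, 5, 8)  # (batch, time, features)
padded = x.clone()
padded[:, 3:, :] = 0.0    # overwrite the last two time steps with "padding"

# LayerNorm computes its statistics per (batch, time) position over the feature
# dimension only, so the unpadded positions come out identical either way.
assert torch.allclose(ln(x)[:, :3], ln(padded)[:, :3])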

@tubali12345

This implementation does not work when track_running_stats=False, since self.running_mean and self.running_var are None at both training and inference time. How can I modify the implementation so that it works in this case as well?

@amiasato (Author) commented Apr 15, 2024

@tubali12345 I can't verify the code right now, but you'll need to use the masked mean/var as-is in all cases. No running stats means that the n==1 case is ill-defined, so you should raise an exception for that. Edit: I modified the code to cover your case, PLMK if it works. Edit2: I'm unsure whether the untracked-stats version should now use the biased or unbiased instantaneous variance, so feel free to fix that as well.
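
A minimal sketch of the guard mentioned above, assuming it is placed right after n = mask.sum() in forward():

# Sketch: with no running statistics to fall back on, a batch with a single
# unmasked element has no well-defined unbiased variance, so fail loudly.
if not self.track_running_stats and n <= 1:
    raise ValueError(
        'MaskedBatchNorm1d with track_running_stats=False needs more than '
        'one unmasked element to estimate a variance (got n=%d).' % int(n)
    )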

@tubali12345

I covered the n==1 case and it works fine, thanks! I think we should use unbiased variance in this case, but I am not sure either.
