import torch
import torch.nn as nn
import torch.nn.functional as F


def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.
    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397
    :lengths: N-dimensional tensor
    :returns: (N, max_len)-shaped tensor. If max_len is None, max_len = max(lengths)
    """
    assert len(lengths.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or lengths.max().item()
    mask = torch.arange(
        max_len,
        device=lengths.device,
        dtype=lengths.dtype)\
        .expand(len(lengths), max_len) < lengths.unsqueeze(1)
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=lengths.device)
    return mask


class MaskedBatchNorm1d(nn.BatchNorm1d):
    """
    Masked version of the 1D Batch normalization.
    Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
    Receives an N-dim tensor of sequence lengths per batch element
    along with the regular input for masking.
    Check pytorch's BatchNorm1d implementation for argument details.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MaskedBatchNorm1d, self).__init__(
            num_features,
            eps,
            momentum,
            affine,
            track_running_stats
        )

    def forward(self, inp, lengths):
        self._check_input_dim(inp)
        exponential_average_factor = 0.0
        # We transform the mask into a sort of P(inp) with equal probabilities
        # for all unmasked elements of the tensor, and 0 probability for masked
        # ones.
        mask = lengths_to_mask(lengths, max_len=inp.shape[-1], dtype=inp.dtype)
        n = mask.sum()
        mask = mask / n
        mask = mask.unsqueeze(1).expand(inp.shape)

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if not self.track_running_stats:
            # The unbiased variance is undefined for a single unmasked element;
            # raise rather than silently return inf/nan (see discussion below).
            if n == 1:
                raise ValueError('Cannot compute a variance from a single unmasked element.')
            mean = (mask * inp).sum([0, 2])
            var = ((mask * inp ** 2).sum([0, 2]) - mean ** 2) * n / (n - 1)
        elif self.training and n > 1:
            # Here lies the trick. Using Var(X) = E[X^2] - E[X]^2 as the biased
            # variance, we do not need any tensor shape manipulation.
            # mean = E[X] is simply the sum-product of our "probability" mask with the input...
            mean = (mask * inp).sum([0, 2])
            # ...whereas Var(X) is directly derived from the above formula.
            # This should be numerically equivalent to the biased sample variance.
            var = (mask * inp ** 2).sum([0, 2]) - mean ** 2
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # Update running_var with the unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        inp = (inp - mean[None, :, None]) / (torch.sqrt(var[None, :, None] + self.eps))
        if self.affine:
            inp = inp * self.weight[None, :, None] + self.bias[None, :, None]

        return inp
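A minimal usage sketch (not part of the gist itself; the shapes follow BatchNorm1d's usual (batch, channels, length) convention, and the padding is assumed to sit at the end of the length dimension):

import torch

# Hypothetical example: 4 sequences with 8 channels, zero-padded to length 10.
x = torch.randn(4, 8, 10)              # (batch, channels, padded length)
lengths = torch.tensor([10, 7, 5, 3])  # true length of each sequence
bn = MaskedBatchNorm1d(num_features=8)

bn.train()
y = bn(x, lengths)  # batch statistics are computed over unmasked steps only
print(y.shape)      # torch.Size([4, 8, 10])

Note that the output at padded positions is still (meaninglessly) normalized, so downstream code should keep masking them.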
Hello @amiasato, thank you for the piece of code!
I am struggling a bit to understand whether it is supposed to work for a 3D input tensor. For instance, what lengths should I give to MaskedBatchNorm1d for the x given below?

import torch

x_no_pad = torch.ones(2, 5, 3)  # non-padded samples (n=2)
x_pad = torch.zeros(1, 5, 3)    # padded sample (n=1)
x_pad[:, 0, :] = 1              # we add one non-padded entry for this sample
x = torch.cat([x_no_pad, x_pad])
Please make a layer_norm version~~~
Hi, could you share some sample code that uses this masked batch norm?
I am struggling to understand your code, and a working example would greatly help!
@YooSungHyun layer_norm does not need a mask, since the normalization runs only over the last dim: masked elements are normalized using only their own mean and variance, non-masked elements likewise, and the statistics of masked and non-masked positions are never merged.
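A small check of that claim (my own sketch, assuming the usual (batch, time, features) layout that nn.LayerNorm normalizes over):

import torch
import torch.nn as nn

ln = nn.LayerNorm(3)      # normalizes over the last (feature) dim only
x = torch.randn(2, 5, 3)  # (batch, time, features)
x_padded = torch.cat([x, torch.zeros(2, 4, 3)], dim=1)  # append 4 padded steps

# The first 5 time steps come out identical with or without the padding,
# because each position is normalized by its own mean/variance alone.
assert torch.allclose(ln(x), ln(x_padded)[:, :5], atol=1e-6)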
This implementation does not work when track_running_stats=False, since self.running_mean and self.running_var are None both at training and at inference time. How can I modify the implementation so that it works in this case as well?
@tubali12345 I can't verify the code right now, but you'll need to use the masked mean/var as-is in all cases. Without running stats, the n==1 case is ill-defined, so you should raise an exception for it. Edit: I modified the code to cover your case, please let me know if it works. Edit 2: I'm not sure whether the untracked-stats version should now use the biased or the unbiased instantaneous variance, so feel free to fix that as well.
I covered the n==1 case and it works fine, thanks! I think we should use the unbiased variance in this case, but I am not sure either.
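A quick sanity check along these lines (a sketch, reusing lengths_to_mask from the gist above) would compare both estimators against torch.var on the unmasked elements only:

import torch

x = torch.randn(4, 3, 10)  # (batch, channels, padded length)
lengths = torch.tensor([10, 7, 5, 2])
mask = lengths_to_mask(lengths, max_len=10, dtype=x.dtype)  # (4, 10)
n = mask.sum()
p = (mask / n).unsqueeze(1).expand_as(x)

mean = (p * x).sum([0, 2])
biased_var = (p * x ** 2).sum([0, 2]) - mean ** 2
unbiased_var = biased_var * n / (n - 1)

# Reference: collect the unmasked elements per channel and call torch.var directly.
flat = torch.cat([x[i, :, :lengths[i].item()] for i in range(4)], dim=1)  # (3, 24)
assert torch.allclose(biased_var, flat.var(dim=1, unbiased=False), atol=1e-5)
assert torch.allclose(unbiased_var, flat.var(dim=1, unbiased=True), atol=1e-5)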
I ran only one apples-to-apples comparison, in an audio classification task. Training loss was consistently lower with the masked version, with no visible impact on step processing times, though I don't have the loss curves at hand. Input lengths varied from 1 s to 15 s, but most examples were within the 3-7 s range. I've been using this masked batch norm as an almost drop-in replacement since then.