@ilya16
Created February 26, 2021 15:08
Masked Normalization layers in PyTorch
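
The layers below expect, in addition to the input, a binary mask of shape `(N, 1, ...)` that marks the valid (non-padded) positions. A minimal sketch of building such a mask for padded 1D sequences; the names (`lengths`, `x`) and shapes are illustrative and not part of the gist:

import torch

# Hypothetical example: batch of 3 padded sequences, 8 channels, max length 5.
lengths = torch.tensor([5, 3, 2])                      # true length of each sequence
x = torch.randn(3, 8, 5)                               # (N, C, L), zero-padded past each length
mask = (torch.arange(5)[None, :] < lengths[:, None])   # (N, L): True at valid positions
mask = mask.unsqueeze(1).float()                       # (N, 1, L), broadcastable over channels
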
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
# Masked Batch Normalization
def masked_batch_norm(input: Tensor, mask: Tensor, weight: Optional[Tensor], bias: Optional[Tensor],
                      running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool,
                      momentum: float, eps: float = 1e-5) -> Tensor:
    r"""Applies Masked Batch Normalization for each channel in each data sample in a batch.

    See :class:`~MaskedBatchNorm1d`, :class:`~MaskedBatchNorm2d`, :class:`~MaskedBatchNorm3d` for details.
    """
    if not training and (running_mean is None or running_var is None):
        raise ValueError('Expected running_mean and running_var to be not None when training=False')

    num_dims = len(input.shape[2:])
    _dims = (0,) + tuple(range(-num_dims, 0))
    _slice = (None, ...) + (None,) * num_dims

    if training:
        num_elements = mask.sum(_dims)
        mean = (input * mask).sum(_dims) / num_elements  # (C,)
        var = (((input - mean[_slice]) * mask) ** 2).sum(_dims) / num_elements  # (C,)

        if running_mean is not None:
            running_mean.copy_(running_mean * (1 - momentum) + momentum * mean.detach())
        if running_var is not None:
            running_var.copy_(running_var * (1 - momentum) + momentum * var.detach())
    else:
        mean, var = running_mean, running_var

    out = (input - mean[_slice]) / torch.sqrt(var[_slice] + eps)  # (N, C, ...)

    if weight is not None and bias is not None:
        out = out * weight[_slice] + bias[_slice]

    return out


class _MaskedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_MaskedBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        self._check_input_dim(input)
        if mask is not None:
            self._check_input_dim(mask)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        if mask is None:
            return F.batch_norm(
                input,
                # If buffers are not to be tracked, ensure that they won't be updated
                self.running_mean if not self.training or self.track_running_stats else None,
                self.running_var if not self.training or self.track_running_stats else None,
                self.weight, self.bias, bn_training, exponential_average_factor, self.eps
            )
        else:
            return masked_batch_norm(
                input, mask, self.weight, self.bias,
                self.running_mean if not self.training or self.track_running_stats else None,
                self.running_var if not self.training or self.track_running_stats else None,
                bn_training, exponential_average_factor, self.eps
            )


class MaskedBatchNorm1d(torch.nn.BatchNorm1d, _MaskedBatchNorm):
    r"""Applies Batch Normalization over a masked 3D input
    (a mini-batch of 1D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.BatchNorm1d` for details.

    Shape:
        - Input: :math:`(N, C, L)`
        - Mask: :math:`(N, 1, L)`
        - Output: :math:`(N, C, L)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True) -> None:
        super(MaskedBatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)


class MaskedBatchNorm2d(torch.nn.BatchNorm2d, _MaskedBatchNorm):
    r"""Applies Batch Normalization over a masked 4D input
    (a mini-batch of 2D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.BatchNorm2d` for details.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Mask: :math:`(N, 1, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True) -> None:
        super(MaskedBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)


class MaskedBatchNorm3d(torch.nn.BatchNorm3d, _MaskedBatchNorm):
    r"""Applies Batch Normalization over a masked 5D input
    (a mini-batch of 3D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.BatchNorm3d` for details.

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Mask: :math:`(N, 1, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True) -> None:
        super(MaskedBatchNorm3d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
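
A minimal usage sketch for the batch-norm variant above (variable names, shapes, and values are illustrative assumptions, not part of the gist): the mask is passed alongside the input, so padded positions are excluded from the batch statistics, and calling the layer without a mask falls back to the standard F.batch_norm path.

# Hypothetical usage sketch: normalize only the unpadded time steps of a batch.
import torch

bn = MaskedBatchNorm1d(num_features=8)

x = torch.randn(3, 8, 5)                                                    # (N, C, L)
lengths = torch.tensor([5, 3, 2])
mask = (torch.arange(5)[None, :] < lengths[:, None]).unsqueeze(1).float()   # (N, 1, L)

bn.train()
y = bn(x, mask)        # batch statistics computed over masked-in elements only
y_eval = bn.eval()(x)  # without a mask, the standard F.batch_norm path is used
print(y.shape)         # torch.Size([3, 8, 5])
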
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.instancenorm import _InstanceNorm
# Masked Instance Normalization
def masked_instance_norm(input: Tensor, mask: Tensor, weight: Optional[Tensor], bias: Optional[Tensor],
                         running_mean: Optional[Tensor], running_var: Optional[Tensor], use_input_stats: bool,
                         momentum: float, eps: float = 1e-5) -> Tensor:
    r"""Applies Masked Instance Normalization for each channel in each data sample in a batch.

    See :class:`~MaskedInstanceNorm1d`, :class:`~MaskedInstanceNorm2d`, :class:`~MaskedInstanceNorm3d` for details.
    """
    if not use_input_stats and (running_mean is None or running_var is None):
        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')

    shape = input.shape
    b, c = shape[:2]
    num_dims = len(shape[2:])
    _dims = tuple(range(-num_dims, 0))
    _slice = (...,) + (None,) * num_dims

    running_mean_ = running_mean[None, :].repeat(b, 1) if running_mean is not None else None
    running_var_ = running_var[None, :].repeat(b, 1) if running_var is not None else None

    if use_input_stats:
        lengths = mask.sum(_dims)
        mean = (input * mask).sum(_dims) / lengths  # (N, C)
        var = (((input - mean[_slice]) * mask) ** 2).sum(_dims) / lengths  # (N, C)

        if running_mean is not None:
            running_mean_.mul_(1 - momentum).add_(momentum * mean.detach())
            running_mean.copy_(running_mean_.view(b, c).mean(0, keepdim=False))
        if running_var is not None:
            running_var_.mul_(1 - momentum).add_(momentum * var.detach())
            running_var.copy_(running_var_.view(b, c).mean(0, keepdim=False))
    else:
        mean, var = running_mean_.view(b, c), running_var_.view(b, c)

    out = (input - mean[_slice]) / torch.sqrt(var[_slice] + eps)  # (N, C, ...)

    if weight is not None and bias is not None:
        out = out * weight[None, :][_slice] + bias[None, :][_slice]

    return out


class _MaskedInstanceNorm(_InstanceNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_MaskedInstanceNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        self._check_input_dim(input)
        if mask is not None:
            self._check_input_dim(mask)

        if mask is None:
            return F.instance_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training or not self.track_running_stats, self.momentum, self.eps
            )
        else:
            return masked_instance_norm(
                input, mask, self.weight, self.bias, self.running_mean, self.running_var,
                self.training or not self.track_running_stats, self.momentum, self.eps
            )


class MaskedInstanceNorm1d(torch.nn.InstanceNorm1d, _MaskedInstanceNorm):
    r"""Applies Instance Normalization over a masked 3D input
    (a mini-batch of 1D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.InstanceNorm1d` for details.

    Shape:
        - Input: :math:`(N, C, L)`
        - Mask: :math:`(N, 1, L)`
        - Output: :math:`(N, C, L)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False) -> None:
        super(MaskedInstanceNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)


class MaskedInstanceNorm2d(torch.nn.InstanceNorm2d, _MaskedInstanceNorm):
    r"""Applies Instance Normalization over a masked 4D input
    (a mini-batch of 2D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.InstanceNorm2d` for details.

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Mask: :math:`(N, 1, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False) -> None:
        super(MaskedInstanceNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)


class MaskedInstanceNorm3d(torch.nn.InstanceNorm3d, _MaskedInstanceNorm):
    r"""Applies Instance Normalization over a masked 5D input
    (a mini-batch of 3D inputs with additional channel dimension).

    See documentation of :class:`~torch.nn.InstanceNorm3d` for details.

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Mask: :math:`(N, 1, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False) -> None:
        super(MaskedInstanceNorm3d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
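
A similar hedged usage sketch for the instance-norm variant (names, shapes, and values are illustrative assumptions): statistics are computed per sample and per channel, again restricted to masked-in positions.

# Hypothetical usage sketch for masked instance normalization.
import torch

inorm = MaskedInstanceNorm1d(num_features=8, affine=True)

x = torch.randn(3, 8, 5)                                                    # (N, C, L)
lengths = torch.tensor([5, 3, 2])
mask = (torch.arange(5)[None, :] < lengths[:, None]).unsqueeze(1).float()   # (N, 1, L)

y = inorm(x, mask)  # per-sample, per-channel stats over valid positions only
y_all = inorm(x)    # without a mask, the standard F.instance_norm path is used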