# ilya16 / masked_batchnorm.py
# Created February 26, 2021 15:08
# Masked Normalization layers in PyTorch

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm


# Masked Batch Normalization