Created May 16, 2023 15:57
Gist: Kautenja/99757c6dd428cf014cad248c2dce57f6
A PyTorch method for fusing normalization statistics directly into a convolutional layer
import torch
from torch import nn
from torch.nn import Conv2d


def fuse_normalize_into_conv_2d(conv: Conv2d, mean: torch.Tensor, std: torch.Tensor) -> Conv2d:
    r"""
    Fuse normalization statistics into a convolutional layer.

    Args:
        conv: The convolutional layer to fuse the norm layer into.
        mean: The mean value vector with shape [in_channels].
        std: The standard deviation vector with shape [in_channels].

    Returns:
        The convolutional layer, modified in place.

    Details:
        This fusion is based on the following re-write of norm+conv. First,
        we can fuse the scale (standard deviation) into the convolutional
        weights using the associative property of multiplication.

            \frac{(x-a)}{b} * w = (x-a) * \frac{w}{b}

        Next, we can use the distributive property to re-write the subtraction
        of the mean in such a way that it can be lumped into a single bias term.

            (x-a) * \frac{w}{b} + c = x * \frac{w}{b} + (c - a * \frac{w}{b})

        Ultimately, this means the weight of the layer gets scaled:

            w \gets \frac{w}{b}

        and from the bias term we remove the convolution of the mean with the
        (scaled) weight.

            c \gets c - a * \frac{w}{b}

    """
    # Reshape the statistics to [1, C, 1, 1] so they broadcast over the
    # in_channels axis of the [out_channels, in_channels, H, W] weight.
    mean = mean.view(1, -1, 1, 1)
    std = std.view(1, -1, 1, 1)
    # Parameters are updated in place, so autograd must be disabled.
    with torch.no_grad():
        # Fuse the standard deviation into the convolutional weight.
        conv.weight[:] = conv.weight / std
        # Convolving a constant mean image with the scaled weight gives one
        # value per output channel: the sum of mean * weight over each filter
        # (this assumes groups == 1).
        offset = (conv.weight * mean).sum(dim=(1, 2, 3))
        if conv.bias is None:  # No bias, assign one directly.
            conv.bias = nn.Parameter(-offset)
        else:  # Adjust the existing bias term.
            conv.bias[:] = conv.bias - offset
    return conv


# Explicitly define the outward facing API of this module.
__all__ = [fuse_normalize_into_conv_2d.__name__]
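The re-write in the docstring can be checked numerically without PyTorch. The following is a minimal plain-Python sketch, not part of the gist: the helper `conv1d_valid` and all numeric values are hypothetical, and it assumes a 1-D kernel, a single output channel, and no padding. With zero padding the two paths would differ at the borders, because a zero in normalized space is not a zero in raw space.

```python
import math


def conv1d_valid(x, w, bias):
    """Cross-correlate x[channel][position] with w[channel][tap], no padding."""
    channels, taps = len(w), len(w[0])
    length = len(x[0]) - taps + 1
    return [
        bias + sum(x[ch][i + k] * w[ch][k] for ch in range(channels) for k in range(taps))
        for i in range(length)
    ]


# Hypothetical example values: 2 input channels, 3-tap kernel, 1 output channel.
x = [[0.2, 0.5, 0.9, 0.4, 0.1], [0.7, 0.3, 0.6, 0.8, 0.2]]
w = [[0.5, -1.0, 0.25], [1.5, 0.75, -0.5]]
c = 0.1
mean = [0.45, 0.40]  # a, one value per input channel
std = [0.25, 0.20]   # b, one value per input channel

# Reference path: normalize each channel, then convolve.
x_norm = [[(v - mean[ch]) / std[ch] for v in x[ch]] for ch in range(2)]
reference = conv1d_valid(x_norm, w, c)

# Fused path: w <- w / b and c <- c - sum(a * w / b), applied to the raw input.
w_fused = [[v / std[ch] for v in w[ch]] for ch in range(2)]
c_fused = c - sum(mean[ch] * v for ch in range(2) for v in w_fused[ch])
fused = conv1d_valid(x, w_fused, c_fused)

outputs_match = all(math.isclose(r, f, abs_tol=1e-9) for r, f in zip(reference, fused))
print(outputs_match)
```

The same comparison carries over to a `Conv2d` with `padding=0`: run the normalized input through the original layer, the raw input through the fused layer, and compare the two outputs within a floating-point tolerance.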