@dyigitpolat
Created December 3, 2022 08:31
Batch norm fusing, written by the OpenAI chatbot
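In eval mode, BatchNorm normalizes with its stored running statistics, so a Linear layer followed by BatchNorm1d collapses into a single affine map. Below is a short derivation of the fold. The notation is chosen here, not taken from the gist: W and b are the linear weight and bias, μ and σ² the running mean and variance, γ and β the BN affine parameters, and ε the BN epsilon.

```latex
\begin{aligned}
\mathrm{BN}(Wx + b) &= \gamma \odot \frac{Wx + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = W' x + b', \\
s &= \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad
W' = \operatorname{diag}(s)\, W, \qquad
b' = s \odot (b - \mu) + \beta .
\end{aligned}
```

Each output row i of W is scaled by s_i. The scaling is not applied along the input dimension, and γ and β must be included unless the BN layer was created with affine=False.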
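import torch


@torch.no_grad()
def fuse_linear_bn(linear, bn):
    # Fold an eval-mode BatchNorm1d that follows `linear` into the linear layer:
    #   bn(linear(x)) = gamma * (W x + b - mean) / sqrt(var + eps) + beta
    # Get the weight and bias of the linear layer (bias may be None)
    weight = linear.weight
    bias = linear.bias if linear.bias is not None else torch.zeros_like(bn.running_mean)
    # Get the running mean and variance of the batch norm layer
    running_mean, running_var = bn.running_mean, bn.running_var
    # Get gamma and beta (they default to 1 and 0 when affine=False)
    gamma = bn.weight if bn.weight is not None else torch.ones_like(running_var)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(running_mean)
    # Per-output-feature factor; row i of W is scaled by scale[i]
    scale = gamma / torch.sqrt(running_var + bn.eps)
    fused_weight = weight * scale[:, None]
    # Fold the running mean and beta into the bias
    fused_bias = (bias - running_mean) * scale + beta
    # Replace the weight and bias of the linear layer with the fused parameters
    linear.weight = torch.nn.Parameter(fused_weight)
    linear.bias = torch.nn.Parameter(fused_bias)
    # Return the fused linear layer; use it in place of bn(linear(x))
    return linear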
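A fusion like this is easy to get subtly wrong, for example by scaling along the input dimension instead of the output dimension. A numerical comparison against the unfused pair catches that. The sketch below assumes PyTorch and uses a hypothetical helper, `fold_bn_into_linear`, that returns a new `nn.Linear` and leaves the original modules untouched. It uses deliberately different input and output sizes (4 and 3), so a broadcasting mistake fails loudly.

```python
import torch
from torch import nn


def fold_bn_into_linear(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    """Return a new nn.Linear equivalent to bn(linear(x)) with bn in eval mode.

    Hypothetical helper (not from the gist); the input modules are not modified.
    """
    with torch.no_grad():
        # gamma/beta default to 1/0 when the BN layer has affine=False
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)
        # Per-output-feature factor gamma / sqrt(running_var + eps)
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)
        bias = linear.bias if linear.bias is not None else torch.zeros_like(bn.running_mean)

        fused = nn.Linear(linear.in_features, linear.out_features, bias=True)
        # Row i of W is scaled by scale[i]; the mean and beta fold into the bias
        fused.weight.copy_(linear.weight * scale[:, None])
        fused.bias.copy_((bias - bn.running_mean) * scale + beta)
    return fused


# Usage: give the BN non-trivial statistics and affine params, then compare.
torch.manual_seed(0)
linear = nn.Linear(4, 3)
bn = nn.BatchNorm1d(3)
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # the fold is only valid when BN uses its running statistics

x = torch.randn(8, 4)
reference = bn(linear(x))
fused_out = fold_bn_into_linear(linear, bn)(x)
print(torch.allclose(reference, fused_out, atol=1e-5))
```

Returning a fresh module rather than mutating `linear` in place keeps the unfused pair available for exactly this kind of comparison. In training mode, BN normalizes with per-batch statistics instead, and no fixed fold exists.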