Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Last active March 2, 2022 04:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Lyken17/deb98385a06ae67ce1252c1a17ad181d to your computer and use it in GitHub Desktop.
Save Lyken17/deb98385a06ae67ce1252c1a17ad181d to your computer and use it in GitHub Desktop.
[pytorch] Fuse Conv2d and BatchNorm at module level
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.resnet import BasicBlock, ResNet
def remove_dropout(module):
    """Recursively replace every ``nn.Dropout`` in a module tree with ``nn.Identity``.

    Each child slot is re-registered with the result of the recursive call, so the
    tree is edited in place.

    Args:
        module: root ``nn.Module`` to process.

    Returns:
        The module that should occupy *module*'s position: a fresh ``nn.Identity``
        when *module* is itself an ``nn.Dropout``, otherwise *module* unchanged in
        identity (but with its dropout descendants replaced).
    """
    if isinstance(module, nn.Dropout):
        print("removing dropout")
        target = nn.Identity()
    else:
        target = module

    for child_name, child in module.named_children():
        cleaned = remove_dropout(child)
        target.add_module(child_name, cleaned)

    return target
def _fold_bn_into_conv(conv, bn):
    """Rewrite *conv*'s weight and bias in place so that ``conv(x) == bn(conv(x))``.

    The equality holds when *bn* normalises with its running statistics
    (evaluation mode). With ``s = gamma / sqrt(running_var + eps)`` the fused
    parameters are ``W' = W * s`` (per output channel) and
    ``b' = (b - running_mean) * s + beta``.
    """
    with torch.no_grad():
        invstd = 1 / torch.sqrt(bn.running_var + bn.eps)
        # affine=False batch-norms have weight/bias == None: treat them as 1 and 0.
        scale = invstd if bn.weight is None else bn.weight * invstd
        if conv.bias is None:
            # Create the bias on the conv's own device/dtype so a GPU or
            # half-precision model still runs after fusion.
            conv.bias = nn.Parameter(
                torch.zeros(
                    conv.out_channels,
                    device=conv.weight.device,
                    dtype=conv.weight.dtype,
                )
            )
        new_bias = (conv.bias - bn.running_mean) * scale
        if bn.bias is not None:
            new_bias = new_bias + bn.bias
        # Conv2d weight is (out_channels, in_channels/groups, kH, kW): one scale per
        # output channel, broadcast over the rest. In-place ops keep the original
        # Parameter objects, dtype and device.
        conv.weight.mul_(scale.to(conv.weight).view(-1, 1, 1, 1))
        conv.bias.copy_(new_bias)


def fuse_bn(module):
    """Fold ``Conv2d -> BatchNorm2d`` pairs inside ``nn.Sequential`` containers.

    Walks the module tree recursively. Whenever an ``nn.Sequential`` holds an
    ``nn.Conv2d`` immediately followed by an ``nn.BatchNorm2d``, the batch-norm is
    folded into the convolution's weight/bias and its slot is replaced by
    ``nn.Identity``. Batch-norms that keep no running statistics
    (``track_running_stats=False``) cannot be folded and are left untouched.

    The fused network matches the original only when the batch-norms would have
    used their running statistics, so call ``.eval()`` on the model first.

    The tree is modified in place.

    Args:
        module: root ``nn.Module`` to process.

    Returns:
        The same *module* object, with every eligible pair fused.
    """
    if isinstance(module, nn.Sequential):
        print("[nn.Sequential]\tfusing BN and dropout")
        for idx in range(len(module) - 1):
            conv, bn = module[idx], module[idx + 1]
            if not (isinstance(conv, nn.Conv2d) and isinstance(bn, nn.BatchNorm2d)):
                continue
            if bn.running_mean is None or bn.running_var is None:
                # Batch-statistics-only BN: nothing constant to fold.
                continue
            _fold_bn_into_conv(conv, bn)
            module[idx + 1] = nn.Identity()

    for name, child in module.named_children():
        module.add_module(name, fuse_bn(child))
    return module
def get_model():
    """Build an inference-ready pretrained MobileNetV2.

    The torchvision model is switched to eval mode, its Conv2d/BatchNorm2d pairs
    are folded by ``fuse_bn``, and its dropout layers are replaced by
    ``remove_dropout``.

    Returns:
        The processed ``nn.Module``.
    """
    return remove_dropout(fuse_bn(models.mobilenet_v2(pretrained=True).eval()))
if __name__ == "__main__":
    import copy

    # Reference network, left untouched for comparison.
    orig = models.mobilenet_v2(pretrained=True).eval()
    # fuse_bn / remove_dropout edit their argument in place, so work on a deep
    # copy instead of loading the pretrained weights a second time.
    fuse = remove_dropout(fuse_bn(copy.deepcopy(orig)))
    print(fuse)

    # Pure inference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        for _ in range(10):
            data = torch.randn(1, 3, 224, 224)
            diff = orig(data) - fuse(data)
            # mean/std alone can hide large element-wise errors that cancel out,
            # so also report the worst-case absolute deviation.
            print(diff.mean().item(), diff.std().item(), diff.abs().max().item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment