Last active: March 2, 2022 04:13
-
-
Save Lyken17/deb98385a06ae67ce1252c1a17ad181d to your computer and use it in GitHub Desktop.
[pytorch] Fuse Conv2d and BatchNorm at module level
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torchvision import models | |
from torchvision.models.resnet import BasicBlock, ResNet | |
def remove_dropout(module):
    """Recursively swap every ``nn.Dropout`` submodule for ``nn.Identity``.

    Children are re-registered under their original names with the result of
    the recursive call, so the tree is rewritten in place.

    Args:
        module: root ``nn.Module`` to process.

    Returns:
        A fresh ``nn.Identity()`` when ``module`` itself is an ``nn.Dropout``;
        otherwise ``module`` itself (the same object, mutated in place).
    """
    if not isinstance(module, nn.Dropout):
        replacement = module
    else:
        print("removing dropout")
        replacement = nn.Identity()
    for child_name, child_module in module.named_children():
        replacement.add_module(child_name, remove_dropout(child_module))
    return replacement
def _fold_bn_into_conv(conv, bn):
    """Rewrite ``conv``'s weight and bias in place so ``conv(x)`` equals the old ``bn(conv(x))``.

    The equality holds for ``bn`` evaluated with its ``running_mean`` /
    ``running_var`` buffers (eval mode). The caller must ensure those buffers
    are not ``None``.
    """
    if conv.bias is None:
        # Allocate on the conv's own device/dtype so GPU and non-float32
        # models can be fused.
        conv.bias = nn.Parameter(
            torch.zeros(
                conv.out_channels,
                device=conv.weight.device,
                dtype=conv.weight.dtype,
            )
        )
    # Pure parameter surgery: no autograd graph should be recorded.
    with torch.no_grad():
        invstd = 1 / torch.sqrt(bn.running_var + bn.eps)
        # BatchNorm2d(affine=False) has weight/bias == None, i.e. gamma=1, beta=0.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(invstd)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(invstd)
        scale = gamma * invstd
        # bn(conv(x)) = gamma * (W*x + b - mean) * invstd + beta
        #             = (scale * W)*x + (b - mean) * scale + beta
        # Weight is (out_channels, in/groups, kH, kW): scale per output channel.
        conv.weight.copy_(conv.weight * scale[:, None, None, None])
        conv.bias.copy_((conv.bias - bn.running_mean) * scale + beta)


def fuse_bn(module):
    """Fold ``BatchNorm2d`` layers into the preceding ``Conv2d`` throughout a model.

    Every ``nn.Sequential`` found (recursively) is scanned for adjacent
    ``(nn.Conv2d, nn.BatchNorm2d)`` pairs. For each pair, the conv's weight and
    bias are rescaled to absorb the BN, and the BN slot is replaced with
    ``nn.Identity()``. BN layers without running statistics
    (``track_running_stats=False``) are left untouched, since they always
    normalize with batch statistics.

    The fold uses the BN running statistics, so the fused model only matches
    the original when that model is evaluated in eval mode; call
    ``model.eval()`` before fusing.

    Args:
        module: root ``nn.Module``; modified in place.

    Returns:
        ``module`` itself (the same object, mutated in place).
    """
    if isinstance(module, nn.Sequential):
        print("[nn.Sequential]\tfusing BN")
        for idx in range(len(module) - 1):
            conv, bn = module[idx], module[idx + 1]
            if not (isinstance(conv, nn.Conv2d) and isinstance(bn, nn.BatchNorm2d)):
                continue
            if bn.running_mean is None or bn.running_var is None:
                # No constant statistics to fold into the conv.
                continue
            _fold_bn_into_conv(conv, bn)
            module[idx + 1] = nn.Identity()
    # fuse_bn mutates children in place and returns them unchanged, so the
    # recursion does not need to re-register anything.
    for child in module.children():
        fuse_bn(child)
    return module
def get_model():
    """Build an eval-mode pretrained MobileNetV2 with BN folded into convs and dropout removed.

    ``fuse_bn`` runs first, then ``remove_dropout``; both rewrite the model in place.
    """
    return remove_dropout(fuse_bn(models.mobilenet_v2(pretrained=True).eval()))
if __name__ == "__main__":
    # Sanity check: the fused model should reproduce the original model's
    # outputs up to floating-point rounding.
    orig = models.mobilenet_v2(pretrained=True).eval()
    fuse = get_model()
    print(fuse)
    # Inference only: no_grad avoids building an autograd graph per forward pass.
    with torch.no_grad():
        for _ in range(10):
            data = torch.randn(1, 3, 224, 224)
            diff = orig(data) - fuse(data)
            # Mean/std can hide a few large mismatches; max |diff| does not.
            print(diff.mean(), diff.std(), diff.abs().max())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.