@sandeepkumar-skb
Last active January 10, 2021 21:17
Folding BN into convolution
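Background: in eval mode, BatchNorm applies a fixed per-channel affine transform, so it can be merged into the preceding convolution. With conv output W*x + b and BN parameters gamma (weight), beta (bias), running mean mu, running variance var, and eps:

    y = gamma * (W*x + b - mu) / sqrt(var + eps) + beta

which rearranges to a plain convolution with

    W_fold = W * gamma / sqrt(var + eps)    (scaled per output channel)
    b_fold = (b - mu) * gamma / sqrt(var + eps) + beta

This is exactly the identity that fold_bn_util below computes.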
import copy

import torch
import torch.nn as nn
import torchvision.models as models


class BN_Folder():
    def fold(self, model):
        mymodel = copy.deepcopy(model)
        mymodel.eval()  # folding uses the BN running stats, so the model must be in eval mode
        model_keys = list(mymodel._modules.keys())  # child module names, in registration order
        prev = None
        for name in model_keys:
            if len(mymodel._modules[name]._modules) > 0:
                # Recurse into containers (Sequential, BasicBlock, ...)
                mymodel._modules[name] = self.fold(mymodel._modules[name])
            elif (prev is not None and
                  isinstance(mymodel._modules[name], nn.BatchNorm2d) and
                  isinstance(mymodel._modules[prev], nn.Conv2d)):
                # Fold the BN into the preceding conv and replace the BN with
                # an Identity, so forward() still finds the attribute it expects.
                folded_conv = self.fold_bn(mymodel._modules[prev], mymodel._modules[name])
                mymodel._modules[name] = nn.Identity()
                mymodel._modules[prev] = folded_conv
            prev = name
        return mymodel

    def fold_bn(self, conv, bn):
        folded_conv = copy.deepcopy(conv)
        conv_w = conv.weight
        conv_b = conv.bias
        bn_rv = bn.running_var
        bn_rm = bn.running_mean
        bn_eps = bn.eps
        bn_w = bn.weight
        bn_b = bn.bias
        folded_conv.weight, folded_conv.bias = self.fold_bn_util(
            conv_w, conv_b, bn_w, bn_b, bn_rv, bn_rm, bn_eps)
        return folded_conv

    def fold_bn_util(self, conv_w, conv_b, bn_w, bn_b, bn_rv, bn_rm, bn_eps):
        if conv_b is None:  # convs followed by BN are typically created with bias=False
            conv_b = torch.zeros_like(bn_b)
        inv_std = torch.rsqrt(bn_rv + bn_eps)  # 1 / sqrt(var + eps)
        # Scale each output channel of the conv weight and shift the bias.
        folded_w = conv_w * (bn_w * inv_std).view(-1, 1, 1, 1)
        folded_b = (conv_b - bn_rm) * bn_w * inv_std + bn_b
        return torch.nn.Parameter(folded_w), torch.nn.Parameter(folded_b)


if __name__ == "__main__":
    rn18 = models.resnet18(pretrained=True)
    bn_folder = BN_Folder()
    new_mod = bn_folder.fold(rn18)
    print(new_mod)
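As a quick sanity check (a minimal sketch, assuming the BN_Folder class above is in scope), the folded ResNet-18 should reproduce the original's outputs in eval mode up to float32 rounding:

import torch
import torchvision.models as models

rn18 = models.resnet18(pretrained=True).eval()
folded = BN_Folder().fold(rn18)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # Max absolute difference should be tiny (float32 rounding only).
    print((rn18(x) - folded(x)).abs().max())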
import copy
import time

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(32, 64, (3, 3))
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, inp):
        x = self.conv(inp)
        out = self.bn(x)
        return out


if __name__ == "__main__":
    model = Model()
    model.eval().cuda()
    with torch.no_grad():
        x = torch.randn((64, 32, 56, 56), device='cuda')
        z = model(x)  # reference output from the unfolded model

        # Fold the BN parameters into the conv, as in fold_bn_util above.
        conv = model.conv
        bn = model.bn
        conv_w = conv.weight
        conv_b = conv.bias
        bn_w = bn.weight
        bn_b = bn.bias
        bn_rv = bn.running_var
        bn_rm = bn.running_mean
        bn_eps = bn.eps

        inv_std = torch.rsqrt(bn_rv + bn_eps)
        folded_w = conv_w * (bn_w * inv_std).view(-1, 1, 1, 1)
        folded_b = (conv_b - bn_rm) * bn_w * inv_std + bn_b
        folded_conv = copy.deepcopy(conv)
        folded_conv.weight = torch.nn.Parameter(folded_w)
        folded_conv.bias = torch.nn.Parameter(folded_b)

        # Sanity check: the two paths should agree up to float rounding.
        y = folded_conv(x)
        print(torch.sum(y - z))

        # Time the folded conv.
        torch.cuda.synchronize()
        start = time.time()
        num_iter = 10000
        for _ in range(num_iter):
            z = folded_conv(x)
        torch.cuda.synchronize()
        print("With conv-bn folding: {:.2f}ms".format((time.time() - start) * 1000 / num_iter))

        # Time the original conv + BN.
        start = time.time()
        for _ in range(num_iter):
            z = model(x)
        torch.cuda.synchronize()
        print("Without conv-bn folding: {:.2f}ms".format((time.time() - start) * 1000 / num_iter))
@sandeepkumar-skb (Author):
Results averaged over 10k iterations:

python bn_folding.py
tensor(0.0082, device='cuda:0')
With conv-bn folding: 0.68ms
Without conv-bn folding: 0.84ms
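
The speedup comes from eliminating the separate BatchNorm kernel: after folding, each inference step launches one conv instead of a conv plus an elementwise pass over the full activation tensor.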
