Skip to content

Instantly share code, notes, and snippets.

@XinDongol
Last active October 17, 2018 03:47
Show Gist options
  • Save XinDongol/2753e126e4468ebfc2f5fca462495776 to your computer and use it in GitHub Desktop.
Save XinDongol/2753e126e4468ebfc2f5fca462495776 to your computer and use it in GitHub Desktop.
Fold batch normalization (BN) into a convolution layer
import torch
import torchvision
import numpy as np
class Fold_BN_v1(torch.nn.Module):
    """Conv2d followed by BatchNorm2d, evaluated as one BN-folded convolution.

    The BatchNorm scale and shift are folded into the convolution's weight
    and bias on every forward pass, and the input is convolved with the
    folded parameters.  The stored parameters of ``self.conv`` are never
    modified, so the layer can be called repeatedly and remains trainable.

    Which statistics are folded:
      * ``use_running=False`` (default): batch statistics while training,
        running statistics for inference.
      * ``use_running=True``: running statistics in both modes.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels (and BatchNorm features).
        kernel_size, stride, padding: forwarded to ``torch.nn.Conv2d``.  The
            ``-1`` defaults are placeholders and must be overridden.
        groups: forwarded to ``torch.nn.Conv2d``.
        dropout: accepted for interface compatibility; currently unused.
        affine: forwarded to ``torch.nn.BatchNorm2d``.
        bias: whether the convolution has a bias term.
        use_running: fold the running statistics even in training mode.

    Attributes set by ``forward``:
        mu, var: the per-channel mean / variance used for the last fold.
        real_output: ``self.bn(self.conv(x))``, the unfolded reference result.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0,
                 affine=True, bias=True, use_running=False):
        super(Fold_BN_v1, self).__init__()
        self.use_running = use_running
        self.bn = torch.nn.BatchNorm2d(out_channels, affine=affine)
        # You can replace it with your own layer.
        # Keyword arguments are required here: positionally, the slot after
        # ``padding`` is ``dilation``, not ``bias``.
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride=stride, padding=padding,
                                    groups=groups, bias=bias)

    def forward(self, x):
        """Return the BN-folded convolution of ``x``.

        Also stores the unfolded reference ``bn(conv(x))`` in
        ``self.real_output`` (this call updates the BatchNorm running
        statistics when the module is in training mode).
        """
        conv, bn = self.conv, self.bn
        raw = conv(x)

        # Pick the statistics to fold.
        if self.use_running or not self.training:
            mu, var = bn.running_mean, bn.running_var
        else:
            per_channel = raw.transpose(0, 1).reshape(raw.size(1), -1)
            mu = per_channel.mean(1)
            # BatchNorm normalizes with the biased variance during training.
            var = per_channel.var(1, unbiased=False)
        self.mu, self.var = mu, var

        # Per-output-channel scale gamma / sqrt(var + eps).
        scale = torch.rsqrt(var + bn.eps)
        if bn.weight is not None:
            scale = scale * bn.weight

        # Folded bias: beta + (b_conv - mu) * scale.
        shift = -mu * scale
        if conv.bias is not None:
            shift = shift + conv.bias * scale
        if bn.bias is not None:
            shift = shift + bn.bias

        # Folded weight: each output filter multiplied by its channel scale.
        # Maybe you want to quantize the folded weights here.
        folded_weight = conv.weight * scale.view(-1, 1, 1, 1)

        # Computed after the fold so the fold sees the running statistics
        # as they were before this step's update.
        self.real_output = bn(raw)

        return torch.nn.functional.conv2d(
            x, folded_weight, shift,
            conv.stride, conv.padding, conv.dilation, conv.groups)
if __name__ == '__main__':
    # Smoke test: compare the folded output with the unfolded bn(conv(x))
    # reference, first with batch statistics (train), then with running
    # statistics (eval).

    def _report(net, x, check_nan=False):
        """Run one forward pass and print folded-vs-reference discrepancies."""
        y1 = net(x)
        y2 = net.real_output
        if check_nan:
            print(torch.isnan(y1).data.sum().item())
            print(torch.isnan(y2).data.sum().item())
        diff = (y1 - y2).norm()
        print("error: %.8f" % diff.div(y1.norm()).item())
        # Normalize by the actual number of output elements rather than a
        # hard-coded shape.
        print('delta: ', diff / y1.numel())

    # Gradients are not needed here; keep the switch scoped to the test.
    with torch.no_grad():
        x = torch.randn(16, 3, 256, 256)
        net = Fold_BN_v1(in_channels=3, out_channels=5,
                         kernel_size=3, stride=1, padding=1, groups=1,
                         dropout=0, affine=True, bias=True,
                         use_running=False)
        net.train()
        _report(net, x)
        net.eval()
        _report(net, x, check_nan=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment