Skip to content

Instantly share code, notes, and snippets.

@MohitLamba94
Last active July 21, 2020 10:22
Show Gist options
  • Save MohitLamba94/006fefea58d2c32f16a006c310e10c00 to your computer and use it in GitHub Desktop.
Save MohitLamba94/006fefea58d2c32f16a006c310e10c00 to your computer and use it in GitHub Desktop.
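This Python script empirically measures the receptive field of an arbitrarily complex CNN: it back-propagates from a single output pixel and marks every input pixel that receives a gradient.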
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.init as init
import imageio
def downshuffle(var, r):
    """Space-to-depth: fold every r x r spatial patch into the channel axis.

    A (B, C, H, W) tensor becomes (B, C*r*r, H//r, W//r). This is the inverse
    of ``nn.PixelShuffle(r)``. H and W must be divisible by ``r``; otherwise
    ``view`` raises.
    """
    batch, chans, height, width = var.size()
    new_h, new_w = height // r, width // r
    # split each spatial axis into (coarse position, offset inside the patch)
    tiled = var.contiguous().view(batch, chans, new_h, r, new_w, r)
    # move the two intra-patch offsets next to the channel axis
    tiled = tiled.permute(0, 1, 3, 5, 2, 4).contiguous()
    return tiled.view(batch, chans * r * r, new_h, new_w).contiguous()
def _to_uint8_hwc(t):
    """Return item 0 of a (B, C, H, W) tensor as an (H, W, C) uint8 array (clip to [0, 1], scale by 255)."""
    arr = t[0].detach().cpu().numpy().transpose(1, 2, 0)
    return (np.clip(arr, 0, 1) * 255).astype(np.uint8)


def calculate_RF(img_np, cnn, RF_name, save_images=True):
    """Empirically measure the receptive field of ``cnn`` at its central output pixel.

    A unit gradient is back-propagated from output element
    ``[0, 0, h // 2, w // 2]`` (centre of channel 0 of batch item 0). Every
    input pixel whose gradient is nonzero is marked as belonging to the
    receptive field. The bounding box is printed. Pixels whose gradient
    cancels to exactly zero are not marked.

    Args:
        img_np: NumPy array fed to ``cnn`` after conversion to float32. It is
            indexed as (B, C, H, W) here and is never modified.
        cnn: callable (e.g. an ``nn.Module``) whose output is 4-D (B, C, H, W).
        RF_name: label used in the printed message and the file name the RF
            mask image is written to.
        save_images: when True (default), write the clipped input to
            ``'input_img.jpg'`` and the RF mask to ``RF_name`` via imageio.

    Returns:
        ``(row_min, row_max, col_min, col_max)`` of the marked input pixels, or
        ``None`` if no input pixel receives a nonzero gradient.
    """
    img_ip = torch.from_numpy(img_np).float().requires_grad_(True)
    original_image = _to_uint8_hwc(img_ip)

    img_out = cnn(img_ip)
    _, _, h, w = img_out.size()
    grad = torch.zeros_like(img_out)  # matches the output's dtype and device
    grad[0, 0, h // 2, w // 2] = 1
    img_out.backward(gradient=grad)

    # Mark pixels by their gradient rather than by applying an update step and
    # testing for nonzero values. This stays correct for inputs that are
    # already nonzero, and never writes into img_np, which may share memory
    # with img_ip.
    mask = (img_ip.grad != 0).float()
    updated_image = _to_uint8_hwc(mask)
    idx = np.nonzero(updated_image)

    if idx[0].size == 0:
        print(RF_name, '-- no input pixel influences the selected output pixel')
        bbox = None
    else:
        bbox = (int(idx[0].min()), int(idx[0].max()),
                int(idx[1].min()), int(idx[1].max()))
        print(RF_name, '-- row_min:', bbox[0], ' row_max:', bbox[1],
              ' col_min:', bbox[2], ' col_max:', bbox[3])

    if save_images:
        imageio.imwrite('input_img.jpg', original_image)
        imageio.imwrite(RF_name, updated_image)
    return bbox
class unet3(nn.Module):
    """Toy down/up-sampling network used only to measure its receptive field.

    The stride-2 layer ``conv1`` is applied twice (shared weights). It is
    followed by the stride-2 ``conv2`` (1 -> 64 channels) and
    ``PixelShuffle(8)`` back to a single channel. For H, W divisible by 8 the
    spatial size goes H -> H/2 -> H/4 -> H/8 -> H. Every weight and bias is set
    to 1.
    """

    def __init__(self):
        super(unet3, self).__init__()
        # 3x3 kernels with padding 1 ("same" padding before the stride)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=2)
        self.pixelshuffle = nn.PixelShuffle(8)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=2)
        for layer in (self.conv1, self.conv2):
            nn.init.constant_(layer.bias, 1)
            nn.init.constant_(layer.weight, 1)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv1(y)  # same layer again: weights are shared
        y = self.conv2(y)
        return self.pixelshuffle(y)
class unet5(nn.Module):
    """Toy down/up-sampling network with five conv applications, for RF measurement.

    The stride-2 layer ``conv1`` is applied three times (shared weights). It is
    followed by the stride-1 ``conv3`` (1 -> 1), the stride-1 ``conv4``
    (1 -> 64) and ``PixelShuffle(8)`` back to one channel. For H, W divisible
    by 8 the output has the input's spatial size. Every weight and bias is set
    to 1.
    """

    def __init__(self):
        super(unet5, self).__init__()
        # all kernels are 3x3 with padding 1 ("same" padding before the stride)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=2)
        self.pixelshuffle = nn.PixelShuffle(8)
        self.conv3 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=1)
        self.conv4 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=1)
        for layer in (self.conv1, self.conv3, self.conv4):
            nn.init.constant_(layer.bias, 1)
            nn.init.constant_(layer.weight, 1)

    def forward(self, x):
        y = x
        for _ in range(3):  # three passes through the shared stride-2 conv
            y = self.conv1(y)
        y = self.conv3(y)
        y = self.conv4(y)
        return self.pixelshuffle(y)
class simpl5(nn.Module):
    """Plain stack of five 3x3 stride-1 convs sharing one layer, for RF measurement.

    The single 1 -> 1 layer ``conv`` (weights and bias all 1, padding 1)
    is applied five times, so the spatial size is preserved.
    """

    def __init__(self):
        super(simpl5, self).__init__()
        layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                          padding=1, groups=1, bias=True, stride=1)
        nn.init.constant_(layer.bias, 1)
        nn.init.constant_(layer.weight, 1)
        self.conv = layer

    def forward(self, x):
        y = x
        for _ in range(5):
            y = self.conv(y)
        return y
class ours(nn.Module):
    """Space-to-depth -> three 3x3 convs -> depth-to-space, for RF measurement.

    ``downshuffle(x, 8)`` turns a 1-channel input into 64 channels at 1/8
    resolution. ``conv1`` reduces 64 -> 1, ``conv11`` keeps 1 channel,
    ``conv2`` expands 1 -> 64, and ``PixelShuffle(8)`` restores one channel at
    full resolution. All convs are stride 1 with padding 1, and every weight
    and bias is set to 1.
    """

    def __init__(self):
        super(ours, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=1)
        self.conv11 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                                padding=1, groups=1, bias=True, stride=1)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3,
                               padding=1, groups=1, bias=True, stride=1)
        for layer in (self.conv1, self.conv11, self.conv2):
            nn.init.constant_(layer.bias, 1)
            nn.init.constant_(layer.weight, 1)
        self.pixelshuffle = nn.PixelShuffle(8)

    def forward(self, x):
        y = downshuffle(x, 8)  # (B, 1, H, W) -> (B, 64, H/8, W/8)
        y = self.conv1(y)
        y = self.conv11(y)
        y = self.conv2(y)
        return self.pixelshuffle(y)
# Measure and save the receptive field of each toy network on a blank
# 256x256 single-channel image, in the order listed.
img_np = np.zeros((1, 1, 256, 256))
for model_cls, out_name in ((unet3, 'unet3_image.jpg'),
                            (unet5, 'unet5_image.jpg'),
                            (simpl5, 'simpl5_image.jpg'),
                            (ours, 'ours3_image.jpg')):
    mycnn = model_cls()
    calculate_RF(img_np, mycnn, out_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment