Skip to content

Instantly share code, notes, and snippets.

@shreejalt
Last active June 18, 2020 10:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shreejalt/2c499be21f45ff404f9fe964d24795cb to your computer and use it in GitHub Desktop.
Save shreejalt/2c499be21f45ff404f9fe964d24795cb to your computer and use it in GitHub Desktop.
PyTorch implementation to generate different families of RegNet models (Facebook AI Research, March 2020)
'''
Name: Shreejal Trivedi
Description: Script that generates RegNetX and RegNetY models
References: "Designing Network Design Spaces", Facebook AI Research, March 2020
'''
#Importing Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
# Projection shortcut for bottleneck blocks whose stride or channel count changes
class Downsample(nn.Module):
    """Projection shortcut: a bias-free strided 1x1 convolution followed by BatchNorm.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        stride: spatial stride of the 1x1 convolution.
    """

    def __init__(self, in_filters, out_filters, stride):
        super().__init__()
        self.conv1x1 = nn.Conv2d(
            in_channels=in_filters,
            out_channels=out_filters,
            kernel_size=1,
            stride=stride,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_filters)

    def forward(self, x):
        projected = self.conv1x1(x)
        return self.bn(projected)
#SE Attention Module for RegNetY
class SqueezeExcitation(nn.Module):
def __init__(self, in_filters, se_ratio):
super(SqueezeExcitation, self).__init__()
#Calculate bottleneck SE filters
out_filters = int(in_filters * se_ratio)
#Average Pooling Layer
self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
#Squeeze
self.conv1_1x1 = nn.Conv2d(in_filters, out_filters, kernel_size=1, bias=True)
# Excite
self.conv2_1x1 = nn.Conv2d(out_filters, in_filters, kernel_size=1, bias=True)
def forward(self, x):
out = self.avgpool(x)
out = F.relu(self.conv1_1x1(out))
out = self.conv2_1x1(out).sigmoid()
out = x * out
return out
#Bottleneck Residual Block in Layer
class Bottleneck(nn.Module):
def __init__(self, in_filters, out_filters, bottleneck_ratio, group_size, stride=1, se_ratio=0):
super(Bottleneck, self).__init__()
#1x1 Bottleneck Convolution Block
bottleneck_filters = in_filters // bottleneck_ratio
self.conv1_1x1 = nn.Conv2d(in_filters, bottleneck_filters, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(bottleneck_filters)
#3x3 Convolution Block with Group Convolutions ---> ResNext alike structure
num_groups = bottleneck_filters // group_size
self.conv2_3x3 = nn.Conv2d(bottleneck_filters, bottleneck_filters, kernel_size=3, stride=stride, padding=1, groups=num_groups, bias=False)
self.bn2 = nn.BatchNorm2d(bottleneck_filters)
#Squeeze-Exictation Block: Only for RegNetY
self.se_module = SqueezeExcitation(bottleneck_filters, se_ratio) if se_ratio < 1 else None
#Downsample if stride=2
self.downsample = Downsample(in_filters, out_filters, stride) if stride != 1 or in_filters != out_filters else None
#1x1 Convolution Block
self.conv3_1x1 = nn.Conv2d(bottleneck_filters, out_filters, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_filters)
def forward(self, x):
residual = x
out = F.relu(self.bn1(self.conv1_1x1(x)))
out = F.relu(self.bn2(self.conv2_3x3(out)))
if self.se_module is not None:
out = self.se_module(out)
out = self.bn3(self.conv3_1x1(out))
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = F.relu(out)
return out
class Stem(nn.Module):
    """Network stem: stride-2 3x3 convolution (no bias) -> BatchNorm -> ReLU.

    Args:
        out_filters: number of output channels.
        in_filters: number of input channels (3 by default).
    """

    def __init__(self, out_filters, in_filters=3):
        super().__init__()
        self.conv3x3 = nn.Conv2d(
            in_filters, out_filters, 3, stride=2, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_filters)

    def forward(self, x):
        features = self.conv3x3(x)
        features = self.bn(features)
        return F.relu(features)
class Layer(nn.Module):
    """One RegNet stage made of ``depth`` Bottleneck blocks of output width ``width``.

    The first block takes ``in_filters`` channels and uses stride 2; every
    following block maps ``width`` -> ``width`` at stride 1.

    Args:
        in_filters: channels entering the stage.
        depth: number of Bottleneck blocks in the stage.
        width: output channels of every block in the stage.
        bottleneck_ratio: bottleneck ratio forwarded to each block.
        group_size: group size forwarded to each block.
        se_ratio: Squeeze-Excitation ratio forwarded to each block.
    """

    def __init__(self, in_filters, depth, width, bottleneck_ratio, group_size, se_ratio):
        super(Layer, self).__init__()
        blocks = [
            Bottleneck(
                in_filters if idx == 0 else width,
                width,
                bottleneck_ratio,
                group_size,
                2 if idx == 0 else 1,
                se_ratio,
            )
            for idx in range(depth)
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
class Head(nn.Module):
    """Classification head: global average pooling, flatten, then a linear layer.

    Args:
        in_filters: channels of the incoming feature map.
        classes: number of output logits.
    """

    def __init__(self, in_filters, classes):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_filters, classes)

    def forward(self, x):
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)
class RegNet(nn.Module):
    """RegNet classifier: Stem -> one Layer per stage -> Head.

    Args:
        paramaters: tuple ``(w, d, b, g, se_ratio)`` as returned by
            ``generate_parameters_regnet``. ``w`` holds the per-stage widths, ``d``
            the per-stage depths (same length), ``b`` the bottleneck ratio, ``g``
            the group size and ``se_ratio`` the Squeeze-Excitation ratio. (The
            parameter name keeps its original spelling for keyword callers.)
        classes: number of output classes.

    Raises:
        ValueError: if ``w`` and ``d`` have different lengths.
    """

    def __init__(self, paramaters, classes=2):
        super(RegNet, self).__init__()
        # Channel count produced by the stem (stride-2 3x3 conv)
        self.in_filters = 32
        # Unpack the argument itself; reading a module-level ``parameters``
        # only worked when this file was run as a script.
        self.w, self.d, self.b, self.g, self.se_ratio = paramaters
        if len(self.w) != len(self.d):
            raise ValueError(
                "width list and depth list must have the same length "
                f"(got {len(self.w)} and {len(self.d)})"
            )
        # One stage per (width, depth) pair instead of a hard-coded 4, so the
        # head always matches the last stage that was actually built.
        self.num_layers = len(self.d)
        # Stem part of the generic ResNet/ResNeXt architecture
        self.stem = Stem(self.in_filters)
        # Body part: stages of bottleneck residual blocks
        body = []
        for width, depth in zip(self.w, self.d):
            body.append(Layer(self.in_filters, int(depth), int(width), self.b, self.g, self.se_ratio))
            self.in_filters = int(width)
        self.body = nn.Sequential(*body)
        # Head part: average pooling + fully connected classifier on the final width
        self.head = Head(self.in_filters, classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.body(out)
        out = self.head(out)
        return out
def generate_parameters_regnet(D, w0, wa, wm, b, g, q):
    """Generate per-stage RegNet widths and depths from the quantized linear rule.

    Args:
        D: total number of blocks (network depth), >= 1.
        w0: initial width, > 0.
        wa: slope of the linear width function, >= 0.
        wm: width multiplier used for quantization, > 1.
        b: bottleneck ratio (stage width is divided by it inside each block), >= 1.
        g: requested group size of the 3x3 grouped convolutions, >= 1.
        q: Squeeze-Excitation ratio, passed through unchanged.

    Returns:
        Tuple ``(w, d, b, g, q)``:
        ``w`` is an int array of stage widths (ascending) whose bottleneck widths
        ``w // b`` are multiples of the per-stage group size; ``d`` is the number
        of blocks per stage (sums to ``D``); ``b`` and ``q`` are unchanged;
        ``g`` is the smallest per-stage group size after clamping it to the
        bottleneck width.

    Raises:
        ValueError: if any argument is outside the ranges above (they would
            otherwise produce NaN widths or divide by zero).
    """
    if D < 1 or w0 <= 0 or wa < 0 or wm <= 1 or b < 1 or g < 1:
        raise ValueError(
            "expected D >= 1, w0 > 0, wa >= 0, wm > 1, b >= 1 and g >= 1 "
            f"(got D={D}, w0={w0}, wa={wa}, wm={wm}, b={b}, g={g})"
        )
    u = w0 + wa * np.arange(D)  # Equation 1: linear per-block widths
    s = np.round(np.log(u / w0) / np.log(wm))  # Equation 2: rounded block exponents
    w = w0 * np.power(wm, s)  # Equation 3: quantized widths
    # Widths divisible by 8; never 0 (tiny w0 would otherwise round to 0)
    w = np.maximum(np.round(w / 8) * 8, 8)
    # Consecutive blocks sharing a width form one stage. ``int`` replaces the
    # ``np.int`` alias, which was removed in NumPy 1.24.
    w, d = np.unique(w.astype(int), return_counts=True)
    # Make each stage's bottleneck width a multiple of its group size, where a
    # group can be no wider than the bottleneck itself.
    w_bot = np.maximum(w // b, 1)
    group_sizes = np.minimum(g, w_bot)
    w_bot = (np.round(w_bot / group_sizes) * group_sizes).astype(int)
    # Convert bottleneck widths back to stage widths (previously omitted, so
    # b > 1 returned bottleneck widths as stage widths).
    w = w_bot * b
    # A single group size is shared by all stages: the smallest one. It is the
    # group width itself; the former extra ``// b`` shrank it wrongly.
    g = int(group_sizes.min())
    return (w, d, b, g, q)
if __name__ == '__main__':
    # Command-line entry point: build a RegNet from design-space parameters.
    parser = argparse.ArgumentParser(description="RegNetX | RegNetY Models Generation")
    parser.add_argument('-D', default=13, type=int, help='Network Depth: Range::[12, 13, ..., 28]')
    parser.add_argument('-w0', default=24, type=int, help='Initial Width of the First Layer > 0')
    # The slope need not be an integer; integer inputs still parse as before.
    parser.add_argument('-wa', default=36, type=float, help='Slope Parameter: Range::[0, 255]')
    parser.add_argument('-wm', default=2.5, type=float, help='Quantization Parameter: Range::[1.5, 3]')
    parser.add_argument('-b', default=1, type=int, help='Bottleneck Ratio: Range::{1, 2, 4}')
    parser.add_argument('-g', default=8, type=int, help='Group Size: Range::{1, 2, 4, 8, 16, 32} OR {16, 24, 32, 40, 48, 56, 64}')
    # The default q=1 disables Squeeze-Excitation; the previous help text
    # ("0 <= SE Ratio < 1") excluded its own default.
    parser.add_argument('-q', default=1, type=float, help='SE Ratio: 0 < q < 1 enables Squeeze-Excitation (RegNetY); q >= 1 disables it (RegNetX)')
    args = parser.parse_args()
    parameters = generate_parameters_regnet(args.D, args.w0, args.wa, args.wm, args.b, args.g, args.q)
    model = RegNet(parameters)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment