Skip to content

Instantly share code, notes, and snippets.

@soumith
Created December 24, 2018 04:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soumith/f7e9afbdc561a2cfa9a2c5bdf443aa8b to your computer and use it in GitHub Desktop.
Save soumith/f7e9afbdc561a2cfa9a2c5bdf443aa8b to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
class simpnet_imgnet_drpall(nn.Module):
    """SimpNet variant for ImageNet-sized inputs.

    Args:
        classes: number of logits produced by the final linear classifier.
        scale: width multiplier applied to every conv layer's channel count
            (each scaled count is rounded to the nearest integer).
        network_idx: index into ``self.networks``; 0 -> 'simpnet5m',
            1 -> 'simpnet8m'.
        mode: key into ``self.strides`` (1, 2, 3, 4 or 5) selecting the stride
            schedule applied to the leading conv layers.
        simpnet_name: accepted for interface compatibility; not used here.
    """

    def __init__(self, classes=1000, scale=1.0, network_idx=0, mode=1,
                 simpnet_name='simpnet_imgnet_drpall'):
        super(simpnet_imgnet_drpall, self).__init__()
        # Layer specs: ['C', n] is a 3x3 conv block with n (unscaled) output
        # channels; ['P'] inserts a 2x2 max-pool.
        self.cfg = {
            'simpnet5m': [['C', 66], ['C', 128], ['C', 128], ['C', 128], ['C', 192], ['C', 192], ['C', 192], ['C', 192], ['C', 192], ['C', 288], ['P'], ['C', 288], ['C', 355], ['C', 432]],
            'simpnet8m': [['C', 128], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 182], ['C', 430], ['P'], ['C', 430], ['C', 455], ['C', 600]]}
        self.scale = scale
        self.networks = ['simpnet5m', 'simpnet8m']
        self.network_idx = network_idx
        self.mode = mode
        # Per-mode stride of the i-th conv layer for the first len(list)
        # convs; every later conv uses stride 1. A max-pool is also inserted
        # when the conv index reaches len(list) (see _make_layers).
        self.strides = {1: [2, 2, 2, 1, 1],  # s1
                        2: [2, 2, 1, 2, 1, 1],  # s4
                        3: [2, 2, 1, 1, 2, 1],  # s3
                        4: [2, 1, 2, 1, 2, 1],  # s5
                        5: [2, 1, 2, 1, 2, 1, 1]}  # s6
        self.features = self._make_layers(scale)
        # forward() global-max-pools to 1x1, so the classifier's input width
        # is the scaled channel count of the last conv layer.
        self.classifier = nn.Linear(
            round(self.cfg[self.networks[network_idx]][-1][1] * scale), classes)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model in place.

        Every 'module.' substring is stripped from the keys (the prefix that
        nn.DataParallel adds), so a checkpoint saved from a wrapped model
        loads into the bare module. Keys this model does not have are
        skipped. Entries that cannot be copied (shape mismatch) are reported
        and the model keeps its current values for them.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            print("STATE_DICT: {}".format(name))
            try:
                own_state[name].copy_(param)
            except RuntimeError:
                # copy_ raises RuntimeError when the shapes are incompatible;
                # keep the initial parameters rather than abort the load.
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ... Using Initial Params'.format(
                          name, own_state[name].size(), param.size()))

    def forward(self, x):
        """Return logits of shape (batch, classes) for a 3-channel input batch."""
        out = self.features(x)
        # Global max pooling over the whole spatial extent -> (N, C, 1, 1).
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        # training=False makes this dropout a no-op in every mode; kept to
        # preserve existing behaviour.
        out = F.dropout2d(out, 0.01, training=False)
        out = out.view(out.size(0), -1)
        return self.classifier(out)

    def _make_layers(self, scale):
        """Build the convolutional feature extractor as an nn.Sequential.

        Each 'C' spec becomes Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU,
        followed by Dropout2d(p=0) except at conv indices
        len(strides)-1, 9 and 12. A MaxPool2d(2x2) + Dropout2d(p=0) pair is
        inserted for each 'P' spec and once when the conv index reaches the
        length of the selected stride schedule.
        """
        layers = []
        input_channel = 3
        idx = 0
        strides = self.strides[self.mode]
        for x in self.cfg[self.networks[self.network_idx]]:
            if idx == len(strides) or x[0] == 'P':
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
                           nn.Dropout2d(p=0.00)]
            if x[0] != 'C':
                continue
            filters = round(x[1] * scale)
            stride = strides[idx] if idx < len(strides) else 1
            layers += [nn.Conv2d(input_channel, filters, kernel_size=[3, 3], stride=(stride, stride), padding=(1, 1)),
                       nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True),
                       nn.ReLU(inplace=True)]
            # These conv positions are deliberately not followed by dropout.
            if idx not in (len(strides) - 1, 9, 12):
                layers.append(nn.Dropout2d(p=0.000))
            input_channel = filters
            idx += 1
        model = nn.Sequential(*layers)
        print(model)
        # Iterate over `model`, not self.modules(): the new Sequential is not
        # registered on self until __init__ assigns self.features, so
        # self.modules() would never reach these Conv2d layers.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        return model
if __name__ == "__main__":
    # Smoke test: forward + backward through a DataParallel-wrapped model.
    # Guarded so that importing this module does not allocate device memory
    # or run a training step. Uses CUDA when available, otherwise the CPU
    # (nn.DataParallel simply calls the wrapped module when no GPU exists).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = simpnet_imgnet_drpall()
    model = nn.DataParallel(model).to(device)
    x = torch.randn(10, 3, 224, 224, device=device)
    y = model(x)
    y.sum().backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment