Created November 6, 2018 22:09
Simple code to export a PyTorch ResNet model to ONNX
import torch

import resnet

# Discover the model constructors exported by resnet.py (resnet20, resnet32, ...).
model_names = sorted(name for name in resnet.__dict__
                     if name.islower() and not name.startswith("__")
                     and name.startswith("resnet")
                     and callable(resnet.__dict__[name]))

def remove_module_prefix(param_dict):
    # Checkpoints saved from an nn.DataParallel-wrapped model prefix every
    # parameter name with 'module.'; strip it so a plain model can load them.
    new_checkpoint = {}
    for key in param_dict.keys():
        assert key.startswith('module.')
        new_checkpoint[key[7:]] = param_dict[key]
    return new_checkpoint

def main():
    model = resnet.__dict__['resnet110']()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load('pretrained_models/resnet110.th', map_location='cpu')
    checkpoint['state_dict'] = remove_module_prefix(checkpoint['state_dict'])
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()  # freeze batch-norm statistics for a deterministic export

    # A dummy CIFAR-10-shaped input (batch of 1, 3x32x32) drives the trace.
    dummy_input = torch.randn(1, 3, 32, 32)
    input_names = ["data"]
    output_names = ["output"]
    torch.onnx.export(model, dummy_input, 'cifar10_resnet110.onnx',
                      input_names=input_names, output_names=output_names)

if __name__ == '__main__':
    main()
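To sanity-check the exported graph, a minimal sketch along these lines can validate the file and run a dummy input through ONNX Runtime (this assumes the onnx and onnxruntime packages, which the gist itself does not use):

import numpy as np
import onnx
import onnxruntime

# Structural validation: parses the protobuf and checks the graph is well formed.
onnx.checker.check_model(onnx.load('cifar10_resnet110.onnx'))

# Run one CIFAR-10-shaped batch through ONNX Runtime on CPU.
session = onnxruntime.InferenceSession('cifar10_resnet110.onnx',
                                       providers=['CPUExecutionProvider'])
dummy = np.random.randn(1, 3, 32, 32).astype(np.float32)
(logits,) = session.run(['output'], {'data': dummy})
print(logits.shape)  # expected: (1, 10)

Note that the export above traces a fixed 1x3x32x32 input, so the graph is locked to batch size 1; passing dynamic_axes={'data': {0: 'batch'}} to torch.onnx.export would relax that.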
# From https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py
'''
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file are hugely influenced by [2],
which targets ImageNet and doesn't offer option A for the identity shortcut.
Moreover, most implementations on the web are copy-pasted from torchvision's
resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
numbers of layers and parameters:

name      | layers | params
ResNet20  |     20 |  0.27M
ResNet32  |     32 |  0.46M
ResNet44  |     44 |  0.66M
ResNet56  |     56 |  0.85M
ResNet110 |    110 |   1.7M
ResNet1202|   1202 |  19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention
the author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
def _weights_init(m):
    # He (Kaiming) initialization for every conv and linear layer.
    # init.kaiming_normal was renamed init.kaiming_normal_ in PyTorch 0.4.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
class LambdaLayer(nn.Module):
    """Wraps an arbitrary function as an nn.Module (used for the option-A shortcut)."""
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10, the ResNet paper uses option A: a parameter-free
                # shortcut that subsamples spatially with stride 2 (x[:, :, ::2, ::2])
                # and zero-pads planes//4 channels on each side so the channel
                # count matches the main path.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                # Option B: a learned 1x1 convolution projection, as in torchvision.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual connection
        out = F.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])  # global average pool (8x8 -> 1x1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])

def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])

def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])

def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])

def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])

def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])
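# Depth bookkeeping: each stage stacks n BasicBlocks of two 3x3 convs, so
# num_blocks = [n, n, n] gives 6n conv layers plus the stem conv and the
# final linear layer, i.e. depth 6n + 2. resnet110, for example, uses n = 18:
# 6*18 + 2 = 110.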
def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))

if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()
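As a quick shape check (a hypothetical snippet, not part of the original gist, assuming the file above is saved as resnet.py), any of the constructors should map a CIFAR-10-sized batch to ten logits:

import torch
from resnet import resnet20

net = resnet20()
net.eval()
with torch.no_grad():
    logits = net(torch.randn(2, 3, 32, 32))  # two 32x32 RGB images
print(logits.shape)  # torch.Size([2, 10]) -- one logit per CIFAR-10 class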
ONNX file generated here: https://www.dropbox.com/s/jk8g9noswa0b0c0/cifar10_resnet101.onnx?dl=0