Skip to content

Instantly share code, notes, and snippets.

@sampathweb
Created May 24, 2018 01:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sampathweb/7b6fbb35095835e9f5a12136f7494197 to your computer and use it in GitHub Desktop.
Save sampathweb/7b6fbb35095835e9f5a12136f7494197 to your computer and use it in GitHub Desktop.
fastai / PyTorch: CIFAR-10 ONNX Export Example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Taken from https://github.com/fastai/fastai/blob/master/courses/dl1/cifar10.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.imports import *\n",
"from fastai.transforms import *\n",
"from fastai.conv_learner import *\n",
"from fastai.model import *\n",
"from fastai.dataset import *\n",
"from fastai.sgdr import *\n",
"from fastai.plots import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data loaders are created further below with `get_data(...)`, after `PATH`, `stats` and `get_data` itself have been defined; calling it before those cells raises a `NameError`."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"PATH = \"/data/cifar10/\"\n",
"os.makedirs(PATH,exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def get_data(sz, bs):\n",
"    '''Return fastai ImageClassifierData for the image folders under PATH.\n",
"\n",
"    sz: image size, passed to tfms_from_stats together with `stats`.\n",
"    bs: batch size of the data loaders.\n",
"    The 'test' folder serves as the validation set; augmentation is a random\n",
"    flip, with pad=sz//8 also passed to tfms_from_stats.\n",
"    '''\n",
"    augmentations = [RandomFlip()]\n",
"    image_tfms = tfms_from_stats(stats, sz, aug_tfms=augmentations, pad=sz//8)\n",
"    return ImageClassifierData.from_paths(PATH, val_name='test', tfms=image_tfms, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"bs=128"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Use the batch size defined above: the learner is later built from this `data`,\n",
"# so this is also the batch size used for training.\n",
"data = get_data(32,bs)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"x,y=next(iter(data.trn_dl))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7efd93814c18>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(data.trn_ds.denorm(x)[0]);"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"'''LeNet in PyTorch.'''\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class LeNet(nn.Module):\n",
"    '''LeNet-style CNN for 3-channel 32x32 images with 10 output classes.\n",
"\n",
"    fc1 expects 16*5*5 = 400 features: each (5x5 conv, no padding -> 2x2 max-pool)\n",
"    stage maps a 32x32 input to 28 -> 14 and then 10 -> 5 pixels per side.\n",
"    forward() returns per-class log-probabilities of shape (batch, 10).\n",
"    '''\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16*5*5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        out = F.relu(self.conv1(x))\n",
"        out = F.max_pool2d(out, 2)\n",
"        out = F.relu(self.conv2(out))\n",
"        out = F.max_pool2d(out, 2)\n",
"        out = out.view(out.size(0), -1)  # flatten to (batch, 400)\n",
"        out = F.relu(self.fc1(out))\n",
"        out = F.relu(self.fc2(out))\n",
"        out = self.fc3(out)\n",
"        # Explicit dim=1 (the class axis); calling log_softmax without dim is deprecated.\n",
"        return F.log_softmax(out, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# fastai models use AdaptiveMaxPool2d and AdaptiveAvgPool2d, which the ONNX exporter\n",
"# did not support in PyTorch 0.3. Use a custom model (as here) to work around that,\n",
"# or retry with PyTorch 0.4 in case the issue has been resolved there.\n",
"\n",
"model = LeNet()"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"learn = ConvLearner.from_model_data(model, data)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# Train the Model\n",
"# learn.fit(0.01, 2)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LeNet(\n",
" (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.models.model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Export to ONNX"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"graph(%1 : Float(1, 3, 32, 32)\n",
" %2 : Float(6, 3, 5, 5)\n",
" %3 : Float(6)\n",
" %4 : Float(16, 6, 5, 5)\n",
" %5 : Float(16)\n",
" %6 : Float(120, 400)\n",
" %7 : Float(120)\n",
" %8 : Float(84, 120)\n",
" %9 : Float(84)\n",
" %10 : Float(10, 84)\n",
" %11 : Float(10)) {\n",
" %13 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[0, 0, 0, 0], dilations=[1, 1], group=1](%1, %2), uses = [[%14.i0]], scope: LeNet/Conv2d[conv1];\n",
" %14 : Float(1, 6, 28, 28) = Add[broadcast=1, axis=1](%13, %3), uses = [%15.i0], scope: LeNet/Conv2d[conv1];\n",
" %15 : Float(1, 6, 28, 28) = Relu(%14), uses = [%16.i0], scope: LeNet;\n",
" %16 : Float(1, 6, 14, 14) = MaxPool[kernel_shape=[2, 2], pads=[0, 0], strides=[2, 2]](%15), uses = [%17.i0], scope: LeNet;\n",
" %18 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[0, 0, 0, 0], dilations=[1, 1], group=1](%16, %4), uses = [[%19.i0]], scope: LeNet/Conv2d[conv2];\n",
" %19 : Float(1, 16, 10, 10) = Add[broadcast=1, axis=1](%18, %5), uses = [%20.i0], scope: LeNet/Conv2d[conv2];\n",
" %20 : Float(1, 16, 10, 10) = Relu(%19), uses = [%21.i0], scope: LeNet;\n",
" %21 : Float(1, 16, 5, 5) = MaxPool[kernel_shape=[2, 2], pads=[0, 0], strides=[2, 2]](%20), uses = [%22.i0], scope: LeNet;\n",
" %22 : Float(1, 400) = Reshape[shape=[1, -1]](%21), uses = [%25.i0], scope: LeNet;\n",
" %25 : Float(1, 120) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%22, %6, %7), uses = [%26.i0], scope: LeNet/Linear[fc1];\n",
" %26 : Float(1, 120) = Relu(%25), uses = [%29.i0], scope: LeNet;\n",
" %29 : Float(1, 84) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%26, %8, %9), uses = [%30.i0], scope: LeNet/Linear[fc2];\n",
" %30 : Float(1, 84) = Relu(%29), uses = [%33.i0], scope: LeNet;\n",
" %33 : Float(1, 10) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%30, %10, %11), uses = [%34.i0], scope: LeNet/Linear[fc3];\n",
" %34 : Float(1, 10) = Softmax[axis=1](%33), uses = [%35.i0], scope: LeNet;\n",
" %35 : Float(1, 10) = Log(%34), uses = [%0.i0], scope: LeNet;\n",
" return (%35);\n",
"}\n",
"\n"
]
}
],
"source": [
"from torch.autograd import Variable\n",
"import torch.onnx\n",
"\n",
"# Export the learner's model (untrained here, since learn.fit is commented out) to ONNX\n",
"# by tracing it on one random input of shape (1, 3, 32, 32).\n",
"# to_gpu is fastai's helper; presumably it moves the input onto the model's device.\n",
"# verbose=True prints the traced graph; export_params=True writes the weights into the file.\n",
"# NOTE(review): torch.onnx._export is a private API. It is used instead of torch.onnx.export\n",
"# because, in the PyTorch 0.3/0.4 releases this notebook targets, it also returns the model's\n",
"# output for dummy_input (torch_out), e.g. for comparison with an ONNX backend's result.\n",
"# Confirm this before upgrading PyTorch.\n",
"model = learn.models.model\n",
"dummy_input = to_gpu(Variable(torch.randn(1, 3, 32, 32)))\n",
"\n",
"torch_out = torch.onnx._export(\n",
"    model,\n",
"    dummy_input,\n",
"    \"lenet_cifar10.onnx\",\n",
"    verbose=True,\n",
"    export_params=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lenet_cifar10.onnx\r\n"
]
}
],
"source": [
"!ls lenet*.onnx"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment