Skip to content

Instantly share code, notes, and snippets.

@sampathweb
Created May 24, 2018 01:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sampathweb/7b6fbb35095835e9f5a12136f7494197 to your computer and use it in GitHub Desktop.
Save sampathweb/7b6fbb35095835e9f5a12136f7494197 to your computer and use it in GitHub Desktop.
FastAI - PyTorch Cifar10 Onnx Export Example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# fastai / PyTorch CIFAR-10 to ONNX export\n",
"\n",
"Adapted from the [fast.ai CIFAR-10 notebook](https://github.com/fastai/fastai/blob/master/courses/dl1/cifar10.ipynb). A small LeNet is wrapped in a fastai `ConvLearner` and exported to `lenet_cifar10.onnx`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Pre-1.0 fastai star imports; the cells below rely on them for os, np, plt, torch and the fastai helpers\n",
"from fastai.imports import *\n",
"from fastai.transforms import *\n",
"from fastai.conv_learner import *\n",
"from fastai.model import *\n",
"from fastai.dataset import *\n",
"from fastai.sgdr import *\n",
"from fastai.plots import *"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Dataset root; get_data() below reads its 'test' folder as the validation set\n",
"PATH = \"/data/cifar10/\"\n",
"os.makedirs(PATH, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"classes = ('plane', 'car', 'bird', 'cat', 'deer',\n",
"           'dog', 'frog', 'horse', 'ship', 'truck')\n",
"# Per-channel mean and std handed to tfms_from_stats for normalisation\n",
"cifar_mean = np.array([0.4914, 0.48216, 0.44653])\n",
"cifar_std = np.array([0.24703, 0.24349, 0.26159])\n",
"stats = (cifar_mean, cifar_std)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def get_data(sz, bs):\n",
"    \"\"\"Build fastai ImageClassifierData for the CIFAR-10 folders under PATH.\n",
"\n",
"    sz: image size passed to the transforms; bs: batch size.\n",
"    Uses random flips as augmentation and pad=sz//8 (as in the upstream notebook);\n",
"    the 'test' folder serves as the validation set.\n",
"    \"\"\"\n",
"    augment = [RandomFlip()]\n",
"    transforms = tfms_from_stats(stats, sz, aug_tfms=augment, pad=sz // 8)\n",
"    return ImageClassifierData.from_paths(PATH, val_name='test', tfms=transforms, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"bs = 128  # batch size for the training/validation DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# 32x32 CIFAR images; use the configured batch size instead of a hard-coded 4\n",
"data = get_data(32, bs)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# One augmented training batch, used below to eyeball a sample image\n",
"x, y = next(iter(data.trn_dl))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAHaNJREFUeJztnW2MnNd13/9nZmd2Z9/J5fubqEh0LFm2ZYeWhahOnaQxVCOA7CAJ7A+BPhhRUMRADaQfBBdoXKAfnKK24Q+BC7oWorSubbW2ayEwkiiqCzVwooh2JJISbVmiKGrJJZf7vtyd1+c5/bDDlFrd/93hvsxSvv8fQHD2nrnPPXPnOc8zc/9zzjV3hxAiPQrb7YAQYntQ8AuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hE6dlIZzN7EMCXARQB/Bd3/3zs+WO7dvnh224LG6O/NLRga5ZnHfm5mmKhGLFG/LCwHwXSvjZ8LCOvudvcGl6ITjl//jympqY6etvWHfxmVgTwpwB+A8A4gOfM7El3f4n1OXzbbXj6h38XtGV5g49F3FxYXLgJj/8/w0PD1OZoUVuxUA6295XC7QBQKOTckZyPVSjwtyb2cY2+6xE3YkTHKnCr5+EBLXJEW+fnUI+8OB4Ft86HXjZX6/HxA/cd7/i5G5mB+wC84u7n3L0B4JsAHtrA8YQQXWQjwX8QwBs3/D3ebhNCvA3YSPCHPlG95UusmT1iZifN7OT01akNDCeE2Ew2EvzjAA7f8PchAJdWP8ndT7j7cXc/PrZ71waGE0JsJhsJ/ucAHDOz282sDOATAJ7cHLeEEFvNulf73b1lZp8G8FdYkfoec/cXY31q1WX89KUfB21nTj9L++3csTvYfuXK5U7dfRN79+6jtpnZq9R2z7s/GGzft/cQ7TM7c4Xaqsvz1HZg/wFq2znKP0EVyPL2UD9XOIyoGGsRE2cLESWAkbe44hMTK4o96/P/ViGmmmwlG9L53f37AL6/Sb4IIbrIrSN2CiG6ioJfiERR8AuRKAp+IRJFwS9Eomxotf9mmZ2dxLee+NOg7cUzz9N+pUIl3F5cn/vNjCfUNPMqtb149mSw/ReP/RLtc3mCS31TV16jtoMHd1BbMZL0A6J6vetO7uPYKJc+J6/OUNuhI1yOHBgMS4unX+Tvc6N6jdoOHjhMbUeO/iK1DY+EZdHlBS6zjo/z92V4dJDahga4LQfPJC33hftNT/NfxFZr4aS2peVF2mc1uvMLkSgKfiESRcEvRKIo+IVIFAW/EInS1dX+anUJp1/6+6Dt6vQy73ctvDq/a9fouvyYmpqjtsogn5Kc+H55+nU+WM6TTmYmpqltcaGf2rKsTm3XfCnYfvK552gftHq5zfl87No5RG2sXNeFKa4eoJerMHcfO0JtB/5xjNqqzbD/jaVIEtFSjdp6e/h8FAu8puTwMJ+rpSw8V2fHJ2mfnLxlVyMKwWp05xciURT8QiSKgl+IRFHwC5EoCn4hEkXBL0SidFXqazQzXLwYltmqy3xvldzD16jFGpcHY9SafKzGXJPbGmHf6xHpbXg4nJQEANYT2U/GBqhtaJTbFmfCyTGNBk9Y8kXux/Is3xWpOcGlqDuPhLdl+6U9vN7h6wtvKf78T0y+zhOk0Meluawalg+vLXFZbmaWJ8dcusJfc5bxeTyym0t977wzLGP21fi5eP5K+H1uNnif1ejOL0SiKPiFSBQFvxCJouAXIlEU/EIkioJfiETZkNRnZucBLALIALTc/Xjs+XkLWJ4NyyGtVkTqIxtD1StcvorRXOb11ArBzYdXcKLoXW1xqW95kGeP7RzkNd+WFvlr6yvzLLzRcrhmXX8vl5oWIzLr4sIEtd12x35q22N9wfYRK9E+u3aH5UEAmFoMZysCwMJVLvXNz4Rr9V2c5se7MM3r+y01ubzsOb+Xzkxz+XC4Lyzd3v
+eY7TPwcbOYPsL5Rdon9Vshs7/q+6uvbeFeJuhj/1CJMpGg98B/LWZ/cjMHtkMh4QQ3WGjH/sfcPdLZrYHwFNm9hN3f+bGJ7QvCo8AQLFHHzSEuFXYUDS6+6X2/5MAvgvgvsBzTrj7cXc/vp4924UQW8O6o9HMBsxs6PpjAB8BcGazHBNCbC0b+di/F8B3zez6cf67u/9lrEO5VMLB3QeDthYpYggATEHZuW997s+UeKHIQs779RTDxmKkqKODH3CoHN7SCgDGIttCjQ1HxpsKZ6tdeo0XC52Z4bLiYIVLc2Nj3P9GIyy/XbjKZdGxEV60tL/GfVyo8szDVy6NB9uvLPDst1bOpeAeD8vOAODOZcB65JgvnDsfbD92hBcm/cg/vz/Y/t+e/BvaZzXrDn53PwfgvevtL4TYXvQlXIhEUfALkSgKfiESRcEvRKIo+IVIlK4W8NwxugOf+K2PB20z8zzLar4azojqG47sMRehtsDlppEKz37bORLOvuofCGewAUAhIgMWjF97dwzy1zY/d5naXn4xvCffy69zOaxkXL46cjvfD7Gvl8uA9Swsp06QLDsAKJV4dmEx49mRM0tcBpwjRTBbzgt4AlwKRkS6zT1SkLXAbTNz4ffmh6dfpn0+9KEPBds94sNbfOr4mUKInysU/EIkioJfiERR8AuRKAp+IRKlq6v9w4MV/Oov3xO0VRf5CvzcbDgppSey3VWMVouvbo/u4MkUlaHwCnweWQG+tsRVjKUqrz3nGV+NPh+pMTc3PRtsj13lSz086aQUqRc4TVapAWBkIJyk08p4Qk1MCYjxs4nwawaApWp4PH4GxK2Z8/fawd+zLItkjBHVZ+Iqf10vv/KzYHutzs+p1ejOL0SiKPiFSBQFvxCJouAXIlEU/EIkioJfiETpqtTnnqFZD0tfWYEnU9SbYUlpqMLr3MWYql6jtqwQTt4BgGY97GM9Iq/kOX9d/b382nt1hsto5155ndqKZLyBEpevShE/5hb4NlN5lUtbfX3hpJ9yLz/lJmb5WNMLXDKdnI+8nySBx/OY9BaRkCM1/Mz4fOSR+6wXwnNSbfDjTVwKb6PWbHIpdTW68wuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJR1pT6zOwxAL8JYNLd72m37QTwLQBHAZwH8LvuzlOQ2mRZjgVSb61W51l9zDa/EJFrItTqXA6ZX+CyUV9vOMOt3Be7hkbknwKvPXfh4nlqm7jEp7pYCEtRu4Z55l4zUktwepHLb8VIncHJ6avh40XkvPGrkWzFRb4VVh6R7TJis4icZ4jU4ityWx7Zcg6RY2YkK7Tcz2sklnrC55VF6jGuppM7/58BeHBV26MAnnb3YwCebv8thHgbsWbwu/szAGZWNT8E4PH248cBfGyT/RJCbDHr/c6/190nAKD9/57Nc0kI0Q22fMHPzB4xs5NmdnIu8n1aCNFd1hv8V8xsPwC0/59kT3T3E+5+3N2Pjw6v77f4QojNZ73B/ySAh9uPHwbwvc1xRwjRLTqR+r4B4MMAdpnZOIA/BvB5AE+Y2acAXADwO50MViwVMLy7HLSVlsPtADA1SzLcMt4nRj0iuxzYybfrqoRrUmK5yiWqpWq4+CgAZJGMv5df5pl71uLy4Z7BsDw0X+d95ue5jNaIZIk1B7h8OL8Ufm0XJvhWYwuR2pNZRMHynPvIsvfyyKlfKPB7osWy8yJbZXlEWmQ+1hpc/gbY3Hde1HbN4Hf3TxLTr3c8ihDilkO/8BMiURT8QiSKgl+IRFHwC5EoCn4hEqWrBTybzRomJ8N7jDVzLttdng1nsdUbnWcw3UhvmcshIwtcfitdawTbPeeSV6HIJbapCV6UcvwcL+C5nyfTYf9QWI984WLYdwCI/fKyVOZzvFjj946Z+fA8Li7z+c3yWOFMbooaScHNPCKJ5THpMLKHYh5xo1Dkc1UgWZW1SIHUq7NhebbV6jzTVXd+IRJFwS9Eoij4hUgUBb8Qia
LgFyJRFPxCJEpXpb5GI8P46+EMuLzBU7qKzbA8NFJZn/vNGpebJl7lsleh3Bds7ymF2wFgdJTXMLjwGi9muTzPbUMHx6gtt7DsOBuR81otPh95zu8PVxq8AGlGJLFWLCsuUoizGMtWi0itDXLIzHnG3J4DfL/Gwb5hanvj9XDRUgDoKUReG5kSss0gAODS5fBYTRIrIXTnFyJRFPxCJIqCX4hEUfALkSgKfiESpaur/VnmmF8Ir7IO9vNV8b5hklBjsRpnnL4yKcYHwFs8wWh+IZyIU+7jiSB9/fx4b4zzenbFjCfiHNy9k9pm6uFV/VrGk4hiK995zrOI8shqdJaHjxkpj4diKZJsE6klWCjwfgMD4fm/6859tM8Hf/ld1DZ+bora3jh/hdo+8H5+TPfwCv0Lp16jfXKSsHQz6M4vRKIo+IVIFAW/EImi4BciURT8QiSKgl+IROlku67HAPwmgEl3v6fd9jkAvw/genbBZ939+2sdq1AooK8STproH6zQfnUi6VVrEa0pQqXCZZJe5360SLdWzpNwJqcmqG38Ipf6+svhbbcA4FBE6tsxsivYXjwYbgeAv/nBc9R2aTxcPxEAiqT2HAAcOjIabD/2ziO0z/Aol2AXFnhNw2KBy6k7RsPbr43t4WONjnHZ+eJ5uictSj1ccnzXXYeprbc3nJj0yqtv0D47BsPJZMWI7LmaTu78fwbgwUD7l9z93va/NQNfCHFrsWbwu/szAGa64IsQoots5Dv/p83slJk9ZmY7Ns0jIURXWG/wfwXAHQDuBTAB4AvsiWb2iJmdNLOT15Z48QchRHdZV/C7+xV3z9w9B/BVAPdFnnvC3Y+7+/HBAb6YJoToLusKfjPbf8OfHwdwZnPcEUJ0i06kvm8A+DCAXWY2DuCPAXzYzO7Fyj5J5wH8QUeDFYsYGwnXQOstRWqcNcNu9vaOdDLsW49nfKxKmdeDKxXCvi/V+TV0eoZngdWWeObejh6eTTc8wj9BDR8Iz0n/fl73r1rjfnzvyf9LbZGydHjgvnuD7XsOhiVAAMiMZ+69657beb9I7b+Ch6WvVsYzGZcXeD3JkREuA1YiNSXrNS4Hj+wMy7CVAX68q9PTwfZW1nkNvzWD390/GWj+WscjCCFuSfQLPyESRcEvRKIo+IVIFAW/EImi4BciUbpawBPuQCMsK9UitTgd4UzASoVnZsWoV5eprVbnUomRhKnGMpfKrl3jx2s2IoU/K1xSKla4DIhiONtrOSI59kUyCC1ScTOPSGw9pfCpVSnz1+Xg2ZZl5z6W+7iPvf1hP3pK/HgZS98EMNjP5eUXTr5EbfNzc9R2+91Hg+1H7zxA+4yfvhhsb7Q6z3TVnV+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJ0lWpL89z1OvhjKliOSxRAUCxJ5xp18p59lWMAjkeAGQRqSQjvp8/f572uTTJfWzGJKVIRlce8f/iRLgoaDXnY9Ub/DXnTd6vUOR+DO0IF84c2RWRKRGTqbhk2lPm9zAvheXU5dZ8ZCxeBLNviPu4e3f4NQNAneyhCAD9w2H/73rPUdrn1VPjwfab2cFPd34hEkXBL0SiKPiFSBQFvxCJouAXIlG6utpfKvdg3xGybVQxsk5Jau5ZdHWY4+Cr1PDI9TALJ6VcmrsabAeAv/rfp6mtcY0rAWO7j1JbaTdPjumvhVe3RyJJUFWPrOizbCYAsbXl0mB4jvtG+fw2mjxBKs/5an8tYmOCUBZRP3pK/PywEve/0s+VjP4hbsvKYf93HQontAHA2L6wstBzmScQrUZ3fiESRcEvRKIo+IVIFAW/EImi4BciURT8QiRKJ9t1HQbw5wD2AcgBnHD3L5vZTgDfAnAUK1t2/a67z8aO5chR9XCCQ0/sOpSFpb7qEq/FF6MyEKn9F6lZ1/KwH6ORhI6eIpd4CgVew683UsPvcmSbr8JwWKaq2xLtc+nqG9SW59zHmCQ2ORuuMTe4xI/XbHFbsR
i7T3HZzjPiY0RZLkQShbKMy8v1nNuGdu+ltrxCZNFIbcV9B8KSeemlCdpnNZ3c+VsA/sjd7wJwP4A/NLO7ATwK4Gl3Pwbg6fbfQoi3CWsGv7tPuPuP248XAZwFcBDAQwAebz/tcQAf2yonhRCbz0195zezowDeB+BZAHvdfQJYuUAA2LPZzgkhto6Og9/MBgF8G8Bn3H3hJvo9YmYnzezk/ML6vqMLITafjoLfzEpYCfyvu/t32s1XzGx/274fwGSor7ufcPfj7n58ZHh9m2wIITafNYPfzAzA1wCcdfcv3mB6EsDD7ccPA/je5rsnhNgqOsnqewDA7wE4bWbPt9s+C+DzAJ4ws08BuADgd9Y6UCvLMLMQrp0Wyx3rLZeD7eVSuH0tFpar1FYn24kBXB2amuKZVHmd70M20M8/CWUZvy6/8pNz1Db2jnAmWAP8K9f45UvUNrxnmNoqg/z0uXxlOti++/YK7VOL1BJk238BvMYjADSb4Yy5cuTUbyzy+n6L8zyDECX+2vp38Qy92QV2/nA9soXwOezgW6itZs3gd/e/BY/NX+94JCHELYV+4SdEoij4hUgUBb8QiaLgFyJRFPxCJEp3t+vyHHVSpDEmUNSzcLZXpX99Ul+1yuW8WGFHdqW8eDG8RRYAXFvi2zTlJZ7x1yCZjAAwVORbmy1fC89V1suLhZbL/DQ4ePsOaitEzp7ZqbDEGdv+C84F31otUtwzUoC0WgtLc31FnjHXG3ld9Qafj7vf/V5qK1Z4kdfFxfAPZi0Snk2SQeiRuViN7vxCJIqCX4hEUfALkSgKfiESRcEvRKIo+IVIlK5Kfe45qo2w9JV5ZH80kld0rcZltLgjfCyPZFIVLSyvlAcj19B+bluu8gyx6UWehTdy8DC1mYUltmtVfrxmZI+8wSEuv2WsOCaA8fFwptriDH/PrMTHarZ4xl+pl0u+Q71hWTTP+Xy0cp6dV6/x1zzMVUC0IsVanRRJ7Y1kCe47cFuwvVTmxVhXozu/EImi4BciURT8QiSKgl+IRFHwC5Eo3V3th6NFEhJ6CnzFtpWFV8Vzj9RTi1AwvqrcU+RT0srDq+I7947SPmORGnjzbyxS26WpKX7MJb710+794dXtrM5XqWtLfCV9cITPVYPnCmGgP5y0VOnjSUleiCWlRLYNK0SUInLILOeJU8uLfCxr8XMuc14bshpJTKrWw7ZWgdd/HBoKb+dWiGw395bndvxMIcTPFQp+IRJFwS9Eoij4hUgUBb8QiaLgFyJR1pT6zOwwgD8HsA8rpfZOuPuXzexzAH4fwPXiZJ919+/HjlUs9mB0lMhiLX4dyok8CHCJKg6XhgoR2Qg9RL4a5HLYrh1hSQYAskkuDeUR2at3JLKNk4flw6zFa9bNzXAZamCAS5XVJS6J3XnsYLC9UOLzu0wkLwBoRmorLl/jSTpL18JyWbXO+yzP8bAY7uPyWyuyxdpCRE6tVsPzWF8M7n0LADi4MzxWHqn9uJpOdP4WgD9y9x+b2RCAH5nZU23bl9z9P3U8mhDilqGTvfomAEy0Hy+a2VkA4cu6EOJtw0195zezowDeB+DZdtOnzeyUmT1mZpFsZiHErUbHwW9mgwC+DeAz7r4A4CsA7gBwL1Y+GXyB9HvEzE6a2clri/w7rhCiu3QU/GZWwkrgf93dvwMA7n7F3TN3zwF8FcB9ob7ufsLdj7v78cEhXplECNFd1gx+MzMAXwNw1t2/eEP7/hue9nEAZzbfPSHEVtHJav8DAH4PwGkze77d9lkAnzSzewE4gPMA/mCtAzVbGS5PLQVtRePXoR6ipJXLEVkuQqPBJapWJLEs87CMUq3HJB4uAx46cju1ve8D76C2VuR1//S12WD79MXwllAAMDvL5aumhd8vAKg2+Ne43qGwj+M/nKF96s1I5l5kZzYn7wsAFD18ig+P9NM+IyP8E2qrzmsQLkbe677KELUNDu8Mtvfsi9QLHApLt8Uy9+Etx1/rCe7+t0
CwgmZU0xdC3NroF35CJIqCX4hEUfALkSgKfiESRcEvRKJ0tYDn3FwV/+svToeNkWwvVpOwn2zFtBbLdV55MlLXEWiGM7NqVX681iyXyt79gfdS22CFZwM+c5L/pGJqISxFzVziGWKDg1xHq/TzbMCde7h8VSRvWn9krGbGpcNKhZ8fY7tGqG2wfyB8vH5+6pfKkazJFs88LPfyY8a2gQORuRuRYqHLVSLdFjrPdNWdX4hEUfALkSgKfiESRcEvRKIo+IVIFAW/EInSVamvUAQqw+Gso2bOr0Pz82FZY2aO73UXoxTJfBoZ4QUrS5WwDtjTw31fWuBTXJ3nWWwOnoV34AA14djd4Qpr1eWIHDbMZbSR0bBUBgDlIpftvEXmqpf36YlkK5pxCauvj8uROcJyWa3GJdhiD/ej1/l77ZH9BHPn51xGim7G9iDsLYeLyVokO3Y1uvMLkSgKfiESRcEvRKIo+IVIFAW/EImi4BciUboq9ZVLRRzZF85WW17i6XQ7KmG5aWl5fXv1DfRzCaWvzKWo/oFwFmFfP9+vZH6MF8csRyqZV/by6/IH3/lOfkwmpRW51FRv8my6rM7lq1jR1RrtF5OvuGQH41lxzSafY3aGl/p4Ac/Yfo3FIvejWIwUlDUeavVG+DxuRvbd67Pw+xz1YRW68wuRKAp+IRJFwS9Eoij4hUgUBb8QidLJXn19ZvYPZvaCmb1oZv++3X67mT1rZj8zs2+ZkeVHIcQtSSd3/jqAX3P392JlO+4Hzex+AH8C4EvufgzALIBPbZ2bQojNZs3g9xWul4Qttf85gF8D8D/b7Y8D+NiWeCiE2BI6+s5vZsX2Dr2TAJ4C8CqAOXe/niw9DiCcSC6EuCXpKPjdPXP3ewEcAnAfgLtCTwv1NbNHzOykmZ2s1/ivxYQQ3eWmVvvdfQ7A/wFwP4BRs3/6zeIhAJdInxPuftzdj/dGKq4IIbpLJ6v9u81stP24AuBfADgL4AcAfrv9tIcBfG+rnBRCbD6dJPbsB/C4mRWxcrF4wt3/wsxeAvBNM/sPAP4RwNfWOlCjmeHC5fB2Us2cuzI/H6631mxEtkCKEK/hxxMjSovhenDNZZ5YsnSRJ83suOMd1Fa9wpM6zrz4E2obGA3X6qsuh+cduHVq+LVaPFFr3TX8POxHs7ZM+8Rq+LWcn3MO7mO8hl/Yljv3gyVjZVnnyW5rBr+7nwLwvkD7Oax8/xdCvA3RL/yESBQFvxCJouAXIlEU/EIkioJfiEQxj0gXmz6Y2VUAr7f/3AVgqmuDc+THm5Efb+bt5sdt7r67kwN2NfjfNLDZSXc/vi2Dyw/5IT/0sV+IVFHwC5Eo2xn8J7Zx7BuRH29GfryZn1s/tu07vxBie9HHfiESZVuC38weNLOfmtkrZvbodvjQ9uO8mZ02s+fN7GQXx33MzCbN7MwNbTvN7Kl2QdSnzIzvAba1fnzOzC625+R5M/toF/w4bGY/MLOz7SKx/7rd3tU5ifjR1TnpWtFcd+/qP6xs1vYqgF8AUAbwAoC7u+1H25fzAHZtw7i/AuD9AM7c0PYfATzafvwogD/ZJj8+B+DfdHk+9gN4f/vxEICXAdzd7TmJ+NHVOQFgAAbbj0sAnsVKAZ0nAHyi3f6fAfyrjYyzHXf++wC84u7n3L0B4JsAHtoGP7YNd38GwMyq5oewUggV6FJBVOJH13H3CXf/cfvxIlaKxRxEl+ck4kdX8RW2vGjudgT/QQBv3PD3dhb/dAB/bWY/MrNHtsmH6+x19wlg5SQEsGcbffm0mZ1qfy3Y8q8fN2JmR7FSP+JZbOOcrPID6PKcdKNo7nYEf6hsyXZJDg+4+/sB/EsAf2hmv7JNftxKfAXAHVjZo2ECwBe6NbCZDQL4NoDPuPtCt8btwI+uz4lvoGhup2xH8I8DOHzD37T451bj7pfa/08C+C62tzLRFTPbDwDt/ye3wwl3v9I+8XIAX0WX5sTMSlgJuK+7+3fazV2fk5Af2z
Un7bFvumhup2xH8D8H4Fh75bIM4BMAnuy2E2Y2YGZD1x8D+AiAM/FeW8qTWCmECmxjQdTrwdbm4+jCnJiZYaUG5Fl3/+INpq7OCfOj23PStaK53VrBXLWa+VGsrKS+CuDfbpMPv4AVpeEFAC920w8A38DKx8cmVj4JfQrAGICnAfys/f/ObfLjvwI4DeAUVoJvfxf8+GdY+Qh7CsDz7X8f7facRPzo6pwAeA9WiuKewsqF5t/dcM7+A4BXAPwPAL0bGUe/8BMiUfQLPyESRcEvRKIo+IVIFAW/EImi4BciURT8QiSKgl+IRFHwC5Eo/w8v4hsRu8LZkAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7efd93814c18>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample = data.trn_ds.denorm(x)[0]  # denorm() maps the normalised batch back to displayable images\n",
"plt.imshow(sample);"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"'''LeNet in PyTorch.'''\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class LeNet(nn.Module):\n",
"    \"\"\"LeNet-style CNN for 3x32x32 inputs, returning log-probabilities over 10 classes.\n",
"\n",
"    Spatial sizes: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5, so fc1 sees 16*5*5 = 400 features.\n",
"    Built only from Conv2d / max_pool2d / Linear / ReLU / log_softmax, which export to ONNX.\n",
"    \"\"\"\n",
"    def __init__(self):\n",
"        super(LeNet, self).__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16*5*5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        out = F.relu(self.conv1(x))\n",
"        out = F.max_pool2d(out, 2)\n",
"        out = F.relu(self.conv2(out))\n",
"        out = F.max_pool2d(out, 2)\n",
"        out = out.view(out.size(0), -1)  # flatten to (batch, 400)\n",
"        out = F.relu(self.fc1(out))\n",
"        out = F.relu(self.fc2(out))\n",
"        out = self.fc3(out)\n",
"        # Explicit class dimension: the implicit-dim form is deprecated and emits a warning\n",
"        return F.log_softmax(out, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# fastai's built-in models use AdaptiveMaxPool / AdaptiveAvgPool, which PyTorch 0.3 could not export to ONNX.\n",
"# A plain custom model (LeNet above) avoids that; try PyTorch 0.4 to see whether those layers are supported there.\n",
"\n",
"model = LeNet()"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"learn = ConvLearner.from_model_data(model, data)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# Training is off by default, so the exported ONNX file contains the random initial weights.\n",
"# Set TRAIN = True to fit the model before exporting.\n",
"TRAIN = False\n",
"if TRAIN:\n",
"    learn.fit(0.01, 2)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LeNet(\n",
" (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.models.model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export to ONNX\n",
"\n",
"Trace the model with a dummy CIFAR-sized input and write `lenet_cifar10.onnx`."
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"graph(%1 : Float(1, 3, 32, 32)\n",
" %2 : Float(6, 3, 5, 5)\n",
" %3 : Float(6)\n",
" %4 : Float(16, 6, 5, 5)\n",
" %5 : Float(16)\n",
" %6 : Float(120, 400)\n",
" %7 : Float(120)\n",
" %8 : Float(84, 120)\n",
" %9 : Float(84)\n",
" %10 : Float(10, 84)\n",
" %11 : Float(10)) {\n",
" %13 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[0, 0, 0, 0], dilations=[1, 1], group=1](%1, %2), uses = [[%14.i0]], scope: LeNet/Conv2d[conv1];\n",
" %14 : Float(1, 6, 28, 28) = Add[broadcast=1, axis=1](%13, %3), uses = [%15.i0], scope: LeNet/Conv2d[conv1];\n",
" %15 : Float(1, 6, 28, 28) = Relu(%14), uses = [%16.i0], scope: LeNet;\n",
" %16 : Float(1, 6, 14, 14) = MaxPool[kernel_shape=[2, 2], pads=[0, 0], strides=[2, 2]](%15), uses = [%17.i0], scope: LeNet;\n",
" %18 : UNKNOWN_TYPE = Conv[kernel_shape=[5, 5], strides=[1, 1], pads=[0, 0, 0, 0], dilations=[1, 1], group=1](%16, %4), uses = [[%19.i0]], scope: LeNet/Conv2d[conv2];\n",
" %19 : Float(1, 16, 10, 10) = Add[broadcast=1, axis=1](%18, %5), uses = [%20.i0], scope: LeNet/Conv2d[conv2];\n",
" %20 : Float(1, 16, 10, 10) = Relu(%19), uses = [%21.i0], scope: LeNet;\n",
" %21 : Float(1, 16, 5, 5) = MaxPool[kernel_shape=[2, 2], pads=[0, 0], strides=[2, 2]](%20), uses = [%22.i0], scope: LeNet;\n",
" %22 : Float(1, 400) = Reshape[shape=[1, -1]](%21), uses = [%25.i0], scope: LeNet;\n",
" %25 : Float(1, 120) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%22, %6, %7), uses = [%26.i0], scope: LeNet/Linear[fc1];\n",
" %26 : Float(1, 120) = Relu(%25), uses = [%29.i0], scope: LeNet;\n",
" %29 : Float(1, 84) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%26, %8, %9), uses = [%30.i0], scope: LeNet/Linear[fc2];\n",
" %30 : Float(1, 84) = Relu(%29), uses = [%33.i0], scope: LeNet;\n",
" %33 : Float(1, 10) = Gemm[alpha=1, beta=1, broadcast=1, transB=1](%30, %10, %11), uses = [%34.i0], scope: LeNet/Linear[fc3];\n",
" %34 : Float(1, 10) = Softmax[axis=1](%33), uses = [%35.i0], scope: LeNet;\n",
" %35 : Float(1, 10) = Log(%34), uses = [%0.i0], scope: LeNet;\n",
" return (%35);\n",
"}\n",
"\n"
]
}
],
"source": [
"from torch.autograd import Variable\n",
"import torch.onnx\n",
"\n",
"model = learn.models.model\n",
"model.eval()  # export in inference mode (matters for dropout / batch-norm layers)\n",
"# One 3x32x32 image; to_gpu moves it to the GPU when available, presumably matching where fastai put the model.\n",
"# The traced graph is fixed to this input shape (see Float(1, 3, 32, 32) in the printed graph).\n",
"dummy_input = to_gpu(Variable(torch.randn(1, 3, 32, 32)))\n",
"\n",
"# Public export API; the private torch.onnx._export only added a returned model output that was never used\n",
"torch.onnx.export(\n",
"    model,\n",
"    dummy_input,\n",
"    \"lenet_cifar10.onnx\",\n",
"    verbose=True,\n",
"    export_params=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lenet_cifar10.onnx\r\n"
]
}
],
"source": [
"!ls lenet*.onnx"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment