@ThomasDelteil
Created March 21, 2018 00:02
{"nbformat": 4, "cells": [{"source": "# Running inference on MXNet/Gluon from an ONNX model\n\n[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n\nIn this tutorial we will:\n\n- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n- learn how to test this model using the sample input/output\n- learn how to test the model on custom images\n\n## Prerequisites\n\nTo run this tutorial, you will need the following Python modules installed:\n- [MXNet](http://mxnet.incubator.apache.org/install/index.html)\n- [onnx](https://github.com/onnx/onnx)\n- [onnx-mxnet](https://github.com/onnx/onnx-mxnet)\n- matplotlib\n- wget", "cell_type": "markdown", "metadata": {}}, {"source": "import numpy as np\nimport onnx_mxnet\nimport mxnet as mx\nfrom mxnet import gluon, nd\n%matplotlib inline\nimport matplotlib.pyplot as plt\nimport tarfile, os\nimport wget\nimport json", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "### Downloading supporting files\nThese are images and a visualization script.", "cell_type": "markdown", "metadata": {}}, {"source": "image_folder = \"images\"\nutils_file = \"utils.py\" # contains utility functions for plotting the visualizations\nimage_net_labels_file = \"image_net_labels.json\"\nimages = ['apron', 'hammerheadshark', 'dog', 'wrench', 'dolphin', 'lotus']\nbase_url = \"https://raw.githubusercontent.com/ThomasDelteil/web-data/tutorial_onnx/mxnet/doc/tutorials/onnx/{}?raw=true\"\n\n# Download each supporting file only if it is not already present\nif not os.path.isdir(image_folder):\n    os.makedirs(image_folder)\n    for image in images:\n        wget.download(base_url.format(\"{}/{}.jpg\".format(image_folder, image)), image_folder)\nif not os.path.isfile(utils_file):\n    wget.download(base_url.format(utils_file))\nif not os.path.isfile(image_net_labels_file):\n    wget.download(base_url.format(image_net_labels_file))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "from utils import *", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Downloading a model from the ONNX model zoo\n\nWe download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1409.1556) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file and some sample input/output data.", "cell_type": "markdown", "metadata": {}}, {"source": "base_url = \"https://s3.amazonaws.com/download.onnx/models/\"\ncurrent_model = \"vgg16\"\nmodel_folder = \"model\"\narchive = \"{}.tar.gz\".format(current_model)\narchive_file = os.path.join(model_folder, archive)\nurl = \"{}{}\".format(base_url, archive)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Create the model folder and download the compressed model archive", "cell_type": "markdown", "metadata": {}}, {"source": "os.makedirs(model_folder, exist_ok=True)\nif not os.path.isfile(archive_file):\n    wget.download(url, model_folder)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Extract the model", "cell_type": "markdown", "metadata": {}}, {"source": "if not os.path.isdir(os.path.join(model_folder, current_model)):\n    tar = tarfile.open(archive_file, \"r:gz\")\n    tar.extractall(model_folder)\n    tar.close()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The model has been pre-trained on ImageNet, so let's load the label mapping of the 1000 classes.", "cell_type": "markdown", "metadata": {}}, {"source": "# Use a context manager so the label file is closed after reading\nwith open(image_net_labels_file, 'r') as f:\n    categories = json.load(f)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Loading the model into MXNet Gluon", "cell_type": "markdown", "metadata": {}}, {"source": "onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We get the symbol and parameter objects", "cell_type": "markdown", "metadata": {}}, {"source": "sym, params = onnx_mxnet.import_model(onnx_path)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We pick a context, CPU or GPU", "cell_type": "markdown", "metadata": {}}, {"source": "ctx = mx.cpu()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "And load them into an MXNet Gluon symbol block. For ONNX models the default input name is `input_0`.", "cell_type": "markdown", "metadata": {}}, {"source": "net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))\nnet_params = net.collect_params()\n# Copy each imported ONNX parameter into the matching Gluon parameter\nfor param in params:\n    if param in net_params:\n        net_params[param]._load_init(params[param], ctx=ctx)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can now cache the computational graph through [hybridization](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html) to gain some performance", "cell_type": "markdown", "metadata": {}}, {"source": "net.hybridize()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Test using sample inputs and outputs\nThe model comes with sample input/output data that we can use to test whether the model is correctly loaded", "cell_type": "markdown", "metadata": {}}, {"source": "numpy_path = os.path.join(model_folder, current_model, 'test_data_0.npz')\nsample = np.load(numpy_path, encoding='bytes')\ninputs = sample['inputs']\noutputs = sample['outputs']", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "print(\"Input format: {}\".format(inputs[0].shape))\nprint(\"Output format: {}\".format(outputs[0].shape))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can visualize the network (requires graphviz to be installed)", "cell_type": "markdown", "metadata": {}}, {"source": "mx.visualization.plot_network(sym, shape={\"input_0\":inputs[0].shape}, node_attrs={\"shape\":\"oval\",\"fixedsize\":\"false\"})", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "This is a helper function to run M batches of data of batch size N through the net and collate the outputs into an array of shape (K, 1000), where K = M x N is the total number of examples (number of batches x batch size) run through the network.", "cell_type": "markdown", "metadata": {}}, {"source": "def run_batch(net, data):\n    results = []\n    for batch in data:\n        outputs = net(batch)\n        # Append one row of class scores per example in the batch\n        results.extend([o for o in outputs.asnumpy()])\n    return np.array(results)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "result = run_batch(net, nd.array(inputs, ctx))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "print(\"Loaded model and sample output predict the same class: {}\".format(np.argmax(result) == np.argmax(outputs[0])))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Good, the sample output and our prediction match. Now we can run against real data.\n\n## Test using real images", "cell_type": "markdown", "metadata": {}}, {"source": "TOP_P = 3 # How many top guesses we show in the visualization", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Transform function to put the data into the format the network expects, (N, 3, 224, 224), where N is the batch size.", "cell_type": "markdown", "metadata": {}}, {"source": "def transform(img):\n    # (H, W, C) image -> (1, C, H, W) float32 batch of one\n    return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We load two sets of images in memory", "cell_type": "markdown", "metadata": {}}, {"source": "image_net_images = [plt.imread('images/{}.jpg'.format(path)) for path in ['apron', 'hammerheadshark','dog']]\ncaltech101_images = [plt.imread('images/{}.jpg'.format(path)) for path in ['wrench', 'dolphin','lotus']]\nimages = image_net_images + caltech101_images", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "And run them as a batch through the network to get the predictions", "cell_type": "markdown", "metadata": {}}, {"source": "batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)\nresult = run_batch(net, [batch])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "plot_predictions(image_net_images, result[:3], categories, TOP_P)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "**Well done!** Looks like it is doing a pretty good job at classifying pictures when the category is an ImageNet label.\n\nLet's now see the results on the 3 other images", "cell_type": "markdown", "metadata": {}}, {"source": "plot_predictions(caltech101_images, result[3:6], categories, TOP_P)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "**Hmm, not so good...** Even though the predictions are close, they are not accurate, because the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories and our network has been trained on ImageNet.\n\nLucky for us, the [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) has them. Let's see how we can fine-tune our network to classify these categories correctly.\n\nWe show that in our next tutorials:\n\n- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon API](addlink)\n- [Fine-tuning an ONNX Model using the symbolic MXNet/Module API](addlink)\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}