Created March 21, 2018 00:02
ThomasDelteil/faffe75b292fbc9c2d610721b5835bc5 (Jupyter notebook)
{"nbformat": 4, "cells": [
{"source": "# Running inference on MXNet/Gluon from an ONNX model\n\n[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides an open-source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.\n\nIn this tutorial we will:\n\n- learn how to load a pre-trained .onnx model file into MXNet/Gluon\n- learn how to test this model using the sample input/output\n- learn how to test the model on custom images\n\n## Prerequisites\n\nTo run this tutorial, you need the following Python modules installed:\n\n- [MXNet](http://mxnet.incubator.apache.org/install/index.html)\n- [onnx](https://github.com/onnx/onnx)\n- [onnx-mxnet](https://github.com/onnx/onnx-mxnet)\n- matplotlib\n- wget", "cell_type": "markdown", "metadata": {}},
{"source": "import json\nimport os\nimport tarfile\n\nimport numpy as np\nimport onnx_mxnet\nimport mxnet as mx\nfrom mxnet import gluon, nd\n%matplotlib inline\nimport matplotlib.pyplot as plt\nimport wget", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "### Downloading supporting files\nThese are the test images, the ImageNet label file and a visualization script.", "cell_type": "markdown", "metadata": {}},
{"source": "image_folder = \"images\"\nutils_file = \"utils.py\"  # contains utility functions for plotting visualizations\nimage_net_labels_file = \"image_net_labels.json\"\nimages = ['apron', 'hammerheadshark', 'dog', 'wrench', 'dolphin', 'lotus']\nbase_url = \"https://raw.githubusercontent.com/ThomasDelteil/web-data/tutorial_onnx/mxnet/doc/tutorials/onnx/{}?raw=true\"\n\nif not os.path.isdir(image_folder):\n    os.makedirs(image_folder)\n    for image in images:\n        wget.download(base_url.format(\"{}/{}.jpg\".format(image_folder, image)), image_folder)\nif not os.path.isfile(utils_file):\n    wget.download(base_url.format(utils_file))\nif not os.path.isfile(image_net_labels_file):\n    wget.download(base_url.format(image_net_labels_file))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "from utils import *", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "## Downloading a model from the ONNX model zoo\n\nWe download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1409.1556) model, trained on [ImageNet](http://www.image-net.org/) from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file and some sample input/output data.", "cell_type": "markdown", "metadata": {}},
{"source": "base_url = \"https://s3.amazonaws.com/download.onnx/models/\"\ncurrent_model = \"vgg16\"\nmodel_folder = \"model\"\narchive = \"{}.tar.gz\".format(current_model)\narchive_file = os.path.join(model_folder, archive)\nurl = \"{}{}\".format(base_url, archive)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "Create the model folder and download the compressed model archive.", "cell_type": "markdown", "metadata": {}},
{"source": "os.makedirs(model_folder, exist_ok=True)\nif not os.path.isfile(archive_file):\n    wget.download(url, model_folder)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "Extract the model.", "cell_type": "markdown", "metadata": {}},
{"source": "if not os.path.isdir(os.path.join(model_folder, current_model)):\n    with tarfile.open(archive_file, \"r:gz\") as tar:\n        tar.extractall(model_folder)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "The model has been pre-trained on ImageNet, so let's load the label mapping of its 1000 classes.", "cell_type": "markdown", "metadata": {}},
{"source": "with open(image_net_labels_file, 'r') as f:\n    categories = json.load(f)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "## Loading the model into MXNet Gluon", "cell_type": "markdown", "metadata": {}},
{"source": "onnx_path = os.path.join(model_folder, current_model, \"model.onnx\")", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "We get the symbol and parameter objects.", "cell_type": "markdown", "metadata": {}},
{"source": "sym, params = onnx_mxnet.import_model(onnx_path)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "We pick a context: CPU (`mx.cpu()`) or GPU (`mx.gpu()`).", "cell_type": "markdown", "metadata": {}},
{"source": "ctx = mx.cpu()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "And load them into an MXNet Gluon `SymbolBlock`. For ONNX models imported with `onnx_mxnet`, the default input name is `input_0`.", "cell_type": "markdown", "metadata": {}},
{"source": "net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('input_0'))\nnet_params = net.collect_params()\n# Copy each imported weight into the matching SymbolBlock parameter\nfor param in params:\n    if param in net_params:\n        net_params[param]._load_init(params[param], ctx=ctx)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "We can now cache the computational graph through [hybridization](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html) to gain some performance.", "cell_type": "markdown", "metadata": {}},
{"source": "net.hybridize()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "## Test using sample inputs and outputs\nThe model comes with sample input/output data that we can use to test whether the model is correctly loaded.", "cell_type": "markdown", "metadata": {}},
{"source": "numpy_path = os.path.join(model_folder, current_model, 'test_data_0.npz')\nsample = np.load(numpy_path, encoding='bytes', allow_pickle=True)\ninputs = sample['inputs']\noutputs = sample['outputs']", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "print(\"Input shape: {}\".format(inputs[0].shape))\nprint(\"Output shape: {}\".format(outputs[0].shape))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "We can visualize the network (this requires graphviz to be installed).", "cell_type": "markdown", "metadata": {}},
{"source": "mx.visualization.plot_network(sym, shape={\"input_0\": inputs[0].shape}, node_attrs={\"shape\": \"oval\", \"fixedsize\": \"false\"})", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "This helper function runs M batches of data, each of batch size N, through the network and collates the outputs into an array of shape (K, 1000), where K = M x N is the total number of examples (number of batches x batch size) run through the network.", "cell_type": "markdown", "metadata": {}},
{"source": "def run_batch(net, data):\n    results = []\n    for batch in data:\n        outputs = net(batch)\n        # One row per example in the batch\n        results.extend([o for o in outputs.asnumpy()])\n    return np.array(results)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "result = run_batch(net, nd.array(inputs, ctx=ctx))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "print(\"Loaded model and sample output predict the same class: {}\".format(np.argmax(result) == np.argmax(outputs[0])))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "Good, the sample output and our prediction match. Now we can run against real data.\n\n## Test using real images", "cell_type": "markdown", "metadata": {}},
{"source": "TOP_P = 3  # How many top guesses we show in the visualization", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "A transform function converts the data into the format the network expects, (N, 3, 224, 224), where N is the batch size. It moves the channel axis of the (height, width, 3) arrays returned by `plt.imread` to the front and adds a batch axis.", "cell_type": "markdown", "metadata": {}},
{"source": "def transform(img):\n    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "We load two sets of images into memory: one whose categories are ImageNet labels, and one from Caltech101 whose categories are not.", "cell_type": "markdown", "metadata": {}},
{"source": "image_net_images = [plt.imread('images/{}.jpg'.format(path)) for path in ['apron', 'hammerheadshark', 'dog']]\ncaltech101_images = [plt.imread('images/{}.jpg'.format(path)) for path in ['wrench', 'dolphin', 'lotus']]\nimages = image_net_images + caltech101_images", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "And run them as a batch through the network to get the predictions.", "cell_type": "markdown", "metadata": {}},
{"source": "batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)\nresult = run_batch(net, [batch])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "plot_predictions(image_net_images, result[:3], categories, TOP_P)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "**Well done!** It looks like the network does a pretty good job of classifying pictures whose category is an ImageNet label.\n\nLet's now see the results on the other 3 images.", "cell_type": "markdown", "metadata": {}},
{"source": "plot_predictions(caltech101_images, result[3:], categories, TOP_P)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}},
{"source": "**Hmm, not so good...** Even though the predictions are close, they are not accurate. This is because our network has been trained on ImageNet, and the ImageNet dataset does not contain `wrench`, `dolphin`, or `lotus` categories.\n\nLuckily, the [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) does, so let's see how we can fine-tune our network to classify these categories correctly.\n\nWe show how in our next tutorials:\n\n- [Fine-tuning an ONNX model using the modern imperative MXNet/Gluon API](addlink)\n- [Fine-tuning an ONNX model using the symbolic MXNet/Module API](addlink)\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n", "cell_type": "markdown", "metadata": {}}
], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}
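The download and extraction cells in the notebook above are written to be idempotent: they skip work whose result is already on disk. The sketch below isolates that "extract once" pattern using only the standard library. The `extract_once` name is hypothetical, and the tiny archive built in a temporary directory is a placeholder shaped like `vgg16.tar.gz`, not the real model.

```python
import os
import tarfile
import tempfile

def extract_once(archive_file, model_folder, model_name):
    """Extract archive_file into model_folder unless model_name/ already exists.

    Returns True if extraction happened, False if it was skipped.
    """
    target = os.path.join(model_folder, model_name)
    if os.path.isdir(target):
        return False  # already extracted on a previous run
    with tarfile.open(archive_file, "r:gz") as tar:
        tar.extractall(model_folder)
    return True

# Build a tiny placeholder archive with the same layout: vgg16/model.onnx
workdir = tempfile.mkdtemp()
src = os.path.join(workdir, "src", "vgg16")
os.makedirs(src)
with open(os.path.join(src, "model.onnx"), "wb") as f:
    f.write(b"placeholder, not a real ONNX model")
archive = os.path.join(workdir, "vgg16.tar.gz")
with tarfile.open(archive, "w:gz") as tar:
    tar.add(src, arcname="vgg16")

model_folder = os.path.join(workdir, "model")
os.makedirs(model_folder, exist_ok=True)
print(extract_once(archive, model_folder, "vgg16"))  # True
print(extract_once(archive, model_folder, "vgg16"))  # False
```

Checking for the extracted directory rather than the archive means a re-run neither downloads nor unpacks the model again.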
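The notebook's `run_batch` helper iterates over M batches of size N and stacks the per-example outputs into a (K, 1000) array with K = M x N. The sketch below reproduces that collation so the shape arithmetic can be checked without loading VGG16. The network is replaced by a hypothetical stand-in, `fake_net`, which returns a plain NumPy array, so the `.asnumpy()` call from the real version is not needed here.

```python
import numpy as np

NUM_CLASSES = 1000  # ImageNet classes, as in the tutorial

def fake_net(batch):
    """Hypothetical stand-in for the network: one score row per example."""
    n = batch.shape[0]
    return np.arange(n * NUM_CLASSES, dtype=np.float32).reshape(n, NUM_CLASSES)

def run_batch(net, data):
    """Run each batch through `net` and collate the per-example outputs."""
    results = []
    for batch in data:
        outputs = net(batch)
        # One entry per example, so M batches of N examples give K = M * N rows
        results.extend(list(outputs))
    return np.array(results)

M, N = 4, 2
data = [np.zeros((N, 3, 224, 224), dtype=np.float32) for _ in range(M)]
result = run_batch(fake_net, data)
print(result.shape)  # (8, 1000)
```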
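The notebook's sanity check compares only the arg-max class of our output with that of the bundled sample output. A stricter check could also compare the scores element-wise within a tolerance. This is an assumption, not part of the original tutorial. The hypothetical `matches_reference` function below sketches that check. Its `rtol`/`atol` defaults are illustrative guesses, since small numeric differences between frameworks are expected.

```python
import numpy as np

def matches_reference(result, expected, rtol=1e-3, atol=1e-5):
    """Check both the predicted class and the raw scores against a reference."""
    # Align shapes, e.g. (1, 1000) from run_batch vs. the sample output
    result = np.asarray(result).reshape(expected.shape)
    same_class = np.argmax(result) == np.argmax(expected)
    close = np.allclose(result, expected, rtol=rtol, atol=atol)
    return bool(same_class), bool(close)

expected = np.array([[0.1, 2.0, -1.0]])
print(matches_reference(expected + 1e-6, expected))  # (True, True)
print(matches_reference(expected + 0.5, expected))   # (True, False)
```

The second call shows why the arg-max check alone is weak: the class agrees even though every score is off by 0.5.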
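The notebook's `transform` function turns a (height, width, 3) image from `plt.imread` into a (1, 3, height, width) float32 batch. The sketch below performs the same reshaping with NumPy. The `hwc_to_nchw` name is hypothetical, and a random 224x224 RGB array stands in for a real image. Neither version resizes the image, so the result is (N, 3, 224, 224) only if the input images are already 224x224 pixels.

```python
import numpy as np

def hwc_to_nchw(img):
    """Convert one HWC image to a float32 NCHW batch of size 1."""
    # (H, W, C) -> (C, H, W): move the channel axis to the front
    chw = np.transpose(img, (2, 0, 1))
    # (C, H, W) -> (1, C, H, W): add a leading batch axis
    return np.expand_dims(chw, axis=0).astype(np.float32)

# Stand-in for plt.imread output: a 224x224 RGB image with uint8 pixels
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

batch = np.concatenate([hwc_to_nchw(img), hwc_to_nchw(img)], axis=0)
print(batch.shape)  # (2, 3, 224, 224)
# The pixel at row 10, column 20, channel 1 ends up at [0, 1, 10, 20]
assert batch[0, 1, 10, 20] == img[10, 20, 1]
```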
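The `plot_predictions` helper comes from the downloaded `utils.py`, which is not reproduced here, so its exact logic is unknown. One common way to pick the `TOP_P` most likely classes for each image is to sort each row of the (K, 1000) output. The hypothetical `top_predictions` function below sketches this with `np.argsort`. It assumes `categories` is indexable by class index (a list), which is an assumption about `image_net_labels.json`. It uses a tiny 3-class toy output with made-up labels in place of the 1000 ImageNet categories.

```python
import numpy as np

def top_predictions(scores, categories, top_p):
    """Return the top_p (label, score) pairs for each row of scores."""
    # argsort is ascending, so reverse it and keep the first top_p indices
    top_idx = np.argsort(scores, axis=1)[:, ::-1][:, :top_p]
    return [[(categories[i], float(row[i])) for i in idx]
            for row, idx in zip(scores, top_idx)]

# Toy output for 2 images over 3 hypothetical classes
scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.2, 0.3]])
labels = ["cat", "dog", "fish"]
print(top_predictions(scores, labels, 2))
# [[('dog', 0.7), ('fish', 0.2)], [('cat', 0.5), ('fish', 0.3)]]
```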