Skip to content

Instantly share code, notes, and snippets.

@tzechienchu
Forked from yoku001/TVM-quantization.ipynb
Created April 3, 2019 04:00
Show Gist options
  • Save tzechienchu/40305e0f17bc873194759c37ab0adb0e to your computer and use it in GitHub Desktop.
Save tzechienchu/40305e0f17bc873194759c37ab0adb0e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import os\n",
"import urllib.request\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def download(url, path, overwrite=False):\n",
"    \"\"\"Download `url` to the local file `path`.\n",
"\n",
"    An existing `path` is kept (and the download skipped) unless\n",
"    `overwrite` is True. Data is first written to `path + '.part'` and only\n",
"    moved to `path` once the transfer has finished, so an interrupted\n",
"    download never leaves a truncated file that later calls would skip.\n",
"    Network/HTTP errors from urllib propagate to the caller.\n",
"    \"\"\"\n",
"    if os.path.isfile(path) and not overwrite:\n",
"        print('File {} exists, skip.'.format(path))\n",
"        return\n",
"    print('Downloading from url {} to {}'.format(url, path))\n",
"    tmp_path = path + '.part'\n",
"    try:\n",
"        urllib.request.urlretrieve(url, tmp_path)\n",
"        # os.replace overwrites an existing `path` (needed for overwrite=True).\n",
"        os.replace(tmp_path, path)\n",
"    finally:\n",
"        # The temp file is left over only if the download or rename failed.\n",
"        if os.path.exists(tmp_path):\n",
"            os.remove(tmp_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import keras\n",
"\n",
"# ResNet50 with ImageNet weights and the 1000-class classifier head.\n",
"model = keras.applications.resnet50.ResNet50(\n",
"    include_top=True,\n",
"    weights='imagenet',\n",
"    input_tensor=None,\n",
"    input_shape=None,\n",
"    pooling=None,\n",
"    classes=1000,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"from matplotlib import pyplot as plt\n",
"from keras.applications.resnet50 import preprocess_input\n",
"\n",
"# prepare data\n",
"img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'\n",
"download(img_url, 'cat.png')\n",
"# convert('RGB') guarantees exactly 3 channels: a PNG may be RGBA, palette or\n",
"# greyscale, which would give the wrong channel count for ResNet50. The\n",
"# `with` block closes the file handle opened by Image.open.\n",
"with Image.open('cat.png') as src:\n",
"    img = src.convert('RGB').resize((224, 224))\n",
"plt.imshow(img)\n",
"plt.show()\n",
"\n",
"# input preprocess: add a batch axis -> (1, H, W, 3), apply the Keras\n",
"# ResNet50 preprocessing, then transpose to (1, 3, H, W) for the TVM input.\n",
"data = np.array(img)[np.newaxis, :].astype('float32')\n",
"data = preprocess_input(data).transpose([0, 3, 1, 2])\n",
"print('input_1', data.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import tvm\n",
"import tvm.relay as relay\n",
"\n",
"input_name = 'input_1'\n",
"shape_dict = {input_name: data.shape}\n",
"\n",
"# Import the Keras model into Relay: a Relay function plus its parameter dict.\n",
"func, params = relay.frontend.from_keras(model, shape_dict)\n",
"\n",
"# Quantize inside a qconfig scope (global_scale=8.0) and show the new IR.\n",
"quant_cfg = relay.quantize.qconfig(global_scale=8.0)\n",
"with quant_cfg:\n",
"    func_quant = relay.quantize.quantize(func, params)\n",
"    print(func_quant)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target = 'llvm'\n",
"\n",
"# Compile the float and the quantized function under identical build\n",
"# settings; each result is a (graph, lib, params) tuple.\n",
"with relay.build_config(opt_level=0):\n",
"    modules, modules_quant = [\n",
"        relay.build_module.build(fn, target, params=params)\n",
"        for fn in (func, func_quant)\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import ast\n",
"\n",
"import tvm.contrib.graph_runtime\n",
"\n",
"\n",
"def predict_tvm(modules, data):\n",
"    \"\"\"Run one inference with a compiled Relay build.\n",
"\n",
"    `modules` is the (graph, lib, params) tuple from relay.build_module.build;\n",
"    `data` is cast to float32 and fed to the graph input named by the global\n",
"    `input_name`. Returns output 0, read into a (1, 1000) buffer, as NumPy.\n",
"    \"\"\"\n",
"    # create module\n",
"    graph, lib, params = modules\n",
"    module = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n",
"\n",
"    # set input and parameters\n",
"    module.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n",
"    module.set_input(**params)\n",
"\n",
"    # get output\n",
"    module.run()\n",
"    return module.get_output(0, tvm.nd.empty((1, 1000))).asnumpy()\n",
"\n",
"\n",
"def show_top5_accuracy(y_pred, synset):\n",
"    \"\"\"Print the five highest-scoring classes of `y_pred`.\n",
"\n",
"    Despite the name, this prints the top-5 predictions (rank, class id,\n",
"    label), not an accuracy figure. `synset` maps class id -> label.\n",
"    \"\"\"\n",
"    top5_ids = y_pred.flatten().argsort()[::-1][:5]\n",
"\n",
"    for i, c in enumerate(top5_ids):\n",
"        print(f'{i+1} : {c} {synset[c]}')\n",
"\n",
"\n",
"# get ImageNet synset dictionary\n",
"synset_url = ('https://gist.githubusercontent.com/zhreshold/'\n",
"              '4d0b62f3d01426887599d4f7ede23ee5/raw/'\n",
"              '596b27d23537e5a1b5751d2b0481ef172f58b539/'\n",
"              'imagenet1000_clsid_to_human.txt')\n",
"synset_name = 'synset.txt'\n",
"download(synset_url, synset_name)\n",
"\n",
"# The downloaded file is a Python dict literal. ast.literal_eval parses it\n",
"# without executing arbitrary code, unlike eval() on network content.\n",
"with open(synset_name, encoding='utf-8') as f:\n",
"    synset = ast.literal_eval(f.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred_tvm = predict_tvm(modules, data)\n",
"y_pred_tvm_quant = predict_tvm(modules_quant, data)\n",
"\n",
"# Report the top-5 predictions of the float build, then the quantized one.\n",
"for header, y_pred in (('\\nw/o quantization', y_pred_tvm),\n",
"                       ('\\nwith quantization', y_pred_tvm_quant)):\n",
"    print(header)\n",
"    show_top5_accuracy(y_pred, synset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save weight\n",
"model.save_weights('weight_keras.h5')\n",
"\n",
"# Serialize the parameter dict of each build result (element 2 of the\n",
"# (graph, lib, params) tuple). A dedicated name is used instead of rebinding\n",
"# the global `params`, which holds the from_keras parameter dict that the\n",
"# build cell passes to relay.build_module.build; clobbering it would make a\n",
"# re-run of that cell bind the compiled params instead.\n",
"params_float = modules[2]\n",
"with open(\"deploy_param.params\", \"wb\") as f:\n",
"    f.write(relay.save_param_dict(params_float))\n",
"\n",
"params_quant = modules_quant[2]\n",
"with open(\"deploy_param_quant.params\", \"wb\") as f:\n",
"    f.write(relay.save_param_dict(params_quant))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a runtime for the non-quantized build and bind its input/params.\n",
"# The compiled params are unpacked into `params_float` rather than `params`\n",
"# so the from_keras parameter dict used by the build cell is not clobbered.\n",
"graph, lib, params_float = modules\n",
"module = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n",
"\n",
"# set input and parameters\n",
"module.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n",
"module.set_input(**params_float)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a runtime for the quantized build and bind its input/params.\n",
"# The compiled params are unpacked into `params_quant` rather than `params`\n",
"# so the from_keras parameter dict used by the build cell is not clobbered.\n",
"graph, lib, params_quant = modules_quant\n",
"module_quant = tvm.contrib.graph_runtime.create(graph, lib, tvm.cpu())\n",
"\n",
"# set input and parameters\n",
"module_quant.set_input(input_name, tvm.nd.array(data.astype(np.float32)))\n",
"module_quant.set_input(**params_quant)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Benchmark run() on each prepared runtime; -n 10 = 10 loops per timing repeat.\n",
"print('w/o quantization')\n",
"%timeit -n 10 module.run()\n",
"\n",
"print('\\nwith quantization')\n",
"%timeit -n 10 module_quant.run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment