Created
February 3, 2021 16:18
-
-
Save chetanambi/c548446086bb19f5c9330caa6209ac44 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "skl2onnx.ipynb", | |
"provenance": [], | |
"toc_visible": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "lclvsatpL0hT", | |
"outputId": "d4ae4d96-d988-4d35-e532-36f996d74644" | |
}, | |
"source": [ | |
"# Use %pip (not !pip) so packages install into the environment of the running kernel.\n", | |
"# The original run resolved skl2onnx 1.7.0 and onnxruntime 1.6.0 (see the install log below).\n", | |
"%pip install -q skl2onnx onnxruntime" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting skl2onnx\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8a/41/47cb3c420d3a1d0a1ad38ef636ac2d4929c938c2f209582bcf3b33440b1a/skl2onnx-1.7.0-py2.py3-none-any.whl (191kB)\n", | |
"\u001b[K |████████████████████████████████| 194kB 5.3MB/s \n", | |
"\u001b[?25hCollecting onnxruntime\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b3/a9/f009251fd1b91a2e1ce6f22d4b5be9936fbd0072842c5087a2a49706c509/onnxruntime-1.6.0-cp36-cp36m-manylinux2014_x86_64.whl (4.1MB)\n", | |
"\u001b[K |████████████████████████████████| 4.1MB 6.0MB/s \n", | |
"\u001b[?25hRequirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.4.1)\n", | |
"Requirement already satisfied: scikit-learn>=0.19 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (0.22.2.post1)\n", | |
"Requirement already satisfied: protobuf in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (3.12.4)\n", | |
"Collecting onnxconverter-common>=1.5.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/fe/7a/7e30c643cd7d2ad87689188ef34ce93e657bd14da3605f87bcdbc19cd5b1/onnxconverter_common-1.7.0-py2.py3-none-any.whl (64kB)\n", | |
"\u001b[K |████████████████████████████████| 71kB 8.2MB/s \n", | |
"\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.15.0)\n", | |
"Collecting onnx>=1.2.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f1/db/608877fea324c3a44aaa50dbcb23ff5b7e3d222a7c5511c19d1651db512e/onnx-1.8.1-cp36-cp36m-manylinux2010_x86_64.whl (14.5MB)\n", | |
"\u001b[K |████████████████████████████████| 14.5MB 324kB/s \n", | |
"\u001b[?25hRequirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.19.5)\n", | |
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.19->skl2onnx) (1.0.0)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf->skl2onnx) (51.3.3)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.6/dist-packages (from onnx>=1.2.1->skl2onnx) (3.7.4.3)\n", | |
"Installing collected packages: onnx, onnxconverter-common, skl2onnx, onnxruntime\n", | |
"Successfully installed onnx-1.8.1 onnxconverter-common-1.7.0 onnxruntime-1.6.0 skl2onnx-1.7.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aW8ldQL3LgYG" | |
}, | |
"source": [ | |
"# Step 1: Train the model with scikit-learn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "YDUIgy9HLYPG", | |
"outputId": "57af6222-bc42-4999-e25f-0b9ecbd82750" | |
}, | |
"source": [ | |
"import time\r\n", | |
"\r\n", | |
"from sklearn.ensemble import RandomForestRegressor\r\n", | |
"from sklearn.model_selection import train_test_split\r\n", | |
"\r\n", | |
"# load_boston was deprecated in scikit-learn 1.0 and removed in 1.2 (the dataset has\r\n", | |
"# known ethical problems). When it is unavailable, fall back to the original StatLib\r\n", | |
"# copy using the recipe from scikit-learn's deprecation notice (needs network access).\r\n", | |
"try:\r\n", | |
"    from sklearn.datasets import load_boston\r\n", | |
"    boston = load_boston()\r\n", | |
"    X, y = boston.data, boston.target\r\n", | |
"except ImportError:\r\n", | |
"    import numpy as np\r\n", | |
"    import pandas as pd\r\n", | |
"    raw_df = pd.read_csv(\"http://lib.stat.cmu.edu/datasets/boston\", sep=r\"\\s+\", skiprows=22, header=None)\r\n", | |
"    # The StatLib file wraps each record over two lines: 11 features, then 2 features + the target.\r\n", | |
"    X = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])\r\n", | |
"    y = raw_df.values[1::2, 2]\r\n", | |
"\r\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\r\n", | |
"\r\n", | |
"boston_skl = RandomForestRegressor(random_state=42)\r\n", | |
"boston_skl.fit(X_train, y_train)\r\n", | |
"\r\n", | |
"# perf_counter is a monotonic, high-resolution clock intended for measuring intervals.\r\n", | |
"# A single call is noisy; use %timeit for a more reliable comparison.\r\n", | |
"start_time = time.perf_counter()\r\n", | |
"pred_skl = boston_skl.predict(X_test)\r\n", | |
"\r\n", | |
"print(\"Time taken by Sklearn: \", time.perf_counter() - start_time)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Time taken by Sklearn: 0.008110284805297852\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "RB5j--Z0Lj-6" | |
}, | |
"source": [ | |
"# Step 2: Convert the model from scikit-learn to ONNX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qMRAQXHULl64" | |
}, | |
"source": [ | |
"from skl2onnx import convert_sklearn\r\n", | |
"from skl2onnx.common.data_types import FloatTensorType\r\n", | |
"\r\n", | |
"# Declare a float32 input with a dynamic batch dimension (None). Take the feature count\r\n", | |
"# from the training data instead of hard-coding it, so the ONNX signature always matches.\r\n", | |
"n_features = X_train.shape[1]\r\n", | |
"initial_type = [('float_input', FloatTensorType([None, n_features]))]\r\n", | |
"onx = convert_sklearn(boston_skl, initial_types=initial_type)\r\n", | |
"\r\n", | |
"with open(\"boston.onnx\", \"wb\") as f:\r\n", | |
"    f.write(onx.SerializeToString())" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Gtf-XqPbMBml" | |
}, | |
"source": [ | |
"## Load the model from disk" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hwKD35RpLvtT" | |
}, | |
"source": [ | |
"import onnx\r\n", | |
"\r\n", | |
"onnx_model = onnx.load(\"boston.onnx\")\r\n", | |
"# Validate the serialized graph so conversion problems surface here rather than at inference time.\r\n", | |
"onnx.checker.check_model(onnx_model)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "t9v9MmFYMFtq" | |
}, | |
"source": [ | |
"# Step 3: Run inference with the ONNX model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "YFgIMI77LuRH", | |
"outputId": "091642c0-ac67-4d8e-d675-3daee9d47102" | |
}, | |
"source": [ | |
"import onnxruntime as rt\r\n", | |
"import numpy as np\r\n", | |
"\r\n", | |
"# Request the CPU provider explicitly: since onnxruntime 1.9, `providers` must be set\r\n", | |
"# when the build offers more than one execution provider (e.g. GPU builds).\r\n", | |
"sess = rt.InferenceSession(\"boston.onnx\", providers=[\"CPUExecutionProvider\"])\r\n", | |
"input_name = sess.get_inputs()[0].name\r\n", | |
"label_name = sess.get_outputs()[0].name\r\n", | |
"\r\n", | |
"# The graph was declared with a float32 input (FloatTensorType), so cast before running.\r\n", | |
"start_time = time.perf_counter()\r\n", | |
"pred_onnx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]\r\n", | |
"\r\n", | |
"print(\"Time taken by ONNX: \", time.perf_counter() - start_time)\r\n", | |
"\r\n", | |
"# A speed-up only matters if the converted model agrees with the original one.\r\n", | |
"print(\"Max abs difference vs Sklearn: \", np.abs(pred_onnx.ravel() - pred_skl).max())" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Time taken by ONNX: 0.0012629032135009766\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment