Skip to content

Instantly share code, notes, and snippets.

@chetanambi
Created February 3, 2021 16:18
Show Gist options
  • Save chetanambi/c548446086bb19f5c9330caa6209ac44 to your computer and use it in GitHub Desktop.
Save chetanambi/c548446086bb19f5c9330caa6209ac44 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "skl2onnx.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lclvsatpL0hT",
"outputId": "d4ae4d96-d988-4d35-e532-36f996d74644"
},
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel\r\n",
"%pip install skl2onnx onnxruntime"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting skl2onnx\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8a/41/47cb3c420d3a1d0a1ad38ef636ac2d4929c938c2f209582bcf3b33440b1a/skl2onnx-1.7.0-py2.py3-none-any.whl (191kB)\n",
"\u001b[K |████████████████████████████████| 194kB 5.3MB/s \n",
"\u001b[?25hCollecting onnxruntime\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b3/a9/f009251fd1b91a2e1ce6f22d4b5be9936fbd0072842c5087a2a49706c509/onnxruntime-1.6.0-cp36-cp36m-manylinux2014_x86_64.whl (4.1MB)\n",
"\u001b[K |████████████████████████████████| 4.1MB 6.0MB/s \n",
"\u001b[?25hRequirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.4.1)\n",
"Requirement already satisfied: scikit-learn>=0.19 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (0.22.2.post1)\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (3.12.4)\n",
"Collecting onnxconverter-common>=1.5.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/fe/7a/7e30c643cd7d2ad87689188ef34ce93e657bd14da3605f87bcdbc19cd5b1/onnxconverter_common-1.7.0-py2.py3-none-any.whl (64kB)\n",
"\u001b[K |████████████████████████████████| 71kB 8.2MB/s \n",
"\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.15.0)\n",
"Collecting onnx>=1.2.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f1/db/608877fea324c3a44aaa50dbcb23ff5b7e3d222a7c5511c19d1651db512e/onnx-1.8.1-cp36-cp36m-manylinux2010_x86_64.whl (14.5MB)\n",
"\u001b[K |████████████████████████████████| 14.5MB 324kB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.6/dist-packages (from skl2onnx) (1.19.5)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.19->skl2onnx) (1.0.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf->skl2onnx) (51.3.3)\n",
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.6/dist-packages (from onnx>=1.2.1->skl2onnx) (3.7.4.3)\n",
"Installing collected packages: onnx, onnxconverter-common, skl2onnx, onnxruntime\n",
"Successfully installed onnx-1.8.1 onnxconverter-common-1.7.0 onnxruntime-1.6.0 skl2onnx-1.7.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aW8ldQL3LgYG"
},
"source": [
"# Step 1: Build the model with scikit-learn"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YDUIgy9HLYPG",
"outputId": "57af6222-bc42-4999-e25f-0b9ecbd82750"
},
"source": [
"import time\r\n",
"from sklearn.ensemble import RandomForestRegressor\r\n",
"from sklearn.model_selection import train_test_split\r\n",
"\r\n",
"# load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, where\r\n",
"# importing it raises ImportError. Fall back to the original CMU copy of the\r\n",
"# data (the recipe from scikit-learn's deprecation notice) so the notebook\r\n",
"# still runs on current versions with the same feature layout.\r\n",
"try:\r\n",
"    from sklearn.datasets import load_boston\r\n",
"    X, y = load_boston(return_X_y=True)\r\n",
"except ImportError:\r\n",
"    import numpy as np\r\n",
"    import pandas as pd\r\n",
"    raw_df = pd.read_csv(\"http://lib.stat.cmu.edu/datasets/boston\",\r\n",
"                         sep=r\"\\s+\", skiprows=22, header=None)\r\n",
"    # Each record spans two text lines: even rows hold the first block of\r\n",
"    # features, odd rows hold the remaining two features and then the target.\r\n",
"    X = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])\r\n",
"    y = raw_df.values[1::2, 2]\r\n",
"\r\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\r\n",
"\r\n",
"boston_skl = RandomForestRegressor(random_state=42)\r\n",
"boston_skl.fit(X_train, y_train)\r\n",
"\r\n",
"# perf_counter is a monotonic, high-resolution clock intended for measuring\r\n",
"# elapsed time; time.time() follows the wall clock and can jump.\r\n",
"start_time = time.perf_counter()\r\n",
"pred_skl = boston_skl.predict(X_test)\r\n",
"\r\n",
"print(\"Time taken by Sklearn: \", time.perf_counter() - start_time)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Time taken by Sklearn: 0.008110284805297852\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RB5j--Z0Lj-6"
},
"source": [
"# Step 2: Convert the model from scikit-learn to ONNX"
]
},
{
"cell_type": "code",
"metadata": {
"id": "qMRAQXHULl64"
},
"source": [
"from skl2onnx import convert_sklearn\r\n",
"from skl2onnx.common.data_types import FloatTensorType\r\n",
"\r\n",
"# Declare one float32 input with a dynamic batch dimension (None). The feature\r\n",
"# count comes from the training data rather than a hard-coded 13, so the\r\n",
"# converted graph always matches what the model was fitted on.\r\n",
"n_features = X_train.shape[1]\r\n",
"initial_type = [('float_input', FloatTensorType([None, n_features]))]\r\n",
"onx = convert_sklearn(boston_skl, initial_types=initial_type)\r\n",
"\r\n",
"with open(\"boston.onnx\", \"wb\") as f:\r\n",
"    f.write(onx.SerializeToString())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gtf-XqPbMBml"
},
"source": [
"## Load the saved ONNX model from disk"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hwKD35RpLvtT"
},
"source": [
"import onnx\r\n",
"\r\n",
"onnx_model = onnx.load(\"boston.onnx\")\r\n",
"# Validate the saved model so a malformed graph fails here with a clear\r\n",
"# error instead of later inside ONNX Runtime.\r\n",
"onnx.checker.check_model(onnx_model)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "t9v9MmFYMFtq"
},
"source": [
"# Step 3: Run inference with the ONNX model"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YFgIMI77LuRH",
"outputId": "091642c0-ac67-4d8e-d675-3daee9d47102"
},
"source": [
"import time\r\n",
"\r\n",
"import numpy as np\r\n",
"import onnxruntime as rt\r\n",
"\r\n",
"# Pass the provider explicitly: since ONNX Runtime 1.9, builds with several\r\n",
"# execution providers (e.g. onnxruntime-gpu) raise an error if none is given.\r\n",
"sess = rt.InferenceSession(\"boston.onnx\", providers=[\"CPUExecutionProvider\"])\r\n",
"input_name = sess.get_inputs()[0].name\r\n",
"label_name = sess.get_outputs()[0].name\r\n",
"\r\n",
"# The graph was declared with a float32 input. Cast outside the timed region so\r\n",
"# only inference is measured, as in the scikit-learn timing above.\r\n",
"X_test_f32 = X_test.astype(np.float32)\r\n",
"\r\n",
"start_time = time.perf_counter()\r\n",
"pred_onnx = sess.run([label_name], {input_name: X_test_f32})[0]\r\n",
"\r\n",
"print(\"Time taken by ONNX: \", time.perf_counter() - start_time)\r\n",
"\r\n",
"# Sanity-check the conversion against scikit-learn. ravel() flattens the ONNX\r\n",
"# output (presumably shape (n, 1)); small gaps are expected from float32 input.\r\n",
"print(\"Max abs difference vs Sklearn: \", np.abs(pred_onnx.ravel() - pred_skl).max())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Time taken by ONNX: 0.0012629032135009766\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment