-
-
Save guillaume-be/ebf4977131f93882add1516073ad2f42 to your computer and use it in GitHub Desktop.
IO Bindings with onnxruntime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "59fb24ff", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import onnxruntime\n", | |
"import torch\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e2a609b1", | |
"metadata": {}, | |
"source": [ | |
    "### Create a model and export it to ONNX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "8085be8f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([[-0.8987, 0.2827, -0.9207, -0.3547, 0.2086],\n", | |
" [-0.1900, 0.2688, -1.2061, -0.7395, -0.6321],\n", | |
" [ 0.5187, 0.2548, -1.4914, -1.1244, -1.4728],\n", | |
" [ 1.2275, 0.2408, -1.7768, -1.5093, -2.3135],\n", | |
" [ 1.9362, 0.2269, -2.0622, -1.8941, -3.1543],\n", | |
" [ 2.6449, 0.2129, -2.3475, -2.2790, -3.9950],\n", | |
" [ 3.3536, 0.1990, -2.6329, -2.6638, -4.8357],\n", | |
" [ 4.0623, 0.1850, -2.9183, -3.0487, -5.6764]],\n", | |
" grad_fn=<AddmmBackward0>)\n", | |
"============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============\n", | |
"verbose: False, log level: Level.ERROR\n", | |
"======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"class Net(torch.nn.Module):\n", | |
" def __init__(self):\n", | |
" super().__init__()\n", | |
" self.dense = torch.nn.Linear(in_features=2, out_features=5)\n", | |
" \n", | |
" def forward(self, input_0: torch.Tensor) -> torch.Tensor:\n", | |
" return self.dense(input_0)\n", | |
"\n", | |
"model = Net()\n", | |
"input_0 = torch.arange(8* 2).view(8,2).float()\n", | |
"print(model(input_0))\n", | |
"torch.onnx.export(\n", | |
" model, \n", | |
" input_0, \n", | |
" \"net.onnx\", \n", | |
" input_names=[\"some_input\"], \n", | |
" output_names=[\"some_output\"],\n", | |
" dynamic_axes={\n", | |
" 'some_input': {0: 'batch_size'}\n", | |
" }\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ad648021", | |
"metadata": {}, | |
"source": [ | |
"### Run basic ONNX inference" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "b403d130", | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ -0.89867944, 0.28273344, -0.9206959 , -0.35467649,\n", | |
" 0.20861459],\n", | |
" [ -0.18996829, 0.26876992, -1.2060621 , -0.7395382 ,\n", | |
" -0.6321051 ],\n", | |
" [ 0.5187429 , 0.25480622, -1.4914281 , -1.1243999 ,\n", | |
" -1.4728248 ],\n", | |
" [ 1.2274541 , 0.24084276, -1.7767942 , -1.5092618 ,\n", | |
" -2.3135445 ],\n", | |
" [ 1.9361652 , 0.2268793 , -2.0621605 , -1.8941233 ,\n", | |
" -3.1542642 ],\n", | |
" [ 2.644876 , 0.2129156 , -2.3475266 , -2.278985 ,\n", | |
" -3.9949834 ],\n", | |
" [ 3.3535876 , 0.1989519 , -2.6328928 , -2.663847 ,\n", | |
" -4.835704 ],\n", | |
" [ 4.062299 , 0.1849882 , -2.9182587 , -3.0487087 ,\n", | |
" -5.6764235 ],\n", | |
" [ 4.7710094 , 0.1710245 , -3.203625 , -3.4335701 ,\n", | |
" -6.5171432 ],\n", | |
" [ 5.4797206 , 0.15706176, -3.4889905 , -3.818432 ,\n", | |
" -7.3578625 ],\n", | |
" [ 6.1884317 , 0.14309806, -3.7743576 , -4.2032933 ,\n", | |
" -8.198583 ],\n", | |
" [ 6.897143 , 0.1291334 , -4.059723 , -4.5881553 ,\n", | |
" -9.039302 ],\n", | |
" [ 7.605855 , 0.11517066, -4.345089 , -4.973017 ,\n", | |
" -9.880022 ],\n", | |
" [ 8.314566 , 0.10120696, -4.630455 , -5.357878 ,\n", | |
" -10.720741 ],\n", | |
" [ 9.023276 , 0.08724326, -4.915822 , -5.74274 ,\n", | |
" -11.561461 ],\n", | |
" [ 9.731987 , 0.07327956, -5.201188 , -6.127601 ,\n", | |
" -12.402181 ]], dtype=float32)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ort_session = onnxruntime.InferenceSession(\"net.onnx\", providers=[\"CUDAExecutionProvider\"])\n", | |
"\n", | |
"outputs = ort_session.run(\n", | |
" None,\n", | |
" {\"some_input\": np.arange(2*16).reshape(16, 2).astype(np.float32)},\n", | |
")\n", | |
"outputs[0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0e974564", | |
"metadata": {}, | |
"source": [ | |
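    "### Use I/O binding\n", | |
    "\n", | |
    "Bind the device pointers of CUDA `torch` tensors directly to the session's input and output. ONNX Runtime then reads the input from, and writes the output into, GPU memory in place, instead of exchanging NumPy host arrays as `run` does in the previous cell." | |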
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "dcb2cdac", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ -0.8987, 0.2827, -0.9207, -0.3547, 0.2086],\n", | |
" [ -0.1900, 0.2688, -1.2061, -0.7395, -0.6321],\n", | |
" [ 0.5187, 0.2548, -1.4914, -1.1244, -1.4728],\n", | |
" [ 1.2275, 0.2408, -1.7768, -1.5093, -2.3135],\n", | |
" [ 1.9362, 0.2269, -2.0622, -1.8941, -3.1543],\n", | |
" [ 2.6449, 0.2129, -2.3475, -2.2790, -3.9950],\n", | |
" [ 3.3536, 0.1990, -2.6329, -2.6638, -4.8357],\n", | |
" [ 4.0623, 0.1850, -2.9183, -3.0487, -5.6764],\n", | |
" [ 4.7710, 0.1710, -3.2036, -3.4336, -6.5171],\n", | |
" [ 5.4797, 0.1571, -3.4890, -3.8184, -7.3579],\n", | |
" [ 6.1884, 0.1431, -3.7744, -4.2033, -8.1986],\n", | |
" [ 6.8971, 0.1291, -4.0597, -4.5882, -9.0393],\n", | |
" [ 7.6059, 0.1152, -4.3451, -4.9730, -9.8800],\n", | |
" [ 8.3146, 0.1012, -4.6305, -5.3579, -10.7207],\n", | |
" [ 9.0233, 0.0872, -4.9158, -5.7427, -11.5615],\n", | |
" [ 9.7320, 0.0733, -5.2012, -6.1276, -12.4022]], device='cuda:0')" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ort_session = onnxruntime.InferenceSession(\"net.onnx\", providers=[\"CUDAExecutionProvider\"])\n", | |
"io_binding = ort_session.io_binding()\n", | |
"\n", | |
"input_0 = torch.arange(2*16).view((16, 2)).float().cuda()\n", | |
"\n", | |
"io_binding.bind_input(\n", | |
" \"some_input\",\n", | |
" input_0.device.type,\n", | |
" 0,\n", | |
" np.float32,\n", | |
" tuple(input_0.shape),\n", | |
" input_0.data_ptr(),\n", | |
")\n", | |
"\n", | |
"output_buffer = torch.empty(\n", | |
" (input_0.shape[0], 5),\n", | |
" dtype=torch.float, \n", | |
" device=\"cuda:0\").contiguous()\n", | |
"output_buffer\n", | |
"\n", | |
"io_binding.bind_output(\n", | |
" \"some_output\",\n", | |
" output_buffer.device.type,\n", | |
" 0,\n", | |
" np.float32,\n", | |
" [16,5],\n", | |
" output_buffer.data_ptr(),\n", | |
")\n", | |
"\n", | |
"io_binding.synchronize_inputs()\n", | |
"ort_session.run_with_iobinding(io_binding)\n", | |
"io_binding.synchronize_outputs()\n", | |
"\n", | |
"output_buffer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b0973e53", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "transformers", | |
"language": "python", | |
"name": "transformers" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment