Skip to content

Instantly share code, notes, and snippets.

@guillaume-be
Created July 22, 2023 08:07
Show Gist options
  • Save guillaume-be/ebf4977131f93882add1516073ad2f42 to your computer and use it in GitHub Desktop.
Save guillaume-be/ebf4977131f93882add1516073ad2f42 to your computer and use it in GitHub Desktop.
IO Bindings with onnxruntime
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "59fb24ff",
"metadata": {},
"outputs": [],
"source": [
"import onnxruntime\n",
"import torch\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "e2a609b1",
"metadata": {},
"source": [
"### Create and export model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8085be8f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-0.8987, 0.2827, -0.9207, -0.3547, 0.2086],\n",
" [-0.1900, 0.2688, -1.2061, -0.7395, -0.6321],\n",
" [ 0.5187, 0.2548, -1.4914, -1.1244, -1.4728],\n",
" [ 1.2275, 0.2408, -1.7768, -1.5093, -2.3135],\n",
" [ 1.9362, 0.2269, -2.0622, -1.8941, -3.1543],\n",
" [ 2.6449, 0.2129, -2.3475, -2.2790, -3.9950],\n",
" [ 3.3536, 0.1990, -2.6329, -2.6638, -4.8357],\n",
" [ 4.0623, 0.1850, -2.9183, -3.0487, -5.6764]],\n",
" grad_fn=<AddmmBackward0>)\n",
"============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============\n",
"verbose: False, log level: Level.ERROR\n",
"======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
"\n"
]
}
],
"source": [
"class Net(torch.nn.Module):\n",
"    \"\"\"Minimal model for the ONNX export demo: a single linear layer (2 -> 5 features).\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.dense = torch.nn.Linear(in_features=2, out_features=5)\n",
"\n",
"    def forward(self, input_0: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Apply the linear layer to `input_0` and return the result.\"\"\"\n",
"        return self.dense(input_0)\n",
"\n",
"model = Net()\n",
"# Example input: used for the sanity-check print and passed to the exporter.\n",
"input_0 = torch.arange(8 * 2).view(8, 2).float()\n",
"print(model(input_0))\n",
"torch.onnx.export(\n",
"    model,\n",
"    input_0,\n",
"    \"net.onnx\",\n",
"    input_names=[\"some_input\"],\n",
"    output_names=[\"some_output\"],\n",
"    # The batch dimension (axis 0) differs between export (8 rows) and the\n",
"    # inference cells (16 rows), so mark it dynamic on both input and output.\n",
"    dynamic_axes={\n",
"        \"some_input\": {0: \"batch_size\"},\n",
"        \"some_output\": {0: \"batch_size\"},\n",
"    },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ad648021",
"metadata": {},
"source": [
"### Run basic ONNX inference"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b403d130",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ -0.89867944, 0.28273344, -0.9206959 , -0.35467649,\n",
" 0.20861459],\n",
" [ -0.18996829, 0.26876992, -1.2060621 , -0.7395382 ,\n",
" -0.6321051 ],\n",
" [ 0.5187429 , 0.25480622, -1.4914281 , -1.1243999 ,\n",
" -1.4728248 ],\n",
" [ 1.2274541 , 0.24084276, -1.7767942 , -1.5092618 ,\n",
" -2.3135445 ],\n",
" [ 1.9361652 , 0.2268793 , -2.0621605 , -1.8941233 ,\n",
" -3.1542642 ],\n",
" [ 2.644876 , 0.2129156 , -2.3475266 , -2.278985 ,\n",
" -3.9949834 ],\n",
" [ 3.3535876 , 0.1989519 , -2.6328928 , -2.663847 ,\n",
" -4.835704 ],\n",
" [ 4.062299 , 0.1849882 , -2.9182587 , -3.0487087 ,\n",
" -5.6764235 ],\n",
" [ 4.7710094 , 0.1710245 , -3.203625 , -3.4335701 ,\n",
" -6.5171432 ],\n",
" [ 5.4797206 , 0.15706176, -3.4889905 , -3.818432 ,\n",
" -7.3578625 ],\n",
" [ 6.1884317 , 0.14309806, -3.7743576 , -4.2032933 ,\n",
" -8.198583 ],\n",
" [ 6.897143 , 0.1291334 , -4.059723 , -4.5881553 ,\n",
" -9.039302 ],\n",
" [ 7.605855 , 0.11517066, -4.345089 , -4.973017 ,\n",
" -9.880022 ],\n",
" [ 8.314566 , 0.10120696, -4.630455 , -5.357878 ,\n",
" -10.720741 ],\n",
" [ 9.023276 , 0.08724326, -4.915822 , -5.74274 ,\n",
" -11.561461 ],\n",
" [ 9.731987 , 0.07327956, -5.201188 , -6.127601 ,\n",
" -12.402181 ]], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ort_session = onnxruntime.InferenceSession(\n",
"    \"net.onnx\",\n",
"    providers=[\"CUDAExecutionProvider\"],\n",
")\n",
"\n",
"# 16 rows of 2 features, passed to the session as a float32 NumPy array.\n",
"feed = {\"some_input\": np.arange(2 * 16).reshape(16, 2).astype(np.float32)}\n",
"outputs = ort_session.run(None, feed)\n",
"outputs[0]"
]
},
{
"cell_type": "markdown",
"id": "0e974564",
"metadata": {},
"source": [
"### Use I/O-binding"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dcb2cdac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ -0.8987, 0.2827, -0.9207, -0.3547, 0.2086],\n",
" [ -0.1900, 0.2688, -1.2061, -0.7395, -0.6321],\n",
" [ 0.5187, 0.2548, -1.4914, -1.1244, -1.4728],\n",
" [ 1.2275, 0.2408, -1.7768, -1.5093, -2.3135],\n",
" [ 1.9362, 0.2269, -2.0622, -1.8941, -3.1543],\n",
" [ 2.6449, 0.2129, -2.3475, -2.2790, -3.9950],\n",
" [ 3.3536, 0.1990, -2.6329, -2.6638, -4.8357],\n",
" [ 4.0623, 0.1850, -2.9183, -3.0487, -5.6764],\n",
" [ 4.7710, 0.1710, -3.2036, -3.4336, -6.5171],\n",
" [ 5.4797, 0.1571, -3.4890, -3.8184, -7.3579],\n",
" [ 6.1884, 0.1431, -3.7744, -4.2033, -8.1986],\n",
" [ 6.8971, 0.1291, -4.0597, -4.5882, -9.0393],\n",
" [ 7.6059, 0.1152, -4.3451, -4.9730, -9.8800],\n",
" [ 8.3146, 0.1012, -4.6305, -5.3579, -10.7207],\n",
" [ 9.0233, 0.0872, -4.9158, -5.7427, -11.5615],\n",
" [ 9.7320, 0.0733, -5.2012, -6.1276, -12.4022]], device='cuda:0')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ort_session = onnxruntime.InferenceSession(\"net.onnx\", providers=[\"CUDAExecutionProvider\"])\n",
"io_binding = ort_session.io_binding()\n",
"\n",
"# bind_input is given only a raw pointer and a shape, so make sure the\n",
"# tensor is stored contiguously before taking its data_ptr().\n",
"input_0 = torch.arange(2 * 16).view((16, 2)).float().cuda().contiguous()\n",
"\n",
"io_binding.bind_input(\n",
"    \"some_input\",\n",
"    input_0.device.type,\n",
"    input_0.device.index,  # device id taken from the tensor instead of a literal 0\n",
"    np.float32,\n",
"    tuple(input_0.shape),\n",
"    input_0.data_ptr(),\n",
")\n",
"\n",
"# Pre-allocate the output on the same device as the input; 5 matches the\n",
"# out_features of the Linear layer in Net.\n",
"output_buffer = torch.empty(\n",
"    (input_0.shape[0], 5),\n",
"    dtype=torch.float,\n",
"    device=input_0.device,\n",
")\n",
"\n",
"io_binding.bind_output(\n",
"    \"some_output\",\n",
"    output_buffer.device.type,\n",
"    output_buffer.device.index,\n",
"    np.float32,\n",
"    tuple(output_buffer.shape),  # derived from the buffer rather than hard-coded\n",
"    output_buffer.data_ptr(),\n",
")\n",
"\n",
"io_binding.synchronize_inputs()\n",
"ort_session.run_with_iobinding(io_binding)\n",
"io_binding.synchronize_outputs()\n",
"\n",
"output_buffer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0973e53",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "transformers",
"language": "python",
"name": "transformers"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment