Skip to content

Instantly share code, notes, and snippets.

@catethos
Created December 23, 2021 14:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save catethos/63375d072fc1fd9e3212dfe460a7fa51 to your computer and use it in GitHub Desktop.
Save catethos/63375d072fc1fd9e3212dfe460a7fa51 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "9fcfd4ff-6340-4ef1-a501-ece7fdf138a1",
"metadata": {
"tags": []
},
"source": [
"# Libraries"
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "a29b0d3a-f593-4684-bba9-77e18e8669e2",
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import beta\n",
"import numpy as np\n",
"from torch import nn\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import onnxruntime as ort"
]
},
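{
"cell_type": "markdown",
"id": "38db0175-8d59-4bdc-967c-50ecd23c73e9",
"metadata": {
"tags": []
},
"source": [
"# Data\n",
"\n",
"The regression target is the exact CDF of the Beta(1, 2) distribution, $F(t) = 1 - (1 - t)^2$, evaluated at 150 evenly spaced points in $[0, 1]$."
]
},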
{
"cell_type": "code",
"execution_count": 180,
"id": "8e0dbad0-841b-4b68-b271-d32aae16a0a5",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the exact CDF on the NumPy grid; convert to torch tensors only for training.\n",
"x_np = np.linspace(0, 1, 150, dtype=np.float32).reshape(-1, 1)  # shape (150, 1): one feature per row\n",
"y_np = beta(1, 2).cdf(x_np).astype(np.float32)\n",
"\n",
"x = torch.from_numpy(x_np)\n",
"y = torch.from_numpy(y_np)"
]
},
{
"cell_type": "markdown",
"id": "8184ce8d-efa9-449b-bc2d-ee93e2c28005",
"metadata": {},
"source": [
"# Model"
]
},
{
"cell_type": "code",
"execution_count": 190,
"id": "15fe497e-762f-4527-8cd1-9dd7764f9745",
"metadata": {},
"outputs": [],
"source": [
"class Network(nn.Module):\n",
"    \"\"\"Small fully connected network mapping one input feature to one output.\n",
"\n",
"    Architecture: Linear(1, hidden_size) -> ReLU -> Linear(hidden_size, 1).\n",
"\n",
"    Args:\n",
"        hidden_size: width of the single hidden layer (default 25).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, hidden_size: int = 25):\n",
"        super().__init__()\n",
"        self.layers = nn.Sequential(\n",
"            nn.Linear(1, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_size, 1),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the MLP; the last dimension of x must be 1 to match the first Linear layer.\"\"\"\n",
"        return self.layers(x)"
]
},
{
"cell_type": "markdown",
"id": "c8c38d77-eef5-4551-9eb4-3d23b5ce3b68",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "c9751d33-b9ce-4261-a15e-126a12ef99ce",
"metadata": {},
"outputs": [],
"source": [
"# Seed before building the net so the random weight init, and hence the fit, is reproducible.\n",
"torch.manual_seed(0)\n",
"N_EPOCHS = 500\n",
"\n",
"net = Network()\n",
"loss_function = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
"\n",
"# Full-batch training: every step uses all of x.\n",
"for _ in range(N_EPOCHS):\n",
"    optimizer.zero_grad()\n",
"    outputs = net(x)\n",
"    loss = loss_function(outputs, y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"loss.item()  # training MSE computed in the last iteration (before its update)"
]
},
{
"cell_type": "markdown",
"id": "93994203-db33-4ea8-808c-e981049565d0",
"metadata": {},
"source": [
"## Exporting to ONNX"
]
},
{
"cell_type": "code",
"execution_count": 230,
"id": "1a2236f1-ce20-46dd-b92d-bdc43578c1c9",
"metadata": {},
"outputs": [],
"source": [
"net.eval()  # switch to inference mode before tracing\n",
"torch.onnx.export(\n",
"    net, x, \"percentile.onnx\",\n",
"    input_names=['input'],\n",
"    output_names=['output'],\n",
"    # Mark dim 0 as dynamic; otherwise the graph is fixed to x's batch size and\n",
"    # onnxruntime rejects inputs with any other number of rows.\n",
"    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "74fa9197-94d3-4bf9-85b4-7ddfca5d1aaa",
"metadata": {},
"source": [
"# Inference Timing: ONNX Runtime vs. SciPy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6de4285b-6b79-4117-9155-171a81964806",
"metadata": {},
"outputs": [],
"source": [
"# Pin the CPU provider: onnxruntime >= 1.9 requires `providers` on builds with several\n",
"# execution providers (e.g. onnxruntime-gpu), and it keeps the timing comparison CPU-vs-CPU.\n",
"sess = ort.InferenceSession(\"percentile.onnx\", providers=[\"CPUExecutionProvider\"])"
]
},
{
"cell_type": "code",
"execution_count": 248,
"id": "31e6d34f-5f1f-4358-a320-59f956576a73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 110 µs, sys: 5 µs, total: 115 µs\n",
"Wall time: 78.7 µs\n",
"CPU times: user 1.18 ms, sys: 8 µs, total: 1.19 ms\n",
"Wall time: 939 µs\n"
]
}
],
"source": [
"# Compare the ONNX network against scipy's exact CDF on the same float32 NumPy input.\n",
"x_np = x.numpy()\n",
"y_est = sess.run(None, {'input': x_np})[0]\n",
"y_true = beta(1, 2).cdf(x_np)\n",
"\n",
"# %timeit repeats each call many times; a single %time sample is mostly noise at the microsecond scale.\n",
"%timeit sess.run(None, {'input': x_np})\n",
"%timeit beta(1, 2).cdf(x_np)"
]
},
{
"cell_type": "code",
"execution_count": 247,
"id": "400350cf-8c99-443e-8912-4bcf7bd05675",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f09319506a0>]"
]
},
"execution_count": 247,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Calibration plot: points on the dashed diagonal mean the network matches the exact CDF.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(y_est, y_true, label='ONNX network')\n",
"ax.plot([0, 1], [0, 1], linestyle='--', label='perfect agreement')\n",
"ax.set(xlabel='network estimate', ylabel='exact Beta(1, 2) CDF', title='ONNX network vs. exact CDF')\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9711541-2169-4cee-9b38-c3eec6d3b11e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "290f61ae-d229-4a28-a9d7-da19508c1ff3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "analysis",
"language": "python",
"name": "analysis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment