Skip to content

Instantly share code, notes, and snippets.

@vishwanath79
Last active December 28, 2020 05:14
Show Gist options
  • Save vishwanath79/0db14419c2841c843d09bf5420a1bdd0 to your computer and use it in GitHub Desktop.
Save vishwanath79/0db14419c2841c843d09bf5420a1bdd0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"from tensorflow import keras\n",
"import onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [],
"source": [
"# Fix the global TF seed so the random weight initialisation, and hence the trained\n",
"# prediction compared against ONNX below, is repeatable on Restart & Run All.\n",
"tf.random.set_seed(42)\n",
"\n",
"# One Dense layer with a single unit fed a single value: y = w*x + b (2 trainable params)\n",
"model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"dense_2 (Dense) (None, 1) 2 \n",
"=================================================================\n",
"Total params: 2\n",
"Trainable params: 2\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Plain SGD minimising mean squared error, then print the layer/parameter table\n",
"model.compile(loss='mean_squared_error', optimizer='sgd')\n",
"model.summary()\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [],
"source": [
"# Training data sampled from the target line y = 3x + 1 at x = 1..6 (float64 arrays)\n",
"xs = np.arange(1.0, 7.0)\n",
"ys = 3.0 * xs + 1.0"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"data": {
"text/plain": "<tensorflow.python.keras.callbacks.History at 0x7f96ba8d43d0>"
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train for 500 epochs over the six (x, y) pairs with progress output suppressed\n",
"model.fit(x=xs, y=ys, epochs=500, verbose=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [
{
"data": {
"text/plain": "array([[31.02069]], dtype=float32)"
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keras reference prediction for x = 10 (the exact target is 3*10 + 1 = 31)\n",
"original = model.predict(x=[10.0])\n",
"original"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 29,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: tf/saved/assets\n"
]
}
],
"source": [
"# Persist architecture, weights and training config as a SavedModel directory,\n",
"# which is the input format tf2onnx reads in the next cell\n",
"model.save(filepath='tf/saved')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 30,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020-12-27 12:44:43,732 - WARNING - '--tag' not specified for saved_model. Using --tag serve\r\n",
"2020-12-27 12:44:43,800 - INFO - Signatures found in model: [serving_default].\r\n",
"2020-12-27 12:44:43,800 - WARNING - '--signature_def' not specified, using first signature: serving_default\r\n",
"WARNING:tensorflow:From /Users/vishwanath/opt/miniconda3/envs/onnx_virt/lib/python3.8/site-packages/tf2onnx/tf_loader.py:416: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\r\n",
"Instructions for updating:\r\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\r\n",
"2020-12-27 12:44:43,816 - WARNING - From /Users/vishwanath/opt/miniconda3/envs/onnx_virt/lib/python3.8/site-packages/tf2onnx/tf_loader.py:416: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\r\n",
"Instructions for updating:\r\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\r\n",
"2020-12-27 12:44:43,823 - INFO - Using tensorflow=2.4.0, onnx=1.8.0, tf2onnx=1.7.2/995bd6\r\n",
"2020-12-27 12:44:43,823 - INFO - Using opset <onnx, 10>\r\n",
"2020-12-27 12:44:43,823 - INFO - Computed 0 values for constant folding\r\n",
"2020-12-27 12:44:43,829 - INFO - Optimizing ONNX model\r\n",
"2020-12-27 12:44:43,836 - INFO - After optimization: Identity -5 (5->0)\r\n",
"2020-12-27 12:44:43,836 - INFO - \r\n",
"2020-12-27 12:44:43,837 - INFO - Successfully converted TensorFlow model tf/saved to ONNX\r\n",
"2020-12-27 12:44:43,837 - INFO - ONNX model is saved at tf/output/simplemodel.onnx\r\n"
]
}
],
"source": [
"# Convert the SavedModel to ONNX with tf2onnx.\n",
"# --opset pins the ONNX operator set the generated graph may use (opset 10 here); older\n",
"# opsets generally provide fewer ops, so some models will not convert on an older opset.\n",
"# sys.executable runs tf2onnx with the kernel's own interpreter instead of whichever\n",
"# `python` happens to be first on PATH (which may lack tf2onnx).\n",
"!\"{sys.executable}\" -m tf2onnx.convert --opset 10 --saved-model tf/saved --output tf/output/simplemodel.onnx\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 31,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input name dense_2_input:0\n",
"input shape ['unk__6', 1]\n"
]
}
],
"source": [
"path = 'tf/output/simplemodel.onnx'\n",
"sess = onnxruntime.InferenceSession(path)\n",
"\n",
"# The converted graph exposes one input; its name is the feed key for sess.run below\n",
"input_name = sess.get_inputs()[0].name\n",
"input_shape = sess.get_inputs()[0].shape\n",
"print(\"input name\", input_name)\n",
"print(\"input shape\", input_shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Input values to ONNX model = \n",
"[[10.]]\n",
"\n",
"Output value from ONNX model = \n",
"[array([[31.02069]], dtype=float32)]\n"
]
}
],
"source": [
"# Run the ONNX graph on the same x = 10 used for the Keras prediction (float32, shape (1, 1))\n",
"x = np.full((1, 1), 10.0, dtype=np.float32)\n",
"print(\"\\nInput values to ONNX model = \", x, sep=\"\\n\")\n",
"# None -> return every graph output, as a list\n",
"res = sess.run(None, {input_name: x})\n",
"print(\"\\nOutput value from ONNX model = \", res, sep=\"\\n\")\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 33,
"outputs": [],
"source": [
"# Check the ONNX output matches the Keras prediction.\n",
"# sess.run returns a list of outputs, so compare its single element against `original`;\n",
"# the two runtimes may differ in the last float32 bits, so use a tolerance, not exact ==.\n",
"assert len(res) == 1, \"expected exactly one ONNX output\"\n",
"np.testing.assert_allclose(res[0], original, rtol=1e-5, atol=1e-6)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 33,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment