Skip to content

Instantly share code, notes, and snippets.

@halhorn
Created February 12, 2020 06:22
Show Gist options
  • Save halhorn/1439d583d900bc27c45c556c21d4371c to your computer and use it in GitHub Desktop.
Save halhorn/1439d583d900bc27c45c556c21d4371c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
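{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TF-TRT benchmark\n",
"\n",
"Exports a 102-layer Dense Keras model as a SavedModel, converts it with TF-TRT, and compares inference latency. In the recorded run (Tesla T4, float16 / FP16) a call took about 3.5 ms with plain TensorFlow versus 1.31 ms after TF-TRT conversion."
]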
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Walk up to the project root ('ml_sandbox') so the relative paths used below (tmp/models/...) resolve.\n",
"# NOTE(review): sys.path receives the notebook's own directory (appended before the chdir), not the project root - confirm intent.\n",
"import sys,os\n",
"sys.path.append(os.getcwd())\n",
"while os.getcwd().split('/')[-1] != 'ml_sandbox':\n",
"    if os.path.dirname(os.getcwd()) == os.getcwd():\n",
"        # At the filesystem root: chdir('..') is a no-op here, so without this check the loop never ends.\n",
"        raise RuntimeError('ml_sandbox directory not found in any parent of the working directory')\n",
"    os.chdir('..')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check Environment"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wed Feb 12 09:04:24 2020 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:1E.0 Off | 0 |\n",
"| N/A 43C P0 27W / 70W | 1328MiB / 15109MiB | 0% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: GPU Memory |\n",
"| GPU PID Type Process name Usage |\n",
"|=============================================================================|\n",
"| 0 22454 C ...envs/ml_sandbox-YhSVM9Gx/bin/python3.6m 1317MiB |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"# Record the driver / CUDA versions and current GPU memory usage for this benchmark run.\n",
"get_ipython().system('nvidia-smi')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensorflow version: 2.1.0\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"# Report which TensorFlow build this benchmark ran against.\n",
"tf_version = tf.version.VERSION\n",
"print(\"Tensorflow version: \", tf_version)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor Core GPU Present: True\n"
]
}
],
"source": [
"# confirm gpu environment\n",
"from tensorflow.python.client import device_lib\n",
"\n",
"def check_tensor_core_gpu_present():\n",
"    \"\"\"Return True if any local device reports compute capability >= 7.0, else False.\n",
"\n",
"    The capability is parsed from each device's `physical_device_desc` string as listed by\n",
"    device_lib.list_local_devices(); devices without a 'compute capability' entry are skipped.\n",
"    \"\"\"\n",
"    for device in device_lib.list_local_devices():\n",
"        desc = device.physical_device_desc\n",
"        if \"compute capability\" in desc:\n",
"            compute_capability = float(desc.split(\"compute capability: \")[-1])\n",
"            if compute_capability >= 7.0:\n",
"                return True\n",
"    # No qualifying GPU: return an explicit bool instead of falling through to None.\n",
"    return False\n",
"\n",
"# Query the devices once and reuse the result rather than running the check twice.\n",
"tensor_core_gpu = check_tensor_core_gpu_present()\n",
"print(\"Tensor Core GPU Present:\", tensor_core_gpu)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Use half precision only when the environment check found a compute-capability >= 7.0 GPU;\n",
"# otherwise fall back to float32. Replace with tf.float16 to force an FP16 run regardless.\n",
"compute_dtype = tf.float16 if tensor_core_gpu else tf.float32"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using TFTRT"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time  # NOTE(review): unused in this notebook\n",
"\n",
"import numpy as np  # NOTE(review): unused in this notebook\n",
"import tensorflow as tf\n",
"from tensorflow.python.compiler.tensorrt import trt_convert as trt\n",
"from tensorflow.python.saved_model import tag_constants\n",
"\n",
"# For a half-precision run, make Keras layers use the 'mixed_float16' policy\n",
"# (the model below then produces float16 outputs).\n",
"if compute_dtype == tf.float16:\n",
"    mixed_policy = tf.keras.mixed_precision.experimental.Policy(\"mixed_float16\")\n",
"    tf.keras.mixed_precision.experimental.set_policy(mixed_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model definition (Keras subclassing API)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class MyModel(tf.keras.models.Model):\n",
"    \"\"\"Deep stack of Dense layers used as a TF-TRT benchmark workload.\n",
"\n",
"    Concatenates the 'x1' and 'x2' inputs along the last axis, applies\n",
"    `num_hidden_layers + 1` ReLU Dense layers of width `units`, and finishes\n",
"    with a 2-unit linear Dense layer named 'mypredict'.\n",
"\n",
"    Args:\n",
"        num_hidden_layers: number of ReLU layers after the first one (default 100).\n",
"        units: width of every ReLU layer (default 128).\n",
"    \"\"\"\n",
"    def __init__(self, num_hidden_layers=100, units=128):\n",
"        super(MyModel, self).__init__()\n",
"        # No input_shape: the first layer receives concat(x1, x2), i.e. 2 features, not 1,\n",
"        # and a subclassed model does not need an input_shape hint.\n",
"        self.dense_layers = [tf.keras.layers.Dense(units, activation='relu')]\n",
"        for _ in range(num_hidden_layers):\n",
"            self.dense_layers.append(tf.keras.layers.Dense(units, activation='relu'))\n",
"        self.dense_layers.append(tf.keras.layers.Dense(2, name='mypredict'))\n",
"\n",
"    @tf.function(input_signature=[{\n",
"        'x1': tf.TensorSpec(shape=[None, 1], dtype=compute_dtype, name='x1'),\n",
"        'x2': tf.TensorSpec(shape=[None, 1], dtype=compute_dtype, name='x2'),\n",
"    }])\n",
"    def call(self, inputs):\n",
"        # Both inputs are (None, 1) per the signature, so the concatenation is (None, 2).\n",
"        x = tf.concat([inputs['x1'], inputs['x2']], axis=-1)\n",
"        for layer in self.dense_layers:\n",
"            x = layer(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the benchmark model defined above with its default depth and width.\n",
"model = MyModel()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0.],\n",
" [0., 0.]], dtype=float16)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity-check the Keras model on a two-row batch before exporting it.\n",
"sample_inputs = {\n",
"    'x1': tf.constant([[1], [2]], compute_dtype),\n",
"    'x2': tf.constant([[1], [2]], compute_dtype),\n",
"}\n",
"model.predict(sample_inputs)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Export location of the plain (non-TRT) SavedModel, relative to the working directory set in the first cell.\n",
"model_dir = os.path.join('tmp', 'models', 'dense', 'dense_normal')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /home/ubuntu/.local/share/virtualenvs/ml_sandbox-YhSVM9Gx/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"If using Keras pass *_constraint arguments to layers.\n",
"INFO:tensorflow:Assets written to: tmp/models/dense/dense_normal/assets\n"
]
}
],
"source": [
"# Export the Keras model as a TensorFlow SavedModel directory (model_dir has no .h5 suffix).\n",
"tf.keras.models.save_model(model, model_dir)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020-02-12 09:04:34.615769: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer.so.6\n",
"2020-02-12 09:04:34.617253: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer_plugin.so.6\n",
"\n",
"MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
"\n",
"signature_def['__saved_model_init_op']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['__saved_model_init_op'] tensor_info:\n",
" dtype: DT_INVALID\n",
" shape: unknown_rank\n",
" name: NoOp\n",
" Method name is: \n",
"\n",
"signature_def['serving_default']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['x1'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: (-1, 1)\n",
" name: serving_default_x1:0\n",
" inputs['x2'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: (-1, 1)\n",
" name: serving_default_x2:0\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['output_1'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: (-1, 2)\n",
" name: StatefulPartitionedCall:0\n",
" Method name is: tensorflow/serving/predict\n",
"WARNING:tensorflow:From /home/ubuntu/.local/share/virtualenvs/ml_sandbox-YhSVM9Gx/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"If using Keras pass *_constraint arguments to layers.\n",
"\n",
"Defined Functions:\n",
" Function Name: '__call__'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n",
"\n",
" Function Name: '_default_save_signature'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n",
"\n",
" Function Name: 'call'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n",
"\n",
" Function Name: 'call_and_return_all_conditional_losses'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n"
]
}
],
"source": [
"# Inspect the exported signatures; model_dir is interpolated so this stays in sync with the export path.\n",
"!saved_model_cli show --all --dir {model_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline: plain TensorFlow SavedModel"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 2), dtype=float16, numpy=\n",
"array([[0., 0.],\n",
" [0., 0.]], dtype=float16)>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reload the export with tf.saved_model.load (rather than Keras) and run the same batch through it.\n",
"tf_model = tf.saved_model.load(model_dir)\n",
"feed = dict(\n",
"    x1=tf.constant([[1], [2]], compute_dtype),\n",
"    x2=tf.constant([[1], [2]], compute_dtype),\n",
")\n",
"tf_model(feed)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.5 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit x1 = tf.constant([[1], [2]], compute_dtype); x2 = tf.constant([[1], [2]], compute_dtype)\n",
"# The inputs are built in the %%timeit setup statement (first line), so only inference is timed.\n",
"tf_model({'x1': x1, 'x2': x2})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TF-TRT converted model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converting to TF-TRT...\n",
"precision_mode: FP16\n",
"INFO:tensorflow:Linked TensorRT version: (6, 0, 1)\n",
"INFO:tensorflow:Loaded TensorRT version: (6, 0, 1)\n",
"INFO:tensorflow:Could not find TRTEngineOp_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
"INFO:tensorflow:Assets written to: tmp/models/dense/dense_trt/assets\n",
"Done Converting to TF-TRT\n"
]
}
],
"source": [
"trt_model_dir = 'tmp/models/dense/dense_trt'\n",
"print('Converting to TF-TRT...')\n",
"\n",
"# Match the TensorRT precision to the dtype the SavedModel was exported with.\n",
"if compute_dtype == tf.float16:\n",
"    precision_mode = trt.TrtPrecisionMode.FP16\n",
"else:\n",
"    precision_mode = trt.TrtPrecisionMode.FP32\n",
"print('precision_mode:', precision_mode)\n",
"\n",
"# TensorRT workspace limit: 8e9 bytes (8 GB).\n",
"workspace_bytes = 8 * 10**9\n",
"conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(\n",
"    precision_mode=precision_mode,\n",
"    max_workspace_size_bytes=workspace_bytes,\n",
")\n",
"\n",
"# converter.build() is not called, so TensorRT engines are built at first inference (see the log).\n",
"converter = trt.TrtGraphConverterV2(\n",
"    input_saved_model_dir=model_dir,\n",
"    conversion_params=conversion_params,\n",
")\n",
"converter.convert()\n",
"converter.save(output_saved_model_dir=trt_model_dir)\n",
"print('Done Converting to TF-TRT')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020-02-12 09:04:51.539274: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer.so.6\n",
"2020-02-12 09:04:51.540759: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer_plugin.so.6\n",
"\n",
"MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
"\n",
"signature_def['__saved_model_init_op']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['__saved_model_init_op'] tensor_info:\n",
" dtype: DT_INVALID\n",
" shape: unknown_rank\n",
" name: NoOp\n",
" Method name is: \n",
"\n",
"signature_def['serving_default']:\n",
" The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['x1'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: (-1, 1)\n",
" name: serving_default_x1:0\n",
" inputs['x2'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: (-1, 1)\n",
" name: serving_default_x2:0\n",
" The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['output_1'] tensor_info:\n",
" dtype: DT_HALF\n",
" shape: unknown_rank\n",
" name: PartitionedCall:0\n",
" Method name is: tensorflow/serving/predict\n",
"WARNING:tensorflow:From /home/ubuntu/.local/share/virtualenvs/ml_sandbox-YhSVM9Gx/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"If using Keras pass *_constraint arguments to layers.\n",
"\n",
"Defined Functions:\n",
" Function Name: '__call__'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n",
"\n",
" Function Name: '_default_save_signature'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2'), 'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1')}\n",
"\n",
" Function Name: 'call'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1'), 'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2')}\n",
"\n",
" Function Name: 'call_and_return_all_conditional_losses'\n",
" Option #1\n",
" Callable with:\n",
" Argument #1\n",
" DType: dict\n",
" Value: {'x2': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x2'), 'x1': TensorSpec(shape=(None, 1), dtype=tf.float16, name='x1')}\n"
]
}
],
"source": [
"# Inspect the converted model's signatures; trt_model_dir is interpolated so this stays in sync with the save path.\n",
"!saved_model_cli show --all --dir {trt_model_dir}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['serving_default']\n",
"{'output_1': TensorSpec(shape=<unknown>, dtype=tf.float16, name='output_1')}\n",
"[[-0. 0.]\n",
" [-0. 0.]]\n"
]
}
],
"source": [
"saved_model_loaded = tf.saved_model.load(trt_model_dir, tags=[tag_constants.SERVING])\n",
"signature_keys = list(saved_model_loaded.signatures.keys())\n",
"print(signature_keys)\n",
"\n",
"# Call the converted graph through its serving signature, which takes keyword tensors\n",
"# and returns a dict keyed by output name.\n",
"infer = saved_model_loaded.signatures['serving_default']\n",
"print(infer.structured_outputs)\n",
"\n",
"feed = dict(\n",
"    x1=tf.constant([[1], [2]], compute_dtype),\n",
"    x2=tf.constant([[1], [2]], compute_dtype),\n",
")\n",
"y = infer(**feed)\n",
"preds = y['output_1'].numpy()\n",
"print(preds)\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.31 ms ± 7.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%%timeit x1 = tf.constant([[1], [2]], compute_dtype); x2 = tf.constant([[1], [2]], compute_dtype)\n",
"# The inputs are built in the %%timeit setup statement (first line), so only inference is timed.\n",
"infer(x1=x1, x2=x2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment