@sayakpaul
Last active December 22, 2022 12:48
import time
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants

BATCH_SIZE = 8
BATCH_INPUT = tf.random.normal((BATCH_SIZE, 224, 224, 3))
N_WARMUP_RUN = 50
N_RUN = 1000


def benchmark(model: Union[tf.keras.Model, str]) -> str:
    """Benchmarking utility for a TensorFlow model and its TF-TRT optimized version.

    Args:
        model: Either a TensorFlow model of instance `tf.keras.Model` or a path to
            the saved TF-TRT model.

    Returns:
        A string containing throughput information for the given model.

    References:
        * https://github.com/tensorflow/tensorrt/blob/master/tftrt/benchmarking-python/image_classification/NGC-TFv2-TF-TRT-inference-from-Keras-saved-model.ipynb
    """
    elapsed_time = []

    if isinstance(model, tf.keras.Model):
        predict_fn = model.predict
    else:
        saved_model_loaded = tf.saved_model.load(model, tags=[tag_constants.SERVING])
        predict_fn = saved_model_loaded.signatures["serving_default"]

    # Warm-up runs (not timed).
    for i in range(N_WARMUP_RUN):
        _ = predict_fn(BATCH_INPUT)

    # Timed runs.
    for i in range(N_RUN):
        start_time = time.time()
        _ = predict_fn(BATCH_INPUT)
        end_time = time.time()
        elapsed_time = np.append(elapsed_time, end_time - start_time)
        if i % 50 == 0:
            print("Step {}: {:4.1f}ms".format(i, (elapsed_time[-50:].mean()) * 1000))

    return_str = "Throughput: {:.0f} images/s".format(
        N_RUN * BATCH_SIZE / elapsed_time.sum()
    )
    print(return_str)
    return return_str


print("Starting TF benchmarking.")
model = tf.keras.applications.ResNet50(weights="imagenet")
_ = benchmark(model)

print("Converting to TF-TRT FP32...")
model.save("resnet50_saved_model")
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="resnet50_saved_model",
    precision_mode=trt.TrtPrecisionMode.FP32,
    max_workspace_size_bytes=8000000000,
)
converter.convert()
trt_model_path = "resnet50_saved_model_TFTRT_FP32"
converter.save(output_saved_model_dir=trt_model_path)
print("Done converting to TF-TRT FP32.")

print("Starting TRT benchmarking.")
benchmark(trt_model_path)
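
The script above only converts and benchmarks an FP32 engine. As a rough, unbenchmarked sketch, the same `TrtGraphConverterV2` API also accepts `trt.TrtPrecisionMode.FP16`, and `converter.build()` can pre-build the TensorRT engines for a known input shape before saving, so the first inference call does not pay the engine-build cost. The FP16 output directory name below is illustrative.

# Sketch only (not part of the benchmark above): TF-TRT FP16 conversion of the
# same SavedModel exported by the script, with engines pre-built for batch 8.
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter_fp16 = trt.TrtGraphConverterV2(
    input_saved_model_dir="resnet50_saved_model",
    precision_mode=trt.TrtPrecisionMode.FP16,
    max_workspace_size_bytes=8000000000,
)
converter_fp16.convert()

def input_fn():
    # Yield one batch matching the benchmark's input shape so the TRT engine
    # for that shape is built ahead of time.
    yield (tf.random.normal((8, 224, 224, 3)),)

converter_fp16.build(input_fn=input_fn)
converter_fp16.save(output_saved_model_dir="resnet50_saved_model_TFTRT_FP16")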

Machine information

Parameter    Value
OS           Ubuntu 20.04.5
Python       3.8.10
CUDA         11.8
TensorFlow   2.10.1
TensorRT     8.5.1
GPU          T4

Results

Model           Throughput
TensorFlow      89 images/s
TF-TRT (FP32)   497 images/s

The TF-TRT FP32 model is roughly 5.6x faster than the native TensorFlow model on this setup (497 / 89 ≈ 5.6).
