Skip to content

Instantly share code, notes, and snippets.

Last active October 31, 2018 18:16
Show Gist options
  • Save rsandler00/a63cf0db03d745dde4ea07df8ec5abbc to your computer and use it in GitHub Desktop.
Save rsandler00/a63cf0db03d745dde4ea07df8ec5abbc to your computer and use it in GitHub Desktop.
from timeit import default_timer as timer

import numpy as np
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from tqdm import tqdm
"""
This script benchmarks how long it takes to perform inference with a pure TensorFlow (TF) model vs. a converted TensorRT model.
To run it, you need a frozen TF graph in the form of a *.pb file.
The specific model used here is MobileNetV2.
Frozen graph can be downloaded from: (URL missing in this copy)
"""
## File paths (need to be set)
pb_path = "mobilenet_v2_1.0_224_frozen.pb"  # frozen TF graph
input_name = 'input'  # input layer of the graph
output_names = ['MobilenetV2/Predictions/Reshape_1']  # output layer(s) of the graph (can be multiple)
img_path = 'data/panda.jpg'  # image to perform inference on

## Timing parameters
trls = 100  # number of inference trials to average over

## Load image
# NOTE(review): the image-loading expression was truncated in this copy of the
# script (only ", 224)))" survived). It is reconstructed here as a PIL
# open + resize to 224x224 -- confirm against the original.
with PIL.Image.open(img_path) as img_file:
    # convert('RGB') guarantees three channels, so the reshape below cannot
    # fail on grayscale or RGBA inputs.
    img_resized = img_file.convert('RGB').resize((224, 224))

# Map 8-bit pixel values from [0, 255] to roughly [-1, 1).
# np.float64 is used instead of the np.float alias, which newer NumPy
# releases no longer provide; the resulting dtype is the same.
img_raw = np.array(img_resized).astype(np.float64) / 128 - 1
img = img_raw.reshape(1, 224, 224, 3)  # add a leading batch dimension of 1
##### Run w/ standard TF
# Read the serialized GraphDef; the context manager closes the file handle
# as soon as the bytes have been parsed.
with open(pb_path, 'rb') as pb_file:
    frozen_graph = tf.GraphDef.FromString(pb_file.read())

# Tensor names are derived from the configuration values instead of being
# repeated as literals, so the TF and TensorRT runs always agree.
inp, predictions = tf.import_graph_def(
    frozen_graph,
    return_elements=[input_name + ':0', output_names[0] + ':0']
)

# predictions.eval() relies on a default session, so the timing loop has to
# run inside the session context.
with tf.Session(graph=inp.graph):
    start = timer()
    tf_trl_time = []  # cumulative elapsed time recorded after each trial
    for _ in tqdm(range(trls)):
        # A small random perturbation makes every trial feed different data.
        x = predictions.eval(feed_dict={inp: img + np.random.random(img.shape) / 10})
        tf_trl_time.append(timer() - start)

# Differences between consecutive cumulative timestamps are per-trial
# durations; the first trial has no predecessor, so it is left out of the mean.
tf_avg_trl_time = np.mean(np.diff(tf_trl_time))
print("TF: Top 1 prediction: ", x.argmax(), x.max())
print('TF took on average %.3fs over %d trls' % (tf_avg_trl_time, trls))
##### Run w/ TensorRT
# NOTE(review): most arguments of this call were lost in this copy of the
# script (only precision_mode survived). The graph and output names below are
# the required inputs of create_inference_graph; any further tuning arguments
# the original passed (workspace size, minimum segment size, ...) are left at
# the library defaults -- confirm against the original.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,  # matches the single-image batch fed below
    precision_mode='FP16'  # 'INT8'/'FP16'
)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

# The context manager releases the session's resources once the benchmark is
# done; the original never closed the session.
with tf.Session(config=tf_config) as tf_sess:
    # name='' imports the nodes without a prefix so they can be looked up by
    # the plain names configured in input_name / output_names.
    tf.import_graph_def(trt_graph, name='')
    tf_input = tf_sess.graph.get_tensor_by_name(input_name + ':0')
    tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')

    start = timer()
    trt_trl_time = []  # cumulative elapsed time recorded after each trial
    for _ in tqdm(range(trls)):
        # NOTE(review): the call target was truncated in this copy
        # ("output =, feed_dict=..."); reconstructed as a session run of the
        # output tensor.
        output = tf_sess.run(
            tf_output,
            feed_dict={tf_input: img + np.random.random(img.shape) / 10}
        )
        trt_trl_time.append(timer() - start)

# Differences between consecutive cumulative timestamps are per-trial
# durations; the first trial has no predecessor, so it is left out of the mean.
trt_avg_trl_time = np.mean(np.diff(trt_trl_time))
print("TRT: Top 1 prediction: ", output.argmax(), output.max())
print('TRT took on average %.3fs over %d trls' % (trt_avg_trl_time, trls))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment