import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import numpy as np
import PIL
from timeit import default_timer as timer
from tqdm import tqdm
'''
This script benchmarks how long it takes to perform inference on a pure TensorFlow (TF) model vs. a converted TensorRT model.
To run it, you need a frozen TF graph in the form of a *.pb file.
The specific model is MobileNetV2
(https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet)
The frozen graph can be downloaded from: https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz
'''
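## Optional helper (not part of the original gist): a minimal sketch for fetching and unpacking the
## frozen graph, assuming the tarball URL from the docstring above. The extracted *_frozen.pb name
## may differ from the pb_path set below (e.g. width multiplier 1.4 vs 1.0), so adjust pb_path to
## whichever file actually gets extracted.
import os
import tarfile
import urllib.request

_tgz_name = 'mobilenet_v2_1.4_224.tgz'
_tgz_url = 'https://storage.googleapis.com/mobilenet_v2/checkpoints/' + _tgz_name
if not os.path.exists(_tgz_name):
    urllib.request.urlretrieve(_tgz_url, _tgz_name)
    with tarfile.open(_tgz_name) as tgz:
        tgz.extractall('.')
        print('Extracted:', [m.name for m in tgz.getmembers()])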
## File paths (needs to be set)
pb_path = "mobilenet_v2_1.0_224_frozen.pb" # frozen TF graph
input_name = 'input' # input layer of the graph
output_names = ['MobilenetV2/Predictions/Reshape_1'] # output layers of the graph (can be multiple)
img_path = 'data/panda.jpg' # image to perform inference on
## Timing parameters
trls = 100 # number of inference trials to average over
## Load image
img_raw = np.array(PIL.Image.open(img_path).resize((224, 224))).astype(np.float32) / 128 - 1  # normalize to roughly [-1, 1]
img = img_raw.reshape(1, 224, 224, 3)
##### Run w/ standard TF
with open(pb_path, 'rb') as f:
    frozen_graph = tf.GraphDef.FromString(f.read())
inp, predictions = tf.import_graph_def(frozen_graph, return_elements=[input_name + ':0', output_names[0] + ':0'])
with tf.Session(graph=inp.graph):
    start = timer()
    tf_trl_time = []
    for _ in tqdm(range(trls)):
        # Add a little noise so each trial runs on a slightly different input
        x = predictions.eval(feed_dict={inp: img + np.random.random(img.shape) / 10})
        tf_trl_time += [timer() - start]  # cumulative elapsed time after each trial
# np.diff converts the cumulative timestamps into per-trial durations
# (this also drops the first, warm-up trial from the average)
tf_avg_trl_time = np.mean(np.diff(tf_trl_time))
print("TF: Top 1 prediction: ", x.argmax(), x.max())
print('TF took on average %.3fs over %d trls' % (tf_avg_trl_time, trls))
##### Run w/ TensorRT
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    precision_mode='FP16',  # 'INT8'/'FP16'
)
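# Optional sanity check (not in the original gist): count how many sub-graphs were replaced by
# TensorRT engine ops ('TRTEngineOp' is the op name used by the TF 1.x contrib converter).
# If this prints 0, the conversion did not actually produce any TRT engines.
n_trt_engines = len([n for n in trt_graph.node if n.op == 'TRTEngineOp'])
print('Number of TensorRT engine nodes: %d' % n_trt_engines)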
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
tf_sess = tf.Session(config=tf_config)
tf.import_graph_def(trt_graph, name='')
tf_input = tf_sess.graph.get_tensor_by_name(input_name + ':0')
tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')
start = timer()
trt_trl_time = []
for _ in tqdm(range(trls)):
    output = tf_sess.run(tf_output, feed_dict={tf_input: img + np.random.random(img.shape) / 10})
    trt_trl_time += [timer() - start]  # cumulative elapsed time after each trial
trt_avg_trl_time = np.mean(np.diff(trt_trl_time))  # per-trial durations; first (warm-up) trial excluded
print("TRT: Top 1 prediction: ", output.argmax(), output.max())
print('TRT took on average %.3fs over %d trls' % (trt_avg_trl_time, trls))
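# Optional summary (not in the original gist): relative speedup of the TensorRT graph over plain TF,
# computed from the two per-trial averages above.
print('Speedup (TF time / TRT time): %.2fx' % (tf_avg_trl_time / trt_avg_trl_time))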