from timeit import default_timer as timer

import numpy as np
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from tqdm import tqdm
'''
This script benchmarks how long it takes to perform inference with a pure TensorFlow (TF) model vs. a converted TensorRT model.
To run it, you need a frozen TF graph in the form of a *.pb file.
The specific model is MobileNetV2
(https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet).
A frozen graph can be downloaded from: https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz
Note: the archive name above refers to the 1.4_224 variant, while the default `pb_path` below names a 1.0_224 graph; point `pb_path` at the file you actually downloaded.
'''
## File paths (need to be set before running)
# Frozen TF graph (*.pb) to benchmark.
pb_path = 'mobilenet_v2_1.0_224_frozen.pb'
# Image to perform inference on.
img_path = 'data/panda.jpg'
# Name of the graph's input layer.
input_name = 'input'
# Name(s) of the graph's output layer(s); more than one may be listed.
output_names = [
    'MobilenetV2/Predictions/Reshape_1',
]

## Timing parameters
# Number of inference trials to average over.
trls = 100
## Load image
# Decode the image inside a context manager so the file handle is released,
# force three 8-bit channels (RGBA or grayscale inputs would otherwise break
# the reshape below), and resize to 224x224.
with PIL.Image.open(img_path) as pil_img:
    pixels = np.array(pil_img.convert('RGB').resize((224, 224)))

# Scale 8-bit pixel values from [0, 255] to [-1, 1).
# `np.float` was only an alias for the builtin float (float64) and has been
# removed from recent NumPy releases, so the dtype is spelled out explicitly.
img_raw = pixels.astype(np.float64) / 128 - 1

# Add a leading batch dimension of size 1.
img = img_raw.reshape(1, 224, 224, 3)
##### Run w/ standard TF
# Read the serialized GraphDef; the context manager closes the file promptly.
with open(pb_path, 'rb') as pb_file:
    frozen_graph = tf.GraphDef.FromString(pb_file.read())

# Import the frozen graph and fetch its input and (first) output tensors,
# using the names from the configuration section instead of repeating them.
inp, predictions = tf.import_graph_def(
    frozen_graph,
    return_elements=[input_name + ':0', output_names[0] + ':0'],
)

with tf.Session(graph=inp.graph):
    start = timer()
    # Elapsed time since `start`, recorded at the end of every trial.
    tf_trl_time = []
    for _ in tqdm(range(trls)):
        # A small random perturbation gives every trial a different input.
        x = predictions.eval(feed_dict={inp: img + np.random.random(img.shape) / 10})
        tf_trl_time.append(timer() - start)

# Differences between consecutive timestamps are per-trial durations. The
# first trial has no preceding timestamp, so it is left out of the average.
tf_avg_trl_time = np.mean(np.diff(tf_trl_time))
print("TF: Top 1 prediction: ", x.argmax(), x.max())
print('TF took on average %.3fs over %d trls' % (tf_avg_trl_time, trls))
##### Run w/ TensorRT
# Convert the frozen TF graph into a TensorRT-optimized inference graph.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    precision_mode='FP16',  # 'INT8'/'FP16'
)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

# The context manager closes the session (and frees its resources) once the
# benchmark loop has finished.
with tf.Session(config=tf_config) as tf_sess:
    # name='' imports the nodes without a prefix, so the tensors can be looked
    # up by their original names.
    tf.import_graph_def(trt_graph, name='')
    tf_input = tf_sess.graph.get_tensor_by_name(input_name + ':0')
    tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')

    start = timer()
    # Elapsed time since `start`, recorded at the end of every trial.
    trt_trl_time = []
    for _ in tqdm(range(trls)):
        # A small random perturbation gives every trial a different input.
        output = tf_sess.run(tf_output, feed_dict={tf_input: img + np.random.random(img.shape) / 10})
        trt_trl_time.append(timer() - start)

# Differences between consecutive timestamps are per-trial durations. The
# first trial has no preceding timestamp, so it is left out of the average.
trt_avg_trl_time = np.mean(np.diff(trt_trl_time))
print("TRT: Top 1 prediction: ", output.argmax(), output.max())
print('TRT took on average %.3fs over %d trls' % (trt_avg_trl_time, trls))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment