import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import numpy as np
import PIL
from timeit import default_timer as timer
from tqdm import tqdm
'''
This script benchmarks how long it takes to perform inference with a pure TensorFlow (TF) model vs. a converted TensorRT model.
It assumes you have a frozen TF graph in the form of a *.pb file.
The specific model is MobileNetV2
(https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet)
The frozen graph can be downloaded from: https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz
'''
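## (Optional) Fetch and unpack the frozen graph if it is not already on disk.
## A minimal sketch, not part of the original gist; it assumes Python 3 and that
## the checkpoint tarball from the docstring unpacks the *.pb file at its top level.
import os
import tarfile
import urllib.request
if not os.path.exists('mobilenet_v2_1.0_224_frozen.pb'):
    urllib.request.urlretrieve(
        'https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz',
        'mobilenet_v2_1.0_224.tgz')
    with tarfile.open('mobilenet_v2_1.0_224.tgz') as tar:
        tar.extractall('.')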
## File paths (need to be set)
pb_path = "mobilenet_v2_1.0_224_frozen.pb"  # frozen TF graph
input_name = 'input'  # input layer of the graph
output_names = ['MobilenetV2/Predictions/Reshape_1']  # output layer(s) of the graph (can be multiple)
img_path = 'data/panda.jpg'  # image to perform inference on
## Timing parameters
trls = 100  # number of inference trials to average over
## Load image and preprocess to the [-1, 1] range that MobileNetV2 expects
img_raw = np.array(PIL.Image.open(img_path).resize((224, 224))).astype(np.float32) / 128 - 1
img = img_raw.reshape(1, 224, 224, 3)
##### Run w/ standard TF
frozen_graph = tf.GraphDef.FromString(open(pb_path, 'rb').read())
inp, predictions = tf.import_graph_def(frozen_graph, return_elements=['input:0', 'MobilenetV2/Predictions/Reshape_1:0'])
with tf.Session(graph=inp.graph):
    start = timer()
    tf_trl_time = []
    for _ in tqdm(range(trls)):
        # Add a little noise so each trial feeds a slightly different input
        x = predictions.eval(feed_dict={inp: img + np.random.random(img.shape) / 10})
        tf_trl_time += [timer() - start]  # cumulative elapsed time after each trial
# Differencing the cumulative times gives per-trial durations
# (the first trial, which includes warm-up, is dropped)
tf_avg_trl_time = np.mean(np.diff(tf_trl_time))
print("TF: Top 1 prediction: ", x.argmax(), x.max())
print('TF took on average %.3fs over %d trls' % (tf_avg_trl_time, trls))
##### Run w/ TensorRT
# Convert the frozen TF graph into a TensorRT-optimized graph
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    precision_mode='FP16',  # 'FP32'/'FP16'/'INT8'
)
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True  # don't grab all GPU memory up front
tf_sess = tf.Session(config=tf_config)
tf.import_graph_def(trt_graph, name='')
tf_input = tf_sess.graph.get_tensor_by_name(input_name + ':0')
tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')
start = timer()
trt_trl_time = []
for _ in tqdm(range(trls)):
    output = tf_sess.run(tf_output, feed_dict={tf_input: img + np.random.random(img.shape) / 10})
    trt_trl_time += [timer() - start]  # cumulative elapsed time after each trial
trt_avg_trl_time = np.mean(np.diff(trt_trl_time))
print("TRT: Top 1 prediction: ", output.argmax(), output.max())
print('TRT took on average %.3fs over %d trls' % (trt_avg_trl_time, trls))
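## A small closing summary, not in the original gist: report the relative
## speedup of the TensorRT graph over the pure TF graph.
print('TRT speedup over TF: %.2fx' % (tf_avg_trl_time / trt_avg_trl_time))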