import copy
import numpy as np
import sys
import time
import plac


class FrozenGraph(object):
    """Freezes a Keras model into a constant TensorFlow GraphDef."""

    def __init__(self, model, shape):
        import tensorflow.keras as keras
        from tensorflow.keras.models import load_model
        from tensorflow.keras import backend as K
        import tensorflow as tf
        from tensorflow.contrib import tensorrt as tftrt

        # Prepend a dynamic batch dimension to the input shape.
        shape = (None, shape[0], shape[1], shape[2])
        x_name = 'image_tensor_x'
        with K.get_session() as sess:
            x_tensor = tf.placeholder(tf.float32, shape, x_name)
            # Disable dropout / batch-norm training behaviour before tracing the model.
            K.set_learning_phase(0)
            y_tensor = model(x_tensor)
            y_name = y_tensor.name[:-2]  # strip the ':0' output suffix
            graph = sess.graph.as_graph_def()
            # Bake variables into constants and drop training-only nodes.
            graph0 = tf.graph_util.convert_variables_to_constants(sess, graph, [y_name])
            graph1 = tf.graph_util.remove_training_nodes(graph0)

        self.x_name = [x_name]
        self.y_name = [y_name]
        self.frozen = graph1


class TfEngine(object):
    """Runs inference on the frozen graph with plain TensorFlow."""

    def __init__(self, graph):
        import tensorflow.keras as keras
        from tensorflow.keras.models import load_model
        from tensorflow.keras import backend as K
        import tensorflow as tf
        from tensorflow.contrib import tensorrt as tftrt

        g = tf.Graph()
        with g.as_default():
            x_op, y_op = tf.import_graph_def(
                graph_def=graph.frozen, return_elements=graph.x_name + graph.y_name)
            self.x_tensor = x_op.outputs[0]
            self.y_tensor = y_op.outputs[0]

        # Limit the session to half the GPU memory and grow allocations on demand.
        config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=0.5, allow_growth=True))
        self.sess = tf.Session(graph=g, config=config)

    def infer(self, x):
        y = self.sess.run(self.y_tensor, feed_dict={self.x_tensor: x})
        return y


class TftrtEngine(TfEngine):
    """Runs inference on a TF-TRT optimized version of the frozen graph."""

    def __init__(self, graph, batch_size, precision):
        import tensorflow.keras as keras
        from tensorflow.keras.models import load_model
        from tensorflow.keras import backend as K
        import tensorflow as tf
        from tensorflow.contrib import tensorrt as tftrt

        # Convert supported subgraphs into TensorRT engines at the requested precision.
        tftrt_graph = tftrt.create_inference_graph(
            graph.frozen,
            outputs=graph.y_name,
            max_batch_size=batch_size,
            max_workspace_size_bytes=1 << 30,
            precision_mode=precision,
            minimum_segment_size=2)

        self.tftrt_graph = tftrt_graph
        opt_graph = copy.deepcopy(graph)
        opt_graph.frozen = tftrt_graph
        super(TftrtEngine, self).__init__(opt_graph)
        self.batch_size = batch_size

    def infer(self, x):
        num_tests = x.shape[0]
        batch_size = self.batch_size
        parts = []
        # TensorRT engines are built for a fixed max batch size, so feed the
        # input in chunks of at most batch_size samples.
        for i in range(0, num_tests, batch_size):
            x_part = x[i: i + batch_size]
            y_part = self.sess.run(self.y_tensor,
                                   feed_dict={self.x_tensor: x_part})
            parts.append(y_part)
        # Concatenate the per-batch outputs rather than assuming a single scalar
        # output per sample, so models with wider outputs (e.g. a softmax) work too.
        return np.concatenate(parts, axis=0)


@plac.annotations(
    inference_type=("Type of inference to test (TF, FP32, FP16, INT8)", 'option', 'T', str),
    batch_size=("Size of the TensorRT batch", 'option', 'B', int),
    test_size=("Number of samples to run inference on", 'option', 'S', int),
    input_dims=("Comma-separated input dimensions, e.g. 224, 224, 3", 'option', 'D', str),
    model_path=("Saved Keras model", 'positional', None, str)
)
def main(inference_type: str = "FP16",
         batch_size: int = 1,
         test_size: int = 1,
         input_dims: str = "224, 224, 3",
         model_path: str = "mobilenet.h5"):
    import tensorflow.keras as keras
    from tensorflow.keras.models import load_model
    from tensorflow.keras import backend as K
    import tensorflow as tf
    from tensorflow.contrib import tensorrt as tftrt

    # Load the Keras model and freeze it into a constant GraphDef.
    model = load_model(model_path)
    model_dims = [int(d) for d in input_dims.split(",")]
    frozen_graph = FrozenGraph(model, model_dims)

    # Build a random test batch of shape (test_size, H, W, C).
    test_dims = [int(d) for d in input_dims.split(",")]
    test_dims.insert(0, test_size)
    x_test = np.random.random(test_dims)

    # Select the inference engine; every TRT precision mode uses the same path.
    if inference_type == 'TF':
        engine = TfEngine(frozen_graph)
    elif inference_type in ('FP32', 'FP16', 'INT8'):
        engine = TftrtEngine(frozen_graph, batch_size, inference_type)
    else:
        raise ValueError("Invalid inference_type: %s" % inference_type)

    t0 = time.time()
    y = engine.infer(x_test)
    print(y)
    t1 = time.time()
    print('Time', t1 - t0)


if __name__ == '__main__':
    plac.call(main)
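
# Example invocation, as a sketch: the filename "keras_tftrt_benchmark.py" is a
# placeholder for wherever this gist is saved, and mobilenet.h5 stands for
# whatever saved Keras model you want to benchmark. plac maps the single-letter
# abbreviations declared above (-T, -B, -S, -D) to inference_type, batch_size,
# test_size and input_dims; model_path is the trailing positional argument.
#
#   python keras_tftrt_benchmark.py -T FP16 -B 8 -S 1024 mobilenet.h5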