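"""Benchmark inference latency of a saved Keras model, either with plain
TensorFlow or with TensorFlow-TensorRT (TF-TRT) at FP32, FP16, or INT8
precision. Requires TensorFlow 1.x, where tensorflow.contrib.tensorrt
is available."""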
import copy
import time

import numpy as np
import plac
import tensorflow as tf
from tensorflow.contrib import tensorrt as tftrt
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

class FrozenGraph(object):
    """Freezes a Keras model into a constant-folded TensorFlow graph def."""

    def __init__(self, model, shape):
        # Prepend a dynamic batch dimension to the (H, W, C) input shape.
        shape = (None, shape[0], shape[1], shape[2])
        x_name = 'image_tensor_x'
        with K.get_session() as sess:
            x_tensor = tf.placeholder(tf.float32, shape, x_name)
            K.set_learning_phase(0)  # inference mode: disable dropout etc.
            y_tensor = model(x_tensor)
            y_name = y_tensor.name[:-2]  # strip the ':0' tensor suffix
            graph = sess.graph.as_graph_def()
            # Bake variables into constants, then drop training-only nodes.
            graph0 = tf.graph_util.convert_variables_to_constants(
                sess, graph, [y_name])
            graph1 = tf.graph_util.remove_training_nodes(graph0)
        self.x_name = [x_name]
        self.y_name = [y_name]
        self.frozen = graph1


class TfEngine(object):
    """Runs inference on a frozen graph with a plain TensorFlow session."""

    def __init__(self, graph):
        g = tf.Graph()
        with g.as_default():
            # Re-import the frozen graph and grab its input/output tensors.
            x_op, y_op = tf.import_graph_def(
                graph_def=graph.frozen,
                return_elements=graph.x_name + graph.y_name)
            self.x_tensor = x_op.outputs[0]
            self.y_tensor = y_op.outputs[0]
        config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=0.5, allow_growth=True))
        self.sess = tf.Session(graph=g, config=config)

    def infer(self, x):
        y = self.sess.run(self.y_tensor, feed_dict={self.x_tensor: x})
        return y


class TftrtEngine(TfEngine):
    """Runs inference through a TF-TRT optimized graph, batching the input."""

    def __init__(self, graph, batch_size, precision):
        # Replace TensorRT-compatible subgraphs with TRT engine nodes.
        tftrt_graph = tftrt.create_inference_graph(
            graph.frozen,
            outputs=graph.y_name,
            max_batch_size=batch_size,
            max_workspace_size_bytes=1 << 30,
            precision_mode=precision,
            minimum_segment_size=2)
        self.tftrt_graph = tftrt_graph
        opt_graph = copy.deepcopy(graph)
        opt_graph.frozen = tftrt_graph
        super(TftrtEngine, self).__init__(opt_graph)
        self.batch_size = batch_size

    def infer(self, x):
        num_tests = x.shape[0]
        # Size the output buffer from the graph's output shape instead of a
        # hard-coded (num_tests, 1), which only fits single-unit outputs.
        out_shape = (num_tests,) + tuple(self.y_tensor.shape.as_list()[1:])
        y = np.empty(out_shape, np.float32)
        batch_size = self.batch_size
        for i in range(0, num_tests, batch_size):
            x_part = x[i: i + batch_size]
            y_part = self.sess.run(self.y_tensor,
                                   feed_dict={self.x_tensor: x_part})
            y[i: i + batch_size] = y_part
        return y


@plac.annotations(
    inference_type=("Type of inference to test (TF, FP32, FP16, INT8)",
                    'option', 'T', str),
    batch_size=("Size of the TensorRT batch", 'option', 'B', int),
    test_size=("Number of samples to run inference on", 'option', 'S', int),
    input_dims=("Comma-separated input dimensions, e.g. 224,224,3",
                'option', 'D', str),
    model_path=("Path to the saved Keras model (.h5)", 'positional', None, str)
)
def main(inference_type: str = "FP16",
         batch_size: int = 1,
         test_size: int = 1,
         input_dims: str = "224, 224, 3",
         model_path: str = "mobilenet.h5"):
    model = load_model(model_path)
    model_dims = [int(d) for d in input_dims.split(",")]
    frozen_graph = FrozenGraph(model, model_dims)

    # Random test batch with the requested number of samples, cast to float32
    # to match the placeholder dtype.
    test_dims = [test_size] + model_dims
    x_test = np.random.random(test_dims).astype(np.float32)

    if inference_type == 'TF':
        engine = TfEngine(frozen_graph)
    elif inference_type in ('FP32', 'FP16', 'INT8'):
        engine = TftrtEngine(frozen_graph, batch_size, inference_type)
    else:
        raise ValueError("Invalid inference_type: %s" % inference_type)

    # Keep the print outside the timed region so I/O doesn't skew the result.
    t0 = time.time()
    y = engine.infer(x_test)
    t1 = time.time()
    print(y)
    print('Time', t1 - t0)


if __name__ == '__main__':
    plac.call(main)
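
# Example invocation (a sketch: assumes a saved Keras model 'mobilenet.h5'
# exists, and 'tftrt_benchmark.py' is a placeholder name for this file; the
# option letters come from the plac annotations above):
#
#   python tftrt_benchmark.py -T FP16 -B 8 -S 64 -D "224,224,3" mobilenet.h5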