Skip to content

Instantly share code, notes, and snippets.

@r7vme
Created July 8, 2019 17:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save r7vme/deb3b6f713aeb1698593e8f8373b8361 to your computer and use it in GitHub Desktop.
Save r7vme/deb3b6f713aeb1698593e8f8373b8361 to your computer and use it in GitHub Desktop.
convert 3dbox network to tensorrt
#!/usr/bin/env python3
import graphsurgeon as gs
import tensorflow as tf
import tensorrt as trt
import uff
def _select_nodes(node_names, leaky_relu_marker="LeakyRelu",
                  removal_marker="orientation/l2_normalize"):
    """Split graph node names into the groups this script rewrites.

    Args:
        node_names: iterable of TensorFlow graph node names.
        leaky_relu_marker: substring marking nodes to replace with the
            ``LReLU_TRT`` plugin.
        removal_marker: substring marking nodes to delete from the graph.

    Returns:
        A ``(leaky_relu_names, removal_names)`` tuple of lists, each keeping
        the input order. A name containing both markers appears in both lists.
    """
    names = list(node_names)
    leaky_relu_names = [name for name in names if leaky_relu_marker in name]
    removal_names = [name for name in names if removal_marker in name]
    return leaky_relu_names, removal_names


def _convert_graph(graph_pb, output_nodes, neg_slope=0.1):
    """Load a frozen TensorFlow graph, patch it for TensorRT, convert to UFF.

    Args:
        graph_pb: path to the frozen ``.pb`` graph.
        output_nodes: names of the graph outputs to keep in the UFF model.
        neg_slope: slope passed to every ``LReLU_TRT`` plugin node.
            NOTE(review): assumes the trained model's LeakyReLU slope is 0.1
            for every layer — confirm against the training code.

    Returns:
        The value of ``uff.from_tensorflow``, which is handed straight to
        ``UffParser.parse_buffer``.
    """
    dynamic_graph = gs.DynamicGraph(graph_pb)
    leaky_relu_names, removal_names = _select_nodes(
        node.name for node in dynamic_graph.as_graph_def().node
    )
    # Replace each LeakyRelu node with an LReLU_TRT plugin node of the same
    # name, so output names such as "dimension/LeakyRelu" still resolve.
    # Presumably the UFF parser of the targeted TensorRT version cannot
    # handle LeakyRelu natively — confirm.
    plugins = {
        name: gs.create_plugin_node(name=name, op="LReLU_TRT", negSlope=neg_slope)
        for name in leaky_relu_names
    }
    # NOTE(review): the reason for dropping the orientation L2-normalisation
    # nodes is not visible here; they are not among the requested outputs.
    for name in removal_names:
        dynamic_graph.remove(name)
    dynamic_graph.collapse_namespaces(plugins)
    return uff.from_tensorflow(dynamic_graph.as_graph_def(), output_nodes=output_nodes)


def _build_engine(uff_model, input_node, input_shape, output_nodes, data_type,
                  max_batch_size=16, max_workspace_size=1 << 30):
    """Parse a UFF model and build a TensorRT CUDA engine from it.

    Args:
        uff_model: serialized UFF model as returned by ``_convert_graph``.
        input_node: name of the network input to register with the parser.
        input_shape: per-sample input shape (e.g. ``[3, 224, 224]``).
        output_nodes: names of the outputs to register with the parser.
        data_type: ``trt.DataType`` used for weights; ``HALF`` also enables
            FP16 mode on the builder.
        max_batch_size: largest batch size the engine will accept.
        max_workspace_size: builder scratch-memory limit in bytes.

    Returns:
        The built ``ICudaEngine``.

    Raises:
        RuntimeError: if the UFF parser or the engine builder reports failure.
    """
    logger = trt.Logger(trt.Logger.INFO)
    # Register TensorRT's bundled plugin creators before parsing; the
    # LReLU_TRT nodes inserted by _convert_graph are presumably resolved here.
    trt.init_libnvinfer_plugins(logger, "")
    with trt.Builder(logger) as builder, \
            builder.create_network() as network, \
            trt.UffParser() as parser:
        builder.max_batch_size = max_batch_size
        builder.max_workspace_size = max_workspace_size
        if data_type == trt.DataType.HALF:
            builder.fp16_mode = True
        print("network created")
        parser.register_input(input_node, trt.Dims(input_shape))
        for output_node in output_nodes:
            parser.register_output(output_node)
        # parse_buffer reports failure through its return value, not an
        # exception; building from a half-parsed network would be meaningless.
        if not parser.parse_buffer(uff_model, network, data_type):
            raise RuntimeError("failed to parse the UFF model")
        print("starting building an engine...")
        engine = builder.build_cuda_engine(network)
        # build_cuda_engine returns None on failure instead of raising.
        if engine is None:
            raise RuntimeError("failed to build the TensorRT engine")
        print("finished building an engine...")
        return engine


def main():
    """Convert box3d.pb to a serialized TensorRT engine (box3d.engine)."""
    ### USER DEFINED VARIABLES ###
    data_type = trt.DataType.HALF
    # Use trt.DataType.FLOAT instead to build a full-precision engine.
    output_nodes = ["dimension/LeakyRelu", "reshape_2/Reshape", "confidence/Softmax"]
    input_node = "input_1"
    input_shape = [3, 224, 224]
    graph_pb = "box3d.pb"
    engine_file = "box3d.engine"
    ### END USER DEFINED VARIABLES ###

    uff_model = _convert_graph(graph_pb, output_nodes)
    print("converted to UFF")
    engine = _build_engine(uff_model, input_node, input_shape, output_nodes, data_type)
    with open(engine_file, "wb") as f:
        f.write(engine.serialize())


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment