Skip to content

Instantly share code, notes, and snippets.

@rmccorm4
Last active May 11, 2020 19:21
Show Gist options
  • Save rmccorm4/531e7cf9bb8f3be0940354fb3085696f to your computer and use it in GitHub Desktop.
import sys

import tensorrt as trt

# Build a TensorRT engine from "alexnet_dynamic.onnx", whose inputs are
# expected to have a dynamic (-1) batch dimension. Attach one optimization
# profile covering batch sizes 1..max_batch_size, then write the serialized
# engine to "alexnet_dynamic.engine". The script exits with status 1 if
# parsing, profile setup, or the engine build fails.

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# Network-creation flag bitmask requesting an explicit-batch network
# (the batch dimension is part of each input's shape).
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

with trt.Builder(TRT_LOGGER) as builder, \
        builder.create_network(EXPLICIT_BATCH) as network, \
        builder.create_builder_config() as config, \
        trt.OnnxParser(network, TRT_LOGGER) as parser:
    # Fill network attributes with information by parsing the model.
    with open("alexnet_dynamic.onnx", "rb") as f:
        # Parse the model and capture its exit status.
        parse_success = parser.parse(f.read())
    # Report every parser error and exit with a non-zero status on failure.
    if not parse_success:
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        sys.exit(1)

    # Query input names and shapes from the parsed TensorRT network.
    network_inputs = [network.get_input(i) for i in range(network.num_inputs)]
    input_names = [_input.name for _input in network_inputs]  # ex: ["actual_input1"]
    # The original model must have a dynamic (-1) batch dimension so the
    # profile below can vary it between min/opt/max.
    input_shapes = [_input.shape for _input in network_inputs]  # ex: [(-1, 3, 224, 224)]

    max_batch_size = 32
    # Create an optimization profile for the dynamic batch dimension.
    profile0 = builder.create_optimization_profile()
    for name, shape in zip(input_names, input_shapes):
        # This profile varies only dim 0 and copies the remaining dims
        # verbatim, so every input must be (-1, static, static, ...).
        # Reject anything else here with a readable message rather than
        # letting profile validation / the engine build fail later.
        if len(shape) == 0 or shape[0] != -1 or any(d < 0 for d in shape[1:]):
            print(
                "ERROR: input {!r} has shape {}; expected a dynamic batch "
                "dimension (-1) followed by static dimensions".format(
                    name, tuple(shape)
                )
            )
            sys.exit(1)
        fixed_dims = tuple(shape[1:])
        profile0.set_shape(
            name,
            min=(1, *fixed_dims),
            opt=(max_batch_size, *fixed_dims),
            max=(max_batch_size, *fixed_dims),
        )
    config.add_optimization_profile(profile0)

    # Additional builder_config flags can be set prior to building the engine.
    # Newer TensorRT releases provide build_serialized_network (build_engine
    # is deprecated there), so prefer it when available. Both paths yield
    # None on failure; that case is checked below.
    if hasattr(builder, "build_serialized_network"):
        serialized_engine = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized_engine = engine.serialize() if engine is not None else None
    if serialized_engine is None:
        print("ERROR: failed to build the TensorRT engine")
        sys.exit(1)

    # Serialize our engine to a file for future use.
    with open("alexnet_dynamic.engine", "wb") as f:
        f.write(serialized_engine)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment