Skip to content

Instantly share code, notes, and snippets.

@twmht
Last active August 25, 2021 10:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save twmht/f8f73b203ce072d49e058869a2906a97 to your computer and use it in GitHub Desktop.
Save twmht/f8f73b203ce072d49e058869a2906a97 to your computer and use it in GitHub Desktop.
Convert an ONNX model to a TensorRT engine.
import argparse
import math
import onnx
import os
import tensorrt as trt
from calibration import DatasetCalibrator, RandomCalibrator, DEFAULT_CALIBRATION_ALGORITHM
def parse_args():
    """Parse command-line arguments for the ONNX -> TensorRT conversion script.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='onnx to tensorrt')
    parser.add_argument('onnx_path', help='onnx path')
    parser.add_argument('out_path', help='out path')
    parser.add_argument('--bs', help='max batch size', type=int)
    # Fixed copy-pasted help text: this is the workspace size, not the batch size.
    parser.add_argument('--ws', help='max workspace size in bytes (0 = use default)', type=int, default=0)
    parser.add_argument('--width', help='width', type=int, default=0)
    parser.add_argument('--height', help='height', type=int, default=0)
    parser.add_argument('--fp16', action='store_true', help='fp16 mode')
    parser.add_argument('--int8', action='store_true', help='int8 mode')
    # Fixed copy-pasted help text: this is the calibration config, not a mode flag.
    parser.add_argument('--int8-cfg', default=None, help='int8 calibration config')
    parser.add_argument('--int8-images', default=None, help='int8 image path')
    # type=int added: without it a CLI-supplied value arrives as str and is
    # passed as-is to DatasetCalibrator(batch_size=..., batch_runs=...).
    parser.add_argument('--int8-batch-size', default=64, type=int, help='int8 batch size')
    parser.add_argument('--int8-cache-file', default=None, help='int8 cache file')
    parser.add_argument('--int8-batch-runs', default=20, type=int, help='int8 batch runs')
    args = parser.parse_args()
    return args
args = parse_args()
# fp16 and int8 are mutually exclusive build modes. Use an explicit error
# instead of assert, which is stripped when Python runs with -O.
if args.fp16 and args.int8:
    raise SystemExit('--fp16 and --int8 are mutually exclusive')
# exist_ok avoids the check-then-create race of the original exists()/makedirs() pair.
os.makedirs(args.out_path, exist_ok=True)
# Either both spatial dims are given (resize requested) or neither is.
if not ((args.width > 0 and args.height > 0) or (args.width == 0 and args.height == 0)):
    raise SystemExit('--width and --height must be given together (both > 0) or both omitted')
# NOTE(review): need_resize is currently unused — the resize logic that
# consumed it (scaling input/output dims by ceil(target/static)) is disabled.
need_resize = (args.width > 0 and args.height > 0)
onnx_path = args.onnx_path
model = onnx.load_model(onnx_path)
# Names of the graph inputs; input_names[0] is used for the optimization profile below.
input_names = [node.name for node in model.graph.input]
# Static shape dims of the first input (NCHW assumed by the [2]/[3] indexing
# later — TODO confirm against the exported model).
input_shape_dim = model.graph.input[0].type.tensor_type.shape.dim
# Serialized graph handed to the TensorRT ONNX parser.
onnx_bytes = model.SerializeToString()
# --- Build the TensorRT engine from the serialized ONNX graph ---
log_level = trt.Logger.VERBOSE
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
# Explicit-batch network is required by the ONNX parser.
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# Fail fast with the parser's own diagnostics; the original ignored the
# return value and crashed later with an unrelated error.
if not parser.parse(onnx_bytes):
    errors = '\n'.join(str(parser.get_error(i)) for i in range(parser.num_errors))
    raise RuntimeError(f'failed to parse ONNX model {args.onnx_path}:\n{errors}')
max_batch_size = args.bs
config = builder.create_builder_config()
if args.fp16:
    # set_flag adds FP16 without clobbering the config's default flags
    # (the original `config.flags = 1 << FP16` assignment wiped them),
    # and matches how INT8 is enabled below.
    config.set_flag(trt.BuilderFlag.FP16)
    precision = 'fp16'
else:
    precision = 'fp32'
if args.int8:
    precision = 'int8'
# Honor --ws when given (the original parsed it but hard-coded 1<<25);
# fall back to 32 MiB otherwise.
config.max_workspace_size = args.ws if args.ws > 0 else 1 << 25
config.min_timing_iterations = 3
config.avg_timing_iterations = 2
# Dynamic spatial range: 100x100 .. 2000x2000, tuned (opt) at 300x300.
profile = builder.create_optimization_profile()
height = input_shape_dim[2].dim_value
width = input_shape_dim[3].dim_value
profile.set_shape(input_names[0], min=(1, 3, 100, 100), opt=(max(max_batch_size // 2, 1), 3, 300, 300), max=(max_batch_size, 3, 2000, 2000))
config.add_optimization_profile(profile)
builder.max_batch_size = max_batch_size
basename = os.path.basename(args.onnx_path)
out_name = os.path.join(args.out_path, basename.replace('.onnx', f'_bs_{args.bs}_h_{height}_w_{width}_{precision}.trt'))
if args.int8:
    config.set_flag(trt.BuilderFlag.INT8)
    # Calibration cache sits next to the engine unless explicitly overridden.
    cache_file = args.int8_cache_file if args.int8_cache_file else out_name.replace('.trt', '.cache')
    if args.int8_images:
        # Calibrate against a real dataset when images are provided...
        config.int8_calibrator = DatasetCalibrator(
            height, width, args.int8_cfg, args.int8_images, cache_file, batch_size=args.int8_batch_size, batch_runs=args.int8_batch_runs, algorithm=DEFAULT_CALIBRATION_ALGORITHM
        )
    else:
        # ...otherwise fall back to random-input calibration.
        config.int8_calibrator = RandomCalibrator(
            height, width, cache_file, algorithm=DEFAULT_CALIBRATION_ALGORITHM
        )
engine = builder.build_engine(network, config)
# build_engine returns None on failure; catch it before dereferencing.
if engine is None:
    raise RuntimeError('TensorRT engine build failed (see logger output above)')
print(f'save engine file {out_name}')
with open(out_name, 'wb') as f:
    f.write(bytearray(engine.serialize()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment