-
-
Save twmht/f8f73b203ce072d49e058869a2906a97 to your computer and use it in GitHub Desktop.
Convert an ONNX model to a serialized TensorRT engine, with optional FP16 or calibrated INT8 precision.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import math | |
import onnx | |
import os | |
import tensorrt as trt | |
from calibration import DatasetCalibrator, RandomCalibrator, DEFAULT_CALIBRATION_ALGORITHM | |
def parse_args():
    """Parse command-line options for the ONNX -> TensorRT conversion script.

    Returns:
        argparse.Namespace with positional ``onnx_path``/``out_path`` and the
        precision / calibration options below.
    """
    parser = argparse.ArgumentParser(description='onnx to tensorrt')
    parser.add_argument('onnx_path', help='onnx path')
    parser.add_argument('out_path', help='out path')
    parser.add_argument('--bs', help='max batch size', type=int)
    # BUGFIX: help text previously said 'max batch size' (copy-paste); --ws is
    # the builder workspace size.
    parser.add_argument('--ws', help='max workspace size in bytes (0 = use default)', type=int, default=0)
    parser.add_argument('--width', help='width', type=int, default=0)
    parser.add_argument('--height', help='height', type=int, default=0)
    parser.add_argument('--fp16', action='store_true', help='fp16 mode')
    parser.add_argument('--int8', action='store_true', help='int8 mode')
    # BUGFIX: help text previously said 'int8 mode' (copy-paste).
    parser.add_argument('--int8-cfg', default=None, help='int8 calibration config')
    parser.add_argument('--int8-images', default=None, help='int8 image path')
    # BUGFIX: --int8-batch-size / --int8-batch-runs lacked type=int, so values
    # given on the command line arrived as strings while the defaults were ints.
    parser.add_argument('--int8-batch-size', type=int, default=64, help='int8 batch size')
    parser.add_argument('--int8-cache-file', default=None, help='int8 cache file')
    parser.add_argument('--int8-batch-runs', type=int, default=20, help='int8 batch runs')
    args = parser.parse_args()
    return args
# ---- Script body: build a TensorRT engine from an ONNX model ----
args = parse_args()

# fp16 and int8 are mutually exclusive build precisions.
# BUGFIX: was an assert, which is stripped under `python -O`.
if args.fp16 and args.int8:
    raise SystemExit('--fp16 and --int8 are mutually exclusive')

# exist_ok avoids the check-then-create race of the old exists()/makedirs pair.
os.makedirs(args.out_path, exist_ok=True)

# Either both --width and --height are given, or neither.
if not ((args.width > 0 and args.height > 0) or (args.width == 0 and args.height == 0)):
    raise SystemExit('--width and --height must be given together')
need_resize = (args.width > 0 and args.height > 0)

onnx_path = args.onnx_path
model = onnx.load_model(onnx_path)
input_names = [node.name for node in model.graph.input]
# Input tensor is assumed NCHW; dims [2]/[3] are read as height/width below.
input_shape_dim = model.graph.input[0].type.tensor_type.shape.dim
onnx_bytes = model.SerializeToString()

log_level = trt.Logger.VERBOSE
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
# EXPLICIT_BATCH is required for ONNX models (batch handled via profiles).
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# BUGFIX: the parse result was previously ignored, so a malformed ONNX model
# failed later with a confusing builder error. Surface parser diagnostics.
if not parser.parse(onnx_bytes):
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise SystemExit(f'failed to parse ONNX model {onnx_path}')

max_batch_size = args.bs
config = builder.create_builder_config()
if args.fp16:
    # BUGFIX: use set_flag for consistency with the INT8 path; the old
    # `config.flags = 1 << ...` assignment also clobbered any existing flags.
    config.set_flag(trt.BuilderFlag.FP16)
    precision = 'fp16'
else:
    precision = 'fp32'
if args.int8:
    precision = 'int8'

# BUGFIX: honor --ws when given; previously the option was parsed but never
# used. Fall back to the old hard-coded 32 MiB default.
config.max_workspace_size = args.ws if args.ws > 0 else 1 << 25
config.min_timing_iterations = 3
config.avg_timing_iterations = 2

profile = builder.create_optimization_profile()
height = input_shape_dim[2].dim_value
width = input_shape_dim[3].dim_value
# NOTE(review): the dynamic-shape bounds are hard-coded (1x3x100x100 up to
# bs x3x2000x2000) rather than derived from the model -- confirm these ranges
# match the deployment use case.
profile.set_shape(input_names[0], min=(1, 3, 100, 100), opt=(max(max_batch_size // 2, 1), 3, 300, 300), max=(max_batch_size, 3, 2000, 2000))
config.add_optimization_profile(profile)
builder.max_batch_size = max_batch_size

basename = os.path.basename(args.onnx_path)
out_name = os.path.join(args.out_path, basename.replace('.onnx', f'_bs_{args.bs}_h_{height}_w_{width}_{precision}.trt'))

if args.int8:
    config.set_flag(trt.BuilderFlag.INT8)
    # Default the calibration cache next to the output engine file.
    cache_file = args.int8_cache_file if args.int8_cache_file else out_name.replace('.trt', '.cache')
    if args.int8_images:
        # Calibrate on real images for better INT8 accuracy.
        config.int8_calibrator = DatasetCalibrator(
            height, width, args.int8_cfg, args.int8_images, cache_file,
            batch_size=args.int8_batch_size, batch_runs=args.int8_batch_runs,
            algorithm=DEFAULT_CALIBRATION_ALGORITHM
        )
    else:
        # No image set given: fall back to random-data calibration.
        config.int8_calibrator = RandomCalibrator(
            height, width, cache_file, algorithm=DEFAULT_CALIBRATION_ALGORITHM
        )

engine = builder.build_engine(network, config)
# BUGFIX: build_engine returns None on failure; previously this crashed with
# an opaque AttributeError on engine.serialize().
if engine is None:
    raise SystemExit('TensorRT engine build failed')
print(f'save engine file {out_name}')
with open(out_name, 'wb') as f:
    f.write(bytearray(engine.serialize()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment