Skip to content

Instantly share code, notes, and snippets.

@CasiaFan
Created November 5, 2019 10:03
Show Gist options
  • Save CasiaFan/644b9680c2ad060691e132a8ceca4e12 to your computer and use it in GitHub Desktop.
ONNX model inference with onnx_tensorrt backend
import onnx
import argparse
import onnx_tensorrt.backend as backend
import numpy as np
import time
def main():
    """Benchmark ONNX model inference with the onnx_tensorrt backend.

    Loads the ONNX model given by ``--onnx``, prepares a TensorRT engine on
    device CUDA:0, runs repeated forward passes with random float32 input of
    the ``--shape`` given, and prints the total wall-clock time plus the
    per-iteration time averaged over the timed runs (the first 10 runs are
    treated as warm-up and excluded from the average).
    """
    parser = argparse.ArgumentParser(description="Onnx runtime engine.")
    parser.add_argument(
        "--onnx", default="/home/arkenstone/test_face_model/res50/mxnet_exported_mnet.onnx",
        metavar="FILE",
        help="path to onnx file",
    )
    parser.add_argument(
        "--shape",
        default="(1,3,112,112)",
        help="input shape for inference",
    )
    args = parser.parse_args()

    model = onnx.load(args.onnx)
    engine = backend.prepare(model, device='CUDA:0')

    # Parse a shape string like "(1,3,112,112)" into [1, 3, 112, 112].
    input_shape = [int(dim) for dim in args.shape.strip('(').strip(')').split(',')]
    input_data = np.random.random(size=input_shape).astype(np.float32)

    warmup_iters = 10
    timed_iters = 100
    start = time.time()
    cal = []  # timestamp recorded after each iteration completes
    for _ in range(warmup_iters + timed_iters):
        output_data = engine.run(input_data)[0]
        cal.append(time.time())
    end = time.time()

    total_time = end - start
    print("Total Runtime {:.4f} seconds".format(total_time))
    # cal[warmup_iters - 1] is the timestamp right after the last warm-up
    # iteration, so (end - that) spans exactly `timed_iters` iterations.
    # (The original used cal[10], which covered only 99 iterations while
    # dividing by 100 — an off-by-one that understated per-iteration time.)
    per_time = (end - cal[warmup_iters - 1]) / timed_iters
    print("Per iter runtime: {:.4f} seconds".format(per_time))
if __name__ == "__main__":
    # Print a usage reminder before running. The --shape example must be
    # quoted: bare parentheses are shell metacharacters, so the original
    # unquoted form "--shape (1,3,112,112)" is a shell syntax error.
    print ("Usage: .... ")
    print ('python tensorrt_run.py --onnx your.onnx --shape "(1,3,112,112)"')
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment