Skip to content

Instantly share code, notes, and snippets.

@RomanSteinberg
Created July 16, 2019 09:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RomanSteinberg/63ed16aa2e2e4e19c1ad84fdc2b1f551 to your computer and use it in GitHub Desktop.
Save RomanSteinberg/63ed16aa2e2e4e19c1ad84fdc2b1f551 to your computer and use it in GitHub Desktop.
PyTorch -> TensorRT
import tensorrt as trt
import os
import torch
import onnx
# Module-wide TensorRT logger, constructed with the WARNING severity level.
# It is handed to both trt.Builder and trt.OnnxParser in convert_to_trt.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def convert_to_trt(image_width, image_height):
    """Take an ONNX file and create a TensorRT engine to run inference with.

    Reads ``model.onnx`` (exporting it first via ``convert_to_onnx`` when it
    does not exist yet), builds a CUDA engine from it and writes the
    serialized engine to ``model.trt``.

    Args:
        image_width: Input width forwarded to ``convert_to_onnx``.
        image_height: Input height forwarded to ``convert_to_onnx``.

    Raises:
        SystemExit: With a non-zero status if ``model.onnx`` is still missing
            after the export step.
        RuntimeError: If the ONNX parser rejects the model or the builder
            fails to produce an engine.
    """
    onnx_file_path = 'model.onnx'
    engine_file_path = 'model.trt'

    if not os.path.exists(onnx_file_path):
        convert_to_onnx(image_width, image_height)

    # NOTE(review): max_workspace_size / max_batch_size / build_cuda_engine
    # look like the legacy builder API; confirm they exist in the installed
    # TensorRT version.
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = 1 << 30  # 1GB
        builder.max_batch_size = 1

        # Parse model file
        if not os.path.exists(onnx_file_path):
            print('ONNX file {} not found.'.format(onnx_file_path))
            # A missing input is a failure, so do not report success (0).
            raise SystemExit(1)

        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parsed = parser.parse(model.read())
        if not parsed:
            # parse() signals failure through its return value; surface the
            # parser's own diagnostics instead of building a broken network.
            errors = [str(parser.get_error(i))
                      for i in range(parser.num_errors)]
            raise RuntimeError(
                'Failed to parse ONNX file {}:\n{}'.format(
                    onnx_file_path, '\n'.join(errors)))
        print('Completed parsing of ONNX file')

        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        engine = builder.build_cuda_engine(network)
        if engine is None:
            # The builder returns None on failure; fail with a clear message
            # rather than an AttributeError on engine.serialize().
            raise RuntimeError(
                'Failed to build a TensorRT engine from {}'.format(
                    onnx_file_path))
        print("Completed creating Engine")

        with open(engine_file_path, "wb") as f:
            f.write(engine.serialize())
def convert_to_onnx(image_width, image_height):
    """Export the PyTorch model stored in ``model.pth`` to ``model.onnx``.

    The weights are loaded into a ``PytorchModel`` instance, the model is
    exported with a random dummy input of shape
    ``(1, 3, image_height, image_width)``, and the resulting ONNX file is
    re-loaded, checked and printed.

    Args:
        image_width: Width of the dummy input passed to the exporter.
        image_height: Height of the dummy input passed to the exporter.
    """
    onnx_file_path = 'model.onnx'

    torch_model = PytorchModel()
    # The model is exported from the CPU, so map the checkpoint tensors to the
    # CPU at load time; this also lets a GPU-saved checkpoint load on a
    # machine without CUDA.
    torch_model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    torch_model.eval()
    torch_model.to('cpu')

    dummy_input = torch.randn(1, 3, image_height, image_width)
    torch.onnx.export(torch_model, dummy_input, onnx_file_path, verbose=True)

    onnx_model = onnx.load(onnx_file_path)
    # Check that the IR is well formed
    onnx.checker.check_model(onnx_model)
    # Print a human readable representation of the graph; printable_graph
    # only returns the text, so it has to be printed explicitly.
    print(onnx.helper.printable_graph(onnx_model.graph))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment