Skip to content

Instantly share code, notes, and snippets.

@dsuess
Last active March 28, 2019 23:55
Show Gist options
  • Save dsuess/bd4f3385451241a48338c0e01f74d4fc to your computer and use it in GitHub Desktop.
Save dsuess/bd4f3385451241a48338c0e01f74d4fc to your computer and use it in GitHub Desktop.
An example of a failed import of a (trivial) ONNX model into TensorRT.
import torch
from torch import nn
import tensorrt as trt
# Logger shared by the TensorRT Builder and OnnxParser created in this script,
# configured at INFO severity.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
class Model(nn.Module):
    """Minimal module used to reproduce the ONNX -> TensorRT import issue.

    The forward pass multiplies the input by 2 and then slices out the first
    entry along dimension 0 (a ``0:1`` slice, so that dimension is kept with
    size 1).
    """

    def forward(self, x):
        # Keep the exact ops (scalar multiply, then 0:1 slice) so the exported
        # ONNX graph matches the original repro.
        scaled = 2 * x
        first = scaled[0:1]
        return first
# Script body: export the trivial Model to ONNX, then try to parse that ONNX
# file with TensorRT's OnnxParser, failing loudly with the parser's errors.
print('TensorRT version:', trt.__version__)

model = Model().eval()
dummy_input = torch.randn(4, 4, 4)
with torch.no_grad():
    # Writes the ONNX graph for `model` to test.onnx.
    torch.onnx.export(model, dummy_input, 'test.onnx', verbose=True)

# NOTE(review): create_network() is called without flags (implicit batch).
# Newer TensorRT releases require the EXPLICIT_BATCH creation flag for ONNX
# parsing; keep as-is if this is the intended repro for older versions — confirm.
with trt.Builder(TRT_LOGGER) as builder, \
        builder.create_network() as network, \
        trt.OnnxParser(network, TRT_LOGGER) as parser:
    # Use a distinct name for the file handle so `model` keeps referring to
    # the PyTorch module instead of being rebound to a closed file object.
    with open('test.onnx', 'rb') as onnx_file:
        success = parser.parse(onnx_file.read())
    if not success:
        # Report the individual parser errors. An explicit raise is used
        # instead of `assert` so the check is not stripped under `python -O`.
        details = '\n'.join(
            str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError(
            f'{parser.num_errors} error(s) detected during parsing:\n{details}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment