Created December 10, 2019 21:01
-
-
Save EricCousineau-TRI/74cd51d4e2d7a10d38136d5831893142 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import onnxruntime as ort | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
h = 32 | |
w = 32 | |
c = 1 | |
net = nn.Conv2d(c, c, 3) | |
print(net) | |
with torch.no_grad(): | |
input = torch.randn(1, c, h, w) | |
out = net(input) | |
print(out.shape) | |
# Per https://github.com/pytorch/pytorch/issues/25681, warnings about | |
# bad axis names may just be false positives. | |
torch.onnx.export( | |
net, input, "/tmp/trivial.onnx", | |
input_names=["input"], | |
output_names=["output"], | |
dynamic_axes={ | |
"input": {0: "batch"}, | |
"output": {0: "batch"}, | |
}) | |
# Run the exported model under two execution-provider configurations and
# print the output shape for several batch sizes. Because the batch axis was
# exported as dynamic, the output's leading dimension should equal ``n``.
without_tensorrt = ['CUDAExecutionProvider', 'CPUExecutionProvider']
with_tensorrt = ['TensorrtExecutionProvider'] + without_tensorrt

for providers in (with_tensorrt, without_tensorrt):
    # Build a fresh session per provider list. Passing ``providers`` at
    # construction is required on GPU-enabled onnxruntime builds since
    # ORT 1.9 (the constructor raises ValueError otherwise), and it keeps
    # the two runs independent of each other.
    ort_session = ort.InferenceSession(
        "/tmp/trivial.onnx", providers=providers)
    # Show which providers the session reports it is actually using.
    print(ort_session.get_providers())
    for n in [1, 2, 3]:
        print(f"n: {n}")
        batch_input = np.random.randn(n, c, h, w).astype(np.float32)
        # The model has exactly one output ("output"), hence the 1-tuple unpack.
        out, = ort_session.run(None, {"input": batch_input})
        print(out.shape)
    print("---")

# Output recorded when this gist was first run:
#
#   Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1))
#   torch.Size([1, 1, 30, 30])
#   ... (warnings)
#   ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
#   n: 1
#   (1, 1, 30, 30)
#   n: 2
#   (1, 1, 30, 30)  # Wrong: expected (2, 1, 30, 30).
#   n: 3
#   (1, 1, 30, 30)  # Wrong: expected (3, 1, 30, 30).
#   ---
#   ['CUDAExecutionProvider', 'CPUExecutionProvider']
#   n: 1
#   (1, 1, 30, 30)
#   n: 2
#   (2, 1, 30, 30)
#   n: 3
#   (3, 1, 30, 30)
#   ---
#
# With the TensorRT provider, the dynamic batch axis apparently was not
# honored (the batch dim stayed 1); CUDA/CPU produced the expected shapes.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.