@rmccorm4
Last active August 2, 2020 11:00
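```python
import argparse

import torch
import torchvision

parser = argparse.ArgumentParser()
parser.add_argument("--opset", type=int, default=11, help="ONNX opset version to generate models with.")
args = parser.parse_args()

# Use the GPU when one is available, as the original script did; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example input: a batch of 10 RGB 224x224 images. Only its shape matters for tracing.
dummy_input = torch.randn(10, 3, 224, 224, device=device)
# On torchvision >= 0.13, `weights=torchvision.models.AlexNet_Weights.DEFAULT` replaces `pretrained=True`.
model = torchvision.models.alexnet(pretrained=True).to(device).eval()

# Names given to the graph's input and output tensors.
input_names = ["actual_input_1"]
output_names = ["output1"]

# Fixed shape: the batch dimension is baked into the graph as 10.
torch.onnx.export(model, dummy_input, "alexnet_fixed.onnx", verbose=True, opset_version=args.opset,
                  input_names=input_names, output_names=output_names)

# Dynamic shape: dimension 0 of the input and the output is exported as the symbolic "batch_size".
dynamic_axes = {"actual_input_1": {0: "batch_size"}, "output1": {0: "batch_size"}}
print(dynamic_axes)
torch.onnx.export(model, dummy_input, "alexnet_dynamic.onnx", verbose=True, opset_version=args.opset,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)
```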