Skip to content

Instantly share code, notes, and snippets.

@adujardin
Forked from rmccorm4/alexnet_onnx.py
Last active August 10, 2020 16:03
Show Gist options
  • Save adujardin/5d0a9ec73aa81c694330d39e1638d512 to your computer and use it in GitHub Desktop.
Save adujardin/5d0a9ec73aa81c694330d39e1638d512 to your computer and use it in GitHub Desktop.
import argparse
import torch
import torchvision
import shutil
# https://github.com/vita-epfl/openpifpaf/blob/master/openpifpaf/export_onnx.py
try:
import onnx
import onnx.utils
except ImportError:
onnx = None
try:
import onnxsim
except ImportError:
onnxsim = None
def optimize(infile, outfile=None):
    """Run the ONNX optimizer over a model file.

    Args:
        infile: Path of the ONNX model to optimize.
        outfile: Path to write the optimized model to. When omitted, the
            model is optimized in place: ``infile`` is first copied to a
            sibling ``<name>.unoptimized.onnx`` backup, the backup is
            loaded, and the optimized result overwrites ``infile``.

    Raises:
        ValueError: If ``outfile`` is omitted and ``infile`` does not end
            with ``.onnx`` (no backup name can be derived).
    """
    if outfile is None:
        suffix = '.onnx'
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not infile.endswith(suffix):
            raise ValueError(
                "in-place optimization needs a path ending in '.onnx', "
                "got %r" % (infile,))
        outfile = infile
        # Swap only the trailing extension; ``str.replace`` would also
        # rewrite any earlier '.onnx' occurring in a directory name.
        infile = outfile[:-len(suffix)] + '.unoptimized.onnx'
        shutil.copyfile(outfile, infile)

    model = onnx.load(infile)
    # NOTE(review): ``onnx.optimizer`` appears to have been split out of
    # newer ONNX releases into a separate package -- confirm it exists in
    # the installed version before relying on this function.
    optimized_model = onnx.optimizer.optimize(model)
    onnx.save(optimized_model, outfile)
def check(modelfile):
    """Load the ONNX model at ``modelfile`` and run the ONNX checker on it.

    Args:
        modelfile: Path of the ONNX model file to validate.
    """
    # Load and validate in one step; any error from the checker propagates.
    onnx.checker.check_model(onnx.load(modelfile))
def polish(infile, outfile=None):
    """Run ``onnx.utils.polish_model`` over a model file.

    Args:
        infile: Path of the ONNX model to polish.
        outfile: Path to write the polished model to. When omitted, the
            model is polished in place: ``infile`` is first copied to a
            sibling ``<name>.unpolished.onnx`` backup, the backup is
            loaded, and the polished result overwrites ``infile``.

    Raises:
        ValueError: If ``outfile`` is omitted and ``infile`` does not end
            with ``.onnx`` (no backup name can be derived).
    """
    if outfile is None:
        suffix = '.onnx'
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not infile.endswith(suffix):
            raise ValueError(
                "in-place polishing needs a path ending in '.onnx', "
                "got %r" % (infile,))
        outfile = infile
        # Swap only the trailing extension; ``str.replace`` would also
        # rewrite any earlier '.onnx' occurring in a directory name.
        infile = outfile[:-len(suffix)] + '.unpolished.onnx'
        shutil.copyfile(outfile, infile)

    model = onnx.load(infile)
    # NOTE(review): ``polish_model`` may be absent from newer ONNX
    # releases -- confirm against the installed version.
    polished_model = onnx.utils.polish_model(model)
    onnx.save(polished_model, outfile)
def simplify(infile, outfile=None):
    """Simplify an ONNX model file with ``onnxsim``.

    Args:
        infile: Path of the ONNX model to simplify.
        outfile: Path to write the simplified model to. When omitted, the
            model is simplified in place: ``infile`` is first copied to a
            sibling ``<name>.unsimplified.onnx`` backup, the backup is
            simplified, and the result overwrites ``infile``.

    Raises:
        ValueError: If ``outfile`` is omitted and ``infile`` does not end
            with ``.onnx`` (no backup name can be derived).
    """
    if outfile is None:
        suffix = '.onnx'
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not infile.endswith(suffix):
            raise ValueError(
                "in-place simplification needs a path ending in '.onnx', "
                "got %r" % (infile,))
        outfile = infile
        # Swap only the trailing extension; ``str.replace`` would also
        # rewrite any earlier '.onnx' occurring in a directory name.
        infile = outfile[:-len(suffix)] + '.unsimplified.onnx'
        shutil.copyfile(outfile, infile)

    # NOTE(review): the return value is saved directly as the model; some
    # onnxsim releases may return a ``(model, check_ok)`` tuple instead --
    # confirm against the installed version.
    simplified_model = onnxsim.simplify(infile, check_n=0, perform_optimization=False)
    onnx.save(simplified_model, outfile)
# ---------------------------------------------------------------------------
# Script body: export a pretrained torchvision AlexNet to ONNX, then
# optionally post-process and validate the exported file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--opset", type=int, default=11, help="ONNX opset version to generate models with.")
parser.add_argument('--outfile', default='bar.onnx',
                    help="Path of the ONNX file to write.")
parser.add_argument('--dynamic-dimensions', dest='dynamic', default=True, action='store_true',
                    help="Export with a dynamic batch dimension (default).")
# The flag above defaults to True and could never be switched off, which
# made the fixed-shape export unreachable; this is its negative counterpart.
parser.add_argument('--no-dynamic-dimensions', dest='dynamic', action='store_false',
                    help="Export with the fixed shape of the dummy input.")
# Post-processing steps that were previously hard-wired off; still off by
# default so the default behaviour is unchanged.
parser.add_argument('--optimize', action='store_true',
                    help="Also run optimize() on the exported file.")
parser.add_argument('--polish', action='store_true',
                    help="Also run polish() on the exported file.")
args = parser.parse_args()

# Use the GPU when one is present; otherwise fall back to the CPU so the
# export still works on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dummy_input = torch.randn(10, 3, 224, 224, device=device)
model = torchvision.models.alexnet(pretrained=True).to(device)

input_names = ["actual_input_1"]
output_names = ["output1"]

# Arguments shared by the dynamic- and fixed-shape exports.
export_kwargs = dict(
    verbose=True,
    opset_version=args.opset,
    input_names=input_names,
    output_names=output_names,
)
if args.dynamic:
    # Dynamic shape: mark axis 0 of the input and output as "batch_size".
    dynamic_axes = {"actual_input_1": {0: "batch_size"}, "output1": {0: "batch_size"}}
    print(dynamic_axes)
    export_kwargs["dynamic_axes"] = dynamic_axes
torch.onnx.export(model, dummy_input, args.outfile, **export_kwargs)

# Post-processing needs the optional ``onnx`` package (None if not installed).
if onnx:
    if onnxsim:
        simplify(args.outfile)
    if args.optimize:
        optimize(args.outfile)
    if args.polish:
        polish(args.outfile)
    check(args.outfile)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment