Skip to content

Instantly share code, notes, and snippets.

@maxwillzq
Created July 28, 2020 19:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxwillzq/a594b38e2830476de61e00fc700ad306 to your computer and use it in GitHub Desktop.
Save maxwillzq/a594b38e2830476de61e00fc700ad306 to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
import numpy as np
import onnx
import os
import sys
import json
from google.protobuf.json_format import MessageToJson
from google.protobuf.json_format import Parse
import torch
import torch.onnx
import torchvision
__all__ = ['json2onnx', 'onnx2json', 'torch2onnx']
def json2onnx(json_path, onnx_model_path):
    """Convert a JSON file to an ONNX model file.

    The JSON is parsed as the protobuf JSON mapping of an ``onnx.ModelProto``
    (the format written by ``onnx2json``).

    Args:
        json_path: path of the JSON file to read.
        onnx_model_path: destination path of the ONNX model. If a file already
            exists there, a message is printed and nothing is converted.
    """
    def convert_json_to_onnx(json_path):
        """Load *json_path* and parse it into an ``onnx.ModelProto``."""
        # Imported here so the early-exit path never needs protobuf.
        from google.protobuf.json_format import ParseDict
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            onnx_json = json.load(f)
        # ParseDict consumes the decoded dict directly, avoiding a
        # dict -> string -> dict round trip through Parse().
        return ParseDict(onnx_json, onnx.ModelProto())

    if os.path.exists(onnx_model_path):
        print("output model file " + onnx_model_path +
              " exists. skip the conversion")
        return
    convert_model = convert_json_to_onnx(json_path)
    onnx.save(convert_model, onnx_model_path)
    print("successfully create onnx from json. new onnx model " + onnx_model_path)
def onnx2json(model_path, json_path):
    """Write an ONNX model file out as JSON.

    Args:
        model_path: path of the ONNX model to read.
        json_path: destination path of the JSON file. If a file already
            exists there, a message is printed and nothing is converted.
    """
    def convert_onnx_to_json(model_path):
        """Load *model_path* and return its protobuf JSON mapping as a dict."""
        # Imported here so the early-exit path never needs protobuf.
        from google.protobuf.json_format import MessageToDict
        onnx_model = onnx.load(model_path)
        # Same structure as json.loads(MessageToJson(...)), without first
        # serializing the whole model (weights included) to one big string.
        return MessageToDict(onnx_model)

    if os.path.exists(json_path):
        print("output file " + json_path +
              " exists. skip the conversion")
        return
    convert_json = convert_onnx_to_json(model_path)
    # sort_keys + indent produce a stable, diff-friendly layout.
    with open(json_path, 'w', encoding="utf-8") as f:
        json.dump(convert_json, f, sort_keys=True, indent=2)
    print("successfully convert onnx to json " + json_path)
# Short names accepted by torch2onnx() mapped to the torchvision.models
# constructor that builds the pretrained network. Any other name is treated
# as a file path and loaded with torch.load().
_TORCHVISION_MODELS = {
    "vgg19": "vgg19",
    "resnet50": "resnet50",
    "mobilenet": "mobilenet_v2",
    "mnasnet": "mnasnet1_0",
    "densenet161": "densenet161",
    "googlenet": "googlenet",
    "inception": "inception_v3",
    "squeezenet": "squeezenet1_0",
    "shufflenet": "shufflenet_v2_x1_0",
}


def torch2onnx(torch_model_path, onnx_model_path, dummy_input):
    """ convert torch model to onnx model
    User can use function torch.rand(*input_shape) to create dummy_input

    Args:
        torch_model_path: either a key of ``_TORCHVISION_MODELS`` (a
            pretrained torchvision model is built) or a path to a file
            loadable by ``torch.load``.
        onnx_model_path: destination path of the ONNX model. If a file
            already exists there, a message is printed and nothing is
            converted.
        dummy_input: example input passed to ``torch.onnx.export``.

    Raises:
        SystemExit: with status 1 when the model file cannot be loaded.
    """
    def get_model(model_name):
        """Build a pretrained torchvision model or load one from disk."""
        factory_name = _TORCHVISION_MODELS.get(model_name)
        if factory_name is not None:
            return getattr(torchvision.models, factory_name)(pretrained=True)
        try:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            return torch.load(model_name)
        except Exception:
            # Narrowed from a bare except so Ctrl-C is not swallowed; exit
            # non-zero so shell callers can see that the conversion failed.
            print(f"fail to load torch model {model_name}, check if file \"{model_name}\" exists",
                  file=sys.stderr)
            sys.exit(1)

    if os.path.exists(onnx_model_path):
        print("output model file " + onnx_model_path + " exists. skip the conversion")
        return
    model = get_model(torch_model_path)
    torch.onnx.export(model, dummy_input, onnx_model_path)
    print("successfully convert torch to onnx model " + onnx_model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment