Created
July 28, 2020 19:46
-
-
Save maxwillzq/a594b38e2830476de61e00fc700ad306 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import numpy as np | |
import onnx | |
import os | |
import sys | |
import json | |
from google.protobuf.json_format import MessageToJson | |
from google.protobuf.json_format import Parse | |
import torch | |
import torch.onnx | |
import torchvision | |
# Public API of this module: JSON<->ONNX converters plus a torch->ONNX exporter.
__all__ = ['json2onnx', 'onnx2json', 'torch2onnx']
def json2onnx(json_path, onnx_model_path):
    """Convert a JSON file back into a binary ONNX model file.

    Args:
        json_path: Path to a JSON file in protobuf JSON format (e.g. one
            produced by ``onnx2json``).
        onnx_model_path: Destination path for the ONNX model. If the file
            already exists, the conversion is skipped.
    """
    if os.path.exists(onnx_model_path):
        print(f"output model file {onnx_model_path} exists. skip the conversion")
        return
    # Round-tripping through json.load/json.dumps validates that the input
    # is well-formed JSON before handing it to protobuf's Parse.
    with open(json_path) as f:
        onnx_json = json.load(f)
    convert_model = Parse(json.dumps(onnx_json), onnx.ModelProto())
    onnx.save(convert_model, onnx_model_path)
    print(f"successfully create onnx from json. new onnx model {onnx_model_path}")
def onnx2json(model_path, json_path):
    """Dump an ONNX model file to human-readable JSON.

    Args:
        model_path: Path to the binary ONNX model to read.
        json_path: Destination path for the JSON output. If the file
            already exists, the conversion is skipped.
    """
    if os.path.exists(json_path):
        print(f"output file {json_path} exists. skip the conversion")
        return
    onnx_model = onnx.load(model_path)
    # MessageToJson returns a string; parse it so json.dump can re-emit
    # it sorted and indented for stable, diff-friendly output.
    onnx_json = json.loads(MessageToJson(onnx_model))
    with open(json_path, 'w') as f:
        json.dump(onnx_json, f, sort_keys=True, indent=2)
    print(f"successfully convert onnx to json {json_path}")
def torch2onnx(torch_model_path, onnx_model_path, dummy_input): | |
""" convert torch model to onnx model | |
User can use function torch.rand(*input_shape) to create dummy_input | |
""" | |
def get_model(model_name): | |
if model_name == "vgg19": | |
return torchvision.models.vgg19(pretrained=True) | |
elif model_name == "resnet50": | |
return torchvision.models.resnet50(pretrained=True) | |
elif model_name == "mobilenet": | |
return torchvision.models.mobilenet_v2(pretrained=True) | |
elif model_name == "mnasnet": | |
return torchvision.models.mnasnet1_0(pretrained=True) | |
elif model_name == "densenet161": | |
return torchvision.models.densenet161(pretrained=True) | |
elif model_name == "googlenet": | |
return torchvision.models.googlenet(pretrained=True) | |
elif model_name == "inception": | |
return torchvision.models.inception_v3(pretrained=True) | |
elif model_name == "squeezenet": | |
return torchvision.models.squeezenet1_0(pretrained=True) | |
elif model_name == "shufflenet": | |
return torchvision.models.shufflenet_v2_x1_0(pretrained=True) | |
else: | |
try: | |
return torch.load(model_name) | |
except: | |
print(f"fail to to load torch model {model_name}, check if file \"{model_name}\" exists") | |
sys.exit() | |
if os.path.exists(onnx_model_path): | |
print("output model file " + onnx_model_path + " exists. skip the conversion") | |
return | |
model = get_model(torch_model_path) | |
torch.onnx.export(model, dummy_input, onnx_model_path) | |
print("successully convert torch to onnx model " + onnx_model_path) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment