Skip to content

Instantly share code, notes, and snippets.

@ryujaehun
Created February 3, 2022 07:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryujaehun/3c914acb83dec7b453ae63b72d79098b to your computer and use it in GitHub Desktop.
Save ryujaehun/3c914acb83dec7b453ae63b72d79098b to your computer and use it in GitHub Desktop.
ONNX file generator through PyTorch
import argparse

import torch
import torchvision.models as models
# Shape (N, C, H, W) of the random tensor used to trace every model.
INPUT_SIZE = (1, 3, 224, 224)

# Display names given, in order, to the exported graph's inputs and outputs.
# The "learned_%d" entries mirror the AlexNet example in the torch.onnx docs
# (presumably naming the first parameter inputs); they are kept so the
# tensor names in previously generated .onnx files stay unchanged.
INPUT_NAMES = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
OUTPUT_NAMES = ["output1"]


def factory_names(module):
    """Return the model-constructor names listed in ``module.__all__``.

    torchvision puts the architecture class (e.g. ``ResNet``) in ``__all__``.
    Newer releases also list ``*_Weights`` enums there, next to the lowercase
    factory functions (``resnet18``, ``vgg11_bn``, ...). Only the factories
    accept ``pretrained=True``. So this keeps the all-lowercase names instead
    of blindly dropping the first entry.

    Args:
        module: Any object with an ``__all__`` sequence of strings.

    Returns:
        The lowercase names, in their original ``__all__`` order.
    """
    return [name for name in module.__all__ if name.islower()]


def main(argv=None):
    """Export pretrained torchvision models to ``<model_name>.onnx``.

    By default, exports every factory in the ResNet, DenseNet, SqueezeNet
    and VGG families. With ``-n/--network``, exports only that model.
    Models are traced on CUDA when it is available, otherwise on the CPU.

    Args:
        argv: Argument list for argparse. ``None`` means ``sys.argv[1:]``.
    """
    families = (models.resnet, models.densenet, models.squeezenet, models.vgg)
    # Insertion order keeps the original export order: family, then __all__.
    exportable = {
        name: family for family in families for name in factory_names(family)
    }

    parser = argparse.ArgumentParser(description="Export torchvision models to ONNX.")
    parser.add_argument(
        "-n",
        "--network",
        choices=sorted(exportable),
        help="network for onnx file (default: export all)",
    )
    args = parser.parse_args(argv)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dummy_input = torch.randn(*INPUT_SIZE, device=device)

    selected = [args.network] if args.network else list(exportable)
    for model_name in selected:
        factory = getattr(exportable[model_name], model_name)
        # `pretrained=True` is the legacy torchvision API (newer releases
        # prefer `weights=`). It is kept for compatibility with older
        # installs. Export in inference mode so dropout/batch-norm are frozen.
        model = factory(pretrained=True).to(device).eval()
        torch.onnx.export(
            model,
            dummy_input,
            f"{model_name}.onnx",
            verbose=True,
            input_names=INPUT_NAMES,
            output_names=OUTPUT_NAMES,
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment