Skip to content

Instantly share code, notes, and snippets.

@jalola
Created September 23, 2018 09:22
Show Gist options
  • Save jalola/dafac4a5ebe29ff2c9c3d598134b1aa0 to your computer and use it in GitHub Desktop.
Save jalola/dafac4a5ebe29ff2c9c3d598134b1aa0 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import numpy as np
from onnx_coreml import convert
from torch.autograd import Variable
import torch.onnx
import torchvision
import onnx
def check_onnx_compatible(model, model_name, sz, input_names, output_names,
                          input_shape=None):
    """Export ``model`` to ONNX, then reload and validate the written file.

    The model is traced with a random dummy input, exported to
    ``model_name`` via ``torch.onnx.export``, loaded back with
    ``onnx.load`` and checked with ``onnx.checker.check_model``.

    Args:
        model: The ``torch.nn.Module`` to export.
        model_name: Destination path of the exported ONNX file.
        sz: Height and width of the default dummy input.
        input_names: Names assigned to the exported graph's inputs.
        output_names: Names assigned to the exported graph's outputs.
        input_shape: Optional full shape of the dummy input. When omitted,
            ``(1, 3, sz, sz)`` is used: one 3-channel ``sz`` x ``sz`` sample
            with an explicit batch dimension.

    Returns:
        The ONNX model object returned by ``onnx.load(model_name)``.

    Raises:
        Whatever ``onnx.checker.check_model`` raises (a ``ValidationError``)
        when the exported graph is malformed.
    """
    if input_shape is None:
        # Batched (N, C, H, W) input: conv/linear image models reject or
        # mis-trace an unbatched (3, sz, sz) tensor.
        input_shape = (1, 3, sz, sz)

    # Create the dummy input on the model's own device rather than
    # hard-coding .cuda(), so CPU models / CPU-only machines also work.
    # Parameter-free models fall back to the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cpu")
    dummy_input = torch.randn(*input_shape, device=device)

    torch.onnx.export(model, dummy_input, model_name,
                      input_names=input_names, output_names=output_names,
                      verbose=True)

    # Re-load the file from disk and verify the IR is well formed.
    # (onnx.helper.printable_graph(onnx_model.graph) gives a readable dump
    # when debugging.)
    onnx_model = onnx.load(model_name)
    onnx.checker.check_model(onnx_model)
    print("Done")
    return onnx_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment