@Rishit-dagli
Created October 16, 2020 04:14
Convert a PyTorch model to ONNX, then to a TF 2 SavedModel
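import torch
import onnx
from onnx_tf.backend import prepare

# Assumes SimpleModel, input_size, hidden_sizes, output_size, X_test and device
# are defined by the training code (not part of this gist).
model_pytorch = SimpleModel(input_size=input_size, hidden_sizes=hidden_sizes, output_size=output_size)
model_pytorch.load_state_dict(torch.load('./models/model_simple.pt', map_location=device))
model_pytorch.to(device)
model_pytorch.eval()  # export in inference mode (affects layers such as dropout and batch norm)

# One test example, reshaped into a batch of size 1
sample_input = torch.from_numpy(X_test[0].reshape(1, -1)).float().to(device)
sample_output = model_pytorch(sample_input)  # reference output for checking the conversion

# Export to ONNX format with torch.onnx.export
torch.onnx.export(model_pytorch, sample_input, './models/model_simple.onnx', input_names=['test_input'], output_names=['test_output'])

# Load the ONNX model and convert it to a TensorFlow representation
model_onnx = onnx.load('./models/model_simple.onnx')
tf_rep = prepare(model_onnx)

# Export the model; with TensorFlow 2, onnx-tf writes a SavedModel directory at this path
tf_rep.export_graph('./models/model_simple.pb')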
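The converted model is worth sanity-checking by comparing sample_output with the TensorFlow model's output for the same input. The core of that check is an elementwise tolerance test. The sketch below applies the same tolerance formula as numpy.allclose, |actual - expected| <= atol + rtol * |expected|, in plain Python so it runs without extra dependencies. The function name outputs_match and the default tolerances are assumptions for illustration, not part of the original gist, and NaN/infinity values are not treated specially.

```python
from typing import Sequence


def outputs_match(expected: Sequence[float], actual: Sequence[float],
                  rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Hypothetical helper: True if every pair of values is within tolerance.

    Uses the numpy.allclose rule |actual - expected| <= atol + rtol * |expected|,
    so `expected` (e.g. the PyTorch output) sets the relative scale.
    """
    # Outputs of different sizes can never match
    if len(expected) != len(actual):
        return False
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for e, a in zip(expected, actual)
    )


# Example: flattened outputs that differ only by float round-off
pytorch_out = [0.1234567, -2.5]
tf_out = [0.1234568, -2.5]
print(outputs_match(pytorch_out, tf_out))
```

With the gist's objects, the inputs could be produced by flattening sample_output (for example via `.detach().cpu().numpy().ravel().tolist()`) and the corresponding output of `tf_rep.run(...)`, the ONNX backend's run method, into plain lists. Small differences are expected because the two frameworks may order floating-point operations differently, which is why a tolerance is used instead of exact equality.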