@Rishit-dagli
Created October 16, 2020 04:11
Convert PyTorch .pt model to ONNX
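# Use the torch.onnx.export function
import torch

# SimpleModel, input_size, hidden_sizes, output_size, X_test and device come from the
# original training code, which is not part of this gist
model_pytorch = SimpleModel(input_size=input_size, hidden_sizes=hidden_sizes, output_size=output_size)
model_pytorch.load_state_dict(torch.load('./models/model_simple.pt', map_location=device))
model_pytorch.to(device).eval()  # inference mode: disables dropout, BatchNorm uses running stats
# A single test row shaped (1, n_features) is the example input the exporter runs through the model
sample_input = torch.from_numpy(X_test[0].reshape(1, -1)).float().to(device)
sample_output = model_pytorch(sample_input)  # sanity check that a forward pass works
# Export to ONNX format
torch.onnx.export(model_pytorch, sample_input, './models/model_simple.onnx',
                  input_names=['test_input'], output_names=['test_output'])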
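The gist depends on a `SimpleModel` class and on `input_size`, `hidden_sizes`, `output_size`, `X_test` and `device` from training code that is not included. As a hedged sketch, here is one hypothetical `SimpleModel` that fits the constructor call used in the gist (a plain fully connected network with ReLU activations; the author's real architecture may differ). It is followed by a `torch.save` / `load_state_dict` round trip on CPU, using made-up sizes, to show that the saved `.pt` weights reload into an identically configured model and produce the same outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Hypothetical MLP matching SimpleModel(input_size=..., hidden_sizes=..., output_size=...)."""

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        layers = []
        in_features = input_size
        # One Linear + ReLU pair per hidden layer width in hidden_sizes
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # Final projection to the output size (no activation)
        layers.append(nn.Linear(in_features, output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# Illustrative sizes only; the gist does not state the real values
torch.manual_seed(0)
input_size, hidden_sizes, output_size = 4, [8, 8], 3

original = SimpleModel(input_size, hidden_sizes, output_size).eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_simple.pt")
    # Save only the weights (state_dict), as the gist's .pt file is assumed to contain
    torch.save(original.state_dict(), path)
    # Rebuild the same architecture, then load the weights into it
    restored = SimpleModel(input_size, hidden_sizes, output_size)
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()

# A single row shaped (1, input_size), like X_test[0].reshape(1, -1) in the gist
sample_input = torch.randn(1, input_size)
with torch.no_grad():
    expected = original(sample_input)
    actual = restored(sample_input)
```

`load_state_dict` matches parameters by name and shape, so the `hidden_sizes` used when rebuilding the model must equal the ones used in training; with the default `strict=True`, a mismatch raises a `RuntimeError` before any export is attempted.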