@paulgavrikov
Created February 22, 2022 19:03
import torch

SHAPE = ...  # your batch shape, e.g. (1, 3, 224, 224)
OUT_PATH = ...  # output path

x = torch.randn(SHAPE)  # dummy input used to trace the model

if isinstance(model, torch.nn.DataParallel):  # extract the module from DataParallel models
    model = model.module

model.cpu()   # the converter works best on models stored on the CPU
model.eval()  # switch off dropout / use running batch-norm statistics

with torch.no_grad():
    torch.onnx.export(model,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      OUT_PATH,            # where to save the model (can be a file or file-like object)
                      export_params=True,  # store the trained parameter weights inside the model
                      opset_version=11)    # it's best to pin the opset version; 11 was the latest release at the time of writing