Skip to content

Instantly share code, notes, and snippets.

@Lundez
Created May 15, 2022 18:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Lundez/4e275ab1803de1eb8420b0ccc1eacab0 to your computer and use it in GitHub Desktop.
Save Lundez/4e275ab1803de1eb8420b0ccc1eacab0 to your computer and use it in GitHub Desktop.
PyTorch to ONNX, PyTorch to JIT Traced TorchScript
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Trace a PyTorch model into TorchScript and save it to disk.
#
# Replace the placeholder below with the trained model, e.g.:
#   model = torch.load("model_RTC.pth")
# Until a model is provided the script stops with a clear error.
# https://colab.research.google.com/drive/1TttL1obANbt1hpZfWSifpU4HH0qmU7tj#scrollTo=JLq7Sg-v7i_L
model = None
if model is None:
    # Actually raise the placeholder error so "Add model!" is what the user
    # sees, rather than an opaque failure inside torch.jit.trace.
    raise NotImplementedError("Add model!")
if isinstance(model, torch.nn.Module):
    # Tracing records the ops executed for one forward pass, so switch
    # dropout / batch-norm style layers to inference behaviour first.
    model.eval()

TRACE_PATH = "trace_model.pt"

# Example input used for tracing -- must match the shape/layout the model
# expects. NCHW (batch, channels, height, width) is assumed, the same layout
# used for the ONNX export in this file; TODO confirm for the actual model.
dummy_input = torch.rand(1, 3, 720, 480, requires_grad=True)
trace = torch.jit.trace(model, dummy_input)
torch.jit.save(trace, TRACE_PATH)

# To load the traced model later:
#   loaded = torch.jit.load(TRACE_PATH)
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Export a PyTorch model to ONNX ("imdn.onnx") with dynamic batch/height/width.
#
# Replace the placeholder below with the trained model, e.g.:
#   model = torch.load("model_RTC.pth")
# Until a model is provided the script stops with a clear error.
model = None
if model is None:
    # Actually raise the placeholder error so "Add model!" is what the user
    # sees, rather than an opaque failure inside torch.onnx.export.
    raise NotImplementedError("Add model!")

# Example input used to trace the graph -- must match the shape/layout the
# model expects. Assumed NCHW (batch, channels, height, width), consistent
# with the dynamic_axes mapping below (dim 2 = height, dim 3 = width).
dummy_input = torch.rand(1, 3, 720, 480, requires_grad=True)
torch.onnx.export(
    model,  # model being run
    dummy_input,  # model input (or a tuple for multiple inputs)
    "imdn.onnx",  # where to save the model
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=11,  # the ONNX opset version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=["modelInput"],  # the model's input names
    output_names=["modelOutput"],  # the model's output names
    # Mark batch and spatial dims as dynamic so the exported graph accepts
    # other batch sizes and image sizes. The output mapping assumes the model
    # returns a 4-D image-like tensor -- TODO confirm for the actual model.
    dynamic_axes={
        "modelInput": {0: "batch_size", 2: "height", 3: "width"},
        "modelOutput": {0: "batch_size", 2: "height", 3: "width"},
    },
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment