from typing import Union, Tuple, List, Dict

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn


def torch2onnx(
    model: nn.Module,
    args: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    onnx_path: str,
    input_names: List[str],
    output_names: List[str],
    dynamic_axes: Dict[str, Union[Dict[int, str], List[int]]],
    opset_version: int,
    **kwargs,
):
    device = 'cpu'
    model.to(device)
    # Switch to eval mode before export so dropout/batchnorm are traced
    # deterministically and the later Torch/ONNX comparison is fair
    model.eval()
    if not isinstance(args, tuple):
        args = (args,)
    args = tuple(arg.to(device) for arg in args)
    torch.onnx.export(
        model=model,
        args=args,
        f=onnx_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        **kwargs,
    )

    # Reference outputs from the Torch model (assume a single tensor output)
    with torch.no_grad():
        torch_outputs = model(*args).numpy()

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model, full_check=True)

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    # Serialize the model after graph optimization, overwriting the original file
    sess_options.optimized_model_filepath = onnx_path
    ort_sess = ort.InferenceSession(
        path_or_bytes=onnx_path,
        sess_options=sess_options,
        providers=['CPUExecutionProvider'],
    )
    input_dict = {
        input_name: args[i].numpy() for i, input_name in enumerate(input_names)
    }
    # Assume only one output
    onnx_outputs = ort_sess.run(None, input_dict)[0]
    if not np.allclose(
        torch_outputs, onnx_outputs,
        rtol=1e-3, atol=1e-5, equal_nan=True,
    ):
        print("Outputs from Torch and ONNX don't match")