@romain-keramitas-prl
Created January 10, 2022 14:45
Code to reproduce the ORT Gather node discrepancy.
import numpy as np
from onnxruntime import InferenceSession
import torch
import torch.nn as nn

# Create a simple model with one Gather node
model = nn.Embedding(num_embeddings=10, embedding_dim=5)
x = torch.randint(0, 10, (3,))

torch.onnx.export(
    model,
    args=(x,),
    f="model.onnx",
    input_names=["input_ids"],
    output_names=["output"],
    opset_version=14,
    dynamic_axes={"input_ids": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Feed out-of-range indices (10 is invalid for a 10-row embedding table)
# to both execution providers and compare their behavior.
for provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]:
    print(f"testing {provider} ...")
    session = InferenceSession("model.onnx", providers=[provider])
    io_binding = session.io_binding()
    io_binding.bind_cpu_input("input_ids", np.ones(3, dtype=np.int64) * 10)
    io_binding.bind_output("output")
    try:
        session.run_with_iobinding(io_binding)
        print(f"{provider} did not raise an error")
    except Exception as e:
        print(f"{provider} raised the following error:\n{e}")
    print()
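
For reference, the same out-of-range lookup fails in eager PyTorch, since an embedding table with num_embeddings=10 only accepts indices 0 through 9. This is a minimal sanity check of the expected behavior, not part of the original reproduction:

import torch
import torch.nn as nn

# Out-of-range index in plain PyTorch: nn.Embedding raises an IndexError,
# which is the baseline behavior the execution providers can be compared against.
model = nn.Embedding(num_embeddings=10, embedding_dim=5)
try:
    model(torch.full((3,), 10, dtype=torch.long))
except IndexError as e:
    print(f"PyTorch raised: {e}")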