
Romain Keramitas romain-keramitas-prl

@romain-keramitas-prl
romain-keramitas-prl / test.py
Created January 10, 2022 14:45
Code to reproduce an ORT Gather node discrepancy.
import numpy as np
from onnxruntime import InferenceSession
import torch
import torch.nn as nn
# Create a simple model with a single Gather node
model = nn.Embedding(num_embeddings=10, embedding_dim=5)
x = torch.randint(0, 10, (3,))
torch.onnx.export(
    model,
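Because `nn.Embedding` exports to a single ONNX `Gather` node on axis 0, a NumPy-only implementation of Gather's semantics gives an independent reference for checking both the PyTorch and the ORT outputs. This is a minimal sketch, assuming only NumPy; the helper names `gather_reference` and `max_abs_diff` are hypothetical and not part of the original gist.

```python
import numpy as np


def gather_reference(data: np.ndarray, indices: np.ndarray, axis: int = 0) -> np.ndarray:
    """Hypothetical NumPy reference for ONNX Gather: select slices of `data` along `axis`.

    Output shape is data.shape[:axis] + indices.shape + data.shape[axis + 1:].
    Negative indices count from the end, which the ONNX Gather spec allows.
    """
    axis = axis % data.ndim
    size = data.shape[axis]
    indices = np.asarray(indices, dtype=np.int64)
    if np.any((indices < -size) | (indices >= size)):
        raise IndexError(f"indices must lie in [{-size}, {size - 1}]")
    # Map negative indices to their non-negative equivalents.
    indices = np.where(indices < 0, indices + size, indices)
    return np.take(data, indices, axis=axis)


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest elementwise absolute difference between two same-shaped arrays."""
    return float(np.max(np.abs(np.asarray(a) - np.asarray(b))))


# An embedding table shaped like the gist's model: (num_embeddings=10, embedding_dim=5).
rng = np.random.default_rng(0)
weight = rng.standard_normal((10, 5)).astype(np.float32)
x = np.array([3, 0, 7], dtype=np.int64)

# An embedding lookup is a row selection, i.e. Gather on axis 0.
expected = weight[x]
out = gather_reference(weight, x)
print(out.shape)                    # (3, 5)
print(max_abs_diff(out, expected))  # 0.0
```

With an exported model, `max_abs_diff(ort_output, gather_reference(weight, x))` (where `ort_output` is a hypothetical name for the ORT result) would quantify any discrepancy against this reference rather than against PyTorch alone.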
romain-keramitas-prl / test_ort_error.py
Created December 20, 2021 14:31
Code to reproduce an ORT error for cross-attention with a dynamic past key/value axis
import os
import numpy as np
from onnx import TensorProto, helper, save
from onnxruntime import InferenceSession
hidden_size = 10
num_heads = 2
head_size = 5
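The constants above tie the head layout together: `hidden_size` equals `num_heads * head_size` (2 × 5 = 10). A NumPy-only reference for multi-head cross-attention over cached (past) keys and values, whose sequence length changes between calls, can serve as ground truth when checking ORT outputs. This is a sketch under stated assumptions: the `[batch, num_heads, kv_len, head_size]` cache layout and the helper names `split_heads`, `merge_heads` and `cross_attention` are hypothetical, and the gist's actual ONNX graph may use a different layout.

```python
import numpy as np

hidden_size = 10
num_heads = 2
head_size = 5  # hidden_size == num_heads * head_size


def split_heads(x: np.ndarray) -> np.ndarray:
    """[batch, seq, hidden] -> [batch, num_heads, seq, head_size]."""
    batch, seq, _ = x.shape
    return x.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)


def merge_heads(x: np.ndarray) -> np.ndarray:
    """[batch, num_heads, seq, head_size] -> [batch, seq, hidden]."""
    batch, _, seq, _ = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq, hidden_size)


def cross_attention(query: np.ndarray, past_key: np.ndarray, past_value: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention of `query` over cached keys/values.

    query: [batch, q_len, hidden]
    past_key, past_value: [batch, num_heads, kv_len, head_size] (assumed layout);
    kv_len is the dynamic axis and may differ from q_len and between calls.
    """
    q = split_heads(query)                          # [b, h, q_len, d]
    scores = q @ past_key.transpose(0, 1, 3, 2)     # [b, h, q_len, kv_len]
    scores /= np.sqrt(head_size)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over kv_len
    return merge_heads(weights @ past_value)        # [b, q_len, hidden]


rng = np.random.default_rng(0)
query = rng.standard_normal((1, 3, hidden_size)).astype(np.float32)
for kv_len in (4, 7):  # same computation, different past sequence lengths
    past_key = rng.standard_normal((1, num_heads, kv_len, head_size)).astype(np.float32)
    past_value = rng.standard_normal((1, num_heads, kv_len, head_size)).astype(np.float32)
    out = cross_attention(query, past_key, past_value)
    print(kv_len, out.shape)  # (1, 3, 10) for every kv_len
```

The output shape depends only on the query, so a graph with a dynamic past axis should accept any `kv_len`; comparing ORT's result against this reference for several lengths isolates whether the failure is tied to that axis.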