@romain-keramitas-prl
Created December 20, 2021 14:31
Code to reproduce an ONNX Runtime (ORT) error for cross-attention with dynamic past key/value axes
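Note: running this script requires the onnx and onnxruntime-gpu packages and a CUDA-capable GPU. The DecoderAttention contrib op used below appears to be implemented only for the CUDA execution provider, which is why the session is created with CUDAExecutionProvider.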
import os
import numpy as np
from onnx import TensorProto, helper, save
from onnxruntime import InferenceSession
hidden_size = 10
num_heads = 2
head_size = 5
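# The dimensions must be consistent for multi-head attention:
# hidden_size == num_heads * head_size (here 10 == 2 * 5).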
# Create the DecoderAttention node (a com.microsoft contrib op)
attention_input_names = [
    "query",
    "key",
    "q_weight",
    "kv_weight",
    "qkv_bias",
    "key_mask",
    "past_key",
    "past_value",
    "is_cross",
    "use_past",
    "use_layer",
    "has_mask",
]
attention_output_names = ["hidden_state", "present_key", "present_value"]
attention_node = helper.make_node(
    "DecoderAttention",
    inputs=attention_input_names,
    outputs=attention_output_names,
    name="cross_attention_node",
    domain="com.microsoft",
    num_heads=num_heads,
)
# Create initializers
tensors = []
q_weight = np.random.randn(hidden_size, hidden_size)
kv_weight = np.random.randn(hidden_size, hidden_size * 2)
qkv_bias = np.random.randn(hidden_size * 3)
for x, name in zip(
    [q_weight, kv_weight, qkv_bias], ["q_weight", "kv_weight", "qkv_bias"]
):
    tensor_proto = helper.make_tensor(
        name=name, data_type=TensorProto.FLOAT, dims=x.shape, vals=x.flatten().tolist()
    )
    tensors.append(tensor_proto)
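# The last four node inputs are boolean configuration flags, wired below as scalar
# initializers all set to True; presumably is_cross marks the node as cross-attention
# and use_past makes it consume the past key/value inputs.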
for name in ["is_cross", "use_past", "use_layer", "has_mask"]:
    tensor_proto = helper.make_tensor(
        name=name, data_type=TensorProto.BOOL, dims=(), vals=[True]
    )
    tensors.append(tensor_proto)
# Create the graph inputs, with the past key/value batch and sequence axes either dynamic or fixed
past_kv_shape_1 = ["batch_size", num_heads, "seq_length", head_size]
past_kv_shape_2 = [1, num_heads, 1, head_size]
common_inputs = [
    helper.make_tensor_value_info(
        "query", TensorProto.FLOAT, shape=["seq_length", "batch_size", hidden_size]
    ),
    helper.make_tensor_value_info(
        "key", TensorProto.FLOAT, shape=["seq_length", "batch_size", hidden_size]
    ),
    helper.make_tensor_value_info(
        "key_mask", TensorProto.BOOL, shape=["batch_size", "seq_length"]
    ),
]
inputs_1 = common_inputs + [
    helper.make_tensor_value_info(
        "past_key",
        TensorProto.FLOAT,
        shape=past_kv_shape_1,
    ),
    helper.make_tensor_value_info(
        "past_value",
        TensorProto.FLOAT,
        shape=past_kv_shape_1,
    ),
]
inputs_2 = common_inputs + [
    helper.make_tensor_value_info(
        "past_key",
        TensorProto.FLOAT,
        shape=past_kv_shape_2,
    ),
    helper.make_tensor_value_info(
        "past_value",
        TensorProto.FLOAT,
        shape=past_kv_shape_2,
    ),
]
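# The two input sets differ only in the past key/value shapes: inputs_1 keeps the
# batch and sequence axes dynamic, while inputs_2 fixes them both to 1.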
# Create the graph outputs: the attention output plus the present key/value states
outputs = [
    helper.make_tensor_value_info(
        "hidden_state",
        TensorProto.FLOAT,
        shape=["seq_length", "batch_size", hidden_size],
    ),
    helper.make_tensor_value_info(
        "present_key",
        TensorProto.FLOAT,
        shape=["batch_size", num_heads, "seq_length", head_size],
    ),
    helper.make_tensor_value_info(
        "present_value",
        TensorProto.FLOAT,
        shape=["batch_size", num_heads, "seq_length", head_size],
    ),
]
# Create and save both models
for inputs, name in zip(
    [inputs_1, inputs_2], ["test_dynamic_inputs", "test_fixed_inputs"]
):
    graph = helper.make_graph(
        [attention_node],
        initializer=tensors,
        name=name,
        inputs=inputs,
        outputs=outputs,
    )
    # Import the com.microsoft domain explicitly so the contrib op resolves
    model = helper.make_model(
        graph,
        opset_imports=[
            helper.make_opsetid("", 14),
            helper.make_opsetid("com.microsoft", 1),
        ],
    )
    save(model, f"{name}.onnx")
# Test both models by creating an inference session on the CUDA execution provider
for name in ["test_dynamic_inputs", "test_fixed_inputs"]:
    try:
        InferenceSession(f"{name}.onnx", providers=["CUDAExecutionProvider"])
        print(f"{name}: success")
    except Exception as e:
        print(f"{name}: failed")
        print(f"message: {e}")
    os.remove(f"{name}.onnx")
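# Per the gist description, session creation is expected to fail for
# test_dynamic_inputs (dynamic past key/value axes) and to succeed for
# test_fixed_inputs.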