@pranavsharma
Created March 28, 2023 06:03
# issue 15206 OptionalGetElement
import numpy as np
import onnx
from onnx import TensorProto
from onnx.helper import (
    make_model, make_node, make_graph,
    make_tensor_value_info, make_value_info)
from onnx.checker import check_model
import onnxruntime as rt
# Element type of the optional input: a float tensor of shape [4].
tensor_type_proto = onnx.helper.make_tensor_type_proto(
    elem_type=TensorProto.FLOAT,
    shape=[4],
)
# Wrap the tensor type in an optional type and use it for the graph input X.
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
X = make_value_info('X', optional_type_proto)
# The output Y is a plain float tensor of unknown length.
Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None])
# OptionalGetElement unwraps the optional input and yields the underlying tensor.
node = make_node("OptionalGetElement", ["X"], ["Y"])
#print(node)
graph = make_graph([node],
                   'OptionalGetElement',
                   [X],
                   [Y])
opset_id_proto = onnx.helper.make_opsetid("", 18)
onnx_model = make_model(graph, opset_imports=[opset_id_proto])
check_model(onnx_model)
print(onnx_model)
with open("test_model_15206.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
print("Serialized model")
sess = rt.InferenceSession("test_model_15206.onnx", providers=['CPUExecutionProvider'])
# Feed a concrete tensor for the optional input X.
optional = np.array([1, 2, 3, 4]).astype(np.float32)
out = sess.run(None, {'X': optional})
print(out)
print("done")