Skip to content

Instantly share code, notes, and snippets.

@renxida
Last active April 10, 2024 18:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save renxida/6e859dbfab286916dd8b99542c0a2332 to your computer and use it in GitHub Desktop.
Save renxida/6e859dbfab286916dd8b99542c0a2332 to your computer and use it in GitHub Desktop.
import onnx
import numpy as np
from onnx import numpy_helper, TensorProto, save_model
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
from onnx.checker import check_model
# Build, validate and save a minimal ONNX model demonstrating the `If`
# operator: output = input1 + input2 when `condition` is true, otherwise
# output = input1 - input2.

# The `If` operator's condition must be a boolean tensor holding exactly one
# element (not a float tensor).
condition = make_tensor_value_info('condition', TensorProto.BOOL, [1])
input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [1])
input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [1])
output = make_tensor_value_info('output', TensorProto.FLOAT, [1])

# Each branch gets its own output name so that no value name is reused
# between the main graph and its nested subgraphs.
then_output = make_tensor_value_info('then_output', TensorProto.FLOAT, [1])
else_output = make_tensor_value_info('else_output', TensorProto.FLOAT, [1])

# `If` passes no values into its branch graphs, so the branches declare no
# inputs of their own; 'input1' and 'input2' are resolved from the enclosing
# (main) graph scope.
then_branch = make_graph(
    nodes=[
        make_node('Add', ['input1', 'input2'], ['then_output']),
    ],
    name='then_branch',
    inputs=[],
    outputs=[then_output],
)
else_branch = make_graph(
    nodes=[
        make_node('Sub', ['input1', 'input2'], ['else_output']),
    ],
    name='else_branch',
    inputs=[],
    outputs=[else_output],
)

# The main graph must expose input1/input2 as its own inputs so that the
# branch graphs can capture them.
graph = make_graph(
    nodes=[
        make_node(
            'If',
            ['condition'],
            ['output'],
            then_branch=then_branch,
            else_branch=else_branch,
        ),
    ],
    name='if_example',
    inputs=[condition, input1, input2],
    outputs=[output],
)
model = make_model(graph, producer_name='conditional_example')

# Check the model and save it
check_model(model)
save_model(model, 'conditional_example.onnx')

# Load the model back from disk and check it again to confirm that the
# serialized file round-trips to a valid model.
model = onnx.load('conditional_example.onnx')
check_model(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment