Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Last active April 17, 2024 02:36
Show Gist options
  • Save justinchuby/e69a4c9cb51f7c6ac8fabba6b2d7f3d0 to your computer and use it in GitHub Desktop.
Save justinchuby/e69a4c9cb51f7c6ac8fabba6b2d7f3d0 to your computer and use it in GitHub Desktop.
Remove duplicate cast
from onnxscript import ir
import onnx
# Load the serialized ONNX model from the file "model.onnx".
model_proto = onnx.load("model.onnx")
# Pattern this script looks for and removes:
#   (not const) -> cast to 16 -> cast to 32 -> Op
# Convert the loaded proto into the onnxscript IR representation; the rest of
# the script walks and edits `model.graph`.
model = ir.serde.deserialize_model(model_proto)
def is_cast(node: ir.Node, dtype: ir.DataType):
    """Tell whether ``node`` is a ``Cast`` op that converts to ``dtype``.

    Args:
        node: The IR node to inspect.
        dtype: The target data type the cast is expected to produce.

    Returns:
        True if ``node.op_type`` is ``"Cast"`` and its ``to`` attribute,
        interpreted as an ``ir.DataType``, equals ``dtype``; False otherwise.
    """
    # Short-circuit: the "to" attribute is only read for Cast nodes.
    return (
        node.op_type == "Cast"
        and ir.DataType(node.attributes["to"].value) == dtype
    )
# Collect (cast_to_fp16, cast_to_fp32) node pairs where a value is cast to
# FLOAT16 and the result is immediately (and only) cast back to FLOAT.
# The pairs are only recorded here; the graph is edited afterwards so that it
# is not mutated while being iterated.
cast_pairs = []
for node in model.graph:
    if not is_cast(node, ir.DataType.FLOAT16):
        continue
    source = node.inputs[0]
    if source.name == "x":
        # The value named "x" is deliberately left untouched.
        continue
    producer = source.producer()
    # producer() is None when no node in the graph produces the value
    # (e.g. a graph input or an initializer). Such values cannot be shown to
    # be non-constant, and reading `.op_type` on None would raise
    # AttributeError, so they are skipped.
    if producer is None or producer.op_type == "Constant":
        continue
    users = node.outputs[0].consumers()
    # The FLOAT16 value must feed exactly one node; otherwise something else
    # still needs the half-precision result and the cast has to stay.
    if len(users) != 1:
        continue
    next_node, _ = users[0]
    if is_cast(next_node, ir.DataType.FLOAT):
        cast_pairs.append((node, next_node))
# Bypass every recorded cast pair: rewire the users of the second cast to read
# the value that fed the first cast, then drop both cast nodes.
# NOTE(review): only node consumers are rewired here. If the second cast's
# output is also a graph output, it is not updated by this loop; confirm
# whether that case can occur for the models this is run on.
for node, next_node in cast_pairs:
    original_value = node.inputs[0]
    # Take a snapshot of the consumers first: replace_input_with() changes the
    # use list of this very output while we are looping over it.
    for user, index in tuple(next_node.outputs[0].consumers()):
        user.replace_input_with(index, original_value)
    # Remove the consumer cast before its producer so that no node remaining
    # in the graph reads from an already-removed node.
    model.graph.remove(next_node)
    model.graph.remove(node)
# Convert the edited IR model back to an ONNX proto.
result_proto = ir.serde.serialize_model(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment