Last active
April 17, 2024 02:36
-
-
Save justinchuby/e69a4c9cb51f7c6ac8fabba6b2d7f3d0 to your computer and use it in GitHub Desktop.
Remove redundant Cast pairs (a Cast to FLOAT16 immediately followed by a Cast back to FLOAT) from an ONNX model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnx
from onnxscript import ir

# Load the source model and lift it into the onnxscript IR so its graph can be
# edited in place. The rewrite that follows targets this pattern:
#   (non-constant value) -> Cast(to=FLOAT16) -> Cast(to=FLOAT) -> consumer op
model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)
def is_cast(node: ir.Node, dtype: ir.DataType) -> bool:
    """Return True if *node* is a standard ONNX ``Cast`` converting to *dtype*.

    Args:
        node: The IR node to inspect.
        dtype: The data type the cast's ``to`` attribute must equal.

    Returns:
        True only for a ``Cast`` in the default ONNX domain whose ``to``
        attribute equals *dtype*; False otherwise, including when the node
        has no ``to`` attribute.
    """
    if node.op_type != "Cast":
        return False
    # A custom-domain op that merely shares the name "Cast" is not the ONNX
    # operator and must not be treated (and later deleted) as one.
    if node.domain not in ("", "ai.onnx"):
        return False
    to_attr = node.attributes.get("to")
    if to_attr is None:
        return False
    return ir.DataType(to_attr.value) == dtype
# Collect (Cast-to-FLOAT16, Cast-to-FLOAT) pairs in which the FLOAT16 cast
# feeds only the FLOAT cast, so the round trip can later be bypassed.
cast_pairs = []
for node in model.graph:
    if not is_cast(node, ir.DataType.FLOAT16):
        continue
    source = node.inputs[0]
    # Kept from the original script: the value named "x" is deliberately
    # skipped (presumably the model's graph input — confirm).
    if source.name == "x":
        continue
    producer = source.producer()
    # Graph inputs and initializers have no producer; calling .op_type on
    # None would crash. Treat them like constants and leave them alone.
    if producer is None or producer.op_type == "Constant":
        continue
    # Bypassing the pair hands `source` straight to downstream consumers,
    # which is only type-correct if `source` is already FLOAT. Skip values
    # whose dtype is known to be something else.
    # NOTE(review): values with unknown dtype (None) are still rewritten,
    # matching the original behaviour.
    if source.dtype is not None and source.dtype != ir.DataType.FLOAT:
        continue
    users = node.outputs[0].consumers()
    if len(users) != 1:
        continue
    next_node, _ = users[0]
    if is_cast(next_node, ir.DataType.FLOAT):
        cast_pairs.append((node, next_node))
# Bypass each collected pair: consumers of the FLOAT cast read the FLOAT16
# cast's original input directly, then both casts are dropped from the graph.
graph_output_ids = {id(value) for value in model.graph.outputs}
for node, next_node in cast_pairs:
    # Rewiring consumers does not update graph outputs, so deleting a cast
    # whose output is a graph output would leave that output dangling.
    if (
        id(node.outputs[0]) in graph_output_ids
        or id(next_node.outputs[0]) in graph_output_ids
    ):
        continue
    # Snapshot the consumers first: replace_input_with() updates the use list
    # of next_node.outputs[0] while this loop runs.
    for user, index in tuple(next_node.outputs[0].consumers()):
        user.replace_input_with(index, node.inputs[0])
    # Remove the downstream cast first so the upstream cast has no remaining
    # users by the time it is removed.
    model.graph.remove(next_node)
    model.graph.remove(node)
# NOTE(review): result_proto is built but never written to disk here;
# presumably the caller saves it (e.g. with onnx.save) — confirm.
result_proto = ir.serde.serialize_model(model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment