Created
May 23, 2024 18:18
-
-
Save justinchuby/871d79234c37f180f573eb78dbf1a408 to your computer and use it in GitHub Desktop.
Graph Traversal ONNX IR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Utilities for traversing the IR graph.""" | |
from __future__ import annotations | |
# Names exported by ``from ... import *``; this module's public API.
__all__ = ["RecursiveGraphIterator"]
from typing import Callable, Iterator, Reversible | |
from typing_extensions import Self | |
from onnxscript.ir import _core, _enums | |
class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
    """Iterate over the nodes of a graph, descending into subgraphs.

    Each node is yielded *before* the nodes of the subgraphs stored in its
    ``GRAPH`` / ``GRAPHS`` attributes. With ``reverse=True``, the following are
    all visited back to front:

    * the nodes of every graph,
    * a node's attributes,
    * the graphs inside a ``GRAPHS`` attribute.

    Even in reverse, a node is still yielded before its subgraphs.

    Calling ``iter()`` on an instance restarts the traversal from the
    beginning. ``reversed()`` returns a new iterator with the direction flipped.
    """

    def __init__(
        self,
        graph: _core.Graph | _core.Function | _core.GraphView,
        *,
        enter_graph_handler: Callable[[_core.Graph | _core.Function | _core.GraphView], None]
        | None = None,
        exit_graph_handler: Callable[[_core.Graph | _core.Function | _core.GraphView], None]
        | None = None,
        recursive: Callable[[_core.Node], bool] | None = None,
        reverse: bool = False,
    ):
        """Iterate over the nodes in the graph, recursively visiting subgraphs.

        Args:
            graph: The graph to traverse.
            enter_graph_handler: A callback invoked with each graph, including
                the top-level ``graph``, right before its nodes are visited.
            exit_graph_handler: A callback invoked with each graph, including
                the top-level ``graph``, after all of its nodes (and their
                subgraphs) have been visited. It is not called for a graph whose
                traversal is abandoned before it is exhausted.
            recursive: A predicate deciding whether to descend into the
                subgraphs held in a node's attributes. The node itself is always
                yielded. If not provided, every subgraph is visited.
            reverse: Whether to iterate in reverse order.
        """
        self._graph = graph
        self._enter_graph_handler = enter_graph_handler
        self._exit_graph_handler = exit_graph_handler
        self._recursive = recursive
        self._reverse = reverse
        # Lets ``next()`` work on a fresh instance without an explicit ``iter()``.
        self._iterator = self._recursive_node_iter(graph)

    def __iter__(self) -> Self:
        # Restart the traversal so the same instance can be iterated repeatedly.
        self._iterator = self._recursive_node_iter(self._graph)
        return self

    def __next__(self) -> _core.Node:
        return next(self._iterator)

    def _recursive_node_iter(
        self, graph: _core.Graph | _core.Function | _core.GraphView
    ) -> Iterator[_core.Node]:
        """Yield the nodes of ``graph``, each followed by the nodes of its subgraphs.

        The enter handler (if any) runs when the generator starts. The exit
        handler (if any) runs only once every node has been produced.
        """
        if self._enter_graph_handler is not None:
            self._enter_graph_handler(graph)
        iterable = reversed(graph) if self._reverse else graph
        for node in iterable:  # type: ignore[union-attr]
            yield node
            # The node itself is always produced; ``recursive`` only gates the
            # descent into graphs stored in the node's attributes.
            if self._recursive is not None and not self._recursive(node):
                continue
            yield from self._iterate_subgraphs(node)
        if self._exit_graph_handler is not None:
            self._exit_graph_handler(graph)

    def _iterate_subgraphs(self, node: _core.Node) -> Iterator[_core.Node]:
        """Yield the nodes of every subgraph held in ``node``'s attributes."""
        attributes = (
            reversed(node.attributes.values()) if self._reverse else node.attributes.values()
        )
        for attr in attributes:
            # Only ``_core.Attr`` instances are inspected for graph values.
            if not isinstance(attr, _core.Attr):
                continue
            # Recurse with this instance's own settings. Building a new
            # RecursiveGraphIterator per subgraph would reuse the same settings
            # but create and discard an extra generator each time.
            if attr.type == _enums.AttributeType.GRAPH:
                yield from self._recursive_node_iter(attr.value)
            elif attr.type == _enums.AttributeType.GRAPHS:
                graphs = reversed(attr.value) if self._reverse else attr.value
                for graph in graphs:
                    yield from self._recursive_node_iter(graph)

    def __reversed__(self) -> Iterator[_core.Node]:
        """Return a new iterator over the same graph with the direction flipped."""
        return RecursiveGraphIterator(
            self._graph,
            enter_graph_handler=self._enter_graph_handler,
            exit_graph_handler=self._exit_graph_handler,
            recursive=self._recursive,
            reverse=not self._reverse,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import parameterized | |
from onnxscript import ir | |
from onnxscript.ir import traversal | |
class RecursiveGraphIteratorTest(unittest.TestCase):
    """Tests for ``traversal.RecursiveGraphIterator``."""

    def setUp(self):
        # Fixture layout:
        #   main_graph: Node1, Node2, If
        #     If.then_branch -> then_graph: Node3, Node4
        #     If.else_branch -> else_graph: Node5, Node6
        self.graph = ir.Graph(
            [],
            [],
            nodes=[
                ir.Node("", "Node1", []),
                ir.Node("", "Node2", []),
                ir.Node(
                    "",
                    "If",
                    [],
                    attributes=[
                        ir.AttrGraph(
                            "then_branch",
                            ir.Graph(
                                [],
                                [],
                                nodes=[ir.Node("", "Node3", []), ir.Node("", "Node4", [])],
                                name="then_graph",
                            ),
                        ),
                        ir.AttrGraph(
                            "else_branch",
                            ir.Graph(
                                [],
                                [],
                                nodes=[ir.Node("", "Node5", []), ir.Node("", "Node6", [])],
                                name="else_graph",
                            ),
                        ),
                    ],
                ),
            ],
            name="main_graph",
        )

    @parameterized.parameterized.expand(
        [
            ("forward", False, ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6")),
            # A node is yielded before its subgraphs even when reversed.
            ("reversed", True, ("If", "Node6", "Node5", "Node4", "Node3", "Node2", "Node1")),
        ]
    )
    def test_recursive_graph_iterator(self, _: str, reverse: bool, expected: tuple[str, ...]):
        iterator = traversal.RecursiveGraphIterator(self.graph)
        if reverse:
            iterator = reversed(iterator)
        nodes = list(iterator)
        self.assertEqual(tuple(node.op_type for node in nodes), expected)

    @parameterized.parameterized.expand(
        [
            ("forward", False, ["main_graph", "then_graph", "else_graph"]),
            ("reversed", True, ["main_graph", "else_graph", "then_graph"]),
        ]
    )
    def test_recursive_graph_iterator_enter_graph_handler(
        self, _: str, reverse: bool, expected: list[str]
    ):
        scopes = []

        def enter_graph_handler(graph):
            scopes.append(graph.name)

        for __ in traversal.RecursiveGraphIterator(
            self.graph, enter_graph_handler=enter_graph_handler, reverse=reverse
        ):
            pass
        self.assertEqual(scopes, expected)

    @parameterized.parameterized.expand(
        [
            (
                "forward",
                False,
                [
                    "then_graph",
                    "else_graph",
                    "main_graph",
                ],
            ),
            ("reversed", True, ["else_graph", "then_graph", "main_graph"]),
        ]
    )
    def test_recursive_graph_iterator_exit_graph_handler(
        self, _: str, reverse: bool, expected: list[str]
    ):
        scopes = []

        def exit_graph_handler(graph):
            scopes.append(graph.name)

        for __ in traversal.RecursiveGraphIterator(
            self.graph, exit_graph_handler=exit_graph_handler, reverse=reverse
        ):
            pass
        self.assertEqual(scopes, expected)

    @parameterized.parameterized.expand(
        [
            ("forward", False, ("Node1", "Node2", "If")),
            ("reversed", True, ("If", "Node2", "Node1")),
        ]
    )
    def test_recursive_graph_iterator_recursive_controls_recursive_behavior(
        self, _: str, reverse: bool, expected: tuple[str, ...]
    ):
        # The "If" node itself is still yielded; only its subgraphs are skipped.
        nodes = list(
            traversal.RecursiveGraphIterator(
                self.graph, recursive=lambda node: node.op_type != "If", reverse=reverse
            )
        )
        self.assertEqual(tuple(node.op_type for node in nodes), expected)
if __name__ == "__main__":
    # Running this file directly executes the test cases defined above.
    # Inside this guard ``__name__`` is "__main__", which is unittest.main's
    # default ``module`` argument.
    unittest.main(module=__name__)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment