Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created May 23, 2024 18:18
Show Gist options
Save justinchuby/871d79234c37f180f573eb78dbf1a408 to your computer and use it in GitHub Desktop.
Graph Traversal ONNX IR
"""Graph-traversal helpers for the ONNX IR.

The public entry point is :class:`RecursiveGraphIterator`, which walks the
nodes of a graph, function, or graph view and, optionally, the nodes of the
subgraphs stored in node attributes.
"""

from __future__ import annotations

# Public API of this module.
__all__ = ["RecursiveGraphIterator"]
from typing import Callable, Iterator, Reversible
from typing_extensions import Self
from onnxscript.ir import _core, _enums
class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
    """Iterate over the nodes of a graph and, recursively, of its subgraphs.

    Nodes are yielded in pre-order: each node is yielded before the nodes of any
    subgraphs stored in its ``GRAPH``/``GRAPHS`` attributes. With ``reverse=True``
    the nodes of every graph, the attributes of every node, and the graphs of
    every ``GRAPHS`` attribute are each visited in reverse order. A node is
    still yielded before the nodes of its subgraphs.

    Calling :func:`iter` on an instance restarts the traversal from the
    beginning, so the same instance can be iterated more than once.
    """

    def __init__(
        self,
        graph: _core.Graph | _core.Function | _core.GraphView,
        *,
        enter_graph_handler: Callable[[_core.Graph | _core.Function | _core.GraphView], None]
        | None = None,
        exit_graph_handler: Callable[[_core.Graph | _core.Function | _core.GraphView], None]
        | None = None,
        recursive: Callable[[_core.Node], bool] | None = None,
        reverse: bool = False,
    ):
        """Iterate over the nodes in the graph, recursively visiting subgraphs.

        Args:
            graph: The graph, function, or graph view to traverse.
            enter_graph_handler: Called with each graph just before its nodes are
                visited. This includes the top-level ``graph`` itself, not only
                subgraphs.
            exit_graph_handler: Called with each graph after all of its nodes, and
                the nodes of their subgraphs, have been visited. This includes the
                top-level ``graph``. It is not called for a graph whose traversal
                is abandoned early, e.g. when the consumer stops iterating.
            recursive: Called with each node after it has been yielded. When it
                returns ``False``, the subgraphs of that node are skipped. The
                node itself is always yielded. If not provided, all subgraphs
                are visited.
            reverse: Whether to iterate in reverse order.
        """
        self._graph = graph
        self._enter_graph_handler = enter_graph_handler
        self._exit_graph_handler = exit_graph_handler
        self._recursive = recursive
        self._reverse = reverse
        # Created eagerly so ``next()`` works even if ``iter()`` was never called.
        self._iterator = self._recursive_node_iter(graph)

    def __iter__(self) -> Self:
        # Restart the traversal: every ``iter()`` call begins a fresh walk.
        self._iterator = self._recursive_node_iter(self._graph)
        return self

    def __next__(self) -> _core.Node:
        return next(self._iterator)

    def __reversed__(self) -> Iterator[_core.Node]:
        """Return a fresh iterator that traverses in the opposite direction."""
        return self._child_iterator(self._graph, not self._reverse)

    def _child_iterator(
        self, graph: _core.Graph | _core.Function | _core.GraphView, reverse: bool
    ) -> RecursiveGraphIterator:
        """Create an iterator over ``graph`` sharing this iterator's callbacks."""
        return RecursiveGraphIterator(
            graph,
            enter_graph_handler=self._enter_graph_handler,
            exit_graph_handler=self._exit_graph_handler,
            recursive=self._recursive,
            reverse=reverse,
        )

    def _recursive_node_iter(
        self, graph: _core.Graph | _core.Function | _core.GraphView
    ) -> Iterator[_core.Node]:
        """Yield the nodes of ``graph``, each followed by its subgraphs' nodes."""
        if self._enter_graph_handler is not None:
            self._enter_graph_handler(graph)
        iterable = reversed(graph) if self._reverse else graph
        for node in iterable:  # type: ignore[union-attr]
            yield node
            # The node has already been yielded; ``recursive`` only decides
            # whether to descend into its subgraphs.
            if self._recursive is not None and not self._recursive(node):
                continue
            yield from self._iterate_subgraphs(node)
        # Not in a ``finally``: the exit handler only runs on full traversal.
        if self._exit_graph_handler is not None:
            self._exit_graph_handler(graph)

    def _iterate_subgraphs(self, node: _core.Node) -> Iterator[_core.Node]:
        """Yield the nodes of every subgraph held in ``node``'s attributes."""
        attributes = node.attributes.values()
        ordered_attributes = reversed(attributes) if self._reverse else attributes
        for attr in ordered_attributes:
            # Only ``_core.Attr`` entries are inspected for subgraphs.
            if not isinstance(attr, _core.Attr):
                continue
            if attr.type == _enums.AttributeType.GRAPH:
                yield from self._child_iterator(attr.value, self._reverse)
            elif attr.type == _enums.AttributeType.GRAPHS:
                graphs = reversed(attr.value) if self._reverse else attr.value
                for graph in graphs:
                    yield from self._child_iterator(graph, self._reverse)
import unittest
import parameterized
from onnxscript import ir
from onnxscript.ir import traversal
class RecursiveGraphIteratorTest(unittest.TestCase):
    """Tests for ``traversal.RecursiveGraphIterator``."""

    def setUp(self):
        # main_graph: Node1, Node2, If
        #   If.then_branch -> then_graph: Node3, Node4
        #   If.else_branch -> else_graph: Node5, Node6
        self.graph = ir.Graph(
            [],
            [],
            nodes=[
                ir.Node("", "Node1", []),
                ir.Node("", "Node2", []),
                ir.Node(
                    "",
                    "If",
                    [],
                    attributes=[
                        ir.AttrGraph(
                            "then_branch",
                            ir.Graph(
                                [],
                                [],
                                nodes=[ir.Node("", "Node3", []), ir.Node("", "Node4", [])],
                                name="then_graph",
                            ),
                        ),
                        ir.AttrGraph(
                            "else_branch",
                            ir.Graph(
                                [],
                                [],
                                nodes=[ir.Node("", "Node5", []), ir.Node("", "Node6", [])],
                                name="else_graph",
                            ),
                        ),
                    ],
                ),
            ],
            name="main_graph",
        )

    @parameterized.parameterized.expand(
        [
            ("forward", False, ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6")),
            ("reversed", True, ("If", "Node6", "Node5", "Node4", "Node3", "Node2", "Node1")),
        ]
    )
    def test_recursive_graph_iterator(self, _: str, reverse: bool, expected: tuple[str, ...]):
        iterator = traversal.RecursiveGraphIterator(self.graph)
        # Exercise the ``__reversed__`` path rather than the ``reverse=`` flag.
        nodes = list(reversed(iterator) if reverse else iterator)
        self.assertEqual(tuple(node.op_type for node in nodes), expected)

    @parameterized.parameterized.expand(
        [
            ("forward", False, ["main_graph", "then_graph", "else_graph"]),
            ("reversed", True, ["main_graph", "else_graph", "then_graph"]),
        ]
    )
    def test_recursive_graph_iterator_enter_graph_handler(
        self, _: str, reverse: bool, expected: list[str]
    ):
        scopes = []

        def enter_graph_handler(graph):
            scopes.append(graph.name)

        for __ in traversal.RecursiveGraphIterator(
            self.graph, enter_graph_handler=enter_graph_handler, reverse=reverse
        ):
            pass
        self.assertEqual(scopes, expected)

    @parameterized.parameterized.expand(
        [
            ("forward", False, ["then_graph", "else_graph", "main_graph"]),
            ("reversed", True, ["else_graph", "then_graph", "main_graph"]),
        ]
    )
    def test_recursive_graph_iterator_exit_graph_handler(
        self, _: str, reverse: bool, expected: list[str]
    ):
        scopes = []

        def exit_graph_handler(graph):
            scopes.append(graph.name)

        for __ in traversal.RecursiveGraphIterator(
            self.graph, exit_graph_handler=exit_graph_handler, reverse=reverse
        ):
            pass
        self.assertEqual(scopes, expected)

    @parameterized.parameterized.expand(
        [
            ("forward", False, ("Node1", "Node2", "If")),
            ("reversed", True, ("If", "Node2", "Node1")),
        ]
    )
    def test_recursive_graph_iterator_recursive_controls_recursive_behavior(
        self, _: str, reverse: bool, expected: tuple[str, ...]
    ):
        nodes = list(
            traversal.RecursiveGraphIterator(
                self.graph, recursive=lambda node: node.op_type != "If", reverse=reverse
            )
        )
        self.assertEqual(tuple(node.op_type for node in nodes), expected)

    def test_reversed_preserves_graph_handlers(self):
        entered: list[str] = []
        exited: list[str] = []
        iterator = traversal.RecursiveGraphIterator(
            self.graph,
            enter_graph_handler=lambda graph: entered.append(graph.name),
            exit_graph_handler=lambda graph: exited.append(graph.name),
        )
        nodes = list(reversed(iterator))
        self.assertEqual(
            tuple(node.op_type for node in nodes),
            ("If", "Node6", "Node5", "Node4", "Node3", "Node2", "Node1"),
        )
        self.assertEqual(entered, ["main_graph", "else_graph", "then_graph"])
        self.assertEqual(exited, ["else_graph", "then_graph", "main_graph"])

    def test_double_reversed_restores_forward_order(self):
        iterator = reversed(reversed(traversal.RecursiveGraphIterator(self.graph)))
        self.assertEqual(
            tuple(node.op_type for node in iterator),
            ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6"),
        )


if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment