Skip to content

Instantly share code, notes, and snippets.

@justinchuby
Created March 21, 2024 01:29
Show Gist options
  • Save justinchuby/e2c9a1ed518507833d6f38942047d498 to your computer and use it in GitHub Desktop.
Save justinchuby/e2c9a1ed518507833d6f38942047d498 to your computer and use it in GitHub Desktop.
ONNX Tape
"""Convenience methods for constructing (and manipulating?) the IR."""
from __future__ import annotations
import collections.abc
from typing import Any, Mapping, Sequence
from onnxrewriter.experimental.intermediate_representation import _ir
def _convert_attributes(attrs: Mapping[str, Any]) -> list[_ir.Attr]:
    """Convert a mapping of Python attribute values into IR attribute objects.

    Supported values:

    * an existing ``_ir.Attr`` -> appended unchanged (the mapping key is not
      applied to it)
    * ``int`` -> ``_ir.AttrInt64`` (``bool`` is a subclass of ``int`` and is
      therefore converted the same way)
    * ``float`` -> ``_ir.AttrFloat32``
    * ``str`` -> ``_ir.AttrString``
    * a sequence of ints -> ``_ir.AttrInt64s``
    * a sequence of floats -> ``_ir.AttrFloat32s``
    * a sequence mixing ints and floats -> ``_ir.AttrFloat32s`` with every
      element converted to ``float``
    * a sequence of strs -> ``_ir.AttrStrings``

    An empty sequence becomes ``_ir.AttrInt64s`` because its element type
    cannot be inferred. ``bytes``/``bytearray`` are rejected rather than being
    treated as a sequence of ints.

    Args:
        attrs: Mapping from attribute name to value.

    Returns:
        The converted attributes, in the mapping's iteration order.

    Raises:
        TypeError: If a value, or an element of a sequence value, has an
            unsupported type.
    """
    attributes: list[_ir.Attr] = []
    for name, attr in attrs.items():
        if isinstance(attr, _ir.Attr):
            # Already an IR attribute: pass it through untouched.
            attributes.append(attr)
        elif isinstance(attr, int):
            attributes.append(_ir.AttrInt64(name, attr))
        elif isinstance(attr, float):
            attributes.append(_ir.AttrFloat32(name, attr))
        elif isinstance(attr, str):
            attributes.append(_ir.AttrString(name, attr))
        elif isinstance(attr, Sequence) and not isinstance(attr, (bytes, bytearray)):
            # bytes/bytearray are Sequences of ints; without the guard above
            # they would silently become an INTS attribute of byte values.
            if all(isinstance(x, int) for x in attr):
                attributes.append(_ir.AttrInt64s(name, attr))
            elif all(isinstance(x, float) for x in attr):
                attributes.append(_ir.AttrFloat32s(name, attr))
            elif all(isinstance(x, (int, float)) for x in attr):
                # Mixed ints and floats (e.g. ``[1, 1, 2.0, 2.0]``) are
                # promoted to floats, as onnx.helper.make_attribute does.
                attributes.append(
                    _ir.AttrFloat32s(name, [float(x) for x in attr])
                )
            elif all(isinstance(x, str) for x in attr):
                attributes.append(_ir.AttrStrings(name, attr))
            else:
                element_types = sorted({type(x).__name__ for x in attr})
                raise TypeError(
                    f"Unsupported attribute type for '{name}': '{type(attr)}' "
                    f"with element types {element_types}"
                )
        else:
            raise TypeError(
                f"Unsupported attribute type for '{name}': '{type(attr)}'"
            )
    return attributes
class Tape(collections.abc.Iterable[_ir.Node]):
    """A tape for recording nodes that are created.

    Each call to :meth:`op` or :meth:`op_multi_output` constructs one
    ``_ir.Node`` and appends it to the tape, so the tape holds nodes in
    creation order. Iterating over the tape yields those nodes in that order.
    """

    def __init__(self) -> None:
        # Recorded nodes, in the order they were created.
        self._nodes: list[_ir.Node] = []

    def __iter__(self) -> collections.abc.Iterator[_ir.Node]:
        # ``__iter__`` must return an iterator, not the list itself;
        # returning the list makes ``iter(tape)`` / ``for n in tape`` raise
        # ``TypeError: iter() returned non-iterator``.
        return iter(self._nodes)

    @property
    def nodes(self) -> Sequence[_ir.Node]:
        """A snapshot (tuple copy) of the recorded nodes, in creation order."""
        return tuple(self._nodes)

    def _record_node(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None,
        domain: str,
        num_outputs: int,
    ) -> _ir.Node:
        """Build a node with converted attributes, append it to the tape, and return it."""
        attrs = () if attributes is None else _convert_attributes(attributes)
        node = _ir.Node(
            domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs
        )
        self._nodes.append(node)
        return node

    def op(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None = None,
        domain: str = "",
    ) -> _ir.Value:
        """Record a single-output node and return its output value.

        Args:
            op_type: The operator type of the node.
            inputs: Input values, passed to ``_ir.Node`` as given. ``None``
                entries presumably mark omitted optional inputs (confirm
                against ``_ir.Node``).
            attributes: Optional mapping of attribute names to Python values,
                converted with ``_convert_attributes``.
            domain: Operator domain (defaults to ``""``).

        Returns:
            The node's single output value.

        Raises:
            TypeError: If an attribute value has an unsupported type.
        """
        node = self._record_node(op_type, inputs, attributes, domain, num_outputs=1)
        return node.outputs[0]

    def op_multi_output(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None = None,
        *,
        num_outputs: int,
        domain: str = "",
    ) -> Sequence[_ir.Value]:
        """Record a node with ``num_outputs`` outputs and return its outputs.

        Args:
            op_type: The operator type of the node.
            inputs: Input values, passed to ``_ir.Node`` as given.
            attributes: Optional mapping of attribute names to Python values,
                converted with ``_convert_attributes``.
            num_outputs: Number of outputs the node is created with.
            domain: Operator domain (defaults to ``""``).

        Returns:
            The node's ``outputs``.

        Raises:
            TypeError: If an attribute value has an unsupported type.
        """
        node = self._record_node(
            op_type, inputs, attributes, domain, num_outputs=num_outputs
        )
        return node.outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment