Created
March 21, 2024 01:29
-
-
Save justinchuby/e2c9a1ed518507833d6f38942047d498 to your computer and use it in GitHub Desktop.
ONNX Tape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Convenience methods for constructing (and manipulating?) the IR.""" | |
from __future__ import annotations | |
import collections.abc | |
from typing import Any, Mapping, Sequence | |
from onnxrewriter.experimental.intermediate_representation import _ir | |
def _convert_attributes(attrs: Mapping[str, Any]) -> list[_ir.Attr]:
    """Convert a mapping of Python values into a list of IR attributes.

    Args:
        attrs: Mapping from attribute name to value. Supported values are:
            - ``_ir.Attr`` instances, which are passed through unchanged;
            - ``int``, ``float`` and ``str`` scalars;
            - sequences whose elements are all ``int``, all ``float``, a mix
              of ``int`` and ``float`` (stored as floats), or all ``str``.

    Returns:
        One ``_ir.Attr`` per entry, in the mapping's iteration order.

    Raises:
        TypeError: If a value (or the element types of a sequence value) is
            not supported. This includes ``bytes``/``bytearray``.
    """
    attributes: list[_ir.Attr] = []
    for name, attr in attrs.items():
        if isinstance(attr, _ir.Attr):
            # Check pre-built attributes first so they are never re-wrapped,
            # even if an Attr type also happens to be a number or Sequence.
            attributes.append(attr)
        elif isinstance(attr, int):
            # NOTE: bool is a subclass of int, so True/False become INT attrs.
            attributes.append(_ir.AttrInt64(name, attr))
        elif isinstance(attr, float):
            attributes.append(_ir.AttrFloat32(name, attr))
        elif isinstance(attr, str):
            attributes.append(_ir.AttrString(name, attr))
        elif isinstance(attr, (bytes, bytearray)):
            # bytes are a Sequence of ints; without this guard they would be
            # silently turned into an INT64S attribute.
            raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
        elif isinstance(attr, Sequence):
            attributes.append(_convert_sequence_attribute(name, attr))
        else:
            raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
    return attributes


def _convert_sequence_attribute(name: str, values: Sequence[Any]) -> _ir.Attr:
    """Convert a homogeneous sequence value into a list-typed IR attribute.

    An empty sequence becomes an INT64S attribute, because the all-int check
    is tried first and ``all()`` of an empty iterable is True.
    """
    if all(isinstance(x, int) for x in values):
        return _ir.AttrInt64s(name, values)
    if all(isinstance(x, float) for x in values):
        return _ir.AttrFloat32s(name, values)
    if all(isinstance(x, (int, float)) for x in values):
        # Mixed ints and floats: promote everything to float, as
        # onnx.helper.make_attribute does.
        return _ir.AttrFloat32s(name, [float(x) for x in values])
    if all(isinstance(x, str) for x in values):
        return _ir.AttrStrings(name, values)
    raise TypeError(f"Unsupported attribute type: '{type(values)}'")
class Tape(collections.abc.Iterable[_ir.Node]):
    """A tape for recording nodes that are created.

    Each call to :meth:`op` or :meth:`op_multi_output` builds one
    ``_ir.Node`` and appends it to the tape. Iterating the tape therefore
    yields the recorded nodes in creation order.
    """

    def __init__(self) -> None:
        self._nodes: list[_ir.Node] = []

    def __iter__(self) -> collections.abc.Iterator[_ir.Node]:
        # __iter__ must return an iterator. Returning the list itself makes
        # iter(tape) and `for node in tape` raise TypeError.
        return iter(self._nodes)

    @property
    def nodes(self) -> Sequence[_ir.Node]:
        """A tuple snapshot of the recorded nodes, in creation order."""
        return tuple(self._nodes)

    def op(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None = None,
        domain: str = "",
    ) -> _ir.Value:
        """Record a single-output node and return its output value.

        Args:
            op_type: Operator type of the node.
            inputs: Input values. ``None`` entries are allowed, presumably
                for omitted optional inputs (TODO confirm against ``_ir``).
            attributes: Optional mapping of attribute name to value, converted
                with ``_convert_attributes``.
            domain: Operator domain; defaults to ``""``.

        Returns:
            The node's only output value.
        """
        node = self._record_node(
            op_type, inputs, attributes, domain, num_outputs=1
        )
        return node.outputs[0]

    def op_multi_output(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None = None,
        *,
        num_outputs: int,
        domain: str = "",
    ) -> Sequence[_ir.Value]:
        """Record a node with ``num_outputs`` outputs and return them.

        Args:
            op_type: Operator type of the node.
            inputs: Input values. ``None`` entries are allowed.
            attributes: Optional mapping of attribute name to value, converted
                with ``_convert_attributes``.
            num_outputs: Number of outputs the node is created with.
            domain: Operator domain; defaults to ``""``.

        Returns:
            The node's ``outputs``.
        """
        node = self._record_node(
            op_type, inputs, attributes, domain, num_outputs=num_outputs
        )
        return node.outputs

    def _record_node(
        self,
        op_type: str,
        inputs: Sequence[_ir.Value | None],
        attributes: Mapping[str, Any] | None,
        domain: str,
        num_outputs: int,
    ) -> _ir.Node:
        """Build a node, append it to the tape, and return it."""
        attrs = () if attributes is None else _convert_attributes(attributes)
        node = _ir.Node(
            domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs
        )
        self._nodes.append(node)
        return node
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment