-
-
Save jroesch/de9648556b1f7e1bb3db490e0f708813 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import attr | |
from typing import List, Any, Optional | |
from tvm import relay | |
from tvm.ir import Type, IRModule | |
from tvm.relay import Expr, Op | |
from tvm.relay.expr_functor import ExprVisitor | |
class Matcher:
    """Decides whether a Relay expression satisfies a pattern.

    The matcher is stateful: ``input_map`` remembers the expression each named
    ``InputPattern`` was first bound to, so that later occurrences of the same
    name only match that same expression.
    """

    def __init__(self):
        # Input name -> expression first matched under that name.
        self.input_map = {}

    def match(self, pattern, expr):
        """Return whether ``expr`` satisfies ``pattern``.

        Dispatches on the concrete class of ``pattern`` to the corresponding
        ``match_*`` helper.

        Raises:
            Exception: if ``pattern`` is not one of the supported classes.
        """
        if isinstance(pattern, ExprPattern):
            return self.match_expr(pattern.op, expr)
        if isinstance(pattern, CallPattern):
            return self.match_call(pattern.op, pattern.args, expr)
        if isinstance(pattern, AltPattern):
            return self.match_alt(pattern.left, pattern.right, expr)
        if isinstance(pattern, MatchTypePattern):
            return self.match_type(pattern.pattern, pattern.ty, expr)
        if isinstance(pattern, MatchAttrPattern):
            return self.match_attr(
                pattern.pattern, pattern.attr_name, pattern.attr_value, expr
            )
        if isinstance(pattern, InputPattern):
            return self.match_input(pattern.name, expr)
        if isinstance(pattern, WildCardPattern):
            # NB: a wildcard accepts anything, so there is nothing to inspect.
            return True
        raise Exception(f"unsupported {type(pattern)}")

    def match_expr(self, op, expr):
        """Return True when ``expr`` compares equal (``==``) to ``op``."""
        return bool(op == expr)

    # TODO(@jroesch): we should probably support matching type args and call
    # attributes here.
    def match_call(self, op, args, expr):
        """Match a call whose callee matches ``op`` and whose arguments match ``args``.

        A non-call expression, a callee mismatch, or a different number of
        arguments all count as a failed match.
        """
        if not isinstance(expr, relay.Call):
            return False
        if not self.match(op, expr.op):
            return False
        # A pattern with the wrong number of arguments cannot describe this
        # call; report "no match" rather than asserting (asserts are stripped
        # under ``-O`` and would otherwise abort the whole match).
        if len(args) != len(expr.args):
            return False
        # ``all`` stops at the first argument that fails to match.
        return all(self.match(pat, arg) for pat, arg in zip(args, expr.args))

    def match_alt(self, lhs, rhs, expr):
        """Return True when ``expr`` matches ``lhs`` or, failing that, ``rhs``."""
        # Remember the input bindings so a failed attempt at ``lhs`` cannot
        # leave partial bindings behind that would wrongly constrain ``rhs``.
        saved_inputs = dict(self.input_map)
        if self.match(lhs, expr):
            return True
        self.input_map = saved_inputs
        return self.match(rhs, expr)

    def match_type(self, pattern, ty, expr):
        """Match ``pattern`` against the type-inferred ``expr`` and compare its type to ``ty``."""
        expr = infer_type(expr)
        if not self.match(pattern, expr):
            return False
        return expr.checked_type == ty

    def match_attr(self, pattern, attr_name, attr_value, expr):
        """Match ``pattern`` and require ``expr`` to be an op whose attribute equals ``attr_value``."""
        if not (self.match(pattern, expr) and isinstance(expr, relay.Op)):
            return False
        return bool(expr.get_attr(attr_name) == attr_value)

    def match_input(self, name, expr):
        """Match an input variable, optionally tied to a name.

        With ``name`` set, the first match binds the name to ``expr`` and every
        later match under the same name must compare equal to that expression.
        """
        # TODO(@jroesch) we should probably generalize this to named groups,
        # or some kind of feature to name a matched sub-expression then
        # require the match be identical elsewhere.

        # If it isn't a variable it's not an input, assuming graph form.
        if not isinstance(expr, relay.Var):
            return False
        # If name isn't set we can match any input.
        if name is None:
            return True
        # If name is set we must match the same expression in each position.
        #
        # We could improve the implementation by tracking all possible matches
        # to improve error reporting.
        if name not in self.input_map:
            self.input_map[name] = expr
            # The first occurrence of a name always matches; it only records
            # the binding that later occurrences are checked against.
            return True
        return self.input_map[name] == expr
def infer_type(expr):
    """Run Relay's ``InferType`` pass over ``expr`` and return the result.

    ``expr`` is wrapped in an ``IRModule`` via ``IRModule.from_expr`` and the
    pass is applied to that module. When ``expr`` is a ``relay.Function`` the
    module's ``"main"`` function is returned; otherwise the body of ``"main"``
    is returned.
    """
    typed_mod = relay.transform.InferType()(IRModule.from_expr(expr))
    main_fn = typed_mod["main"]
    return main_fn if isinstance(expr, relay.Function) else main_fn.body
class Pattern:
    """Common base of all patterns; provides the pattern-building operators."""

    def __call__(self, *args: Pattern) -> Pattern:
        """Build a call pattern with this pattern as callee and ``args`` as argument patterns."""
        arg_patterns = list(args)
        return CallPattern(self, arg_patterns)

    def __or__(self, other: Pattern) -> Pattern:
        """Build an alternation of this pattern and ``other`` (``self | other``)."""
        return AltPattern(self, other)

    # NOTE: overloading ``__le__`` as another way to build patterns was
    # sketched as a proof of concept, but it is not implemented.

    def has_attr(self, attr_name, attr_value) -> Pattern:
        """Wrap this pattern in a ``MatchAttrPattern`` for ``attr_name`` / ``attr_value``."""
        return MatchAttrPattern(self, attr_name, attr_value)

    def match(self, expr: Expr) -> bool:
        """Match this pattern against ``expr`` using a fresh ``Matcher``."""
        return Matcher().match(self, expr)
@attr.s(auto_attribs=True)
class ExprPattern(Pattern):
    """Pattern wrapping a single operator.

    ``Matcher.match_expr`` accepts an expression when it compares equal
    (``==``) to ``op``.
    """

    # The operator the candidate expression is compared against.
    op: Op
@attr.s(auto_attribs=True)
class WildCardPattern(Pattern):
    """Pattern with no fields that ``Matcher.match`` accepts for any sub-expression."""
@attr.s(auto_attribs=True)
class CallPattern(Pattern):
    """Pattern for a call expression, built from nested patterns."""

    # Pattern the callee of the call must match.
    op: Pattern
    # Patterns for the call's arguments, matched position by position.
    args: List[Pattern]
@attr.s(auto_attribs=True)
class MatchTypePattern(Pattern):
    """Pattern that matches only when the inner pattern matches and the type equals ``ty``."""

    # Inner pattern that must match first.
    pattern: Pattern
    # Type the expression's checked type is compared against.
    # NOTE: this could later become a type with holes, to allow partial type
    # matching.
    ty: Type
@attr.s(auto_attribs=True)
class MatchAttrPattern(Pattern):
    """Op pattern that matches only when the op's attribute has a given value."""

    # Inner pattern that must match the expression.
    pattern: Pattern
    # Name of the attribute looked up on the op via ``get_attr``.
    attr_name: str
    # Value the looked-up attribute must compare equal to.
    attr_value: Any
@attr.s(auto_attribs=True)
class InputPattern(Pattern):
    """Pattern for an input, i.e. a ``relay.Var`` as far as the matcher is concerned."""

    # Optional input name. ``None`` matches any variable; a name requires each
    # occurrence of that name to match the same expression.
    name: Optional[str]
class WildcardPattern(Pattern):
    """Field-less pattern class whose name differs from ``WildCardPattern`` only by case.

    NOTE(review): ``wildcard()`` constructs ``WildCardPattern`` and
    ``Matcher.match`` only dispatches on ``WildCardPattern``, so this class
    looks like a leftover duplicate -- confirm before using it.
    """
@attr.s(auto_attribs=True)
class AltPattern(Pattern):
    """Alternation of two patterns; matches when either side matches."""

    # Alternative that the matcher tries first.
    left: Pattern
    # Alternative tried when ``left`` does not match.
    right: Pattern
@attr.s(auto_attribs=True)
class Group(Pattern):
    """Pattern wrapping a single inner pattern.

    NOTE(review): ``Matcher.match`` has no branch for ``Group`` in the visible
    code, so this looks like a placeholder -- confirm the intended semantics.
    """

    # The wrapped pattern.
    pattern: Pattern
def is_op(op_name: str) -> Pattern:
    """Look up the Relay operator called ``op_name`` and wrap it in an ``ExprPattern``."""
    return ExprPattern(relay.op.op.get(op_name))
def wildcard() -> Pattern:
    """Create a new ``WildCardPattern``."""
    any_expr = WildCardPattern()
    return any_expr
def has_type(ty, pattern=None):
    """Build a ``MatchTypePattern`` requiring type ``ty``.

    ``pattern`` is the inner pattern that must also match; when it is ``None``
    a fresh wildcard is used instead.
    """
    inner = wildcard() if pattern is None else pattern
    return MatchTypePattern(inner, ty)
def has_attr(attr_name, attr_value, pattern):
    """Wrap ``pattern`` in a ``MatchAttrPattern`` for ``attr_name`` / ``attr_value``."""
    return MatchAttrPattern(
        pattern=pattern,
        attr_name=attr_name,
        attr_value=attr_value,
    )
def is_input(name=None) -> Pattern:
    """Create an ``InputPattern``; ``name`` is optional and defaults to ``None``."""
    return InputPattern(name=name)
def test_match_op():
    """An ``add`` operator pattern matches the ``add`` operator."""
    add_op = relay.op.op.get("add")
    assert is_op('add').match(add_op)
def test_no_match_op():
    """An ``add`` operator pattern does not match the ``subtract`` operator."""
    subtract_op = relay.op.op.get("subtract")
    assert not is_op('add').match(subtract_op)
def test_match_op_or():
    """An alternation built with ``|`` matches either of its operators."""
    add_or_subtract = is_op('add') | is_op('subtract')
    for op_name in ("add", "subtract"):
        assert add_or_subtract.match(relay.op.op.get(op_name))
def test_match_call():
    """A two-argument ``add`` call pattern with wildcard arguments matches ``x + y``."""
    lhs, rhs = relay.var('x'), relay.var('y')
    binary_add = is_op('add')(wildcard(), wildcard())
    assert binary_add.match(lhs + rhs)
def test_no_match_call():
    """A two-argument ``add`` call pattern does not match ``x - y``."""
    lhs, rhs = relay.var('x'), relay.var('y')
    binary_add = is_op('add')(wildcard(), wildcard())
    assert not binary_add.match(lhs - rhs)
def test_match_type():
    """A float32 (10, 10) type pattern matches a variable declared with that type."""
    expected_ty = relay.TensorType((10, 10), "float32")
    float_var = relay.var('x', shape=(10, 10), dtype="float32")
    assert has_type(expected_ty).match(float_var)
def test_no_match_type():
    """A float32 (10, 10) type pattern rejects an int32 variable of the same shape."""
    expected_ty = relay.TensorType((10, 10), "float32")
    int_var = relay.var('x', shape=(10, 10), dtype="int32")
    assert not has_type(expected_ty).match(int_var)
# NB: 1 corresponds to the C++ enum value that specifies this op pattern kind;
# we lose the type safety because of the Python/C++ calling convention, so the
# raw integer is used here.
# NOTE(review): in TVM's ``OpPatternKind`` enum the value 1 appears to be
# ``kBroadcast`` (``kElemWise`` is 0) -- confirm the name matches the intent.
K_ELEMWISE: int = 1
def test_match_attr():
    """An ``add`` pattern constrained to ``TOpPattern == K_ELEMWISE`` matches ``x + y``."""
    lhs, rhs = relay.var('x'), relay.var('y')
    constrained_add = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
    call_pat = constrained_add(wildcard(), wildcard())
    assert call_pat.match(lhs + rhs)
def test_no_match_attr():
    """An ``nn.dense`` pattern constrained to ``TOpPattern == K_ELEMWISE`` rejects a dense call."""
    lhs, rhs = relay.var('x'), relay.var('y')
    constrained_dense = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
    call_pat = constrained_dense(wildcard(), wildcard())
    dense_call = relay.op.nn.dense(lhs, rhs)
    assert not call_pat.match(dense_call)
def test_match_diamond():
    """A conv2d feeding both relu and leaky_relu, joined by add, matches the diamond pattern."""
    # Pattern: one conv2d sub-pattern shared by both branches of the add.
    conv_pat = is_op('nn.conv2d')(is_input(), is_input())
    relu_branch = is_op('nn.relu')(conv_pat)
    leaky_branch = is_op('nn.leaky_relu')(conv_pat)
    diamond = is_op('add')(relu_branch, leaky_branch)

    # Expression: the same shape built from real Relay calls.
    data = relay.var('input')
    kernel = relay.var('weight')
    conv = relay.op.nn.conv2d(data, kernel)
    relu_out = relay.op.nn.relu(conv)
    leaky_out = relay.op.nn.leaky_relu(conv, alpha=0)
    joined = relu_out + leaky_out

    # Check
    assert diamond.match(joined)
def test_no_match_diamond():
    """Neither branch of a diamond matches the full diamond pattern on its own."""
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)

    # Expr: only the two branches are needed; the joined ``add`` expression is
    # never checked here, so it is not built.
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)

    # Check
    assert not diamond.match(leaky_relu)
    assert not diamond.match(relu)
def test_match_fake_diamond():
    """Two separate convolutions feeding relu and leaky_relu are not a diamond.

    The pattern uses the named inputs ``'data'`` and ``'weight'`` in a single
    conv2d sub-pattern shared by both branches, while the expression builds
    each branch from its own convolution over different variables.
    """
    # Pattern
    conv_pat = is_op('nn.conv2d')(is_input('data'), is_input('weight'))
    relu_branch = is_op('nn.relu')(conv_pat)
    leaky_branch = is_op('nn.leaky_relu')(conv_pat)
    diamond = is_op('add')(relu_branch, leaky_branch)

    # Expression: first convolution and the relu branch.
    data_a = relay.var('input1')
    kernel_a = relay.var('weight1')
    conv_a = relay.op.nn.conv2d(data_a, kernel_a)
    relu_out = relay.op.nn.relu(conv_a)

    # Expression: second, unrelated convolution and the leaky_relu branch.
    data_b = relay.var('input2')
    kernel_b = relay.var('weight2')
    conv_b = relay.op.nn.conv2d(data_b, kernel_b)
    leaky_out = relay.op.nn.leaky_relu(conv_b, alpha=0)

    joined = relu_out + leaky_out

    # Check
    assert not diamond.match(joined)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment