Skip to content

Instantly share code, notes, and snippets.

@jroesch
Created February 25, 2020 20:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jroesch/de9648556b1f7e1bb3db490e0f708813 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import attr
from typing import List, Any, Optional
from tvm import relay
from tvm.ir import Type, IRModule
from tvm.relay import Expr, Op
from tvm.relay.expr_functor import ExprVisitor
class Matcher:
    """Matches a ``Pattern`` against a Relay expression.

    Dispatches on the pattern's concrete type and recursively matches
    sub-patterns. Named input bindings are recorded in ``input_map`` so the
    same named input pattern must match the same variable everywhere it
    appears in a single ``match`` invocation.
    """

    def __init__(self):
        # Maps input-pattern names -> the relay.Var they first matched.
        self.input_map = {}

    def match(self, pattern, expr):
        """Return True if `pattern` matches `expr`, dispatching on pattern type."""
        if isinstance(pattern, ExprPattern):
            return self.match_expr(pattern.op, expr)
        elif isinstance(pattern, CallPattern):
            return self.match_call(pattern.op, pattern.args, expr)
        elif isinstance(pattern, AltPattern):
            return self.match_alt(pattern.left, pattern.right, expr)
        elif isinstance(pattern, MatchTypePattern):
            return self.match_type(pattern.pattern, pattern.ty, expr)
        elif isinstance(pattern, MatchAttrPattern):
            return self.match_attr(pattern.pattern, pattern.attr_name, pattern.attr_value, expr)
        elif isinstance(pattern, InputPattern):
            return self.match_input(pattern.name, expr)
        elif isinstance(pattern, WildCardPattern):
            # NB: a wildcard matches anything; nothing needs to be recorded.
            return True
        else:
            raise Exception(f"unsupported {type(pattern)}")

    def match_expr(self, op, expr):
        """Match a single operator pattern by direct equality with `expr`."""
        if op == expr:
            return True
        else:
            return False

    # TODO(@jroesch): we should probably support matching type args and call attributes here
    def match_call(self, op, args, expr):
        """Match a call: `expr` must be a relay.Call whose callee matches `op`
        and whose arguments match `args` pairwise."""
        if not isinstance(expr, relay.Call):
            return False
        if not self.match(op, expr.op):
            return False
        # An arity mismatch is simply not a match. (This was previously an
        # `assert`, which crashed instead of reporting a failed match.)
        if len(args) != len(expr.args):
            return False
        return all(self.match(pat, arg) for pat, arg in zip(args, expr.args))

    def match_alt(self, lhs, rhs, expr):
        """Match an alternation: either side matching is sufficient."""
        return self.match(lhs, expr) or self.match(rhs, expr)

    def match_type(self, pattern, ty, expr):
        """Match `pattern` and additionally require `expr`'s inferred type
        to equal `ty`."""
        expr = infer_type(expr)
        if self.match(pattern, expr):
            return expr.checked_type == ty
        else:
            return False

    def match_attr(self, pattern, attr_name, attr_value, expr):
        """Match `pattern` and additionally require the matched op's
        registered attribute `attr_name` to equal `attr_value`.

        Only relay.Op expressions carry registered attributes; anything else
        fails the match. (A leftover, unreachable `pdb.set_trace()` debugging
        statement was removed from the end of this method.)
        """
        if self.match(pattern, expr) and isinstance(expr, relay.Op):
            actual_value = expr.get_attr(attr_name)
            if actual_value == attr_value:
                return True
            else:
                return False
        else:
            return False

    def match_input(self, name, expr):
        # TODO(@jroesch) we should probably generalize this to named groups,
        # or some kind of feature to name a matched sub-expression then
        # require the match be identical elsewhere
        """Match an input pattern: `expr` must be a variable, and a named
        input must be the same variable at every occurrence."""
        # If it isn't a variable it's not an input, assuming graph form.
        if not isinstance(expr, relay.Var):
            return False
        # If name isn't set we can match any input.
        if name is None:
            return True
        # If name is set we must match the same expression in each position.
        #
        # We could improve the implementation by tracking all possible matches
        # to improve error reporting.
        if name not in self.input_map:
            self.input_map[name] = expr
            # BUG FIX: this previously returned None (falsy), so the *first*
            # occurrence of a named input could never match.
            return True
        else:
            return self.input_map[name] == expr
def infer_type(expr):
    """Run Relay type inference over `expr` and return the typed result.

    Non-function expressions are wrapped in a module's `main` function for
    inference, so we unwrap back to the body in that case.
    """
    typed_mod = relay.transform.InferType()(IRModule.from_expr(expr))
    main_fn = typed_mod["main"]
    return main_fn if isinstance(expr, relay.Function) else main_fn.body
class Pattern:
    """The base class of patterns.

    Provides operator sugar: calling a pattern builds a ``CallPattern`` over
    its arguments, and ``|`` builds an ``AltPattern``. (A commented-out,
    experimental ``__le__`` overload was removed as dead code.)
    """

    def __call__(self, *args: Pattern) -> Pattern:
        """Build a pattern matching a call with this pattern as the callee."""
        return CallPattern(self, list(args))

    def __or__(self, other: Pattern) -> Pattern:
        """Build a pattern matching either this pattern or `other`."""
        return AltPattern(self, other)

    def has_attr(self, attr_name, attr_value) -> Pattern:
        """Restrict this pattern to ops whose `attr_name` equals `attr_value`."""
        return MatchAttrPattern(self, attr_name, attr_value)

    def match(self, expr: Expr) -> bool:
        """Return True if this pattern matches `expr`."""
        matcher = Matcher()
        return matcher.match(self, expr)
@attr.s(auto_attribs=True)
class ExprPattern(Pattern):
    """A pattern which matches a single operation."""

    # The exact operator this pattern matches (compared by equality).
    op: Op
@attr.s(auto_attribs=True)
class WildCardPattern(Pattern):
    """A wildcard pattern, which matches any sub-expression."""
    pass
@attr.s(auto_attribs=True)
class CallPattern(Pattern):
    """A pattern which matches a call, containing recursive patterns."""

    # Pattern for the callee (typically an ExprPattern over an Op).
    op: Pattern
    # Patterns matched pairwise against the call's arguments.
    args: List[Pattern]
@attr.s(auto_attribs=True)
class MatchTypePattern(Pattern):
    """A pattern which only matches if the interior match has a specific type."""

    # The inner pattern that must match first.
    pattern: Pattern
    # Note potentially match a type with holes in order to do partial type matching.
    ty: Type
@attr.s(auto_attribs=True)
class MatchAttrPattern(Pattern):
    """An op pattern which only matches if the op's attribute matches."""

    # The inner pattern; must resolve to a relay.Op for the attr check to apply.
    pattern: Pattern
    # Name of the registered op attribute to look up.
    attr_name: str
    # Expected attribute value, compared with ==.
    attr_value: Any
@attr.s(auto_attribs=True)
class InputPattern(Pattern):
    """A pattern matching a graph input (a relay.Var)."""

    # Optional binding name: a named input must match the same variable at
    # every occurrence; None matches any variable.
    name: Optional[str]
class WildcardPattern(Pattern):
    # NOTE(review): near-duplicate of WildCardPattern above (differs only in
    # capitalization) and is never referenced in this file; Matcher.match does
    # not handle it. Likely a leftover — candidate for removal.
    pass
@attr.s(auto_attribs=True)
class AltPattern(Pattern):
    """An alternation: matches if either side matches."""

    # Tried first.
    left: Pattern
    # Tried only if `left` fails.
    right: Pattern
@attr.s(auto_attribs=True)
class Group(Pattern):
    """A named/group wrapper around a pattern.

    NOTE(review): not handled by Matcher.match and never constructed in this
    file — appears to be a placeholder for future group-capture support.
    """

    pattern: Pattern
def is_op(op_name: str) -> Pattern:
    """Build a pattern matching the Relay operator registered as `op_name`."""
    return ExprPattern(relay.op.op.get(op_name))
def wildcard() -> Pattern:
    """Build a pattern matching any sub-expression."""
    return WildCardPattern()
def has_type(ty, pattern=None):
    """Build a pattern requiring the matched expression to have type `ty`.

    When `pattern` is omitted, any sub-expression of that type matches.
    """
    inner = wildcard() if pattern is None else pattern
    return MatchTypePattern(inner, ty)
def has_attr(attr_name, attr_value, pattern):
    """Build a pattern requiring `pattern`'s op attribute `attr_name` == `attr_value`."""
    return MatchAttrPattern(pattern, attr_name, attr_value)
def is_input(name=None) -> Pattern:
    """Build a pattern matching a graph input; `name` forces a consistent binding."""
    return InputPattern(name)
def test_match_op():
    # An op pattern matches the identical registered op.
    assert is_op('add').match(relay.op.op.get("add"))
def test_no_match_op():
    # An op pattern rejects a different op.
    assert not is_op('add').match(relay.op.op.get("subtract"))
def test_match_op_or():
    # An alternation matches either of its alternatives.
    either_op = is_op('add') | is_op('subtract')
    assert either_op.match(relay.op.op.get("add"))
    assert either_op.match(relay.op.op.get("subtract"))
def test_match_call():
    # A call pattern with wildcard args matches any add of two expressions.
    lhs = relay.var('x')
    rhs = relay.var('y')
    pat = is_op('add')(wildcard(), wildcard())
    assert pat.match(lhs + rhs)
def test_no_match_call():
    # The callee op must match: add pattern rejects a subtraction.
    lhs = relay.var('x')
    rhs = relay.var('y')
    pat = is_op('add')(wildcard(), wildcard())
    assert not pat.match(lhs - rhs)
def test_match_type():
    # A type pattern accepts a variable of exactly the required type.
    var = relay.var('x', shape=(10, 10), dtype="float32")
    pat = has_type(relay.TensorType((10, 10), "float32"))
    assert pat.match(var)
def test_no_match_type():
    # A type pattern rejects a variable with a different dtype.
    var = relay.var('x', shape=(10, 10), dtype="int32")
    pat = has_type(relay.TensorType((10, 10), "float32"))
    assert not pat.match(var)
# NB: 1 corresponds to the C++ enum that specifies this
# we lose the type safety due to the Python/C++ calling
# convention.
K_ELEMWISE = 1
def test_match_attr():
    # An attr-restricted op pattern matches an op carrying that attribute value.
    restricted = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
    pat = restricted(wildcard(), wildcard())
    lhs = relay.var('x')
    rhs = relay.var('y')
    assert pat.match(lhs + rhs)
def test_no_match_attr():
    # nn.dense is not elementwise, so the attr restriction rejects it.
    restricted = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
    pat = restricted(wildcard(), wildcard())
    lhs = relay.var('x')
    rhs = relay.var('y')
    assert not pat.match(relay.op.nn.dense(lhs, rhs))
def test_match_diamond():
    # Pattern: one conv2d feeding both a relu and a leaky_relu, re-joined by add.
    conv_pat = is_op('nn.conv2d')(is_input(), is_input())
    relu_branch = is_op('nn.relu')(conv_pat)
    leaky_branch = is_op('nn.leaky_relu')(conv_pat)
    diamond_pat = is_op('add')(relu_branch, leaky_branch)
    # Expr
    data = relay.var('input')
    kernel = relay.var('weight')
    conv = relay.op.nn.conv2d(data, kernel)
    relu_out = relay.op.nn.relu(conv)
    leaky_out = relay.op.nn.leaky_relu(conv, alpha=0)
    joined = relu_out + leaky_out
    # Check
    assert diamond_pat.match(joined)
def test_no_match_diamond():
    # Pattern: the same diamond as above.
    conv_pat = is_op('nn.conv2d')(is_input(), is_input())
    relu_branch = is_op('nn.relu')(conv_pat)
    leaky_branch = is_op('nn.leaky_relu')(conv_pat)
    diamond_pat = is_op('add')(relu_branch, leaky_branch)
    # Expr
    data = relay.var('input')
    kernel = relay.var('weight')
    conv = relay.op.nn.conv2d(data, kernel)
    relu_out = relay.op.nn.relu(conv)
    leaky_out = relay.op.nn.leaky_relu(conv, alpha=0)
    joined = relu_out + leaky_out
    # Check: neither branch alone is the full diamond.
    assert not diamond_pat.match(leaky_out)
    assert not diamond_pat.match(relu_out)
def test_match_fake_diamond():
    # Pattern: named inputs require the *same* variables on both branches.
    data_pat = is_input('data')
    weight_pat = is_input('weight')
    conv_pat = is_op('nn.conv2d')(data_pat, weight_pat)
    relu_branch = is_op('nn.relu')(conv_pat)
    leaky_branch = is_op('nn.leaky_relu')(conv_pat)
    diamond_pat = is_op('add')(relu_branch, leaky_branch)
    # Expr: two *different* conv2d ops feed the two branches.
    data_a = relay.var('input1')
    weight_a = relay.var('weight1')
    conv_a = relay.op.nn.conv2d(data_a, weight_a)
    data_b = relay.var('input2')
    weight_b = relay.var('weight2')
    conv_b = relay.op.nn.conv2d(data_b, weight_b)
    relu_out = relay.op.nn.relu(conv_a)
    leaky_out = relay.op.nn.leaky_relu(conv_b, alpha=0)
    joined = relu_out + leaky_out
    # Check: the named inputs bind to different vars, so this must not match.
    assert not diamond_pat.match(joined)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment