Skip to content

Instantly share code, notes, and snippets.

@jroesch

jroesch/split.py Secret

Created March 4, 2020 19:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jroesch/a110f8eab27ba2bfb1b433fed2b13c9c to your computer and use it in GitHub Desktop.
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
# Demo input graph: elementwise sum of two (10, 2) tensors named "x" and "y".
x, y = (relay.var(name, shape=(10, 2)) for name in ("x", "y"))
f = relay.Function([x, y], x + y)
class SplitOp(ExprMutator):
    """Split the parameters of a Relay function in half along axis 0.

    Each parameter ``p`` whose type annotation is a tensor with a static
    leading dimension ``d0 >= 2`` is replaced by two parameters:
    ``p1`` with ``d0 // 2`` rows and ``p2`` with the remaining
    ``d0 - d0 // 2`` rows.  The trailing dimensions and the dtype are kept.
    Parameters that cannot be split are kept unchanged.

    The body is rewritten as follows:

    * ``add(p, q)`` where both ``p`` and ``q`` are split parameters with the
      same leading dimension becomes
      ``concatenate([p1 + q1, p2 + q2], axis=0)``.
    * Any other reference to a split parameter becomes
      ``concatenate([p1, p2], axis=0)``.  Because of this, the rewritten
      body never refers to a parameter that no longer exists.

    Only the outermost function visited is split.  Nested functions keep
    their signatures, because their call sites still pass full-size
    arguments.
    """

    def __init__(self):
        super().__init__()
        # Original parameters of the most recently split function, in order.
        self.inputs = []
        # Replacement parameters, in order.  A split param contributes its
        # (first half, second half) pair; an unsplit param appears once.
        self.new_inputs = []
        # One (original_param, first_half, second_half, leading_dim) tuple
        # per split parameter.
        self._splits = []
        # True while the outermost function's body is being rewritten.
        self._active = False

    @staticmethod
    def _leading_dim(param):
        """Return the static size of axis 0 of *param*, or None if unsplittable."""
        shape = getattr(param.type_annotation, "shape", None)
        if shape is None or len(shape) == 0:
            # No annotation, a non-tensor type, or a scalar: nothing to split.
            return None
        try:
            dim0 = int(shape[0])
        except (TypeError, ValueError):
            # Symbolic or dynamic dimension: the size is unknown here.
            return None
        # A leading dimension below 2 cannot be halved into non-empty parts.
        return dim0 if dim0 >= 2 else None

    def _find_split(self, expr):
        """Return the split record for *expr* if it is a split param, else None."""
        for record in self._splits:
            # Compare by object identity, not by an overloadable ``==``.
            if record[0].same_as(expr):
                return record
        return None

    def visit_function(self, func):
        """Split the outermost function's parameters and rewrite its body."""
        if self._active:
            # This is a nested function.  Splitting it would change its
            # signature while its callers still pass full-size tensors.
            return super().visit_function(func)

        # Reset state so that reusing this instance does not mix functions.
        self.inputs = list(func.params)
        self.new_inputs = []
        self._splits = []
        for param in func.params:
            dim0 = self._leading_dim(param)
            if dim0 is None:
                self.new_inputs.append(param)
                continue
            ttype = param.type_annotation
            rest = list(ttype.shape)[1:]
            first = dim0 // 2
            # The second half takes the remainder, so odd sizes lose no row.
            half1 = relay.var(f"{param.name_hint}1", shape=(first, *rest), dtype=ttype.dtype)
            half2 = relay.var(f"{param.name_hint}2", shape=(dim0 - first, *rest), dtype=ttype.dtype)
            self.new_inputs.extend([half1, half2])
            self._splits.append((param, half1, half2, dim0))

        self._active = True
        try:
            body = self.visit(func.body)
        finally:
            self._active = False
        return relay.Function(list(self.new_inputs), body)

    def visit_call(self, call):
        """Rewrite ``add`` of two equally-sized split params into per-half adds."""
        if isinstance(call.op, relay.Op) and call.op.same_as(relay.op.op.get("add")):
            lhs, rhs = call.args
            lhs_split = self._find_split(lhs)
            rhs_split = self._find_split(rhs)
            if (
                lhs_split is not None
                and rhs_split is not None
                and lhs_split[3] == rhs_split[3]
            ):
                _, lhs1, lhs2, _ = lhs_split
                _, rhs1, rhs2, _ = rhs_split
                # Add first halves together and second halves together, so the
                # concatenation yields the rows of the original sum in order.
                return relay.op.concatenate([lhs1 + rhs1, lhs2 + rhs2], axis=0)
        return super().visit_call(call)

    def visit_var(self, var):
        """Replace any remaining use of a split param with its re-joined halves."""
        record = self._find_split(var) if self._active else None
        if record is not None:
            return relay.op.concatenate([record[1], record[2]], axis=0)
        return super().visit_var(var)
# Run the splitting pass on the demo function, type-check it, and show the result.
so = SplitOp()
f_prime = so.visit(f)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f_prime))
print(mod)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment