-
-
Save jroesch/a110f8eab27ba2bfb1b433fed2b13c9c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
from tvm import relay | |
from tvm.relay.expr_functor import ExprMutator | |
# Example input program: element-wise sum of two (10, 2) tensors x and y.
x, y = (relay.var(name, shape=(10, 2)) for name in ("x", "y"))
f = relay.Function([x, y], x + y)
class SplitOp(ExprMutator):
    """Rewrite a Relay function so tensor parameters are split in two along axis 0.

    Every parameter whose type annotation is a tensor with a static leading
    dimension ``n >= 2`` is replaced by two parameters named ``<name>1`` (the
    first ``n - n // 2`` rows) and ``<name>2`` (the remaining ``n // 2`` rows),
    keeping the original trailing dimensions and dtype. Parameters that cannot
    be split (no tensor annotation, rank 0, dynamic or too-small leading
    dimension) are kept as-is.

    Inside the body:

    * ``add(a, b)`` where both operands are split parameters of equal rank and
      equal leading dimension becomes ``concatenate([a1 + b1, a2 + b2], axis=0)``;
    * any other reference to a split parameter ``p`` becomes
      ``concatenate([p1, p2], axis=0)``, so the rewritten function has no free
      variables and computes the same value as the original.

    Only the outermost function visited gets its signature rewritten; nested
    functions are traversed with the default ``ExprMutator`` behaviour.
    """

    def __init__(self):
        super().__init__()
        # Parameters of the outermost function most recently rewritten.
        self.inputs = []
        # Parameters of the rewritten function, in signature order.
        self.new_inputs = []
        # Split original parameter -> (first-half var, second-half var).
        self._halves = {}
        # True while rewriting an outermost function's body.
        self._in_function = False

    def visit_function(self, func):
        """Return ``func`` with split parameters and a rewritten body."""
        if self._in_function:
            # Nested function: keep its own parameters untouched.
            return super().visit_function(func)

        # Reset per-function state so reusing the mutator does not mix functions.
        self.inputs = list(func.params)
        self._halves = {}
        new_inputs = []
        for param in func.params:
            halves = self._split_param(param)
            if halves is None:
                new_inputs.append(param)
            else:
                self._halves[param] = halves
                new_inputs.extend(halves)
        self.new_inputs = new_inputs

        self._in_function = True
        try:
            body = self.visit(func.body)
        finally:
            self._in_function = False
        # Concatenating the halves restores the original output shape, so the
        # original return type and attributes remain valid.
        return relay.Function(
            new_inputs, body, func.ret_type, func.type_params, func.attrs
        )

    def visit_var(self, var):
        """Replace a reference to a split parameter by the concatenation of its halves."""
        halves = self._halves.get(var)
        if halves is None:
            return super().visit_var(var)
        return relay.op.concatenate(list(halves), axis=0)

    def visit_call(self, call):
        """Distribute ``add`` over the row halves when both operands are split params."""
        if isinstance(call.op, relay.Op) and call.op == relay.op.op.get("add"):
            lhs, rhs = call.args
            if self._halves_align(lhs, rhs):
                lhs1, lhs2 = self._halves[lhs]
                rhs1, rhs2 = self._halves[rhs]
                # Pair matching halves: top rows with top rows, bottom with bottom.
                return relay.op.concatenate([lhs1 + rhs1, lhs2 + rhs2], axis=0)
        # Anything else: default traversal; visit_var rebinds split params.
        return super().visit_call(call)

    def _halves_align(self, lhs, rhs):
        """Return True if ``lhs`` and ``rhs`` are split params whose row halves line up."""
        if lhs not in self._halves or rhs not in self._halves:
            return False
        lhs_shape = list(lhs.type_annotation.shape)
        rhs_shape = list(rhs.type_annotation.shape)
        # Equal rank keeps axis 0 aligned under broadcasting; equal leading
        # dimensions give both operands identically sized halves.
        return (
            len(lhs_shape) == len(rhs_shape)
            and int(lhs_shape[0].value) == int(rhs_shape[0].value)
        )

    @staticmethod
    def _split_param(param):
        """Return ``(first, second)`` half vars for ``param``, or None if it cannot be split."""
        ttype = param.type_annotation
        if not isinstance(ttype, relay.TensorType):
            return None
        dims = list(ttype.shape)
        # A static leading dimension of at least 2 is needed for two non-empty halves.
        if not dims or not isinstance(dims[0], tvm.tir.IntImm) or dims[0].value < 2:
            return None
        rows = int(dims[0].value)
        rest = dims[1:]
        # First half takes the extra row when rows is odd, so no row is dropped.
        first = relay.var(
            f"{param.name_hint}1", shape=(rows - rows // 2, *rest), dtype=ttype.dtype
        )
        second = relay.var(
            f"{param.name_hint}2", shape=(rows // 2, *rest), dtype=ttype.dtype
        )
        return first, second
# Split the example function's inputs, wrap the result in a module,
# run type inference on it, and show the rewritten program.
so = SplitOp()
f_prime = so.visit(f)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f_prime))
print(mod)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment