relax custom layer intercept: keep a chosen submodule opaque during torch.fx tracing and map it to custom ops when importing into TVM Relax
original graph:
opcode         name        target                   args           kwargs
-------------  ----------  -----------------------  -------------  --------
placeholder    x           x                        ()             {}
call_function  mul         <built-in function mul>  (x, 1.1)       {}
call_module    mod_linear  mod.linear               (mul,)         {}
output         output      output                   (mod_linear,)  {}
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((128, 4), dtype="float32")) -> R.Tensor((128, 5), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((128, 4), dtype="float32") = R.multiply(inp_0, R.const(1.1000000238418579, "float32"))
            lv1: R.Tensor((4, 5), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv2: R.Tensor((128, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
            lv3: R.Tensor((128, 5), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][1])
            gv: R.Tensor((128, 5), dtype="float32") = lv3
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
custom graph:
opcode         name    target                   args      kwargs
-------------  ------  -----------------------  --------  --------
placeholder    x       x                        ()        {}
call_function  mul     <built-in function mul>  (x, 1.1)  {}
call_module    mod     mod                      (mul,)    {}
output         output  output                   (mod,)    {}
hello, world
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((128, 4), dtype="float32")) -> R.Tensor((128, 4), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((128, 4), dtype="float32") = R.multiply(inp_0, R.const(1.1000000238418579, "float32"))
            lv1: R.Tensor((128, 4), dtype="float32") = R.full_like(lv, R.const(0, "float32"), dtype="float32")
            gv: R.Tensor((128, 4), dtype="float32") = lv1
            R.output(gv)
        return gv
import torch
from torch import nn


# Toy model: Top scales its input, then runs it through InterestingModule,
# the submodule we want to intercept during tracing.
class InterestingModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 5)
        self.hello = "hello, world"

    def forward(self, x):
        return self.linear(x)


class Top(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = InterestingModule()

    def forward(self, x):
        return self.mod(x * 1.1)
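
# Quick eager-mode check (added for illustration, not in the original gist):
# confirms the (128, 4) -> (128, 5) shapes that appear in the Relax IR above.
with torch.no_grad():
    print(Top()(torch.randn(128, 4)).shape)  # torch.Size([128, 5])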

from torch.fx import Tracer, symbolic_trace
from tvm.relax.frontend.torch import from_fx

# Default tracing: symbolic_trace recurses into InterestingModule, so the
# resulting graph exposes the inner nn.Linear as its own call_module node.
module_traced = symbolic_trace(Top())
print("original graph:")
module_traced.graph.print_tabular()

with torch.no_grad():
    fx_module = from_fx(module_traced, [((128, 4), "float32")])
fx_module.show()

from torch.fx.proxy import Scope, ScopeContextManager
from torch.fx.graph_module import GraphModule

modules_to_intercept = {"InterestingModule"}


class MyTracer(Tracer):
    # Same as torch.fx.Tracer.call_module, except that any module whose class
    # name is in modules_to_intercept is treated as a leaf and kept as a
    # single opaque call_module node instead of being traced through.
    def call_module(
        self,
        m: torch.nn.Module,
        forward,
        args,
        kwargs,
    ):
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(
            self.scope, Scope(module_qualified_name, type(m))
        ) as _scope:
            self.module_stack[_scope.module_path] = (
                module_qualified_name,
                _scope.module_type,
            )
            if m._get_name() in modules_to_intercept:
                ret_val = self.create_proxy(
                    "call_module", module_qualified_name, args, kwargs
                )
            elif not self.is_leaf_module(m, module_qualified_name):
                ret_val = forward(*args, **kwargs)
            else:
                ret_val = self.create_proxy(
                    "call_module", module_qualified_name, args, kwargs
                )
            key, _ = self.module_stack.popitem(last=True)
            assert key == _scope.module_path, f"Unexpected key {key}"
        return ret_val

# Drop-in replacement for torch.fx.symbolic_trace that uses MyTracer.
def my_symbolic_trace(
    root,
    concrete_args=None,
) -> GraphModule:
    tracer = MyTracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return GraphModule(tracer.root, graph, name)


module_traced_custom = my_symbolic_trace(Top())
print("custom graph:")
module_traced_custom.graph.print_tabular()

from tvm import relax
from tvm.relax.frontend.torch.fx_translator import TorchFXImporter


# Custom converter for the intercepted call_module node: instead of
# translating InterestingModule's linear layer, emit a zero-filled tensor
# with the same shape as the input.
def _InterestingModule(self, node) -> relax.Var:
    assert node.target in self.named_modules
    module = self.named_modules[node.target]
    print(module.hello)
    x = self.env[node.args[0]]
    dtype = (
        TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
        if "dtype" in node.kwargs
        else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env)
    )
    return self.block_builder.emit(
        relax.op.full_like(
            x,
            relax.const(0, dtype),
            dtype,
        )
    )
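
# Hypothetical variant (an assumption, not from the gist): a converter that
# reproduces the intercepted module's real computation instead of zeros,
# reading the weights off the live nn.Module. relax.op.linear is assumed to
# be available in this TVM version.
def _InterestingModuleReal(self, node) -> relax.Var:
    module = self.named_modules[node.target]
    x = self.env[node.args[0]]
    weight = relax.const(module.linear.weight.detach().numpy(), "float32")
    bias = relax.const(module.linear.bias.detach().numpy(), "float32")
    return self.block_builder.emit(relax.op.linear(x, weight, bias))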

importer = TorchFXImporter()
with torch.no_grad():
    fx_module_custom = importer.from_fx(
        module_traced_custom,
        [((128, 4), "float32")],
        False,  # keep_params_as_input
        False,  # unwrap_unit_return_tuple
        False,  # no_bind_return_tuple
        custom_convert_map={
            InterestingModule: lambda node: _InterestingModule(importer, node)
        },
    )
fx_module_custom.show()
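
# Follow-up sketch (an assumption, not part of the gist): legalize, build,
# and run the intercepted module on the Relax VM; assumes a TVM build with
# an LLVM target available.
import numpy as np
import tvm

mod = relax.transform.LegalizeOps()(fx_module_custom)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(128, 4).astype("float32"))
out = vm["main"](inp)
print(out.numpy())  # all zeros: the converter replaced the layer with full_like(x, 0)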