Skip to content

Instantly share code, notes, and snippets.

@alexfanqi
Last active May 13, 2024 01:58
Show Gist options
  • Save alexfanqi/8b4d549799d002d4ba85f3284c8df90a to your computer and use it in GitHub Desktop.
Save alexfanqi/8b4d549799d002d4ba85f3284c8df90a to your computer and use it in GitHub Desktop.
relax custom layer intercept
original graph:
opcode name target args kwargs
------------- ---------- ----------------------- ------------- --------
placeholder x x () {}
call_function mul <built-in function mul> (x, 1.1) {}
call_module mod_linear mod.linear (mul,) {}
output output output (mod_linear,) {}
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(inp_0: R.Tensor((128, 4), dtype="float32")) -> R.Tensor((128, 5), dtype="float32"):
with R.dataflow():
lv: R.Tensor((128, 4), dtype="float32") = R.multiply(inp_0, R.const(1.1000000238418579, "float32"))
lv1: R.Tensor((4, 5), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
lv2: R.Tensor((128, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
lv3: R.Tensor((128, 5), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][1])
gv: R.Tensor((128, 5), dtype="float32") = lv3
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
custom graph:
opcode name target args kwargs
------------- ------ ----------------------- -------- --------
placeholder x x () {}
call_function mul <built-in function mul> (x, 1.1) {}
call_module mod mod (mul,) {}
output output output (mod,) {}
hello, world
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(inp_0: R.Tensor((128, 4), dtype="float32")) -> R.Tensor((128, 4), dtype="float32"):
with R.dataflow():
lv: R.Tensor((128, 4), dtype="float32") = R.multiply(inp_0, R.const(1.1000000238418579, "float32"))
lv1: R.Tensor((128, 4), dtype="float32") = R.full_like(lv, R.const(0, "float32"), dtype="float32")
gv: R.Tensor((128, 4), dtype="float32") = lv1
R.output(gv)
return gv
import torch
from torch import nn
class InterestingModule(nn.Module):
    """Submodule holding a 4-in / 5-out linear layer and a ``hello`` string.

    The ``hello`` attribute plays no part in ``forward``; it is plain
    Python state stored on the module instance.
    """

    def __init__(self):
        super().__init__()
        self.hello = "hello, world"
        self.linear = nn.Linear(in_features=4, out_features=5)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        projected = self.linear(x)
        return projected
class Top(nn.Module):
    """Root model: scales its input by 1.1, then applies ``self.mod``."""

    def __init__(self):
        super().__init__()
        self.mod = InterestingModule()

    def forward(self, x):
        """Multiply ``x`` by 1.1 and feed the result to the submodule."""
        scaled = x * 1.1
        return self.mod(scaled)
from torch.fx import Tracer, symbolic_trace
from tvm.relax.frontend.torch import from_fx
# Baseline: trace the model with the stock FX tracer and print its graph.
module_traced = symbolic_trace(Top())
print("original graph:")
module_traced.graph.print_tabular()

# Translate the traced graph with the Relax frontend; the input spec is a
# single float32 tensor of shape (128, 4).
with torch.no_grad():
    fx_module = from_fx(module_traced, [((128, 4), "float32")])
fx_module.show()
from torch.fx.proxy import Scope, ScopeContextManager
from torch.fx.graph_module import GraphModule
# Names (as returned by ``Module._get_name()``) of modules that the custom
# tracer must record as a single ``call_module`` node instead of tracing
# through their ``forward``.
modules_to_intercept = {"InterestingModule"}


class myTracer(Tracer):
    """FX tracer that keeps selected modules opaque.

    A module whose ``_get_name()`` is listed in ``modules_to_intercept`` is
    emitted as one ``call_module`` node, exactly like a leaf module; every
    other module follows the usual leaf / non-leaf decision.
    """

    def call_module(
        self,
        m: torch.nn.Module,
        forward,
        args,
        kwargs,
    ):
        """Record or trace through a call to submodule ``m``.

        Args:
            m: The module being called.
            forward: Callable invoked with ``*args, **kwargs`` when the
                module is traced through rather than recorded.
            args: Positional arguments of the module call.
            kwargs: Keyword arguments of the module call.

        Returns:
            The proxy created for the ``call_module`` node when the module
            is intercepted or is a leaf; otherwise whatever ``forward``
            returns.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(
            self.scope, Scope(module_qualified_name, type(m))
        ) as _scope:
            # Push this module onto the stack; it is popped again below once
            # the call has been recorded or traced.
            self.module_stack[_scope.module_path] = (
                module_qualified_name,
                _scope.module_type,
            )
            # The intercept check comes first so that ``is_leaf_module`` is
            # only consulted for modules that are not explicitly intercepted.
            keep_opaque = (
                m._get_name() in modules_to_intercept
                or self.is_leaf_module(m, module_qualified_name)
            )
            if keep_opaque:
                ret_val = self.create_proxy(
                    "call_module", module_qualified_name, args, kwargs
                )
            else:
                ret_val = forward(*args, **kwargs)
            key, _ = self.module_stack.popitem(last=True)
            # Sanity check: the entry popped must be the one pushed above.
            assert key == _scope.module_path, f" Unexpected key {key}"
        return ret_val
def my_symbolic_trace(
    root,
    concrete_args=None,
) -> GraphModule:
    """Trace ``root`` with ``myTracer`` and wrap the graph in a GraphModule.

    Args:
        root: An ``nn.Module`` instance or a callable with a ``__name__``.
        concrete_args: Forwarded unchanged to ``myTracer.trace``.

    Returns:
        A ``GraphModule`` built from the tracer's root and the traced graph,
        named after ``root``'s class (for modules) or ``root.__name__``.
    """
    tracer = myTracer()
    traced_graph = tracer.trace(root, concrete_args)
    if isinstance(root, torch.nn.Module):
        graph_name = root.__class__.__name__
    else:
        graph_name = root.__name__
    return GraphModule(tracer.root, traced_graph, graph_name)
# Re-trace the same model with the custom tracer and print the resulting
# graph for comparison with the stock trace.
module_traced_custom = my_symbolic_trace(Top())
print("custom graph:")
module_traced_custom.graph.print_tabular()
from tvm import relax
from tvm.relax.frontend.torch.fx_translator import TorchFXImporter
def _InterestingModule(self, node) -> relax.Var:
    """Custom converter for ``call_module`` nodes targeting InterestingModule.

    Looks up the module instance behind ``node.target``, prints its
    ``hello`` attribute, and emits a ``full_like`` of zeros shaped like the
    node's first input.

    Args:
        self: The importer; must provide ``named_modules``, ``env`` and
            ``block_builder``.
        node: The FX node being converted.

    Returns:
        The value produced by ``self.block_builder.emit``.
    """
    assert node.target in self.named_modules
    module = self.named_modules[node.target]
    print(module.hello)
    x = self.env[node.args[0]]

    # Use an explicit ``dtype`` kwarg when the node has one, otherwise fall
    # back to torch's current default dtype.
    if "dtype" in node.kwargs:
        requested = str(node.kwargs["dtype"])
    else:
        requested = torch.get_default_dtype()
    dtype = TorchFXImporter._convert_data_type(requested, self.env)

    zeros = relax.op.full_like(x, relax.const(0, dtype), dtype)
    return self.block_builder.emit(zeros)
# Drive the importer directly so a custom converter can be registered for
# InterestingModule.
importer = TorchFXImporter()
with torch.no_grad():
    fx_module_custom = importer.from_fx(
        module_traced_custom,
        [((128, 4), "float32")],  # one float32 input of shape (128, 4)
        # NOTE(review): the three positional flags are presumably
        # keep_params_as_input, unwrap_unit_return_tuple and
        # no_bind_return_tuple -- confirm against the installed TVM version.
        False,
        False,
        False,
        custom_convert_map={
            # The converter receives only the node; bind the importer here.
            InterestingModule: lambda node: _InterestingModule(importer, node)
        },
    )
fx_module_custom.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment