Created
June 27, 2023 15:33
-
-
Save soulitzer/1c2d49a049e3a2d77326d867d875cd42 to your computer and use it in GitHub Desktop.
get_source_partitions produces different results in export aten_graph=True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import pprint

import torch
import torch._dynamo as torchdynamo
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
# Module under test: a 10 -> 10 linear layer, read by `fn` as a global.
m = torch.nn.Linear(in_features=10, out_features=10)
def fn(x):
    """Run the module-level layer ``m`` on ``x`` and return its output."""
    out = m(x)
    return out
def my_backend(gm, example_inputs): | |
gm.graph.print_tabular() | |
module_partitions = get_source_partitions( | |
gm.graph, [torch.nn.Linear, torch.nn.functional.linear] | |
) | |
pprint.pprint(module_partitions) | |
# import inspect | |
# print(inspect.getsource(gm.forward)) | |
return gm | |
# Drive `fn` through torch.compile with the printing backend; calling the
# compiled function is what makes `my_backend` emit its output.
a = torch.rand(10)
compiled_f = torch.compile(fn, backend=my_backend)
compiled_f(a)
# EXPORT
# Export the same Linear module to an aten-level graph and print its source
# partitions. The exported graph module gets its own name so that the global
# `m` read by `fn` keeps pointing at the original nn.Linear.
exported_gm, guards = torchdynamo.export(
    m,
    # Export traces against a copy of the inputs, leaving `a` untouched.
    *copy.deepcopy((a,)),
    aten_graph=True,
    tracing_mode="real",
)
module_partitions = get_source_partitions(
    exported_gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
)
exported_gm.graph.print_tabular()
pprint.pprint(module_partitions)
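"""
Recorded output of the script above. Each section shows the graph and partitions printed by the torch.compile backend, followed by the graph and partitions printed for the export.
before: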
opcode name target args kwargs | |
----------- ------ -------- ------- -------- | |
placeholder l_x_ L_x_ () {} | |
call_module m m (l_x_,) {} | |
output output output ((m,),) {} | |
{<class 'torch.nn.modules.linear.Linear'>: [SourcePartition(nodes=[m], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[l_x_], output_nodes=[m], params=[])]} | |
opcode name target args kwargs | |
------------- ----------------- ---------------------- ------------------------------- -------- | |
placeholder arg0 arg0 () {} | |
get_attr _param_constant0 _param_constant0 () {} | |
call_function t_default aten.t.default (_param_constant0,) {} | |
call_function unsqueeze_default aten.unsqueeze.default (arg0, 0) {} | |
call_function mm_default aten.mm.default (unsqueeze_default, t_default) {} | |
call_function squeeze_dim aten.squeeze.dim (mm_default, 0) {} | |
get_attr _param_constant1 _param_constant1 () {} | |
call_function add_tensor aten.add.Tensor (squeeze_dim, _param_constant1) {} | |
output output output ([add_tensor],) {} | |
{<class 'torch.nn.modules.linear.Linear'>: [SourcePartition(nodes=[_param_constant0, t_default, unsqueeze_default, mm_default, squeeze_dim, _param_constant1, add_tensor], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[add_tensor], params=[_param_constant1, _param_constant0])]} | |
after: | |
opcode name target args kwargs | |
------------- ------------- -------------------------- ---------------------------------- -------- | |
placeholder l_x_ L_x_ () {} | |
get_attr g__m___weight G__m___weight () {} | |
get_attr g__m___bias G__m___bias () {} | |
call_function linear <built-in function linear> (l_x_, g__m___weight, g__m___bias) {} | |
output output output ((linear,),) {} | |
{<built-in function linear>: [SourcePartition(nodes=[linear], source=<built-in function linear>, input_nodes=[l_x_, g__m___weight, g__m___bias], output_nodes=[linear], params=[])]} | |
opcode name target args kwargs | |
------------- ----------------- ---------------------- ------------------------------- -------- | |
placeholder arg0 arg0 () {} | |
get_attr _param_constant0 _param_constant0 () {} | |
call_function t_default aten.t.default (_param_constant0,) {} | |
call_function unsqueeze_default aten.unsqueeze.default (arg0, 0) {} | |
call_function mm_default aten.mm.default (unsqueeze_default, t_default) {} | |
call_function squeeze_dim aten.squeeze.dim (mm_default, 0) {} | |
get_attr _param_constant1 _param_constant1 () {} | |
call_function add_tensor aten.add.Tensor (squeeze_dim, _param_constant1) {} | |
output output output ([add_tensor],) {} | |
{<built-in function linear>: [SourcePartition(nodes=[_param_constant0, t_default, unsqueeze_default, mm_default, squeeze_dim, _param_constant1, add_tensor], source=<built-in function linear>, input_nodes=[arg0], output_nodes=[add_tensor], params=[_param_constant0, _param_constant1])]} | |
""" | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment