@soulitzer
Created June 27, 2023 15:33
get_source_partitions produces different results in export with aten_graph=True
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
import pprint
m = torch.nn.Linear(10, 10)
def fn(x):
    return m(x)

def my_backend(gm, example_inputs):
    gm.graph.print_tabular()
    module_partitions = get_source_partitions(
        gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
    )
    pprint.pprint(module_partitions)
    # import inspect
    # print(inspect.getsource(gm.forward))
    return gm
compiled_f = torch.compile(fn, backend=my_backend)
a = torch.rand(10)
compiled_f(a)
# EXPORT
import copy
import torch._dynamo as torchdynamo
m, guards = torchdynamo.export(
    m,
    *copy.deepcopy((a,)),
    aten_graph=True,
    tracing_mode="real",
)
module_partitions = get_source_partitions(
    m.graph, [torch.nn.Linear, torch.nn.functional.linear]
)
m.graph.print_tabular()
pprint.pprint(module_partitions)
"""
before:
opcode name target args kwargs
----------- ------ -------- ------- --------
placeholder l_x_ L_x_ () {}
call_module m m (l_x_,) {}
output output output ((m,),) {}
{<class 'torch.nn.modules.linear.Linear'>: [SourcePartition(nodes=[m], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[l_x_], output_nodes=[m], params=[])]}
opcode name target args kwargs
------------- ----------------- ---------------------- ------------------------------- --------
placeholder arg0 arg0 () {}
get_attr _param_constant0 _param_constant0 () {}
call_function t_default aten.t.default (_param_constant0,) {}
call_function unsqueeze_default aten.unsqueeze.default (arg0, 0) {}
call_function mm_default aten.mm.default (unsqueeze_default, t_default) {}
call_function squeeze_dim aten.squeeze.dim (mm_default, 0) {}
get_attr _param_constant1 _param_constant1 () {}
call_function add_tensor aten.add.Tensor (squeeze_dim, _param_constant1) {}
output output output ([add_tensor],) {}
{<class 'torch.nn.modules.linear.Linear'>: [SourcePartition(nodes=[_param_constant0, t_default, unsqueeze_default, mm_default, squeeze_dim, _param_constant1, add_tensor], source=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[add_tensor], params=[_param_constant1, _param_constant0])]}
after:
opcode name target args kwargs
------------- ------------- -------------------------- ---------------------------------- --------
placeholder l_x_ L_x_ () {}
get_attr g__m___weight G__m___weight () {}
get_attr g__m___bias G__m___bias () {}
call_function linear <built-in function linear> (l_x_, g__m___weight, g__m___bias) {}
output output output ((linear,),) {}
{<built-in function linear>: [SourcePartition(nodes=[linear], source=<built-in function linear>, input_nodes=[l_x_, g__m___weight, g__m___bias], output_nodes=[linear], params=[])]}
opcode name target args kwargs
------------- ----------------- ---------------------- ------------------------------- --------
placeholder arg0 arg0 () {}
get_attr _param_constant0 _param_constant0 () {}
call_function t_default aten.t.default (_param_constant0,) {}
call_function unsqueeze_default aten.unsqueeze.default (arg0, 0) {}
call_function mm_default aten.mm.default (unsqueeze_default, t_default) {}
call_function squeeze_dim aten.squeeze.dim (mm_default, 0) {}
get_attr _param_constant1 _param_constant1 () {}
call_function add_tensor aten.add.Tensor (squeeze_dim, _param_constant1) {}
output output output ([add_tensor],) {}
{<built-in function linear>: [SourcePartition(nodes=[_param_constant0, t_default, unsqueeze_default, mm_default, squeeze_dim, _param_constant1, add_tensor], source=<built-in function linear>, input_nodes=[arg0], output_nodes=[add_tensor], params=[_param_constant0, _param_constant1])]}
"""