Skip to content

Instantly share code, notes, and snippets.

@anijain2305
Created July 12, 2023 00:53
Show Gist options
  • Save anijain2305/322be86c7f2f9a11975a19559e0606b1 to your computer and use it in GitHub Desktop.
import copy

import torch
class Foo(torch.nn.Module):
    """Module that increments an integer buffer on every forward call.

    Used to exercise how ``torch.compile`` handles an in-place mutation of a
    module buffer performed inside ``forward``.
    """

    def __init__(self):
        super().__init__()
        # 0-dim counter; ``torch.tensor(0)`` infers int64. Registered as a
        # buffer so it is part of the module's state but is not a parameter.
        self.register_buffer("iterations", torch.tensor(0))

    def forward(self, x):
        """Increment ``iterations`` in place, then return ``x`` scaled by it.

        Args:
            x: Input tensor, multiplied elementwise by the 0-dim counter.

        Returns:
            A new tensor ``x * iterations`` that uses the counter value *after*
            this call's increment (1 on the first call, 2 on the second, ...).
        """
        self.iterations.add_(1)
        return x * self.iterations
# Reproducer: a module that mutates a buffer in place inside forward(),
# compiled with the aot_eager backend as a single graph (fullgraph=True).
x = torch.ones(1, 2, 3)
mod = Foo()
# torch.compile(mod) would wrap the *same* module instance and therefore share
# its `iterations` buffer: the eager call would bump the counter to 1 and the
# compiled call to 2, so ref (x * 1) and res (x * 2) would differ on the very
# first comparison. Compile an independent copy so both start from the same
# state and advance in lockstep.
compiled_base = copy.deepcopy(mod)
opt_mod = torch.compile(compiled_base, backend="aot_eager", fullgraph=True)
ref = mod(x)
res = opt_mod(x)
assert torch.allclose(ref, res)
# Check again to ensure that mutation is handled correctly: the second call
# only matches if the first compiled call's add_ was written back to the buffer.
ref = mod(x)
res = opt_mod(x)
assert torch.allclose(ref, res)
# Each module must have applied both increments to its own buffer.
assert mod.iterations.item() == 2
assert compiled_base.iterations.item() == 2
Comment by @anijain2305 (gist author) — debug log from running the script above:

[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_x_ : torch.Tensor):
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_x_ = L_x_
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /scratch/anijain/work/pytorch/examples/aot.py:9, code: self.iterations.add_(1)
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         l__self___iterations = self.L__self___iterations
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         add_ = l__self___iterations.add_(1)
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /scratch/anijain/work/pytorch/examples/aot.py:10, code: return x * self.iterations
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         mul = l_x_ * l__self___iterations;  l_x_ = l__self___iterations = None
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (mul,)
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-07-12 00:54:48,583] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.2 from /scratch/anijain/work/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: i64[], arg1_1: f32[1, 2, 3]):
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /scratch/anijain/work/pytorch/examples/aot.py:9, code: self.iterations.add_(1)
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]         add: i64[] = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /scratch/anijain/work/pytorch/examples/aot.py:10, code: return x * self.iterations
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: f32[1, 2, 3] = torch.ops.aten.mul.Tensor(arg1_1, add);  arg1_1 = None
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (add, mul)
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-12 00:54:48,615] torch._functorch.aot_autograd.__aot_graphs: [INFO]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment