Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save anijain2305/0472b8a9a255541e6e869b81fd2eb4d4 to your computer and use it in GitHub Desktop.
Save anijain2305/0472b8a9a255541e6e869b81fd2eb4d4 to your computer and use it in GitHub Desktop.
import torch
class MockModule(torch.nn.Module):
    """Toy module that mutates its own state on every forward call.

    Used to show how ``torch.compile`` (TorchDynamo) handles attribute
    mutation inside ``forward``:

    * ``a`` and ``b`` are plain tensor attributes (not registered buffers).
    * ``iterations`` is a registered buffer.

    Each call increments ``a`` and ``iterations`` by one, rebinds ``b`` to
    ``a * 2``, and returns ``x * a * b`` using the updated values.
    """

    def __init__(self):
        super().__init__()
        # Plain 0-dim tensor attributes: NOT registered with the module, so
        # they do not show up in named_buffers() / state_dict().
        self.a = torch.tensor(0)
        self.b = torch.tensor(0)
        # Registered buffer: tracked by the module like any other buffer.
        self.register_buffer("iterations", torch.tensor(0))

    def forward(self, x):
        """Advance the internal counters and return ``x * a * b``.

        Args:
            x: input tensor, multiplied elementwise by the 0-dim tensors
                ``a`` and ``b`` after they have been updated.

        Returns:
            ``x * a * b`` computed with the *new* values of ``a`` and ``b``.
        """
        # In-place add on a plain attribute. Under Dynamo the add is traced
        # into the graph, and the rewritten bytecode re-stores ``self.a``
        # after the compiled graph returns.
        self.a += 1
        # Rebinding (not in-place): the graph only computes ``a * 2``; the
        # attribute store itself is replayed in the rewritten bytecode.
        setattr(self, "b", self.a * 2)
        # In-place update of a registered buffer. The traced graph performs
        # the add itself; the rewritten bytecode also re-stores the attribute.
        self.iterations += 1
        return x * self.a * self.b
mod = MockModule()
# backend="eager" runs the captured graph without further compilation;
# fullgraph=True requires the whole forward to be captured as one graph.
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
# Use one named input instead of an unused random tensor plus fresh literals.
x = torch.ones(4)
# State lives on ``mod`` and persists across calls:
# the first call prints all 2s (a=1, b=2), the second all 8s (a=2, b=4).
print(opt_mod(x))
print(opt_mod(x))
@anijain2305
Copy link
Author

[2023-07-11 21:38:27,973] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
 ===== __compiled_fn_0 =====
 <eval_with_key>.0 class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor, L_self_a : torch.Tensor, L_self_iterations : torch.Tensor):
        l_x_ = L_x_
        l_self_a = L_self_a
        l_self_iterations = L_self_iterations

        # File: /scratch/anijain/work/pytorch/examples/custom_setattr.py:13, code: self.a += 1
        l_self_a += 1;  iadd = l_self_a;  l_self_a = None

        # File: /scratch/anijain/work/pytorch/examples/custom_setattr.py:16, code: setattr(self, "b", self.a * 2)
        mul = iadd * 2

        # File: /scratch/anijain/work/pytorch/examples/custom_setattr.py:19, code: self.iterations += 1
        l_self_iterations += 1;  iadd_1 = l_self_iterations;  l_self_iterations = None

        # File: /scratch/anijain/work/pytorch/examples/custom_setattr.py:20, code: return x * self.a * self.b
        mul_1 = l_x_ * iadd;  l_x_ = None
        mul_2 = mul_1 * mul;  mul_1 = None
        return (mul_2, iadd, mul, iadd_1)


[2023-07-11 21:38:27,975] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE forward /scratch/anijain/work/pytorch/examples/custom_setattr.py line 11
 13           0 LOAD_FAST                0 (self)
              2 DUP_TOP
              4 LOAD_ATTR                0 (a)
              6 LOAD_CONST               1 (1)
              8 INPLACE_ADD
             10 ROT_TWO
             12 STORE_ATTR               0 (a)

 16          14 LOAD_GLOBAL              1 (setattr)
             16 LOAD_FAST                0 (self)
             18 LOAD_CONST               2 ('b')
             20 LOAD_FAST                0 (self)
             22 LOAD_ATTR                0 (a)
             24 LOAD_CONST               3 (2)
             26 BINARY_MULTIPLY
             28 CALL_FUNCTION            3
             30 POP_TOP

 19          32 LOAD_FAST                0 (self)
             34 DUP_TOP
             36 LOAD_ATTR                2 (iterations)
             38 LOAD_CONST               1 (1)
             40 INPLACE_ADD
             42 ROT_TWO
             44 STORE_ATTR               2 (iterations)

 20          46 LOAD_FAST                1 (x)
             48 LOAD_FAST                0 (self)
             50 LOAD_ATTR                0 (a)
             52 BINARY_MULTIPLY
             54 LOAD_FAST                0 (self)
             56 LOAD_ATTR                3 (b)
             58 BINARY_MULTIPLY
             60 RETURN_VALUE


[2023-07-11 21:38:27,975] torch._dynamo.convert_frame.__bytecode: [DEBUG] MODIFIED BYTECODE forward /scratch/anijain/work/pytorch/examples/custom_setattr.py line 11
 11           0 LOAD_GLOBAL              4 (__compiled_fn_0)
              2 LOAD_FAST                1 (x)
              4 LOAD_FAST                0 (self)
              6 LOAD_ATTR                0 (a)
              8 LOAD_FAST                0 (self)
             10 LOAD_ATTR                2 (iterations)
             12 CALL_FUNCTION            3
             14 STORE_FAST               2 (___graph_out_0)
             16 LOAD_FAST                2 (___graph_out_0)
             18 LOAD_CONST               4 (0)
             20 BINARY_SUBSCR
             22 LOAD_FAST                2 (___graph_out_0)
             24 LOAD_CONST               1 (1)
             26 BINARY_SUBSCR
             28 LOAD_FAST                0 (self)
             30 LOAD_FAST                2 (___graph_out_0)
             32 LOAD_CONST               3 (2)
             34 BINARY_SUBSCR
             36 LOAD_FAST                0 (self)
             38 LOAD_FAST                2 (___graph_out_0)
             40 LOAD_CONST               5 (3)
             42 BINARY_SUBSCR
             44 LOAD_FAST                0 (self)
             46 STORE_ATTR               2 (iterations)
             48 STORE_ATTR               3 (b)
             50 STORE_ATTR               0 (a)
             52 RETURN_VALUE

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment