Created
July 12, 2023 00:53
-
-
Save anijain2305/322be86c7f2f9a11975a19559e0606b1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy

import torch
class Foo(torch.nn.Module):
    """Toy module whose forward pass mutates a registered buffer in place.

    The ``iterations`` buffer starts at 0 and is incremented by one on every
    forward call before being multiplied into the input, so the n-th call on
    a fresh instance returns ``x * n``.
    """

    def __init__(self):
        super().__init__()
        # 0-dim integer tensor registered as a buffer: it is module state
        # (shows up in state_dict / named_buffers) but not a trainable parameter.
        self.register_buffer("iterations", torch.tensor(0))

    def forward(self, x):
        """Increment the ``iterations`` counter in place, then return ``x`` scaled by it."""
        counter = self.iterations
        # In-place add: mutates the registered buffer itself rather than
        # rebinding the attribute to a new tensor.
        counter.add_(1)
        return x * counter
x = torch.ones(1, 2, 3)
mod = Foo()

# torch.compile wraps the module it is given instead of copying it. Compiling
# `mod` directly would make the eager and compiled paths share one
# `iterations` buffer: each reference call would bump the counter seen by the
# following compiled call (ref = x * 1 but res = x * 2), and a compiler that
# silently dropped the mutation would be the one that "passed". Compile an
# independent copy so both paths start from the same state and advance in step.
mod_copy = copy.deepcopy(mod)
opt_mod = torch.compile(mod_copy, backend="aot_eager", fullgraph=True)

# Two rounds: the second one verifies that the in-place buffer mutation made
# by the first compiled call actually took effect (counter is 2, not 1).
for _ in range(2):
    ref = mod(x)
    res = opt_mod(x)
    # assert_close reports the mismatch and, unlike a bare `assert`,
    # is not stripped under `python -O`.
    torch.testing.assert_close(res, ref)

# The mutated buffers themselves must agree as well, not just the outputs.
torch.testing.assert_close(mod_copy.iterations, mod.iterations)
Author
anijain2305
commented
Jul 12, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment