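A minimal example of PyTorch's experimental compiled autograd: while the context manager is active, the backward pass is captured as a graph and handed to compiler_fn, which lowers it with torch.compile.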
import torch


def compiler_fn(gm):
    # Lower the captured autograd graph with torch.compile:
    # "reduce-overhead" enables CUDA graphs, fullgraph=True forbids
    # graph breaks, and dynamic=True compiles with dynamic shapes.
    return torch.compile(gm, mode="reduce-overhead", fullgraph=True, dynamic=True)


def fn():
    x = torch.randn(2, 2, device="cuda", requires_grad=True)
    y = torch.randn(2, 2, device="cuda")
    out = torch.mm(x, y)
    loss = out.sum() / out.numel()
    loss.backward()


# While enabled, compiled autograd traces the backward pass and runs it
# through compiler_fn instead of the eager autograd engine.
with torch._dynamo.compiled_autograd.enable(compiler_fn):
    fn()
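A quick sanity check (a sketch, not part of the original gist; grad_of_fn is a hypothetical helper, reuses compiler_fn from above, and assumes a CUDA device): rerun the same computation with a fixed seed under eager and compiled autograd and compare the gradients.

def grad_of_fn(use_compiled_autograd):
    # Hypothetical helper: re-run fn's computation with a fixed seed so
    # both runs see identical inputs, and return the gradient of x.
    torch.manual_seed(0)
    x = torch.randn(2, 2, device="cuda", requires_grad=True)
    y = torch.randn(2, 2, device="cuda")
    out = torch.mm(x, y)
    loss = out.sum() / out.numel()
    if use_compiled_autograd:
        with torch._dynamo.compiled_autograd.enable(compiler_fn):
            loss.backward()
    else:
        loss.backward()
    # Clone defensively in case the compiled run reuses CUDA-graph memory.
    return x.grad.clone()

# Gradients from eager autograd and compiled autograd should match.
assert torch.allclose(grad_of_fn(False), grad_of_fn(True))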