# Gist by @ezyang, created October 30, 2022
# https://gist.github.com/ezyang/54f03e02fd36069bf9693ae2ab707d10
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models
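
# Note (not in the original gist): this header matches the repro template
# emitted by the torch._dynamo minifier; some imports (inf, fx, tensor,
# device) are unused in this particular repro but kept as generated.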
args = [((2, 576, 576, 16), (5308416, 576, 1, 331776), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
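
# Comment added for clarity: the strides above make the input non-contiguous.
# They match a contiguous (2, 16, 576, 576) tensor viewed through
# .permute(0, 2, 3, 1), which is presumably the layout the original model fed
# into this block (hence the parameter name permute_1 below).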
from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_blocks_0_attn_proj_l = Linear(in_features=16, out_features=16, bias=True)

    def forward(self, permute_1):
        # Project the trailing (channel) dim, move it to position 1,
        # softmax over the last dim, then permute back.
        self_self_blocks_0_attn_proj_l = self.self_self_blocks_0_attn_proj_l(permute_1); permute_1 = None
        permute_2 = self_self_blocks_0_attn_proj_l.permute(0, 3, 1, 2); self_self_blocks_0_attn_proj_l = None
        softmax = permute_2.softmax(dim=-1); permute_2 = None
        permute_3 = softmax.permute(0, 2, 3, 1); softmax = None
        return (permute_3,)
mod = Repro().cuda()
mod(*args)  # sanity check: run the module once in eager mode

# "aot_eager" traces through AOTAutograd but executes the resulting forward
# and backward graphs eagerly (no backend codegen), isolating tracing bugs.
opt_mod = torch._dynamo.optimize("aot_eager")(mod)

with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
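
# Not part of the original gist: a minimal accuracy check, sketched with the
# same_two_models helper already imported above (assuming it keeps the
# torch._dynamo.debug_utils signature of this era, which re-runs both models
# via run_fwd_maybe_bwd and compares the outputs).
if not same_two_models(mod, opt_mod, args):
    print("eager and aot_eager results diverge")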