# Created October 30, 2022 18:14.
# Save ezyang/54f03e02fd36069bf9693ae2ab707d10 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
from math import inf  # noqa: F401 -- not referenced in this script
import torch
from torch import tensor, device  # noqa: F401 -- not referenced in this script
import torch.fx as fx  # noqa: F401 -- not referenced in this script
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models  # noqa: F401 -- not referenced in this script
from torch.nn import *  # noqa: F401,F403 -- provides ``Linear``, used by ``Repro``

# Recorded input specs: (shape, stride, dtype, device, requires_grad).
# The strides are not contiguous for this shape: they match a contiguous
# (2, 16, 576, 576) tensor viewed through permute(0, 2, 3, 1)
# (331776 == 576 * 576, 5308416 == 16 * 576 * 576).
arg_specs = [
    ((2, 576, 576, 16), (5308416, 576, 1, 331776), torch.float32, 'cuda', True),
]

# Build the actual input tensors from the specs; rand_strided receives the
# recorded shape and stride, and requires_grad is applied per spec.
args = [
    rand_strided(shape, stride, dtype, dev).requires_grad_(requires_grad)
    for (shape, stride, dtype, dev, requires_grad) in arg_specs
]
class Repro(torch.nn.Module):
    """Minified module: a 16 -> 16 Linear followed by a softmax over dim 2.

    The softmax is expressed as ``permute(0, 3, 1, 2)`` ->
    ``softmax(dim=-1)`` -> ``permute(0, 2, 3, 1)``. Do not simplify this to
    ``softmax(dim=2)``: the exact op sequence (and the non-contiguous views it
    produces) is what this repro exercises.
    """

    def __init__(self):
        super().__init__()
        # Attribute name kept verbatim: it is this submodule's key in
        # ``state_dict()`` ("self_self_blocks_0_attn_proj_l.weight" / ".bias").
        self.self_self_blocks_0_attn_proj_l = torch.nn.Linear(
            in_features=16, out_features=16, bias=True
        )

    def forward(self, permute_1):
        """Project the last dimension, then softmax over dimension 2.

        Args:
            permute_1: 4-D tensor whose last dimension has size 16 (required
                by the Linear layer and by the 4-element permutes below).

        Returns:
            A 1-tuple holding a tensor with the same shape as ``permute_1``;
            the two permutes are inverses, so the layout of dims is restored.
        """
        self_self_blocks_0_attn_proj_l = self.self_self_blocks_0_attn_proj_l(permute_1)
        # Drop references to intermediates as soon as they are consumed
        # (mirrors the original FX-style code so tensor lifetimes match).
        permute_1 = None
        # Move dim 2 to the end so softmax(dim=-1) normalizes over it.
        permute_2 = self_self_blocks_0_attn_proj_l.permute(0, 3, 1, 2)
        self_self_blocks_0_attn_proj_l = None
        softmax = permute_2.softmax(dim=-1)
        permute_2 = None
        # Inverse of permute(0, 3, 1, 2): restore the input's dimension order.
        permute_3 = softmax.permute(0, 2, 3, 1)
        softmax = None
        return (permute_3,)
# Eager reference module, moved to the GPU to match the 'cuda' inputs in ``args``.
mod = Repro().cuda()
# Plain eager forward as a sanity run; its output is intentionally discarded.
mod(*args)
# The same module wrapped by dynamo with the "aot_eager" backend.
opt_mod = torch._dynamo.optimize("aot_eager")(mod)

# Autocast is explicitly disabled so both runs use the inputs' own dtype
# (float32 per the recorded specs). ``torch.autocast("cuda", ...)`` is the
# non-deprecated spelling of ``torch.cuda.amp.autocast(...)``.
with torch.autocast("cuda", enabled=False):
    # Presumably runs forward and, when the output requires grad, backward
    # (see torch._dynamo.debug_utils.run_fwd_maybe_bwd) -- TODO confirm.
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)

# NOTE(review): ``ref`` and ``res`` are never compared, so this script only
# checks that both runs complete. For an accuracy check, ``same_two_models``
# (already imported) could be used -- confirm that is the intent before adding.
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.