Skip to content

Instantly share code, notes, and snippets.

@Chillee
Created April 12, 2024 05:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chillee/41baf11aac8036d25d637321c48dad20 to your computer and use it in GitHub Desktop.
Save Chillee/41baf11aac8036d25d637321c48dad20 to your computer and use it in GitHub Desktop.
You Could Have Invented Flash-Attention!
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
# Make 'cuda' the default device for tensor factory calls that do not pass
# device= explicitly (e.g. the torch.randn calls in the benchmark loop), so the
# benchmark tensors are created on the GPU. Requires a CUDA-capable device.
torch.set_default_device('cuda')
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
ms_per_iter = do_bench(f)
iters_per_second = 1e3/ms_per_iter
print(f"{iters_per_second * total_flops / 1e12} TF/s")
def attention(q, k, v, *, scale=None):
    """Naive (unfused) attention: ``softmax(q @ k^T [* scale], dim=-1) @ v``.

    The full score matrix ``q @ k^T`` is materialized before the softmax.

    Args:
        q: Query tensor of shape ``(..., S_q, D)``.
        k: Key tensor of shape ``(..., S_k, D)``.
        v: Value tensor of shape ``(..., S_k, D_v)``.
        scale: Optional multiplier applied to the scores before the softmax.
            ``None`` (the default) applies no scaling, preserving the
            original behaviour; pass ``D ** -0.5`` for standard scaled
            dot-product attention.

    Returns:
        Tensor of shape ``(..., S_q, D_v)``.
    """
    # transpose(-2, -1) instead of .T: identical for 2-D inputs, but it keeps
    # leading batch/head dimensions intact (.T would reverse every dimension).
    scores = q @ k.transpose(-2, -1)
    if scale is not None:
        scores = scores * scale
    return torch.softmax(scores, dim=-1) @ v
S = 4096
# Sweep the head dimension D at a fixed sequence length S and report the
# throughput of the naive attention for each size. (A standalone `D = 256`
# assignment used to sit here; it was dead, since the loop rebinds D.)
for D in [64, 128, 256, 512, 1024]:
    q = torch.randn(S, D, dtype=torch.bfloat16)
    k = torch.randn(S, D, dtype=torch.bfloat16)
    v = torch.randn(S, D, dtype=torch.bfloat16)
    print(f"D={D}")
    # get_flops_achieved runs the lambda before the loop advances, so the
    # closure's late binding of q, k, v always sees this iteration's tensors.
    get_flops_achieved(lambda: attention(q, k, v))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment