@jeromeku
Forked from Chillee/create_block_mask.py
Created February 6, 2025 10:07
Compiling `create_block_mask`
import torch
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.manual_seed(0)
torch.set_default_device('cuda')

# Sliding-window mask: each query position attends only to key positions
# within 2048 tokens of it.
def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() < 2048

S = 16384

# torch.compile fuses the mask evaluation into generated kernels, so the
# block mask is built much faster than the eager implementation.
create_block_mask_compiled = torch.compile(create_block_mask)

print("compiled: ", do_bench(lambda: create_block_mask_compiled(sliding_window, B=None, H=None, Q_LEN=S, KV_LEN=S)))
print("eager: ", do_bench(lambda: create_block_mask(sliding_window, B=None, H=None, Q_LEN=S, KV_LEN=S)))