Attention Calculation Methods
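A small benchmark comparing PyTorch 2.x F.scaled_dot_product_attention under its three CUDA SDP backends (math, memory-efficient, and FlashAttention): each variant is timed with torch.utils.benchmark and the outputs are checked against each other with torch.allclose.
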
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# check dependencies
assert torch.cuda.is_available(), 'CUDA is expected.'
cu_major, cu_minor = torch.version.cuda.split('.')
# assert int(cu_major) >= 11 and int(cu_minor) >= 6, f'Expected CUDA 11.6 and above. But found {cu_major}.{cu_minor}'
pt_major, pt_minor = str(torch.__version__).split('.')[:2]
assert (int(pt_major), int(pt_minor)) >= (2, 0), 'Expected PyTorch 2.0 and above.'
# for reproducible results (may not be fully deterministic)
torch.manual_seed(37)
torch.cuda.manual_seed_all(37)
device = "cuda"
# dummy input with shape (batch_size, num_heads, seq_len, head_dim)
b_size = 32
seq_len = 1024
num_heads = 32
embd_dim = 32
dtype = torch.float16
query = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
key = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
value = torch.rand(b_size, num_heads, seq_len, embd_dim, device=device, dtype=dtype)
print(f"flash_sdp_enabled:\t\t{torch.backends.cuda.flash_sdp_enabled()}")
print(f"mem_efficient_sdp_enabled:\t{torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"math_sdp_enabled:\t\t{torch.backends.cuda.math_sdp_enabled()}")
print("_" * 80)
# Checking different attention calculation methods
# -------------------------------------------------

# Math backend: PyTorch's naive reference implementation defined in C++
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    res_1 = F.scaled_dot_product_attention(query, key, value)

    # calculate time
    t1 = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(query, key, value)",
        setup="import torch.nn.functional as F",
        globals={"query": query, "key": key, "value": value},
    )
    print(f"PyTorch naive attention took:\t\t{round(t1.blocked_autorange().mean * 1e6)} microseconds")
# Memory-Efficient Attention
# Self-attention Does Not Need O(n^2) Memory [https://arxiv.org/abs/2112.05682]
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    res_2 = F.scaled_dot_product_attention(query, key, value)

    # calculate time
    t2 = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(query, key, value)",
        setup="import torch.nn.functional as F",
        globals={"query": query, "key": key, "value": value},
    )
    print(f"Memory-Efficient attention took:\t{round(t2.blocked_autorange().mean * 1e6)} microseconds")
# FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
# [https://arxiv.org/abs/2205.14135]
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    res_3 = None
    try:
        res_3 = F.scaled_dot_product_attention(query, key, value)

        # calculate time
        t3 = benchmark.Timer(
            stmt="F.scaled_dot_product_attention(query, key, value)",
            setup="import torch.nn.functional as F",
            globals={"query": query, "key": key, "value": value},
        )
        print(f"Flash attention took:\t\t\t{round(t3.blocked_autorange().mean * 1e6)} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported on this device")
print('\nRESULT:')
print(f"PyTorch naive vs Memory-Efficient: outputs match? --> {torch.allclose(res_1, res_2, rtol=0.001, atol=0.0000001)}")
if res_3 is not None:
    print(f"PyTorch naive vs Flash: outputs match? --> {torch.allclose(res_1, res_3, rtol=0.001, atol=0.0000001)}")