""" | |
*Experimental* implementation of FlashAttention in Triton. | |
Tested with triton==2.0.0.dev20221202. | |
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions | |
other than 64: | |
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 | |
We'll update this implementation with the new Triton backend once this is fixed. | |
We use the FlashAttention implementation from Phil Tillet a starting point. | |
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py | |
Changes: | |
- Implement both causal and non-causal attention. | |
- Implement both self-attention and cross-attention. | |
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. | |
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. | |
- Support attention bias. | |
- Speed up the forward pass a bit, and only store the LSE instead of m and l. | |
- Make the backward for d=128 much faster by reducing register spilling. | |
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of | |
small batch size * nheads. | |
Caution: | |
- This is an *experimental* implementation. The forward pass should be quite robust but | |
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). | |
- This implementation has only been tested on A100. | |
- If you plan to use headdim other than 64 and 128, you should test for race conditions | |
(due to the Triton compiler), as done in tests/test_flash_attn.py | |
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions | |
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident | |
that there are none left for other head dimensions. | |
Differences between this Triton version and the CUDA version: | |
- Triton version doesn't support dropout. | |
- Triton forward is generally faster than CUDA forward, while Triton backward is | |
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower | |
than CUDA forward + backward. | |
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). | |
- Triton version supports attention bias, while CUDA version doesn't. | |
""" | |
import math

import torch
import triton
import triton.language as tl
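
# What the forward kernel computes (standard attention, stated here for reference):
#   out = softmax(Q @ K^T * softmax_scale + bias [+ causal mask]) @ V
# with the per-row log-sum-exp (LSE) of the softmax written to `Lse` for use in the backward pass.
# Bias layouts accepted by _flash_attn_forward:
#   "vector": last two dims (1, seqlen_k)         -- one value per key position
#   "matrix": last two dims (seqlen_q, seqlen_k)  -- one value per (query, key) pair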
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
#         # This config has a race condition when EVEN_M == False, disabling it for now.
#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
#     ],
#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
    softmax_scale,
    stride_qb,  # Q: [b, h, m, d]
    stride_qh,
    stride_qm,
    stride_kb,  # K: [b, h, n, d]
    stride_kh,
    stride_kn,
    stride_vb,  # V: [b, h, n, d]
    stride_vh,
    stride_vn,
    stride_bb,  # bias: [b, h, m, n]
    stride_bh,
    stride_bm,
    stride_ob,  # output: [b, h, m, d]
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)  # index of the query-row block
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads  # batch index
    off_h = off_hb % nheads  # head index
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # query row offsets; BLOCK_M tiles the rows of Q
    offs_n = tl.arange(0, BLOCK_N)  # BLOCK_N tiles the rows of K
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parentheses around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias
            + off_b * stride_bb
            + off_h * stride_bh
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0,
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0,
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # Workaround note: when this (currently commented-out) bias load was moved up here,
        # before the qk masking below, the kernel produced correct results. To use it,
        # uncomment this block and comment out the matching block between the BEGIN/END
        # markers further down.
        # if BIAS_TYPE != "none":
        #     if BIAS_TYPE == "vector":
        #         if EVEN_N:
        #             bias = tl.load(b_ptrs + start_n).to(tl.float32)
        #         else:
        #             bias = tl.load(
        #                 b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
        #             ).to(tl.float32)
        #         bias = bias[None, :]
        #     elif BIAS_TYPE == "matrix":
        #         if EVEN_M & EVEN_N:
        #             bias = tl.load(b_ptrs + start_n).to(tl.float32)
        #         else:
        #             bias = tl.load(
        #                 b_ptrs + start_n,
        #                 mask=(offs_m[:, None] < seqlen_q)
        #                 & ((start_n + offs_n)[None, :] < seqlen_k),
        #                 other=0.0,
        #             ).to(tl.float32)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != "none":
            # BEGIN: bias load in its original position (comment out from here ...)
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                    ).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # END: ... to here when enabling the workaround block above.
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)
        # -- update output accumulator --
        # BUG: have to store and immediately load
        # tl.store(t_ptrs, acc_o_scale)
        # acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)
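        # Note (added for readability): the update above is the usual online-softmax
        # recurrence, tracking only the per-row log-sum-exp instead of separate m and l:
        #   m_ij  = max(previous lse_i, rowmax(qk * softmax_scale [+ bias]))
        #   lse_i <- m_ij + log(exp(lse_i - m_ij) + sum_j exp(qk_ij - m_ij))
        # acc_o is rescaled by exp(m_i - m_ij) before adding p @ v, so the epilogue only
        # needs one final correction by exp(m_i - lse_i).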
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )


def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, "FlashAttention only supports head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError(
                "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
            )
        print("bias_type", bias_type)
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
    print("bias_strides", bias_strides)
    seqlen_q_rounded = ((seqlen_q + 127) // 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    # num_warps = 4 if d <= 64 else 8
    num_warps = 4  # only num_warps=4 worked reliably on the GPUs I tested
    # 2D launch grid: (cdiv(seqlen_q, BLOCK_M), batch * nheads)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel[grid](
        q,
        k,
        v,
        bias,
        o,
        lse,
        tmp,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        seqlen_q // 32,
        seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type,
        causal,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


def flash_attn_qkvpacked_func(qkv, bias=None, causal=False, softmax_scale=None):
    """
    qkv: (batch, seqlen, 3, nheads, headdim)
    bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
        For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
        ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen).
    """
    # Make sure that the last dimension is contiguous
    if qkv.stride(-1) != 1:
        qkv = qkv.contiguous()
    o, _lse, _softmax_scale = _flash_attn_forward(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        bias=bias,
        causal=causal,
        softmax_scale=softmax_scale,
    )
    return o
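

# The wrappers in this file accept an additive attention bias. Below is a minimal sketch
# (not part of the original gist; the helper name and the -10000.0 fill value are
# illustrative assumptions) of building a "vector" bias from a key-padding mask, which
# _flash_attn_forward recognizes via its (batch, 1, 1, seqlen_k) shape.
def make_key_padding_bias(key_padding_mask, dtype=torch.float32):
    """key_padding_mask: (batch, seqlen_k) bool tensor, True for valid keys.

    Returns a bias of shape (batch, 1, 1, seqlen_k): 0 for valid keys, a large negative
    value for padded keys. float32 is always accepted; q.dtype works as well.
    """
    bias = torch.zeros(key_padding_mask.shape, dtype=dtype, device=key_padding_mask.device)
    bias.masked_fill_(~key_padding_mask, -10000.0)
    return bias[:, None, None, :]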


def flash_attn_kvpacked_func(q, kv, bias=None, causal=False, softmax_scale=None):
    """
    q: (batch, seqlen_q, nheads, headdim)
    kv: (batch, seqlen_k, 2, nheads, headdim)
    bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
        For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
        ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
    """
    # Make sure that the last dimension is contiguous
    q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
    o, _lse, _softmax_scale = _flash_attn_forward(
        q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
    )
    return o


def flash_attn_func(q, k, v, bias=None, causal=False, softmax_scale=None):
    """
    q: (batch_size, seqlen_q, nheads, headdim)
    k, v: (batch_size, seqlen_k, nheads, headdim)
    bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
        For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
        ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
    """
    # Make sure that the last dimension is contiguous
    q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
    o, _lse, _softmax_scale = _flash_attn_forward(
        q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
    )
    return o
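

# A plain-PyTorch reference for sanity-checking the Triton output. This is a minimal
# sketch added for illustration (the helper name is an assumption, not part of the
# original gist); it assumes the same (batch, seqlen, nheads, headdim) layout and the
# same additive-bias convention as flash_attn_func.
def attention_reference(q, k, v, bias=None, causal=False, softmax_scale=None):
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    # scores: (batch, nheads, seqlen_q, seqlen_k), computed in float32
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    if bias is not None:
        scores = scores + bias.float()
    if causal:
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        idx_q = torch.arange(seqlen_q, device=q.device)
        idx_k = torch.arange(seqlen_k, device=q.device)
        causal_mask = idx_k[None, :] > idx_q[:, None]  # True where key is in the future
        scores = scores.masked_fill(causal_mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", attn, v.float())
    return out.to(q.dtype)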
if __name__ == "__main__": | |
B,S,H,D = 1, 32, 4, 128 | |
device = torch.device("cuda:0") | |
dtype = torch.float16 | |
q = torch.rand(B, S, H, D, dtype=dtype, device=device) | |
k = torch.rand_like(q) | |
v = torch.rand_like(q) | |
scale = D ** -0.5 | |
bias = torch.randn(1, 1, S, S, device=device, dtype=dtype) * 5 | |
a = flash_attn_func(q, k, v, bias=bias) | |
print("a", a.shape) |