Triton Puzzles: Simple FlashAttention: Using one program_id to block over q, and loop over k, v columns
# This is the tiling approach in: https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
# NOTE: Meant to be run inside the Triton-Puzzles notebook, which provides `test` and the
# `Float32`/`Tensor` annotations (jaxtyping / torch); the triton imports are repeated here for clarity.
import triton
import triton.language as tl

def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    x = q[:, None] * k[None, :]
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    soft = x_exp / x_exp.sum(1, keepdim=True)
    return (v[None, :] * soft).sum(1)
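# NOTE: In this puzzle q, k, v are 1-D (one scalar per position), i.e. attention with d_head = 1, so
# q[:, None] * k[None, :] is the full (200, 200) score matrix. The spec above materializes that whole
# matrix and applies a numerically stabilized row-wise softmax; the kernel below computes the same
# result tile-by-tile without ever holding the full matrix.
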
@triton.jit
def flashatt_kernel(q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr):
    log2_e = 1.44269504  # log2(e): exp(y) is computed as exp2(log2_e * y) below
    pid_0 = tl.program_id(0)
    # How this works:
    # 1. Each kernel instance reads B0 "rows" of q.
    # 2. The loop reads rows of k and v and uses the FlashAttention recurrence to compute B0 output "rows" for o/z.
    # 3. In each loop iteration:
    #    - Read TILE_SIZE "rows" of k and v.
    #    - Compute the PARTIAL pre-softmax logits x; each iteration's tile has shape (B0, TILE_SIZE).
    #    - Over all iterations this covers the entire pre-softmax logits for the B0 rows, but the full
    #      matrix never needs to be materialized at once.
    #    - Over all iterations, ALL rows of k and v are read.
    #    - Apply the FlashAttention update to m, d, o.
    #    - After the last iteration, `o` holds the B0 output rows.
    # 4. Write B0 "rows" to z from the aggregated results in o.
    # 5. So each kernel instance:
    #    - Reads its own B0 rows of q.
    #    - Reads ALL rows of k and v.
    #    - Writes B0 rows to z/o.
    # 6. Memory footprint of each kernel instance, controlled by B0 and TILE_SIZE
    #    (see the worked sizing example below):
    #    - B0 rows of q
    #    - TILE_SIZE rows of k and v
    #    - Pre-softmax logits tile x: (B0, TILE_SIZE)
    #    - m, d, o: (B0,)
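    # Worked sizing example (matches the test call at the bottom of this file):
    #   B0 = 32, TILE_SIZE = 32, N0 = T = 200.
    #   - The launch grid has cdiv(200, 32) = 7 programs; the last one has only 8 valid q rows (192..199).
    #   - Each program runs ceil(200 / 32) = 7 loop iterations; the last tile covers k/v positions
    #     192..223, of which 200..223 are masked off.
    #   - Each iteration holds only a (32, 32) tile of the full 200 x 200 attention matrix.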
    # TODO: Since T == N0, how are they different and how does it matter here?
    # (In this puzzle N0 is the number of q positions, split across programs, and T is the number of
    # k/v positions the inner loop walks over; they just happen to both be 200.)
    # Read B0 "rows" from q
    q_offsets = pid_0 * B0 + tl.arange(0, B0)
    q_mask = q_offsets < N0  # NOTE: Sequence length T is the same as N0 here.
    q = tl.load(q_ptr + q_offsets, mask=q_mask, other=0.0)
    # Initial values: all have shape (B0,) (the number of rows read from q) because they track per-row
    # state for the B0 rows of the attention matrix being computed by this program.
    m = tl.full(shape=(B0,), value=float('-inf'), dtype=tl.float32)  # Running max per row
    d = tl.full(shape=(B0,), value=0.0, dtype=tl.float32)            # Running softmax denominator per row
    o = tl.full(shape=(B0,), value=0.0, dtype=tl.float32)            # Output per row, aggregated across v
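    # m starts at -inf so the first tile's row max always replaces it; d and o start at 0 because no
    # probability mass has been accumulated yet.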
    TILE_SIZE = 32  # Could be set to B0; kept separate just to exercise the iteration loop logic.
    S = N0  # Or T
    for i in range(0, S, TILE_SIZE):
        kv_offsets = i + tl.arange(0, TILE_SIZE)  # Read only TILE_SIZE "rows" from k, v
        kv_mask = kv_offsets < S
        # other=0.0 keeps the out-of-bounds loads safe; the corresponding logits are pushed to -inf
        # below so they add nothing to the softmax denominator, and v=0 keeps them out of the weighted sum.
        k = tl.load(k_ptr + kv_offsets, mask=kv_mask, other=0.0)
        v = tl.load(v_ptr + kv_offsets, mask=kv_mask, other=0.0)
        # Unsqueeze an extra dimension to make this an outer product, since the `d_head` dimension is only 1.
        # Equivalent to doing q @ k.transpose().
        # NOTE: This is a TILE of the full attention matrix, of dimension (B0, TILE_SIZE):
        # B0 * TILE_SIZE dot products between B0 rows of q and TILE_SIZE columns of k.transpose().
        x = q[:, None] * k[None, :]  # (B0 = num rows from q, TILE_SIZE = num cols from k)
        # Send logits of out-of-range columns (the partial last tile) to -inf so they do not inflate
        # the softmax denominator.
        x = tl.where(kv_mask[None, :], x, float('-inf'))
        # NOTE: Left out for simplicity:
        # 1. Scaling factor 1/sqrt(d_head), which is 1 here so it has no effect.
        # 2. Causal masking, i.e. keeping only the lower-triangular logits and sending the rest to -inf.
        # Update m, d, o.
        # NOTE: m_next[:, None] unsqueezes m_next from (B0,) to (B0, 1) for proper broadcasting against x.
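        # Online-softmax recurrence (per row), written out:
        #   m_next = max(m, rowmax(x))
        #   d_next = d * exp(m - m_next) + sum_j exp(x_j - m_next)
        #   o_next = o * (d * exp(m - m_next) / d_next) + sum_j (exp(x_j - m_next) / d_next) * v_j
        # o is kept normalized by the running denominator, so it is first rescaled back by d,
        # re-normalized by d_next, and then the new tile's contribution is added.
        # exp(y) is computed as exp2(log2_e * y) since exp2 is the cheaper GPU primitive.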
        m_next = tl.maximum(m, x.max(axis=1))
        # Sum across the tile's columns when accumulating the new denominator terms, since a whole
        # block of columns is processed at once.
        d_next = (d * tl.exp2(log2_e * (m - m_next))
                  + tl.sum(tl.exp2(log2_e * (x - m_next[:, None])), axis=1))
        o_next = (o * d * tl.exp2(log2_e * (m - m_next)) / d_next
                  + tl.sum(tl.exp2(log2_e * (x - m_next[:, None])) * v / d_next[:, None], axis=1))
        d = d_next
        m = m_next
        o = o_next
    # Output: same offsets and mask as q; writes B0 "rows" of o to z.
    tl.store(z_ptr + q_offsets, o, mask=q_mask)
    return
# Tests with B0=32 and N0=T=200: 7 programs, each looping over all 200 k/v positions.
test(flashatt_kernel, flashatt_spec, B={"B0": 32},
     nelem={"N0": 200, "T": 200}, viz=False)
# test(flashatt_kernel, flashatt_spec, B={"B0": 200},
#      nelem={"N0": 200, "T": 200}, viz=False)