Triton Puzzles: Simple FlashAttention: Using one program_id to block over q, and loop over k, v columns
# Assumed context: this snippet is from the Triton Puzzles notebook, which provides
# roughly these imports plus the `test` harness used at the bottom.
import torch
from torch import Tensor
import triton
import triton.language as tl
from jaxtyping import Float32

# This is the tiling approach in: https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    x = q[:, None] * k[None, :]
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    soft = x_exp / x_exp.sum(1, keepdim=True)
    return (v[None, :] * soft).sum(1)
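
# For intuition (not part of the puzzle): the spec above is numerically-stabilized row-wise
# softmax attention with head dimension 1. A minimal sketch of the same computation using
# torch.softmax, assuming 1-D q, k, v of equal length; the name flashatt_reference is
# introduced here and is not part of the original gist.
def flashatt_reference(q, k, v):
    x = q[:, None] * k[None, :]        # (T, T) attention logits; outer product since d_head == 1
    soft = torch.softmax(x, dim=1)     # row-wise softmax (handles the max-subtraction internally)
    return (v[None, :] * soft).sum(1)  # softmax-weighted sum of v per query position
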
@triton.jit
def flashatt_kernel(q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr):
    log2_e = 1.44269504
    pid_0 = tl.program_id(0)
    # How this works:
    # 1. Each kernel instance reads B0 "rows" from q.
    # 2. The loop reads rows from k, v and uses the FlashAttention recurrence to compute B0 output "rows" for o/z.
    # 3. In each loop iteration:
    #    - Read TILE_SIZE "rows" from k, v.
    #    - Compute the PARTIAL PRE-SOFTMAX logits x; in each iteration this is of size (B0, TILE_SIZE).
    #    - Over all iterations this covers the entire pre-softmax logits for all B0 rows, but the full matrix is never materialized at once!
    #    - Over all iterations, ALL rows of k, v are read.
    #    - Apply the FlashAttention algorithm to compute/update m, d, o.
    #    - After all iterations have completed, `o` holds the B0 rows of the output.
    # 4. Write B0 "rows" to z from the aggregated results in o.
    # 5. This means that each kernel instance:
    #    - Reads its own B0 rows from q,
    #    - Reads ALL rows of k, v,
    #    - Writes B0 rows to z/o.
    # 6. Memory footprint of each kernel instance (controlled by B0 and TILE_SIZE):
    #    - B0 rows of q
    #    - TILE_SIZE rows of k, v
    #    - Pre-softmax logits x: (B0, TILE_SIZE)
    #    - m, d, o: (B0,)
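    # Concrete example with the test configuration at the bottom (B0=32, TILE_SIZE=32, N0=T=200):
    #    - The puzzles harness launches ceil(200 / 32) = 7 program instances along axis 0;
    #      the last one is partially masked on its q rows.
    #    - Each instance runs the loop for ceil(200 / 32) = 7 iterations; the last k/v tile is partially masked.
    #    - At any one time an instance holds 32 q values, 32 k and 32 v values, one (32, 32) logits tile x,
    #      and three length-32 vectors m, d, o.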
    # TODO: Since T == N0, how are they different and how does it matter here?
    # (Both are 200 in this puzzle; conceptually N0 counts query rows while T is the k/v sequence length.)

    # Read B0 "rows" from q
    q_offsets = pid_0*B0 + tl.arange(0, B0)
    q_mask = q_offsets < N0 # NOTE: Sequence length T is the same as N0 here
    q = tl.load(q_ptr + q_offsets, mask=q_mask, other=0.0)

    # Initial values: all have shape (B0,) (the number of rows read from q) because they track
    # per-row statistics for the B0 rows of the attention matrix being computed at once.
    # -inf and 0 are the identity elements of max and sum, so the first iteration's update
    # reduces to the plain softmax statistics of the first tile.
    m = tl.full(shape=(B0,), value=float('-inf'), dtype=tl.float32) # Running max per row
    d = tl.full(shape=(B0,), value=0.0, dtype=tl.float32) # Running denominator per row
    o = tl.full(shape=(B0,), value=0.0, dtype=tl.float32) # Running output per row, aggregated across v

    TILE_SIZE = 32 # Could be set to B0; just to test out the iteration loop logic
    S = N0 # Or T
    for i in range(0, S, TILE_SIZE):
        kv_offsets = i + tl.arange(0, TILE_SIZE) # Read only TILE_SIZE "rows" from k, v
        kv_mask = kv_offsets < S
        # other=0.0 keeps the out-of-range loads defined; the logits for those columns are
        # explicitly masked to -inf below so they cannot affect the softmax.
        k = tl.load(k_ptr + kv_offsets, mask=kv_mask, other=0.0)
        v = tl.load(v_ptr + kv_offsets, mask=kv_mask, other=0.0)
        # Need to unsqueeze an extra dimension to make this an outer product since the `d_head` dimension is only 1.
        # Equivalent to doing q @ k.transpose()
        # NOTE: This is a TILE of the full attention matrix, of shape (B0, TILE_SIZE):
        # B0 * TILE_SIZE dot products between B0 rows of q and TILE_SIZE cols from k.transpose()
        x = q[:, None] * k[None, :] # (B0=num rows from q, TILE_SIZE=num cols from k)
        # Send logits of out-of-range columns to -inf so exp() of them is 0; otherwise the
        # zero-filled columns would add spurious exp(0 - m) terms to the softmax denominator.
        x = tl.where(kv_mask[None, :], x, float('-inf'))
        # NOTE: Left out for simplicity:
        # 1. Scaling factor 1/sqrt(d_head), which is 1 here so it has no effect.
        # 2. Causal masking: only the lower-triangular part of the pre-softmax logits would be
        #    retained, with the other elements set to -inf.
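        # (Hypothetical sketch, not used in this puzzle: causal masking of this tile could be added here,
        # keeping only keys at or before each query position, e.g.
        #     x = tl.where(q_offsets[:, None] >= kv_offsets[None, :], x, float('-inf'))
        # applied before the m/d/o updates below.)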
        # Update m, d, o using the online-softmax recurrence (see the plain-PyTorch sketch after the kernel).
        # NOTE: m_next[:, None] unsqueezes m_next from (B0,) to (B0, 1) for proper broadcasting against x.
        # exp(a) is computed as tl.exp2(log2_e * a), where log2_e = log2(e) is defined above.
        m_next = tl.maximum(m, x.max(axis=1))
        # Rescale the old denominator by exp(m - m_next), then sum this tile's new terms,
        # since a block of columns is processed at a time.
        d_next = d * tl.exp2(log2_e*(m - m_next)) + tl.sum(tl.exp2(log2_e*(x - m_next[:, None])), axis=1)
        # Rescale the running output the same way, then add this tile's softmax-weighted contribution of v.
        o_next = o * d * tl.exp2(log2_e*(m - m_next))/d_next + tl.sum(tl.exp2(log2_e*(x - m_next[:, None])) * v/d_next[:, None], axis=1)
        d = d_next
        m = m_next
        o = o_next
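        # Invariant after each iteration, for every row r of this q block and all columns j seen so far:
        #   m[r] = max_j x[r, j]
        #   d[r] = sum_j exp(x[r, j] - m[r])
        #   o[r] = sum_j exp(x[r, j] - m[r]) * v[j] / d[r]
        # So once every column has been processed, o is exactly the softmax-weighted sum from the spec.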
    # Output: Same offsets and mask as q: Outputs B0 "rows" of o/z.
    tl.store(z_ptr + q_offsets, o, mask=q_mask)
    return
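
# A minimal plain-PyTorch sketch (not part of the puzzle) of the same tiled online-softmax
# recurrence the kernel uses, handy for checking the m/d/o update logic on CPU.
# The names tiled_flashatt and tile_size are illustrative, not from the original gist.
def tiled_flashatt(q, k, v, tile_size=32):
    T = q.shape[0]
    m = torch.full((T,), float('-inf'))  # running row max
    d = torch.zeros(T)                   # running softmax denominator
    o = torch.zeros(T)                   # running (already normalized) output
    for i in range(0, k.shape[0], tile_size):
        k_tile = k[i:i + tile_size]
        v_tile = v[i:i + tile_size]
        x = q[:, None] * k_tile[None, :]  # (T, tile) logits tile
        m_next = torch.maximum(m, x.max(dim=1).values)
        d_next = d * torch.exp(m - m_next) + torch.exp(x - m_next[:, None]).sum(dim=1)
        o = o * d * torch.exp(m - m_next) / d_next + (torch.exp(x - m_next[:, None]) * v_tile / d_next[:, None]).sum(dim=1)
        m, d = m_next, d_next
    return o

# Example check against the direct spec above:
#   q, k, v = (torch.rand(200) - 0.5 for _ in range(3))
#   torch.testing.assert_close(tiled_flashatt(q, k, v), flashatt_spec(q, k, v))
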
# Testing with B0=32, N0=200, and T=200
test(flashatt_kernel, flashatt_spec, B={"B0": 32},
     nelem={"N0": 200, "T": 200}, viz=False)
# test(flashatt_kernel, flashatt_spec, B={"B0": 200},
#      nelem={"N0": 200, "T": 200}, viz=False)
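# The commented-out call uses B0=200 >= N0, so a single program instance covers every row of q
# and only the k/v loop does any tiling.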