

Last active July 11, 2022 21:51
krunt/72197074816dfe4035fcd9413e4afb22

Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.

TL;DR you can run multi-head attention (fwd+bwd) faster and with no extra memory – with any sequence length and head dim. We’d love to make it available via apex. We need your advice on how best to do that.

Why should I care? Here's how it compares against the standard multihead attention (blue) for one multi-head attention layer of GPT-J on an RTX 3080Ti.

[Figure: time, with backward (ms); peak VRAM allocated (MB)]
  • This can accelerate most foundation models, including BLOOM
    • in contrast, vanilla fused attention requires attention heads with head dim ≤ 64 and short sequences, which is incompatible with GPT-like transformers
  • This lets you train transformers on very long sequences (a few × 10^4 tokens) without micro-checkpointing
  • This also solves NVIDIA/apex#1201

How does it work?

Our implementation is based on apex's fused multihead attention (fmha), a very efficient implementation that works only with head dim 64 for sequences up to 512. We use mathematical invariants of softmax to extend FMHA to longer sequences and to the 128-dim heads required by large language models.

About: FMHA algorithm and its limitations

FMHA implements the entire attention core (from q, k, v to outputs) as a single CUDA kernel. Forward pass: it stores all keys and values in shared memory, iteratively loads a slice of queries, computes all intermediate values (dot products, weights) and saves only the attention outputs, i.e. the weighted sums. Backward pass: it recomputes some of these intermediate values for backpropagation, which turns out to be faster than loading them from memory, as long as you use tensor cores.

This runs 1.5-2x faster and is more memory-efficient, but it requires that everything fits into shared memory (50-100KB), which is why it only works for short sequences and small head dims.
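To make the FMHA data flow concrete, here is a pure-Python sketch of it (not the CUDA kernel): keys and values are held in full, standing in for shared memory, queries are processed one slice at a time, and only the final outputs are kept. `attention_rows`, `fmha_style_forward` and `query_slice` are hypothetical names, and the usual 1/sqrt(head_dim) scaling is omitted, as in the merge pseudocode of this writeup.

```python
import math

def attention_rows(q_rows, keys, values):
    # Softmax attention for a slice of queries against ALL keys/values.
    # No 1/sqrt(head_dim) scaling, matching the merge pseudocode in this writeup.
    out = []
    for q in q_rows:
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        m = max(scores)                               # stabilizer
        w = [math.exp(s - m) for s in scores]         # unnormalized weights
        denom = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, values)) / denom
                    for c in range(len(values[0]))])
    return out

def fmha_style_forward(queries, keys, values, query_slice=2):
    # keys/values stay resident (stand-in for shared memory); only query slices move,
    # and only the final outputs are written back, never the full score matrix.
    outputs = []
    for start in range(0, len(queries), query_slice):
        outputs.extend(attention_rows(queries[start:start + query_slice], keys, values))
    return outputs
```

Because `keys` and `values` must be held whole for every query slice, this layout only works while they fit in shared memory, which is exactly the limitation described above.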

Main observation: if you compute attention outputs over two separate blocks of keys & values, you can merge them efficiently, but only if you saved the softmax denominators and stabilizers (the per-query maxima) from each block. Fortunately, both denominators and stabilizers are small 1d vectors, negligible compared to the attention outputs.

Math: how to merge blocks

Let's consider one attention head and one sample for simplicity:

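```python
import numpy as np

def fused_attention(queries, keys, values):
    # Unfused NumPy stand-in for the fused kernel; returns (outputs, maxima, denominators).
    scores = queries @ keys.T                            # [N, M] dot products
    maxima = scores.max(axis=1, keepdims=True)           # [N, 1]
    weights = np.exp(scores - maxima)                    # [N, M], unnormalized
    denominators = weights.sum(axis=1, keepdims=True)    # [N, 1]
    outputs = (weights / denominators) @ values          # [N, head_dim]
    return outputs, maxima, denominators

rng = np.random.default_rng(0)
N, M1, M2, head_dim = 4, 6, 5, 8
queries = rng.standard_normal((N, head_dim))              # shape: [N, head_dim]
keys1, values1 = rng.standard_normal((2, M1, head_dim))   # shape: [M1, head_dim] each

outputs1, maxima1, denominators1 = fused_attention(queries, keys1, values1)  #, where
# outputs [N, head_dim] are attention outputs, weighted sums of values,
#     outputs1[i] := sum_j( e^(queries[i] · keys1[j] - maxima1[i]) / denominators1[i] * values1[j] )
# maxima [N, 1] are maximum dot products, used for softmax numerical stability
#     maxima1[i] := max(queries[i] · keys1[j] for j in range(M1))
# denominators [N, 1] are softmax denominators,
#     denominators1[i] := sum_j( e^(queries[i] · keys1[j] - maxima1[i]) for j in range(M1) )

# now for the second block
keys2, values2 = rng.standard_normal((2, M2, head_dim))   # shape: [M2, head_dim] each
outputs2, maxima2, denominators2 = fused_attention(queries, keys2, values2)

# convert to new maxima
new_maxima = np.maximum(maxima1, maxima2)

# convert to new denominator
denominator1_new_maxima = denominators1 * np.exp(maxima1 - new_maxima)
denominator2_new_maxima = denominators2 * np.exp(maxima2 - new_maxima)
new_denominators = denominator1_new_maxima + denominator2_new_maxima

# convert attention outputs:    switch to new maxima    ,    then to new denominator
outputs1_rescaled = outputs1 * np.exp(maxima1 - new_maxima) * (denominators1 / new_denominators)
outputs2_rescaled = outputs2 * np.exp(maxima2 - new_maxima) * (denominators2 / new_denominators)

new_outputs = outputs1_rescaled + outputs2_rescaled

# sanity check: identical to attending over both blocks at once
full_outputs, _, _ = fused_attention(queries, np.concatenate([keys1, keys2]),
                                     np.concatenate([values1, values2]))
assert np.allclose(new_outputs, full_outputs)

# if there is a 3rd block of keys/values, do the same calculation again for (new_outputs, new_maxima, new_denominators)
# Multiple heads and/or sequences can be handled independently, typically in separate thread blocks.
```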
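Why the rescaling is exact, in the same notation (a short derivation added for clarity, without the 1/sqrt(head_dim) factor, as in the code): write $s_{ij} = q_i^\top k_j$ and let $m^{(b)}_i$, $\ell^{(b)}_i$, $o^{(b)}_i$ be the maxima, denominators and outputs for key block $B_b$. Since $\sum_{j \in B_b} e^{s_{ij} - m^{(b)}_i} v_j = \ell^{(b)}_i\, o^{(b)}_i$, attention over $B_1 \cup B_2$ splits into the two rescaled block outputs:

```latex
\begin{aligned}
m_i &= \max\big(m^{(1)}_i, m^{(2)}_i\big), \qquad
\ell_i = \ell^{(1)}_i\, e^{m^{(1)}_i - m_i} + \ell^{(2)}_i\, e^{m^{(2)}_i - m_i} \\
o_i &= \frac{1}{\ell_i} \sum_{j \in B_1 \cup B_2} e^{s_{ij} - m_i}\, v_j
     = \frac{1}{\ell_i} \Big( e^{m^{(1)}_i - m_i} \sum_{j \in B_1} e^{s_{ij} - m^{(1)}_i} v_j
                            + e^{m^{(2)}_i - m_i} \sum_{j \in B_2} e^{s_{ij} - m^{(2)}_i} v_j \Big) \\
    &= o^{(1)}_i\, e^{m^{(1)}_i - m_i}\, \frac{\ell^{(1)}_i}{\ell_i}
     + o^{(2)}_i\, e^{m^{(2)}_i - m_i}\, \frac{\ell^{(2)}_i}{\ell_i}
\end{aligned}
```

The last line is term-by-term the `outputs1_rescaled + outputs2_rescaled` computation, and $\ell_i$ is `new_denominators`.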

Our strategy is to compute fused attention one chunk at a time and use the merge rule above to "accumulate" attention outputs as we go. In our implementation, this is done by a single CUDA kernel that iterates over key/value chunks in the outer loop and over queries in the inner loop.
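A pure-Python sketch of this accumulation for a single head, assuming nothing about the kernel's actual tiling: the running outputs, maxima and denominators play the role of the global-memory buffers, the outer loop walks over key/value chunks and the inner loop over queries, and each step applies the two-block merge from the pseudocode above. `streaming_attention` and `block_size` are hypothetical names; the 1/sqrt(head_dim) scaling is omitted.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def streaming_attention(queries, keys, values, block_size=2):
    n, d = len(queries), len(values[0])
    outputs = [[0.0] * d for _ in range(n)]   # running (normalized) outputs, used as accumulator
    maxima = [-math.inf] * n                  # running row maxima
    denominators = [0.0] * n                  # running softmax denominators
    for start in range(0, len(keys), block_size):       # outer loop: key/value chunks
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        for i, q in enumerate(queries):                  # inner loop: all queries
            scores = [dot(q, k) for k in k_blk]
            blk_max = max(scores)
            w = [math.exp(s - blk_max) for s in scores]
            blk_den = sum(w)
            blk_out = [sum(wj * v[c] for wj, v in zip(w, v_blk)) / blk_den for c in range(d)]
            # Two-block merge: rescale both sides to the new maximum and denominator.
            # On the first chunk maxima[i] is -inf, so old_w becomes 0.0.
            new_max = max(maxima[i], blk_max)
            old_w = denominators[i] * math.exp(maxima[i] - new_max)
            blk_w = blk_den * math.exp(blk_max - new_max)
            new_den = old_w + blk_w
            outputs[i] = [(o * old_w + b * blk_w) / new_den
                          for o, b in zip(outputs[i], blk_out)]
            maxima[i], denominators[i] = new_max, new_den
    return outputs, maxima, denominators
```

With `block_size >= len(keys)` this reduces to ordinary softmax attention, which gives an easy correctness check for smaller block sizes.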

There is a similar invariant for the backward pass, but it requires the global softmax maxima and denominators. Fortunately, you compute them anyway during the forward pass, and they are small [N, 1] vectors. For specifics, please refer to our implementation (below). Alternatively, we'll happily explain the math over a voice call if you're interested.
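For reference, here are the standard softmax-attention gradients, written so that each key/value block only needs the global per-query vectors. This is a textbook derivation offered as a sketch (no 1/sqrt(head_dim) factor, as in the pseudocode above); the kernel may organize the computation differently. $m_i$ and $\ell_i$ are the global maxima and denominators from the forward pass, $o_i$ is the output row and $dO_i$ its gradient:

```latex
\begin{aligned}
P_{ij} &= e^{\,q_i^\top k_j - m_i} / \ell_i \\
dP_{ij} &= dO_i^\top v_j, \qquad
D_i = \textstyle\sum_j P_{ij}\, dP_{ij} = dO_i^\top o_i \\
dS_{ij} &= P_{ij}\,\big(dP_{ij} - D_i\big) \\
dv_j &= \textstyle\sum_i P_{ij}\, dO_i, \qquad
dq_i = \textstyle\sum_j dS_{ij}\, k_j, \qquad
dk_j = \textstyle\sum_i dS_{ij}\, q_i
\end{aligned}
```

For a fixed key block, every term needs only the length-N vectors $m$, $\ell$ and $D$, so blocks can be processed one at a time without storing the full $P$.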

This algorithm uses global memory only for the layer outputs and the softmax denominators/maxima (2 * seq_length values). In contrast, most attention implementations explicitly materialize all attention weights (seq_length^2), which makes it difficult to fit long sequences and slows the code down due to slow global memory access.
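A back-of-the-envelope comparison for one head and one sample at 20k tokens, assuming fp16 attention weights and fp32 maxima/denominators (our assumptions; the outputs are needed either way and are left out). `attention_memory_bytes` is a hypothetical helper.

```python
def attention_memory_bytes(seq_len, bytes_per_weight=2, bytes_per_stat=4):
    # Materialized attention: one weight per (query, key) pair.
    full_matrix = seq_len * seq_len * bytes_per_weight
    # Block-merging approach: one maximum and one denominator per query.
    softmax_stats = 2 * seq_len * bytes_per_stat
    return full_matrix, softmax_stats

full, stats = attention_memory_bytes(20_000)
print(f"{full / 2**20:.0f} MiB vs {stats / 2**10:.0f} KiB")  # prints: 763 MiB vs 156 KiB
```

The quadratic term dominates quickly, and it is multiplied again by the number of heads and the batch size.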

Stability-wise, we tested this algorithm in half precision for up to 20k tokens, and it seems to work fine. We anticipated some issues with the softmax denominator (in float32), but it did not cause problems. This is likely because we subtract the maxima before computing the denominator, a trick inherited from the original FMHA.
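One way to see why a float32 denominator is safe: after subtracting the row maximum, each term e^(s - max) lies in (0, 1] and the largest term is exactly 1, so for M keys the denominator satisfies 1 ≤ denominator ≤ M (at most 2 × 10^4 at 20k tokens, far inside float32 range). This bounds the range, not the rounding error. A minimal pure-Python illustration; `stable_denominator` is a hypothetical name.

```python
import math

def stable_denominator(scores):
    # Subtract the row maximum first: every term e^(s - m) is in (0, 1],
    # and the maximal score contributes exactly e^0 = 1.
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

scores = [1000.0, 999.0, -50.0, 998.5]     # naive math.exp(1000.0) would overflow
m, denom = stable_denominator(scores)
assert 1.0 <= denom <= len(scores)         # bounded by [1, number of keys]
```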


An implementation for arbitrary sequence lengths with head dim 64 can be found here: It reuses most of the code from apex.contrib.fmha, modifying it with the block-merging algorithm defined above.

Overview: changes from FMHA
Forward: float vectors of maxima and sums live in global memory and are updated on each block iteration. The output tensor is used as an output accumulator. Each block iteration loads only one block of k, v, but iterates over the full sequence length of queries and outputs.
Backward: using the maxima and sums vectors from the forward pass, calculate the float vector sum(grado * s). Calculate dodv by iterating over the full sequence length of do, writing the block's rows of dodv on each block iteration. As temporary storage in global memory, use an attention matrix with one block's worth of columns holding grado*s - sum(grado*s)*s. From it, calculate dodq analogously to how the forward-pass output is calculated, and calculate dodk the same way as dodv.
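A pure-Python sketch of the backward pass as we read the notes above, for one head and one sample: key/value blocks are visited one at a time, each block's softmax probabilities are recomputed from the global maxima `m` and denominators `l`, and `sum(grado * s)` is computed in the equivalent form `dO_i · o_i`. We assume `grado` is the gradient w.r.t. the softmax output `s`, `do` the gradient w.r.t. the attention output, and `dodq`/`dodk`/`dodv` the gradients w.r.t. q/k/v. `forward_stats` and `blockwise_backward` are hypothetical names, the 1/sqrt(head_dim) scaling is omitted, and unlike the kernel `dq` is simply accumulated across blocks.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward_stats(q, k, v):
    # Plain forward pass returning outputs plus global per-row maxima and denominators.
    o, m, l = [], [], []
    for qi in q:
        s = [dot(qi, kj) for kj in k]
        mi = max(s)
        w = [math.exp(x - mi) for x in s]
        li = sum(w)
        m.append(mi)
        l.append(li)
        o.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / li for c in range(len(v[0]))])
    return o, m, l

def blockwise_backward(q, k, v, o, d_o, m, l, block_size=2):
    # Gradients of sum_i d_o[i] . o[i] w.r.t. q, k, v, one key/value block at a time.
    # P is never stored whole: each entry is recomputed from the global m and l.
    d, dv_dim = len(q[0]), len(v[0])
    D = [dot(g, oi) for g, oi in zip(d_o, o)]      # D_i = sum_j P_ij dP_ij = dO_i . o_i
    dq = [[0.0] * d for _ in q]
    dk = [[0.0] * d for _ in k]
    dv = [[0.0] * dv_dim for _ in v]
    for start in range(0, len(k), block_size):     # outer loop: key/value blocks
        for i, qi in enumerate(q):                 # full pass over queries per block
            for j in range(start, min(start + block_size, len(k))):
                p = math.exp(dot(qi, k[j]) - m[i]) / l[i]   # recomputed probability
                dp = dot(d_o[i], v[j])                      # dP_ij = dO_i . v_j
                ds = p * (dp - D[i])                        # softmax backward
                for c in range(dv_dim):
                    dv[j][c] += p * d_o[i][c]
                for c in range(d):
                    dq[i][c] += ds * k[j][c]
                    dk[j][c] += ds * qi[c]
    return dq, dk, dv
```

The finite-difference check below compares these gradients against numerically differentiated `forward_stats`, which is independent of the block size.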

We also made a proof of concept for 128-dim heads for the forward pass: a larger head size requires running with smaller blocks (128-256 instead of 384-512), but it still runs faster than most popular implementations and uses dramatically less memory.
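A rough illustration of why larger heads force smaller blocks, assuming one fp16 K tile and one fp16 V tile must be resident in shared memory at once (a simplification: the real kernel's layout and other buffers change the absolute numbers). Under this model the footprint scales with block_size × head_dim, so doubling head_dim at a fixed budget halves the block. `kv_tile_bytes` is a hypothetical helper.

```python
def kv_tile_bytes(block_size, head_dim, bytes_per_elem=2):
    # One K tile plus one V tile, each block_size x head_dim, fp16 by default.
    return 2 * block_size * head_dim * bytes_per_elem

for block, dim in [(384, 64), (512, 64), (128, 128), (256, 128)]:
    print(f"block {block}, head dim {dim}: {kv_tile_bytes(block, dim) // 1024} KiB")
```

For example, 256 rows at head dim 128 occupy the same footprint as 512 rows at head dim 64 under this model.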

Unfortunately, we haven't (yet) found a good way to merge both implementations into one. This requires deeper knowledge of CUDA internals than we have. We would appreciate it if someone could help us with that.


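| impl | seq len | head dim | block size | fwd (ms) | bwd (ms) | mem (MB) | fwd max error | fwd avg error | bwd max error | bwd avg error |
|---|---|---|---|---|---|---|---|---|---|---|
| fmha | 1920 | 64 | 384 | 5.23 | 16.87 | 605 | 5e-4 | 1.1e-5 | 2e-4 | 4.3e-6 |
| py reference impl | 1920 | 64 | - | 8.64 | 17.33 | 4769 | - | - | - | - |
| fmha | 2048 | 128 | 128 | 9.57 | 38.5 | 1203 | 8e-4 | 3.8e-6 | | |
| py reference impl | 2048 | 128 | - | 12.8 | 24.6 | 5680 | - | - | - | - |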
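Ratios derived from the table above (reference ÷ fmha, so values above 1 favor fmha). The numbers are copied from the table; nothing new is measured here. Note that the derived backward ratio for head dim 128 is below 1, i.e. that configuration's backward pass is slower than the reference in this table.

```python
# (fwd ms, bwd ms, mem MB) per configuration, copied from the table above.
results = {
    (1920, 64): {"fmha": (5.23, 16.87, 605), "reference": (8.64, 17.33, 4769)},
    (2048, 128): {"fmha": (9.57, 38.5, 1203), "reference": (12.8, 24.6, 5680)},
}

for (seq_len, head_dim), r in results.items():
    fwd, bwd, mem = (ref / ours for ref, ours in zip(r["reference"], r["fmha"]))
    print(f"seq {seq_len}, dim {head_dim}: fwd {fwd:.2f}x, bwd {bwd:.2f}x, "
          f"memory {mem:.1f}x less")
```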


  • Alexey Kuts implemented the algorithm and benchmarks
  • TimDettmers formulated the idea and high-level code design
  • xtinkt and dfyz helped with implementation details
  • this work was done as part of the BigScience engineering group

Related work: Self-attention Does Not Need O(n²) Memory (Rabe & Staats, 2021). This paper proposes a very similar idea, but requires micro-checkpointing and slows computation down slightly, whereas our version accelerates it.


@krunt thank you for this writeup!
