Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.

TL;DR: you can run multi-head attention (forward + backward) faster and with no extra memory – for any sequence length and head dimension. We’d love to make it available via apex, and we need your advice on how best to do that.

Why should I care? Here's how it compares against the standard multi-head attention (blue) for one multi-head attention layer of GPT-J on an RTX 3080Ti.

[Figure: time, with backward (ms) and peak VRAM allocated (MB)]
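For context, here is a minimal sketch of how such a baseline measurement could be taken with stock PyTorch (this is not the benchmark script used above): time one standard multi-head attention layer, forward plus backward, and record peak VRAM. The layer sizes (hidden size 4096, 16 heads, fp16) and the sequence lengths are assumptions chosen to resemble a GPT-J attention layer, not the exact benchmark configuration.

```python
import torch

def bench_standard_mha(seq_len, batch=1, embed_dim=4096, num_heads=16, iters=10):
    """Time forward+backward of a stock nn.MultiheadAttention layer and report peak VRAM."""
    device = "cuda"
    mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).to(device).half()
    x = torch.randn(batch, seq_len, embed_dim, device=device,
                    dtype=torch.half, requires_grad=True)

    # warm-up pass so one-time allocations and autotuning do not skew the numbers
    out, _ = mha(x, x, x, need_weights=False)
    out.sum().backward()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out, _ = mha(x, x, x, need_weights=False)  # self-attention
        out.sum().backward()
    end.record()
    torch.cuda.synchronize()

    ms_per_iter = start.elapsed_time(end) / iters
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return ms_per_iter, peak_mb

if __name__ == "__main__":
    # sequence lengths here are illustrative, not the ones plotted above
    for seq_len in (512, 1024, 2048, 4096):
        ms, peak_mb = bench_standard_mha(seq_len)
        print(f"seq_len={seq_len}: {ms:.2f} ms/iter, peak {peak_mb:.0f} MB")
```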