Relevant PRs
- cuDNN backward
- Flash Attention backward (this works)
Summary
We are implementing variable-length attention with the cuDNN backend, and the outputs of our API and of SDPA with packing do not match after the backward pass.
In the provided repro, we include the definition of `_varlen_attn()`, our private custom op that calls into `_cudnn_attention_forward()`. We also define `_backward()`, the backward pass registered with autograd, which calls `_cudnn_attention_backward()`.
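For context, here is a minimal sketch of that structure, assuming the `torch.library.custom_op` registration API. This is not the actual repro: the `mylib::_varlen_attn` name is illustrative, the varlen arguments (packed layout, cu_seqlens) are omitted, and the cuDNN calls are stubbed with plain SDPA so the sketch is self-contained.

```python
import torch
import torch.nn.functional as F


@torch.library.custom_op("mylib::_varlen_attn", mutates_args=())
def _varlen_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # In the real repro this dispatches to _cudnn_attention_forward();
    # plain SDPA stands in here so the sketch runs without the cuDNN backend.
    return F.scaled_dot_product_attention(q, k, v)


@_varlen_attn.register_fake
def _(q, k, v):
    # Fake (meta) kernel: the output shape matches q for standard attention.
    return torch.empty_like(q)


def _setup_context(ctx, inputs, output):
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)


def _backward(ctx, grad_out):
    # In the real repro this calls _cudnn_attention_backward(); here the
    # gradients are recovered by re-running SDPA under autograd instead.
    q, k, v = (t.detach().requires_grad_() for t in ctx.saved_tensors)
    with torch.enable_grad():
        out = F.scaled_dot_product_attention(q, k, v)
    return torch.autograd.grad(out, (q, k, v), grad_out)


_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
```

Calling `_varlen_attn(q, k, v)` on tensors with `requires_grad=True` and backpropagating then exercises the registered `_backward()`, which is the path where the repro's gradients diverge from SDPA with packing.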