For Turing or older GPUs: to force-enable memory-efficient attention, use this diff to patch ~/.local/lib/python3.10/site-packages/transformers/integrations/sdpa_attention.py.
diff --git a/sdpa_attention.py b/sdpa_attention.py
index e2eb69b..56786e8 100644
--- a/sdpa_attention.py
+++ b/sdpa_attention.py
@@ -93,16 +93,22 @@ def sdpa_attention_forward(
         # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
         attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
-        attn_mask=attention_mask,
-        dropout_p=dropout,
-        scale=scaling,
-        is_causal=is_causal,
-        **sdpa_kwargs,
-    )
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, None
+    if key.size(1) != query.size(1):
+        assert (query.size(1) % key.size(1) == 0)
+        repeat_factor = query.size(1) // key.size(1)
+        key = key.repeat_interleave(repeat_factor, dim=1)
+        value = value.repeat_interleave(repeat_factor, dim=1)
+
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=is_causal,
+            **sdpa_kwargs,
+        )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
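
The added branch does two things: it expands grouped-query KV heads with repeat_interleave so the head counts match (the efficient backend may not broadcast GQA heads the way the flash path does), and it pins SDPA to the EFFICIENT_ATTENTION backend, which matters on Turing (sm_75) since stock PyTorch's flash kernels generally require Ampere or newer. Below is a minimal standalone sketch of the same pattern; it assumes PyTorch 2.3+ (for torch.nn.attention) and a CUDA device, and the tensor shapes are illustrative:

# Minimal standalone sketch of the patched code path.
# Assumptions: PyTorch >= 2.3 (torch.nn.attention), a CUDA device;
# shapes, dtype, and head counts are illustrative, not from the gist.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Grouped-query layout: 8 query heads sharing 2 KV heads.
query = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 2, 128, 64, device="cuda", dtype=torch.float16)

# Expand KV heads so every backend sees matching head counts,
# mirroring the repeat_interleave step in the patch.
if key.size(1) != query.size(1):
    repeat_factor = query.size(1) // key.size(1)
    key = key.repeat_interleave(repeat_factor, dim=1)
    value = value.repeat_interleave(repeat_factor, dim=1)

# Restrict SDPA to the memory-efficient kernel; this raises a runtime
# error instead of silently falling back if the kernel cannot run.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])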
am009 commented Jan 20, 2026

To apply the patch in one command:

curl -fsSL https://gist.github.com/am009/be749ce4b6133a2208366205e9d9728b/raw/329d1b8e5e211c2646159381b5b7cc1d7a6e47ad/sdpa_attention.diff | patch ~/.local/lib/python3.10/site-packages/transformers/integrations/sdpa_attention.py
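
If the hunk fails to apply against a different transformers version, adding -b to patch keeps a backup of the original file as sdpa_attention.py.orig; adjust the python3.10 path segment to match your interpreter.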
