@sayakpaul
Created November 1, 2023 07:06
import torch
import torch.nn.functional as F
import xformers.ops
# xformers' memory_efficient_attention expects (batch, seq_len, num_heads, head_dim) = (1, 4096, 16, 72).
# F.scaled_dot_product_attention (SDPA) expects (batch, num_heads, seq_len, head_dim) = (1, 16, 4096, 72).
q = torch.randn(1, 4096, 16, 72, generator=torch.manual_seed(0)).cuda()
k = torch.randn(1, 4096, 16, 72, generator=torch.manual_seed(1)).cuda()
v = torch.randn(1, 4096, 16, 72, generator=torch.manual_seed(2)).cuda()
# Move the heads dimension to index 1 for SDPA.
query = q.transpose(1, 2)
key = k.transpose(1, 2)
value = v.transpose(1, 2)
x_xformers = xformers.ops.memory_efficient_attention(q, k, v, p=0.0)
# Transpose SDPA's output back to (batch, seq_len, num_heads, head_dim) so the
# two results are directly comparable.
x_sdpa = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
).transpose(1, 2)
print(torch.allclose(x_xformers, x_sdpa))
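# The two backends use different kernels, so their outputs generally agree only
# up to floating-point noise, and allclose's strict default tolerances may print
# False even when the results match for practical purposes. A minimal follow-up
# check; the tolerance values below are illustrative assumptions, not part of
# the original gist.
max_diff = (x_xformers - x_sdpa).abs().max()
print(f"max absolute difference: {max_diff:.2e}")
torch.testing.assert_close(x_xformers, x_sdpa, rtol=1e-4, atol=1e-4)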