@proger
Last active April 26, 2024 07:53
# prompt: https://twitter.com/francoisfleuret/status/1783479122418716805
import os
os.environ['TORCH_LOGS'] = 'output_code' # print the code Inductor generates (shows all the bmms)
import torch
torch.set_float32_matmul_precision('high')
N, T, D, U, C = 3, 128, 5, 32, 32 # batch, time, heads, head_dim, dim
S = T
A = torch.randn(N, T, D, U) / U**0.5
B = torch.randn(N, D, U, S) / U**0.5
V = torch.randn(N, S, C) / C**0.5
@torch.compile
def notscan(A, B, V):
    ABV = V.new_zeros(N, T, C)
    V = V.unsqueeze(1) # N1SC
    for i in range(D):
        a = A[:,:,i,:] # NTU
        b = B[:,[i],:,:] # N1US
        bv = torch.matmul(b, V) # N1US, N1SC -> N1UC
        abv = torch.matmul(a, bv.squeeze(1)) # NTU, NUC -> NTC
        ABV.add_(abv)
    return ABV

assert notscan(A.cuda(), B.cuda(), V.cuda()).shape == (N, T, C)
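
For reference, the per-head loop computes the same contraction as a single einsum; a quick sanity check (a sketch, with a loose tolerance since set_float32_matmul_precision('high') enables TF32 matmuls):

ref = torch.einsum('ntdu,ndus,nsc->ntc', A.cuda(), B.cuda(), V.cuda())
assert torch.allclose(notscan(A.cuda(), B.cuda(), V.cuda()), ref, atol=1e-2)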
@soumith commented Apr 26, 2024
the new_empty should be new_zeros, or else you're adding to uninitialized memory.
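
A minimal sketch of the failure mode, assuming nothing about what happens to live in the reused memory:

x = torch.randn(3, device='cuda')
buf = x.new_empty(3) # uninitialized: holds whatever was in that memory
buf.add_(x)          # garbage + x, can vary from run to run
buf = x.new_zeros(3) # initialized to zeros
buf.add_(x)          # always equals x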

@proger (Author) commented Apr 26, 2024

Thanks @soumith for the catch!
