Skip to content

Instantly share code, notes, and snippets.

@scturtle
Last active April 11, 2024 16:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scturtle/98816ec7828a4619234554b139bc012b to your computer and use it in GitHub Desktop.
FlashAttention v1 and v2 in NumPy
import numpy as np


def flash_attention_v1(Q, K, V, Bc=16, Br=16):
    """Compute softmax(Q @ K.T / sqrt(d)) @ V using the FlashAttention-1 tiling.

    The outer loop walks key/value blocks of ``Bc`` rows and the inner loop
    walks query blocks of ``Br`` rows.  Per-row running maxima ``M`` and
    running softmax denominators ``L`` let each partial output row be
    rescaled whenever a new key block changes the maximum.  As a result, only
    ``Br x Bc`` score tiles are ever formed, never the full score matrix.
    The sequence lengths need not be multiples of the block sizes; the last
    slice is simply shorter.

    Parameters
    ----------
    Q : ndarray of shape (N_out, d)
        Queries.
    K : ndarray of shape (N_inp, d)
        Keys.
    V : ndarray of shape (N_inp, d_v)
        Values.
    Bc : int
        Number of key/value rows per block.
    Br : int
        Number of query rows per block.

    Returns
    -------
    O : ndarray of shape (N_out, d_v)
        Attention output.
    L : ndarray of shape (N_out, 1)
        Per-row sum of ``exp(score - M)`` over all keys.
    M : ndarray of shape (N_out, 1)
        Per-row maximum of the scaled scores.
    """
    n_out = Q.shape[0]
    n_inp = K.shape[0]
    scale = 1 / np.sqrt(Q.shape[-1])

    # The paper initialises O to zero; the first update also multiplies the
    # old O by L == 0, so nothing from the initial value survives.
    O = np.zeros((n_out, V.shape[-1]))
    L = np.zeros((n_out, 1))
    M = np.full((n_out, 1), -np.inf)

    for c0 in range(0, n_inp, Bc):
        Kj = K[c0:c0 + Bc]
        Vj = V[c0:c0 + Bc]
        for r0 in range(0, n_out, Br):
            rows = slice(r0, r0 + Br)
            Qi, Oi, li, mi = Q[rows], O[rows], L[rows], M[rows]

            # Local softmax statistics for this (query block, key block) tile.
            Sij = scale * (Qi @ Kj.T)
            mij = np.max(Sij, axis=1, keepdims=True)
            Pij = np.exp(Sij - mij)
            lij = np.sum(Pij, axis=1, keepdims=True)

            # Merge with the running statistics.  On the first key block
            # mi == -inf, so exp(mi - mi_new) == 0 and the empty history drops out.
            mi_new = np.maximum(mi, mij)
            old_w = np.exp(mi - mi_new)
            new_w = np.exp(mij - mi_new)
            li_new = old_w * li + new_w * lij

            # Undo the previous normalisation (li), rescale to the new max,
            # add this tile's contribution, and renormalise by li_new.
            O[rows] = (li * old_w * Oi + new_w * (Pij @ Vj)) / li_new
            L[rows] = li_new
            M[rows] = mi_new
    return O, L, M


N_inp = 64
N_out = 64
d = 128
Q = np.random.randn(N_out, d)
K = np.random.randn(N_inp, d)
V = np.random.randn(N_inp, d)
Bc = 16
Br = 16
Tc = (N_inp + Bc - 1) // Bc  # number of key/value blocks
Tr = (N_out + Br - 1) // Br  # number of query blocks
scale_factor = 1 / np.sqrt(Q.shape[-1])
O, L, M = flash_attention_v1(Q, K, V, Bc=Bc, Br=Br)

# Reference: materialise the full score matrix and apply a stable softmax.
S_ = scale_factor * Q @ K.T
P_ = np.exp(S_ - np.max(S_, axis=1, keepdims=True))
O_ = (P_ / np.sum(P_, axis=1, keepdims=True)) @ V
assert np.allclose(O, O_)
import numpy as np


def flash_attention_v2(Q, K, V, Bc=16, Br=16):
    """Compute softmax(Q @ K.T / sqrt(d)) @ V using the FlashAttention-2 tiling.

    Unlike v1, the outer loop walks query blocks of ``Br`` rows and the inner
    loop walks key/value blocks of ``Bc`` rows.  Each query block keeps its
    own running row maximum ``mi``, running denominator ``li`` and an
    *unnormalised* output accumulator ``Oi``.  The division by ``li`` is
    deferred until every key block has been seen.  The sequence lengths need
    not be multiples of the block sizes.

    Parameters
    ----------
    Q : ndarray of shape (N_out, d)
        Queries.
    K : ndarray of shape (N_inp, d)
        Keys.
    V : ndarray of shape (N_inp, d_v)
        Values.
    Bc : int
        Number of key/value rows per block.
    Br : int
        Number of query rows per block.

    Returns
    -------
    O : ndarray of shape (N_out, d_v)
        Attention output.
    L : ndarray of shape (N_out, 1)
        Per-row sum of ``exp(score - M)`` over all keys.
    M : ndarray of shape (N_out, 1)
        Per-row maximum of the scaled scores.
    """
    n_out = Q.shape[0]
    n_inp = K.shape[0]
    d_v = V.shape[-1]
    scale = 1 / np.sqrt(Q.shape[-1])

    O = np.zeros((n_out, d_v))
    L = np.zeros((n_out, 1))
    M = np.full((n_out, 1), -np.inf)

    for r0 in range(0, n_out, Br):
        Qi = Q[r0:r0 + Br]
        # Size the accumulators from the actual slice: the last query block is
        # shorter than Br when n_out is not a multiple of Br.
        n_rows = Qi.shape[0]
        Oi = np.zeros((n_rows, d_v))
        li = np.zeros((n_rows, 1))
        mi = np.full((n_rows, 1), -np.inf)

        for c0 in range(0, n_inp, Bc):
            Kj = K[c0:c0 + Bc]
            Vj = V[c0:c0 + Bc]
            Si = scale * (Qi @ Kj.T)
            mi_new = np.maximum(mi, np.max(Si, axis=1, keepdims=True))
            Pi = np.exp(Si - mi_new)
            # Rescale the history to the new row max.  On the first block
            # mi == -inf, so the factor is 0.
            decay = np.exp(mi - mi_new)
            li = decay * li + np.sum(Pi, axis=1, keepdims=True)
            Oi = decay * Oi + Pi @ Vj
            mi = mi_new

        # Deferred normalisation: one division per query block.
        O[r0:r0 + Br] = Oi / li
        L[r0:r0 + Br] = li
        M[r0:r0 + Br] = mi
    return O, L, M


N_inp = 64
N_out = 64
d = 128
Q = np.random.randn(N_out, d)
K = np.random.randn(N_inp, d)
V = np.random.randn(N_inp, d)
Bc = 16
Br = 16
Tc = (N_inp + Bc - 1) // Bc  # number of key/value blocks
Tr = (N_out + Br - 1) // Br  # number of query blocks
scale_factor = 1 / np.sqrt(Q.shape[-1])
O, L = flash_attention_v2(Q, K, V, Bc=Bc, Br=Br)[:2]

# Reference: materialise the full score matrix and apply a stable softmax.
S_ = scale_factor * Q @ K.T
P_ = np.exp(S_ - np.max(S_, axis=1, keepdims=True))
O_ = (P_ / np.sum(P_, axis=1, keepdims=True)) @ V
assert np.allclose(O, O_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.