Skip to content

Instantly share code, notes, and snippets.

@xiabingquan
Last active June 27, 2024 02:59
Show Gist options
  • Save xiabingquan/a4a9a743f97aadd531ed6218be20afd2 to your computer and use it in GitHub Desktop.
Save xiabingquan/a4a9a743f97aadd531ed6218be20afd2 to your computer and use it in GitHub Desktop.
A toy example of flash attention implemented with NumPy.
# A minimal example of flash attention implemented in NumPy
# Contact: bingquanxia AT qq.com
import unittest
from typing import List
import numpy as np
import torch
class SoftMax(object):
    """
    Textbook three-pass softmax over a Python list of floats, written with
    explicit loops so each pass of the algorithm is visible.
    """

    def forward(self, x: List[float]):
        # Pass 1: find max(x); subtracting it keeps exp() from overflowing.
        peak = -np.inf
        for value in x:
            if value > peak:
                peak = value

        # Pass 2: the normalizer sum_i exp(x_i - max(x)).
        normalizer = 0.
        for value in x:
            normalizer += np.exp(value - peak)

        # Pass 3: each output element is exp(x_i - max(x)) / normalizer.
        return [np.exp(value - peak) / normalizer for value in x]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
class SoftMaxWithTiling(object):
    """
    Online softmax: the running maximum and the running exp-sum are computed
    together in a single pass, rescaling the partial sum whenever the maximum
    grows. This is the same trick FlashAttention uses across key blocks.
    """

    def forward(self, x: List[float]):
        running_max = -np.inf
        running_sum = 0.
        for value in x:
            if value > running_max:
                new_max = value
            else:
                new_max = running_max
            # Re-express the old partial sum relative to the new max, then add the new term.
            running_sum = np.exp(running_max - new_max) * running_sum + np.exp(value - new_max)
            running_max = new_max

        # Second pass: normalize exp(x_i - max(x)) by the accumulated sum.
        return [np.exp(value - running_max) / running_sum for value in x]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
class SoftMaxTest(unittest.TestCase):
    """
    Checks SoftMax and SoftMaxWithTiling against PyTorch's softmax on
    random inputs of random length.
    """

    def test_softmax(self):
        n_test = 10
        for _ in range(n_test):
            # Length drawn from [1, 10]; values are standard-normal.
            x = np.random.randn(np.random.randint(1, 11)).tolist()
            expected = torch.nn.functional.softmax(torch.tensor(x), dim=-1).tolist()
            for implementation in (SoftMax(), SoftMaxWithTiling()):
                self.assertTrue(np.allclose(expected, implementation(x), atol=1e-4))
class StandardAttention(object):
    def __init__(self) -> None:
        """
        Attention module implemented in Numpy.

        Formula (naming follows the FlashAttention paper):
            S = QK^T / sqrt(d_k)
            P = softmax(S)
            O = PV

        Reference:
            <<Attention Is All You Need>>
        URL:
            https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
        """
        pass

    def _validity_check(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> None:
        """Assert that q, k and v have compatible 3D shapes for batched attention."""
        assert q.ndim == 3, "q should be a 3D tensor"  # [batch_size, seq_len, hidden_size]
        assert k.ndim == 3, "k should be a 3D tensor"
        assert v.ndim == 3, "v should be a 3D tensor"
        assert q.shape[0] == k.shape[0], "batch_size of q and k should be the same"
        # Only q and k must share the hidden size (they are dotted together);
        # v may have its own hidden size d_v, which becomes the output's last dim.
        assert q.shape[2] == k.shape[2], "hidden_size of q and k should be the same"
        # Every key needs exactly one value row.
        assert k.shape[1] == v.shape[1], "seq_len of k and v should be the same"

    def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        """
        Compute softmax(q k^T / sqrt(d_k)) v.

        Args:
            q: [batch_size, q_len, hidden_size]
            k: [batch_size, k_len, hidden_size]
            v: [batch_size, k_len, v_hidden_size]
        Returns:
            [batch_size, q_len, v_hidden_size]
        """
        self._validity_check(q, k, v)
        hidden_size = q.shape[2]
        denom = np.sqrt(hidden_size)
        attn = np.matmul(q, k.transpose(0, 2, 1))  # [batch_size, q_len, k_len]
        # Subtracting the row max before exp() avoids overflow; the shift cancels in the
        # normalization, and (x - max) / d equals x / d - max / d, so scaling after is exact.
        attn = np.exp((attn - attn.max(axis=-1, keepdims=True)) / denom)
        attn = attn / attn.sum(axis=-1, keepdims=True)
        return np.matmul(attn, v)  # [batch_size, q_len, v_hidden_size]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
def self_attention(x):
    """Run standard attention with ``x`` acting as query, key and value simultaneously."""
    attention = StandardAttention()
    return attention(x, x, x)
class FlashAttention(object):
    """
    Flash Attention (forward pass) simulated in NumPy.

    K/V are visited in column blocks (outer loop) and Q in row blocks (inner loop).
    A running row-max ``m`` and a running exp-sum ``l`` let each output row block be
    rescaled incrementally. Only [batch_size, row_block_size, col_block_size] score tiles
    are formed, never the full [q_len, k_len] score matrix.
    """

    def __init__(self, row_block_size: int, col_block_size: int) -> None:
        """
        Flash Attention in Numpy.

        Reference:
            <<FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness>>
            https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf

        Args:
            row_block_size: number of query rows per block (B_r in the paper).
            col_block_size: number of key/value rows per block (B_c in the paper).
        """
        # (Line 1): set the block size.
        # We set the query and key block sizes manually because we do not know
        # the on-chip SRAM size of the GPU.
        self.row_block_size = row_block_size
        self.col_block_size = col_block_size

    def _validity_check(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> None:
        """Assert that q, k, v have compatible shapes and that the block sizes tile them exactly."""
        assert q.ndim == 3, "q should be a 3D tensor"  # [batch_size, seq_len, hidden_size]
        assert k.ndim == 3, "k should be a 3D tensor"
        assert v.ndim == 3, "v should be a 3D tensor"
        assert q.shape[0] == k.shape[0], "batch_size of q and k should be the same"
        # q and k are dotted together; v may have its own hidden size (the output's last dim).
        assert q.shape[2] == k.shape[2], "hidden_size of q and k should be the same"
        # Without this, a short v silently yields empty value blocks and an opaque matmul error.
        assert k.shape[1] == v.shape[1], "seq_len of k and v should be the same"
        # Checked before the modulo below, which would otherwise raise ZeroDivisionError.
        assert self.row_block_size > 0 and self.col_block_size > 0, "block_size should be positive"
        assert q.shape[1] % self.row_block_size == 0 and k.shape[1] % self.col_block_size == 0, \
            "seq_len should be divisible by block_size"

    @staticmethod
    def load(arr, st, ed, step):
        """Simulate moving blocks [st, ed) of ``arr`` (along axis 1) from HBM to SRAM.

        Returns a basic-slice view of ``arr``, not a copy.
        """
        return arr[:, st * step: ed * step]

    @staticmethod
    def write(arr, val, st, ed, step):
        """Simulate moving ``val`` from SRAM back into blocks [st, ed) of ``arr`` (HBM), in place."""
        arr[:, st * step: ed * step] = val

    def forward(self, q, k, v):
        """
        Compute softmax(q k^T / sqrt(d_k)) v block by block.

        This follows Algorithm 1 of the FLASH-ATTENTION paper (page 5 of the original paper),
        except that it is batched (batch_size is the first dimension of q, k, v) and that
        v may have a hidden size different from q and k.

        Args:
            q: [batch_size, q_len, hidden_size]
            k: [batch_size, k_len, hidden_size]
            v: [batch_size, k_len, v_hidden_size]
        Returns:
            float64 array of shape [batch_size, q_len, v_hidden_size]
        """
        self._validity_check(q, k, v)
        batch_size, q_len, hidden_size = q.shape
        k_len = k.shape[1]
        v_hidden_size = v.shape[2]
        denom = np.sqrt(hidden_size)  # sqrt(d_k), loop-invariant

        # (Line 2): initialize O, l and m
        # O: output, updated one row block at a time.
        out = np.zeros((batch_size, q_len, v_hidden_size))
        # l: running exp-sum of each query row; the softmax denominator.
        l = np.zeros((batch_size, q_len))
        # m: running max of each query row; -inf so the first block always replaces it.
        m = np.full((batch_size, q_len), -np.inf)

        # (Line 3): divide q into row blocks and k, v into column blocks
        Tr = q_len // self.row_block_size  # number of row blocks
        Tc = k_len // self.col_block_size  # number of column blocks

        # (Line 4): pass. The output is not split explicitly. It is still updated
        # one row block at a time to simulate FLASH-ATTENTION.
        # (Line 5): iterate over column blocks
        for j in range(Tc):
            # (Line 6): load the key and value blocks.
            # kj: [batch_size, col_block_size, hidden_size]
            # vj: [batch_size, col_block_size, v_hidden_size]
            kj = self.load(k, j, j + 1, self.col_block_size)
            vj = self.load(v, j, j + 1, self.col_block_size)
            # (Line 7): iterate over row blocks
            for i in range(Tr):
                # (Line 8): load the query block and its running statistics.
                qi = self.load(q, i, i + 1, self.row_block_size)   # [batch_size, row_block_size, hidden_size]
                oi = self.load(out, i, i + 1, self.row_block_size)
                mi = self.load(m, i, i + 1, self.row_block_size)
                li = self.load(l, i, i + 1, self.row_block_size)
                # (Line 9): scaled dot-product scores, [batch_size, row_block_size, col_block_size]
                sij = np.matmul(qi, kj.transpose(0, 2, 1)) / denom
                # (Line 10): block-local max, unnormalized probabilities and exp-sum
                mij = np.max(sij, axis=-1)                # [batch_size, row_block_size]
                pij = np.exp(sij - mij[..., np.newaxis])  # [batch_size, row_block_size, col_block_size]
                lij = pij.sum(axis=-1)                    # [batch_size, row_block_size]
                # (Line 11): update m and l
                m_new = np.maximum(mi, mij)
                alpha = np.exp(mi - m_new)  # rescales statistics accumulated so far to m_new
                beta = np.exp(mij - m_new)  # rescales this block's statistics to m_new
                l_new = alpha * li + beta * lij
                # (Line 12): update the output row block. The old O_i was normalized by l_i,
                # so multiply it back by l_i before adding this block's contribution.
                temp = li[..., np.newaxis] * alpha[..., np.newaxis] * oi \
                    + beta[..., np.newaxis] * np.matmul(pij, vj)
                temp /= l_new[..., np.newaxis]
                self.write(out, temp, i, i + 1, self.row_block_size)
                # (Line 13): store this row block's m and l back into the global m and l
                self.write(m, m_new, i, i + 1, self.row_block_size)
                self.write(l, l_new, i, i + 1, self.row_block_size)
        return out

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
class FlashAttentionTest(unittest.TestCase):
    """
    Cross-checks StandardAttention against PyTorch's scaled_dot_product_attention,
    then checks FlashAttention against StandardAttention, across several
    block sizes, sequence lengths and hidden sizes.
    """

    def run_test(self, batch_size, q_len, k_len, hidden_size, row_block_size, col_block_size):
        # Random inputs, drawn in the order q, k, v.
        shapes = [
            (batch_size, q_len, hidden_size),
            (batch_size, k_len, hidden_size),
            (batch_size, k_len, hidden_size),
        ]
        q, k, v = [np.random.randn(*shape) for shape in shapes]
        tolerance = 1e-8

        # Reference result from the naive implementation.
        reference = StandardAttention()(q, k, v)

        # The naive implementation should agree with PyTorch.
        torch_q, torch_k, torch_v = (torch.from_numpy(t) for t in (q, k, v))
        torch_out = torch.nn.functional.scaled_dot_product_attention(torch_q, torch_k, torch_v)
        self.assertTrue(np.allclose(reference, torch_out.numpy(), atol=tolerance))

        # The tiled implementation should agree with the naive one.
        flash = FlashAttention(row_block_size=row_block_size, col_block_size=col_block_size)
        self.assertTrue(np.allclose(reference, flash(q, k, v), atol=tolerance))

    def test(self):
        n_test = 2
        batch_size = 2
        configs = [
            (row, col, factor)
            for row in (2, 4)
            for col in (4, 8)
            for factor in (10, 20)
        ]
        for row_block_size, col_block_size, factor in configs:
            # Lengths are multiples of the block sizes so the validity check passes.
            q_len, k_len = row_block_size * factor, col_block_size * factor
            for _ in range(n_test):
                for hidden_size in (8, 16, 32):
                    self.run_test(batch_size, q_len, k_len, hidden_size, row_block_size, col_block_size)
if __name__ == "__main__":
    # When executed as a script, let unittest discover and run the TestCase
    # classes defined in this module.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment