Created
August 3, 2023 14:01
-
-
Save tiandiao123/0b82ea31a5dc5865663c2966e369b05a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import torch | |
import triton | |
import triton.language as tl | |
@triton.jit
def max_fn(x, y):
    # Jitted helper: element-wise maximum of two tensors, callable from
    # other @triton.jit kernels (e.g. as a reduction combine function).
    larger = tl.math.max(x, y)
    return larger
@triton.jit
def _flash_attention_alibi_fwd_kernel(
    Q, K, V,
    BIAS,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    # bias strides
    stride_bz, stride_bh, stride_bm, stride_bn,
    # batch_size, num_heads, seq_len
    Z, H, N_CTX, kv_length,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    """Flash-attention forward kernel with an optional additive bias (ALiBi-style).

    Computes, per (query-block, batch*head) program:
        Out = softmax(Q @ K^T * sm_scale + BIAS) @ V
    using the online-softmax algorithm: a running row-max ``m_i`` and running
    row-sum ``l_i`` are maintained so the softmax never materializes the full
    attention matrix.

    Grid: axis 0 indexes blocks of BLOCK_M query rows; axis 1 is a fused
    batch*head index (``off_hz``) — offsets are computed only via the head
    strides (``stride_qh``/``stride_kh``/``stride_bh``), so those strides are
    presumably laid out so that ``off_hz * stride_*h`` covers both the batch
    and head dimensions — TODO confirm against the launcher.

    Args:
        Q, K, V:   attention input pointers; Q has N_CTX rows, K/V have
                   kv_length rows, all with BLOCK_DMODEL features.
        BIAS:      optional pointer to an additive bias of shape covering
                   (query rows, kv_length); pass None to skip the bias add
                   (the ``if BIAS is not None`` below is resolved when the
                   kernel is specialized — NOTE(review): verify this works
                   on the Triton version in use).
        sm_scale:  softmax scale applied to Q before the dot product.
        L:         output pointer, one float per query row, receiving the
                   log-sum-exp ``m_i + log(l_i)`` (used by the backward pass).
        Out:       output pointer, written as float16.
        stride_*:  element strides for each tensor's dimensions.
        Z, H, N_CTX, kv_length: batch size, head count, query length, kv length.
        BLOCK_M, BLOCK_N, BLOCK_DMODEL: tile sizes (compile-time constants).
        IS_CAUSAL: if set, mask out keys with index > query index.
    """
    start_m = tl.program_id(0)   # which BLOCK_M-row tile of queries
    off_hz = tl.program_id(1)    # fused batch*head program index
    # Base offsets for this (batch, head): Q/Out share one offset, K/V another.
    qvk_offset = off_hz * stride_qh
    kv_offset_last2 = off_hz * stride_kh
    # Block pointer over this program's BLOCK_M x BLOCK_DMODEL tile of Q.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is addressed transposed (BLOCK_DMODEL x BLOCK_N) so qk = q @ k directly.
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset_last2,
        shape=(BLOCK_DMODEL, kv_length),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    bias_block_ptr = None
    if BIAS is not None:
        bias_offset = off_hz * stride_bh
        # NOTE(review): shape is declared (BLOCK_M, kv_length) but offsets
        # start at start_m * BLOCK_M, which exceeds the declared first dim for
        # start_m > 0. Without boundary_check the load still uses raw strides,
        # so this is likely benign, but the shape should presumably be
        # (N_CTX, kv_length) — confirm and fix upstream.
        bias_block_ptr = tl.make_block_ptr(
            base=BIAS + bias_offset,
            shape=(BLOCK_M, kv_length),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0)
        )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset_last2,
        shape=(kv_length, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # Absolute query-row indices and relative key-column indices for masking.
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Online-softmax state: running row max (m_i), running row sum (l_i),
    # and the unnormalized output accumulator (acc), all in float32.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # The usual base-2 trick (scaling by log2(e) and using exp2) is disabled
    # here; plain exp is used below, so qk_scale is just sm_scale.
    # qk_scale = sm_scale * 1.44269504
    qk_scale = sm_scale
    # Load q once; it stays resident for the whole key loop.
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    # Loop over key/value blocks. Under causal masking no key block beyond
    # the current query block can contribute, so the loop is truncated.
    lo = 0
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else kv_length
    for start_n in range(lo, hi, BLOCK_N):
        # -- load the current k, v tiles --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute attention scores qk = q @ k (accumulated in fp32) --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if BIAS is not None:
            # Additive bias (e.g. ALiBi slopes) applied before the softmax.
            bias = tl.load(bias_block_ptr)
            qk += bias
        if IS_CAUSAL:
            # Mask future keys: row i may only attend to columns <= i.
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # -- online softmax update: new running max and rescale factor --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp(m_i - m_i_new)      # rescales previous blocks
        p = tl.math.exp(qk - m_i_new[:, None])  # current block's weights
        # -- rescale the accumulator and add this block's contribution --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- fold this block into the running sum; commit the new max --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # Advance block pointers to the next BLOCK_N key/value columns.
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_block_ptr is not None:
            bias_block_ptr = tl.advance(bias_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # Normalize the accumulator and store per-row log-sum-exp for backward.
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log(l_i))
    # Write the output tile back as float16.
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Are there any updates on this?