@tiandiao123
Created August 3, 2023 14:01
import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)


@triton.jit
def _flash_attention_alibi_fwd_kernel(
    Q, K, V,
    BIAS,
    sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    # bias strides
    stride_bz, stride_bh, stride_bm, stride_bn,
    # Z = batch_size, H = num_heads, N_CTX = query seq_len, kv_length = key/value seq_len
    Z, H, N_CTX, kv_length,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # offsets into the (z, h) slice handled by this program
    qvk_offset = off_hz * stride_qh
    kv_offset_last2 = off_hz * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset_last2,
        shape=(BLOCK_DMODEL, kv_length),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    bias_block_ptr = None
    if BIAS is not None:
        bias_offset = off_hz * stride_bh
        bias_block_ptr = tl.make_block_ptr(
            base=BIAS + bias_offset,
            shape=(N_CTX, kv_length),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0)
        )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset_last2,
        shape=(kv_length, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the running max (m_i) and running sum (l_i) of the online softmax
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # NOTE: the upstream Triton tutorial scales sm_scale by log2(e) and uses
    # 2^x instead of exp in the loop (because CSE and LICM don't work as
    # expected with `exp`); this kernel keeps plain exp, so the scale is
    # applied as-is.
    # qk_scale = sm_scale * 1.44269504
    qk_scale = sm_scale
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else kv_length
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if BIAS is not None:
            # add the attention bias (e.g. ALiBi) for this tile of scores
            bias = tl.load(bias_block_ptr)
            qk += bias
        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp(m_i - m_i_new)
        p = tl.math.exp(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_block_ptr is not None:
            bias_block_ptr = tl.advance(bias_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))
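

# ---------------------------------------------------------------------------
# Hypothetical host-side launcher (not part of the original gist): a minimal
# sketch of how this kernel might be invoked. It assumes contiguous fp16
# tensors of shape (Z, H, N_CTX, D) for q/k/v/out, a non-None bias of shape
# (Z, H, N_CTX, kv_length), a supported head dim D (e.g. 64 or 128), and
# N_CTX / kv_length divisible by BLOCK_M / BLOCK_N, since the kernel does no
# boundary checks. The function name and default block sizes are assumptions.
# ---------------------------------------------------------------------------
def flash_attention_alibi_fwd(q, k, v, bias, sm_scale, causal=False,
                              BLOCK_M=128, BLOCK_N=64):
    Z, H, N_CTX, D = q.shape
    kv_length = k.shape[2]
    o = torch.empty_like(q)
    # per-row log-sum-exp written by the kernel (one row per query position per (z, h))
    L = torch.empty((Z * H, N_CTX), device=q.device, dtype=torch.float32)
    # one program per BLOCK_M query rows, per (batch, head) pair
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
    _flash_attention_alibi_fwd_kernel[grid](
        q, k, v, bias, sm_scale, L, o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3),
        Z, H, N_CTX, kv_length,
        BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
        IS_CAUSAL=causal,
        num_warps=4, num_stages=2,
    )
    return o, L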
@shiqingzhangCSU

Is there any update on this?
