@KeremTurgutlu
Last active October 7, 2023 04:44
Multipack Sampler x Flash Attention
"""
Testing flash attn with multipacking which essentially packs sequences using https://github.com/imoneoi/multipack_sampler,
and passes a single sequence of `1 x (bs x seqlen)` to the model to avoid padding.
An alternative is to use block diagonal attention as attention bias, but the following uses flash attention 2 which
is much faster.
Multipacking can be used to speed up both pretraining and finetuning.
"""
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
try:
    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
    )
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )
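# A minimal sketch (an assumption, not part of the multipack_sampler API) of how a collator
# could produce the packed tensors used below. `pack_batch` is a hypothetical helper: it
# concatenates variable-length sequences into one row, tags each token with a segment id
# (0 = padding), restarts position ids per sequence, and records start offsets in `cu_seqlens`.
def pack_batch(sequences, pack_len):
    input_ids = torch.zeros(1, pack_len, dtype=torch.long)
    segment_ids = torch.zeros(1, pack_len, dtype=torch.long)    # 0 marks padding
    position_ids = torch.zeros(1, pack_len, dtype=torch.long)
    cu_seqlens, offset = [0], 0
    for seg, seq in enumerate(sequences, start=1):
        n = seq.numel()
        input_ids[0, offset:offset + n] = seq
        segment_ids[0, offset:offset + n] = seg
        position_ids[0, offset:offset + n] = torch.arange(n)
        offset += n
        cu_seqlens.append(offset)
    if offset < pack_len:                                        # trailing padding is one extra boundary
        cu_seqlens.append(pack_len)
    return input_ids, segment_ids, position_ids, torch.tensor(cu_seqlens, dtype=torch.int32)
# e.g. packing four sequences of lengths 4, 3, 5, 2 into pack_len=16 would reproduce the
# segment ids, position ids, and cu_seqlens hard-coded below.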
# segment ids of the packed sequence, shape [1, bs x seqlen = 16]; 0 marks padding
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
# positions restart from 0 at the beginning of each packed sequence
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
# cumulative sequence lengths: begin index of each sequence in the pack, plus the final end index
# (size: num samples + 1)
cu_seqlens = torch.tensor([0, 4, 7, 12, 14, 16]).to(torch.int32)
# length of the longest sequence in the pack
max_seqlen = 5
# sanity check: print the segment id at each start offset (the padding segment shows up as 0)
for i in cu_seqlens[:-1]: print(attn_mask[0][i])
#tensor(1)
#tensor(2)
#tensor(3)
#tensor(4)
#tensor(0)
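# The per-sequence lengths can be recovered from cu_seqlens by taking consecutive differences;
# max_seqlen is simply the largest of them (5, from the third sequence above).
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
print(seqlens, max(seqlens))
# [4, 3, 5, 2, 2] 5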
# total size: bs x seqlen (after packing sequences)
bs_x_seqlen = 16
# create a random packed qkv of shape (total_tokens=16, 3, nheads=2, headdim=128)
qkv = torch.randn(
    bs_x_seqlen, 3, 2, 128, device="cuda:0", dtype=torch.float16, requires_grad=True
)
# the varlen kernel uses cu_seqlens to apply block-diagonal causal attention, so no explicit mask is needed
attn_output_flash = flash_attn_varlen_qkvpacked_func(
    qkv.cuda(), cu_seqlens.cuda(), max_seqlen, 0.0, softmax_scale=None, causal=True
)
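# `make_decoder_mask_pt` and `attn_mask2attn_bias` are not defined in this snippet; the
# definitions below are a minimal sketch that reproduces the mask/bias printed further down
# (assumed behaviour: block-diagonal causal mask from segment ids, then a 0 / -128 additive bias).
def make_decoder_mask_pt(position_ids, dtype, decoder_segment_ids=None):
    # causal: a query may attend to itself and to earlier positions in the packed row
    bs, seqlen = position_ids.shape
    mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))[None, None]
    if decoder_segment_ids is not None:
        # block-diagonal: query and key must belong to the same packed sequence (segment)
        same_segment = (
            decoder_segment_ids[:, None, :, None] == decoder_segment_ids[:, None, None, :]
        )
        mask = mask & same_segment
    return mask.to(dtype)

def attn_mask2attn_bias(attn_mask, neg_value=-128):
    # additive bias: 0 where attention is allowed, a large negative value where it is masked
    return (1 - attn_mask) * neg_value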
# create the block diagonal attn mask manually for testing against torch
attn_mask = make_decoder_mask_pt(position_ids, torch.int32, decoder_segment_ids=attn_mask)
# convert the 0/1 mask to an additive bias: 0 where attention is allowed, -128 where it is masked
attn_bias = attn_mask2attn_bias(attn_mask)
# inspect both (output below)
attn_mask, attn_bias
# (tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]]],
# dtype=torch.int32),
# tensor([[[[ 0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [ 0, 0, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [ 0, 0, 0, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [ 0, 0, 0, 0, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, 0, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, 0, 0, -128, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, 0, 0, 0, -128, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, 0, -128, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, 0, 0, -128, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, 0, 0, 0, -128,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, 0, 0, 0, 0,
# -128, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, 0, 0, 0, 0,
# 0, -128, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, 0, -128, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, 0, 0, -128, -128],
# [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, 0, -128],
# [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,
# -128, -128, -128, 0, 0]]]], dtype=torch.int32))
# reshape q, k, v for torch scaled_dot_product_attention: add a batch dim and move heads first
q, k, v = qkv.unbind(1)
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = rearrange(q, "b s h d -> b h s d")
k = rearrange(k, "b s h d -> b h s d")
v = rearrange(v, "b s h d -> b h s d")
# is_causal=False because causality is already encoded in the block-diagonal bias
attn_output_torch = F.scaled_dot_product_attention(
    q.cuda(), k.cuda(), v.cuda(), attn_bias.to(q.dtype).cuda(), 0.0, is_causal=False
)
# back to the packed (total_tokens, nheads, headdim) layout to match the flash output
attn_output_torch = rearrange(attn_output_torch[0], "h s d -> s h d")
# fraction of elements where the flash and torch outputs agree within torch.isclose's default tolerances
torch.isclose(attn_output_flash, attn_output_torch).float().mean()
# tensor(1., device='cuda:0')
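# Optional extra check (not part of the original comparison; fp16 flash and sdpa kernels will
# not match bit-for-bit): inspect the largest element-wise gap between the two outputs.
(attn_output_flash - attn_output_torch).abs().max()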