Created November 6, 2023 16:50
FAv2 forward bench
# FAv2 forward bench: compare the FlashAttention unpadded forward kernel against
# a pure-PyTorch attention baseline across several shapes.
from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_forward
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as ck_flash_unpadded_func
from flash_attn.flash_attn_triton import flash_attn_func as triton_flash_unpadded_func


def attention_pytorch(q, k, v, dropout_p=0.0, causal=True):
    """
    Arguments:
        q: (batch_size, seqlen, nheads, head_dim)
        k: (batch_size, seqlen, nheads, head_dim)
        v: (batch_size, seqlen, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, nheads, d = q.shape
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=q.dtype)


torch.manual_seed(0)
repeats = 30
headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.float16
device = 'cuda'

# Sweep over batch size, sequence length and number of heads.
for batch_size in [4, 8, 16]:
    for seqlen in [1024, 4096, 8192]:
        for nheads in [16, 24]:
            # Skip configurations that run out of memory with the PyTorch baseline.
            if (nheads == 24 and batch_size in [8, 16] and seqlen == 8192) or (batch_size == 16 and nheads == 16 and seqlen == 8192):
                continue  # torch oom
            print(f"------------ bs={batch_size}, seqlen={seqlen}, nheads={nheads}")

            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)

            # All sequences share the same length, so the cumulative sequence lengths are multiples of seqlen.
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=q.device)
            max_seqlen = seqlen

            # The unpadded kernel takes (total_tokens, nheads, head_dim); flatten batch and sequence.
            q_flash = rearrange(q, 'b s ... -> (b s) ...')
            k_flash = rearrange(k, 'b s ... -> (b s) ...')
            v_flash = rearrange(v, 'b s ... -> (b s) ...')

            benchmark_forward(ck_flash_unpadded_func, q_flash, k_flash, v_flash,
                              cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p, None, True,
                              repeats=repeats, desc='FlashAttention')
            benchmark_forward(attention_pytorch, q, k, v, dropout_p, causal=causal,
                              repeats=repeats, desc='PyTorch Attention')
            # benchmark_forward(triton_flash_unpadded_func, q, k, v, None, causal, repeats=repeats, desc='FlashAttention Triton')
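
If you also want a fused PyTorch baseline in the comparison, the sketch below (not part of the original gist) wraps torch.nn.functional.scaled_dot_product_attention, available from PyTorch 2.0 onwards; the helper name attention_sdpa is made up here, and it assumes the same (batch_size, seqlen, nheads, head_dim) layout as attention_pytorch above.

def attention_sdpa(q, k, v, dropout_p=0.0, causal=True):
    # SDPA expects (batch, nheads, seqlen, head_dim), so transpose in and out.
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=causal)
    return out.transpose(1, 2)

# Example call, mirroring the benchmark_forward invocations above:
# benchmark_forward(attention_sdpa, q, k, v, dropout_p, causal=causal,
#                   repeats=repeats, desc='PyTorch SDPA')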