Skip to content

Instantly share code, notes, and snippets.

from typing import List, Optional, Tuple
import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# Parameter grids for the attention-kernel tests below.
# NOTE(review): the (64, 8) pair presumably means (num query heads,
# num key/value heads) for a GQA configuration — confirm against usage.
NUM_HEADS: List[Tuple[int, int]] = [(64, 8)]
HEAD_SIZES: List[int] = [128]
BLOCK_SIZES: List[int] = [16]