Standard GEMM kernels in cuBLAS or CUTLASS require the `A` and `B` matrices to have the same dtype. So to use the vendor-optimized fp16 GEMM kernels on an int4-quantized weight, we first have to dequantize the weight into fp16, which adds significant memory-bandwidth overhead.
NVIDIA developed a custom GEMM kernel that operates directly on the quantized `B` matrix. It computes `C_fp16 = A_fp16 * dequantize(B_int4)` in one fused kernel. It was originally released as part of the FasterTransformer repo; we use the same kernel extracted into a self-contained repo.
For more info, see the paper and the GTC 2023 presentation:
- https://arxiv.org/abs/2211.10017
- https://register.nvidia.com/flow/nvidia/gtcspring2023/attendeeportal/page/sessioncatalog/session/1666226207768001N4Fe
The FT kernel has been easily integrated in the branch via:
- A new quantization pass to prepare weights and scales in the layout expected by the FT kernel
- Relax BYOC to offload the `decode -> matmul` sequence to the FT kernel (see the sketch below)
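For illustration, here is a hedged sketch of what the BYOC offloading step can look like with the Relax pattern language. The pattern and codegen names (`fastertransformer.decode_matmul`) and the use of `relax.dequantize` as a stand-in for the decode step are illustrative, not the exact ones in the branch:

```python
import tvm
from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard

# Match the decode -> matmul chain: B_fp16 = decode(B_int4, scales);
# C = matmul(A, B_fp16). `relax.dequantize` stands in for the actual
# decode pattern used in the branch.
w_q = wildcard()     # quantized weight
scales = wildcard()  # quantization scales
act = wildcard()     # fp16 activation
decoded = is_op("relax.dequantize")(w_q, scales)
pattern = is_op("relax.matmul")(act, decoded)

seq = tvm.transform.Sequential(
    [
        # Group each matched chain into a composite function tagged with the
        # pattern name, then merge and hand it to the external codegen that
        # emits calls into the FT kernel.
        relax.transform.FuseOpsByPattern([("fastertransformer.decode_matmul", pattern)]),
        relax.transform.MergeCompositeFunctions(),
        relax.transform.RunCodegen(),
    ]
)
# mod = seq(mod)  # applied to the quantized Relax module
```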
The integration has been validated (generates good tokens) on
- Vicuna
- Dolly v2 12B
- GPTJ-6B
- GPT2-XL
- Implemented two optimized compilation flows where the only difference is the matmul implementation (with a matching quantization scheme; see the quantization sketch after this list):
  - FT matmul with rowwise quantization
  - Dlight GEMV with group quantization
- Note that dlight is at an early stage of development and its perf changes rapidly at the moment. The data was taken using the commit https://github.com/apache/tvm/commit/0681b3959fe65abe8a73c49fa80d1c4f0e793b43
- Compares tok / sec for encoding a 128-token prompt + generating 128 tokens.
- Also shows exllama numbers for reference (combined encode / decode time).
- All numbers are on an RTX 4080.
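For concreteness, here is a minimal numpy sketch of the difference between the two quantization schemes, using illustrative symmetric int4 quantization (the actual passes also handle zero points and the weight layouts each kernel expects):

```python
import numpy as np

def quantize_rowwise(W):
    # One scale per output channel (column of the [in, out] weight),
    # the scheme used with the FT kernel.
    scales = np.abs(W).max(axis=0) / 7.0                     # int4 range: [-8, 7]
    q = np.clip(np.round(W / scales), -8, 7).astype(np.int8)
    return q, scales.astype(np.float16)

def quantize_groupwise(W, group_size=128):
    # One scale per `group_size` consecutive input channels of each output
    # channel, the scheme used with the dlight GEMV kernels here.
    in_dim, out_dim = W.shape
    Wg = W.reshape(in_dim // group_size, group_size, out_dim)
    scales = np.abs(Wg).max(axis=1) / 7.0                    # [num_groups, out]
    q = np.clip(np.round(Wg / scales[:, None, :]), -8, 7).astype(np.int8)
    return q.reshape(in_dim, out_dim), scales.astype(np.float16)

W = np.random.randn(4096, 4096).astype(np.float32)
q_row, s_row = quantize_rowwise(W)     # scales: [4096]
q_grp, s_grp = quantize_groupwise(W)   # scales: [32, 4096]
```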
Vicuna 7B, int4
| mode | seqlen | tok / s |
|---|---|---|
| FT | 128 | 121.7 |
| dlight | 128 | 141.9 |
| exllama | 128 | 121-122 |
This looks disappointing… but we can understand this result as follows:
- In the decoder, the core computation is `vector x matrix`, i.e. GEMV. For example, in Vicuna 7B the vector is of shape `(1, 1, 4096)`.
- The FT kernel uses tensor cores to compute `vector x matrix` via GEMM.
- On the Ampere architecture, the tensor core instruction operates on the `A` matrix with a tile size of `16 x 8`.
- So we end up doing zero padding to create `16 x 8` tiles from `1 x 8` for the matrix `A`. 15 out of 16 rows-worth of tensor core computation are wasted (see the sketch below).
- On the other hand, dlight and exllama implement a GEMV kernel. No wasteful computation, simple and lean.
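To put a number on the waste, a quick back-of-the-envelope sketch (plain arithmetic, no kernel details assumed):

```python
# Row utilization of the Ampere tensor core A-tile (16 x 8) during decode.
m_tile = 16    # rows per tensor core A-tile
for batch in (1, 8, 16, 32):
    padded = -(-batch // m_tile) * m_tile   # round batch up to a tile multiple
    util = batch / padded
    print(f"batch {batch:2d}: {util:6.1%} of tensor core rows do useful work")
# batch  1:   6.2% -> 15/16 of the row computation is zero padding
# batch 16: 100.0% -> no padding, which is why batched FT perf scales well
```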
Comparing decode profile at seqlen 128
Using FT GEMM
| Name | Time (ms) | Count | Total time (ms) | Percentage (%) |
|---|---|---|---|---|
| fused_decode3_relax_matmul_cutlass | 0.0769 | 32 | 2.4594 | 31.57 |
| fused_decode4_relax_matmul_relax_add_cutlass | 0.0684 | 32 | 2.1889 | 28.09 |
| fused_decode1_relax_matmul_cutlass | 0.0519 | 32 | 1.6613 | 21.32 |
| fused_decode2_relax_matmul_relax_add_cutlass | 0.0273 | 32 | 0.8749 | 11.23 |
| fused_relax_nn_attention_cutlass1 | 0.0100 | 32 | 0.3208 | 4.12 |
| fused_rms_norm1_cutlass | 0.0020 | 65 | 0.1323 | 1.70 |
| fused_decode5_relax_matmul_cutlass | 0.1046 | 1 | 0.1046 | 1.34 |
| fused_split_silu_multiply | 0.0014 | 32 | 0.0456 | 0.58 |
| take_decode1 | 0.0020 | 1 | 0.0020 | 0.03 |
| cast | | | | |
Total time: 7.7913 ms
Using Dlight GEMV
| Name | Time (ms) | Count | Total time (ms) | Percentage (%) |
|---|---|---|---|---|
| fused_fused_decode4_NT_matmul2 | 0.0721 | 32 | 2.3069 | 38.99 |
| fused_fused_decode2_NT_matmul | 0.0419 | 32 | 1.3404 | 22.66 |
| fused_fused_decode5_fused_NT_matmul3_add | 0.0367 | 32 | 1.1731 | 19.83 |
| fused_fused_decode3_fused_NT_matmul1_add | 0.0153 | 32 | 0.4910 | 8.30 |
| fused_relax_nn_attention_cutlass1 | 0.0099 | 32 | 0.3178 | 5.37 |
| fused_rms_norm1_cutlass | 0.0019 | 65 | 0.1265 | 2.14 |
| fused_fused_decode1_fused_NT_matmul4_cast | 0.1132 | 1 | 0.1132 | 1.91 |
| fused_split_silu_multiply | 0.0014 | 32 | 0.0459 | 0.78 |
| fused_fused_decode1_take | 0.0018 | 1 | 0.0018 | 0.03 |
Total time: 5.9166 ms
For the other models, FT is the fastest despite the wasted computation:
Vicuna 7B, int8
| mode | seqlen | tok / s |
|---|---|---|
| FT | 128 | 83.7 |
| dlight | 128 | 79.2 |
Vicuna 13B, int4
| mode | seqlen | tok / s |
|---|---|---|
| FT | 128 | 74.6 |
| dlight | 128 | 67.8 |
| exllama | 128 | 68.5 |
Dolly v2 12B, int4
| mode | seqlen | tok / s |
|---|---|---|
| FT | 128 | 65.5 |
| dlight | 128 | 44.6 |
Since the FT kernel is based on tensor core matmul, its performance is expected to get a lot better with batched inference. In particular, there is no wasted zero padding for a batch size that is a multiple of 16.
For a batch size of 16, FT gets only about 1.5x lower tok / s per batch compared to the single-batch case, i.e. roughly 10x higher aggregate throughput (16 × 77.6 ≈ 1242 tok / s vs 121.7 tok / s).
Vicuna 7B
| mode | batch | seqlen | tok / s (per batch) |
|---|---|---|---|
| FT | 1 | 128 | 121.7 |
| FT | 16 | 128 | 77.6 |
| dlight | 1 | 128 | 141.9 |
| dlight | 16 | 128 | 25.9 |
Note: The way the KV cache is organized in the current `mlc-llm` necessitates expensive runtime transpositions of the batched KV cache at each decoding step. This kills batched perf, as shown in the profile below.
The above numbers are computed without this transpose cost.
Batched decode profile at seqlen 128

| Name | Time (ms) | Count | Total time (ms) | Percentage (%) |
|---|---|---|---|---|
| transpose3 | 0.1829 | 64 | 11.7062 | 52.71 |
| fused_decode3_relax_matmul1_cutlass | 0.0863 | 32 | 2.7623 | 12.44 |
| fused_decode4_relax_matmul_relax_add1_cutlass | 0.0760 | 32 | 2.4314 | 10.95 |
| fused_relax_nn_attention1_cutlass1 | 0.0599 | 32 | 1.9168 | 8.63 |
| fused_decode1_relax_matmul1_cutlass | 0.0583 | 32 | 1.8645 | 8.40 |
| fused_decode2_relax_matmul_relax_add1_cutlass | 0.0300 | 32 | 0.9601 | 4.32 |
| fused_rms_norm1_cutlass | 0.0028 | 65 | 0.1797 | 0.81 |
| split2 | 0.0049 | 32 | 0.1566 | 0.70 |
| fused_decode5_relax_matmul_cutlass | 0.1250 | 1 | 0.1250 | 0.56 |
| fused_split3_silu1_multiply1 | 0.0032 | 32 | 0.1012 | 0.46 |
| cast | 0.0035 | 1 | 0.0035 | 0.02 |
| take_decode | 0.0020 | 1 | 0.0020 | 0.01 |

Total time: 22.2093 ms
- Combining parallel matmuls in QKV projections and MLP: 20 tok / s improvement (see the sketch below)
- CUDA Graph: putting MLP and QKV projections into a graph gives 1 tok / s improvement
- Attention: use the CUTLASS attention kernel with causal mask optimization
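The parallel-matmul combination is a standard trick; a minimal numpy sketch of the idea (float32 for easy numerical comparison, hidden size from Vicuna 7B):

```python
import numpy as np

hidden = 4096
x = np.random.randn(1, hidden).astype(np.float32)
Wq, Wk, Wv = (np.random.randn(hidden, hidden).astype(np.float32) for _ in range(3))

# Three separate projections: three kernel launches, x is read three times.
q, k, v = x @ Wq, x @ Wk, x @ Wv

# One combined projection over the concatenated weight, then a cheap split:
# a single larger GEMM amortizes launch overhead and activation traffic.
Wqkv = np.concatenate([Wq, Wk, Wv], axis=1)   # [hidden, 3 * hidden]
q2, k2, v2 = np.split(x @ Wqkv, 3, axis=1)

assert all(np.allclose(a, b, atol=1e-3) for a, b in [(q, q2), (k, k2), (v, v2)])
```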
For attention, there is further room for improvement:
- After a combined QKV matmul, we have `split -> rotary embedding -> KV cache update -> attention`.
- The CUTLASS kernel supports fusing `split` into `attention`. Ideally we want to fuse all of `split`, `rotary`, and `KV cache update` into attention.
- Such a kernel exists in FasterTransformer: https://github.com/NVIDIA/FasterTransformer/tree/main/src/fastertransformer/kernels/decoder_masked_multihead_attention
- Moreover, since it is specialized for the single-query case, it does its two matmuls with GEMV (see the sketch below): another reason this attention implementation is expected to be faster than the standard attention implementations that use tensor cores for matmul.
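To see why the single-query specialization matters: with one query vector, both attention matmuls degenerate into GEMVs. A minimal per-head numpy sketch (illustrative shapes; the FT kernel additionally fuses rotary embedding and the KV cache update):

```python
import numpy as np

head_dim, seqlen = 128, 128
q = np.random.randn(head_dim).astype(np.float32)           # one new token's query
K = np.random.randn(seqlen, head_dim).astype(np.float32)   # cached keys
V = np.random.randn(seqlen, head_dim).astype(np.float32)   # cached values

scores = K @ q / np.sqrt(head_dim)   # GEMV #1: [seqlen]
probs = np.exp(scores - scores.max())
probs /= probs.sum()                 # softmax over past tokens
out = probs @ V                      # GEMV #2: [head_dim]
```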
Hi masahi, I note that you mentioned the int4 implementation in FT for Llama. Could you tell me where you found the implementation?
I found a PR implementing Llama using fp16 (it's here), but with no int4 support.
Hoping for your reply, thank you!