@masahi
Last active August 4, 2023 07:28

LLM inference with FasterTransformer FP16 A - Int4 B GEMM kernel

What is it?

Standard GEMM kernels in cuBLAS or CUTLASS operate on A and B matrices of the same dtype, so we need to dequantize the quantized weight into fp16 before applying the vendor-optimized fp16 GEMM kernels. This adds significant memory-bandwidth overhead.

NVIDIA developed a custom GEMM kernel that operates directly on the quantized B matrix. It computes C_fp16 = A_fp16 * dequantize(B_int4) in one fused kernel. It was originally released as part of the FasterTransformer repo. We use the same kernel extracted into a self-contained repo.
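As a rough illustration of the reference semantics only (not the actual kernel; the shapes and the per-column scale below are illustrative), the unfused path materializes the dequantized weight before the GEMM, whereas the FT kernel never writes the fp16 weight to memory:

```python
import numpy as np

# Reference semantics sketch. Shapes and the per-output-column scale are illustrative,
# matching the rowwise scheme mentioned later in this writeup.
M, K, N = 1, 4096, 4096

A_fp16 = np.random.randn(M, K).astype(np.float16)
B_int4 = np.random.randint(-8, 8, size=(K, N), dtype=np.int8)  # int4 values stored one per int8
scales = np.random.rand(N).astype(np.float16)                  # one scale per output column

# Unfused path: materialize the whole weight in fp16, then call a regular fp16 GEMM.
B_fp16 = B_int4.astype(np.float16) * scales
C_unfused = A_fp16 @ B_fp16

# The FT kernel computes the same C_fp16 = A_fp16 * dequantize(B_int4) in one pass,
# dequantizing B tiles in registers so only the int4 weights are read from DRAM.
```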

For more info, see the paper and the GTC 23 presentations.

Integration into mlc-llm

The FT kernel was easily integrated into the branch via:

  • A new quantization pass to prepare weights and scales as expected by the FT kernel (a rough sketch of this step follows the list)
  • Relax BYOC to offload the decode -> matmul sequence to the FT kernel
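A minimal sketch of what the rowwise quantization step might compute, assuming a per-output-channel fp16 scale. The function name is hypothetical, and the packing of two int4 values per byte plus the interleaved weight layout the FT kernel expects are omitted:

```python
import numpy as np

def rowwise_int4_quantize(weight_fp16):
    """Minimal sketch of rowwise (per-output-channel) int4 quantization.

    The real pass additionally packs two int4 values per byte and applies the
    interleaved layout the FT kernel expects; that preprocessing is not shown.
    """
    # weight_fp16: (K, N); one fp16 scale per output column
    max_abs = np.abs(weight_fp16).max(axis=0)                       # (N,)
    scales = np.maximum(max_abs / 7.0, 1e-8).astype(np.float16)     # int4 range is [-8, 7]
    q = np.clip(np.round(weight_fp16 / scales), -8, 7).astype(np.int8)
    return q, scales

# q.astype(np.float16) * scales approximately reconstructs the original weight.
```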

The integration has been validated (it generates good tokens) on:

  • Vicuna
  • Dolly v2 12B
  • GPTJ-6B
  • GPT2-XL

Perf results on Vicuna and Dolly

  • Implemented two optimized compilation flows where the only difference is the matmul implementation (with a matching quantization scheme):
    • FT matmul with rowwise quantization
    • Dlight GEMV with group quantization
  • Note that dlight is at an early stage of development and its performance is changing rapidly at the moment. The data was taken at commit https://github.com/apache/tvm/commit/0681b3959fe65abe8a73c49fa80d1c4f0e793b43
  • Compares tok / sec for encoding a 128-token prompt + generating 128 tokens.
  • Also shows exllama numbers for reference (combined encode / decode time)
  • All numbers on RTX 4080

Vicuna 7B, int4

mode      seqlen   tok / s
FT        128      121.7
dlight    128      141.9
exllama   128      121-122

This looks disappointing… but we can understand this result as follows:

  • In the decoder, the core computation is vector x matrix, i.e. GEMV. For example, in Vicuna 7B the vector is of shape (1, 1, 4096).
  • The FT kernel uses tensor cores to compute vector x matrix via GEMM.
  • On the Ampere architecture, the tensor core instruction operates on the A matrix with a tile size of 16 x 8.
  • So we end up zero-padding the 1 x 8 tiles of the A matrix to 16 x 8; 15 out of 16 rows' worth of tensor core computation are wasted (see the illustration after this list).
  • On the other hand, dlight and exllama implement a GEMV kernel: no wasteful computation, simple and lean.
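A small NumPy illustration of the padding argument. This is reference semantics only, with illustrative shapes (float32 is used here for clarity; the dtype does not affect the argument):

```python
import numpy as np

# Pad the (1, K) activation up to the 16-row MMA tile height, run a full GEMM,
# and keep only the first output row.
K, N = 4096, 4096
a = np.random.randn(1, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)

a_padded = np.zeros((16, K), dtype=np.float32)
a_padded[:1] = a                     # rows 1..15 stay zero

c_gemm = a_padded @ B                # what the tensor core path effectively computes
c_gemv = a @ B                       # what a dedicated GEMV kernel computes

assert np.allclose(c_gemm[0], c_gemv[0])
# Rows 1..15 of c_gemm are all zero: 15/16 of the MMA throughput produced nothing useful.
```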

Comparing the decode profiles at seqlen 128

Using FT GEMM

Name                                                             Time (ms)   Count   Total time (ms)   Percentage (%)
fused_decode3_relax_matmul_cutlass                               0.0769      32      2.4594            31.57
fused_decode4_relax_matmul_relax_add_cutlass                     0.0684      32      2.1889            28.09
fused_decode1_relax_matmul_cutlass                               0.0519      32      1.6613            21.32
fused_decode2_relax_matmul_relax_add_cutlass                     0.0273      32      0.8749            11.23
fused_relax_nn_attention_cutlass1                                0.0100      32      0.3208            4.12
fused_rms_norm1_cutlass                                          0.0020      65      0.1323            1.70
fused_decode5_relax_matmul_cutlass                               0.1046      1       0.1046            1.34
fused_split_silu_multiply                                        0.0014      32      0.0456            0.58
take_decode1                                                     0.0020      1       0.0020            0.03
cast
Total time: 7.7913 ms

Using Dlight GEMV

Name                                                             Time (ms)   Count   Total time (ms)   Percentage (%)
fused_fused_decode4_NT_matmul2                                   0.0721      32      2.3069            38.99
fused_fused_decode2_NT_matmul                                    0.0419      32      1.3404            22.66
fused_fused_decode5_fused_NT_matmul3_add                         0.0367      32      1.1731            19.83
fused_fused_decode3_fused_NT_matmul1_add                         0.0153      32      0.4910            8.30
fused_relax_nn_attention_cutlass1                                0.0099      32      0.3178            5.37
fused_rms_norm1_cutlass                                          0.0019      65      0.1265            2.14
fused_fused_decode1_fused_NT_matmul4_cast                        0.1132      1       0.1132            1.91
fused_split_silu_multiply                                        0.0014      32      0.0459            0.78
fused_fused_decode1_take                                         0.0018      1       0.0018            0.03
Total time: 5.9166 ms

FT is the fastest for the other models shown below, despite the wasted computation.

Vicuna 7B, int8

mode      seqlen   tok / s
FT        128      83.7
dlight    128      79.2

Vicuna 13B, int4

mode      seqlen   tok / s
FT        128      74.6
dlight    128      67.8
exllama   128      68.5

Dolly v2 12B, int4

mode      seqlen   tok / s
FT        128      65.5
dlight    128      44.6

Batched inference performance

Since the FT kernel is based on tensor core matmul, its performance is expected to get much better with batched inference. In particular, there is no wasted zero padding when the batch size is a multiple of 16.

For a batch size of 16, FT gets only about 1.5x lower tok / s per batch compared to the single-batch case.
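A quick back-of-the-envelope sketch of A-tile utilization versus batch size, assuming the rows of A are padded up to the next multiple of the 16-row MMA tile:

```python
import math

# Tensor core utilization of the A tile for one decode step as a function of batch size.
for batch in (1, 4, 8, 16, 32):
    padded = math.ceil(batch / 16) * 16
    print(f"batch={batch:3d}  padded rows={padded:3d}  A-tile utilization={batch / padded:.0%}")
# batch=1 -> 6% (15/16 wasted), batch=16 or 32 -> 100% (no padding waste).
```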

Vicuna 7B

mode      batch   seqlen   tok / s (per batch)
FT        1       128      121.7
FT        16      128      77.6
dlight    1       128      141.9
dlight    16      128      25.9

Note: The way the KV cache is organized in the current mlc-llm necessitates expensive runtime transpositions of the batched KV cache at each decoding step, which kills batched perf as shown below.

The above numbers are computed without this transpose cost.

  • Batched decode profile at seqlen 128

    ======================= Decoding Profiling =======================
    Name                                                             Time (ms)   Count   Total time (ms)   Percentage (%)
    transpose3                                                       0.1829      64      11.7062           52.71
    fused_decode3_relax_matmul1_cutlass                              0.0863      32      2.7623            12.44
    fused_decode4_relax_matmul_relax_add1_cutlass                    0.0760      32      2.4314            10.95
    fused_relax_nn_attention1_cutlass1                               0.0599      32      1.9168            8.63
    fused_decode1_relax_matmul1_cutlass                              0.0583      32      1.8645            8.40
    fused_decode2_relax_matmul_relax_add1_cutlass                    0.0300      32      0.9601            4.32
    fused_rms_norm1_cutlass                                          0.0028      65      0.1797            0.81
    split2                                                           0.0049      32      0.1566            0.70
    fused_decode5_relax_matmul_cutlass                               0.1250      1       0.1250            0.56
    fused_split3_silu1_multiply1                                     0.0032      32      0.1012            0.46
    cast                                                             0.0035      1       0.0035            0.02
    take_decode                                                      0.0020      1       0.0020            0.01
    Total time: 22.2093 ms
    

Other optimizations

  • Combining parallel matmuls in the QKV projections and the MLP - 20 tok / s improvement (a rough sketch of the idea follows this list)
  • CUDA Graph - Putting MLP and QKV projections into a graph gives 1 tok / s improvement
  • Attention - Use the CUTLASS attention kernel with causal mask optimization
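A minimal sketch of the combined-matmul idea, using hypothetical weight names and shapes (not the actual mlc-llm identifiers): concatenate the Q/K/V weights once, run a single wider matmul, and split the result.

```python
import numpy as np

hidden = 4096
x = np.random.randn(1, hidden).astype(np.float32)
Wq, Wk, Wv = (np.random.randn(hidden, hidden).astype(np.float32) for _ in range(3))

# Before: three separate matmuls (three kernel launches, three reads of x).
q, k, v = x @ Wq, x @ Wk, x @ Wv

# After: concatenate the weights once ahead of time (the pass does this at compile time),
# run a single wider matmul, then split the result. Same math, fewer launches,
# and a better-shaped GEMM.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)       # (hidden, 3 * hidden)
q2, k2, v2 = np.split(x @ W_qkv, 3, axis=1)

assert np.allclose(q, q2, rtol=1e-3, atol=1e-3)
assert np.allclose(k, k2, rtol=1e-3, atol=1e-3)
assert np.allclose(v, v2, rtol=1e-3, atol=1e-3)
```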

For attention, there is further room for improvement:

  • After the combined QKV matmul, we have split -> rotary embedding -> KV cache update -> attention.
  • The CUTLASS kernel supports fusing the split into attention. Ideally we want to fuse all of split, rotary embedding, and KV cache update into attention.
  • Such a kernel exists in FasterTransformer: https://github.com/NVIDIA/FasterTransformer/tree/main/src/fastertransformer/kernels/decoder_masked_multihead_attention
  • Moreover, since it is specialized for the single-query case, it does the two matmuls with GEMV (see the shape sketch after this list), another reason this attention implementation is expected to be faster than standard attention implementations that use tensor cores for matmul.
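A shape-level sketch of single-query decoder attention for one head (rotary embedding, masking, and multi-head details omitted; shapes are illustrative), showing that both matmuls are GEMVs:

```python
import numpy as np

head_dim, kv_len = 128, 129            # KV cache length after appending the new token
q = np.random.randn(1, head_dim).astype(np.float32)             # one new query token
K_cache = np.random.randn(kv_len, head_dim).astype(np.float32)  # keys after the cache update
V_cache = np.random.randn(kv_len, head_dim).astype(np.float32)  # values after the cache update

scores = (q @ K_cache.T) / np.sqrt(head_dim)   # (1, kv_len): a GEMV, not a GEMM
probs = np.exp(scores - scores.max())
probs /= probs.sum()
out = probs @ V_cache                          # (1, head_dim): again a GEMV
```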

sleepwalker2017 commented Aug 4, 2023

Hi masahi, I noticed that you mentioned the int4 implementation in FT for Llama. Could you tell me where you found the implementation?

I found a PR implementing Llama using fp16 (it's here), but int4 is not supported.

Hoping for your reply, thank you!
