@bjacob
Last active June 30, 2023 14:56
Example of data-tiling and microkernels on a matmul.

Basic setup.

Source: matmul_i8.mlir:

func.func @matmul_i8(%lhs: tensor<?x?xi8>, %rhs: tensor<?x?xi8>, %acc: tensor<?x?xi32>) -> tensor<?x?xi32> {
  %result = linalg.matmul ins(%lhs, %rhs: tensor<?x?xi8>, tensor<?x?xi8>) outs(%acc: tensor<?x?xi32>) -> tensor<?x?xi32>
  return %result: tensor<?x?xi32>
}
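For reference, here is what this function computes, written as plain C. This is an illustrative sketch, not IREE code. It assumes row-major storage and that linalg.matmul sign-extends the i8 operands to i32 before multiplying, accumulating into the existing contents of %acc. The function name matmul_i8_ref is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Reference semantics of linalg.matmul on i8 inputs with an i32
 * accumulator: acc[i][j] += lhs[i][k] * rhs[k][j], with each i8 operand
 * sign-extended to i32 before the multiply. All matrices are row-major.
 * M, K, N are runtime sizes, like the dynamic tensor<?x?x...> types. */
void matmul_i8_ref(const int8_t *lhs, const int8_t *rhs, int32_t *acc,
                   int M, int K, int N) {
  assert(M >= 0 && K >= 0 && N >= 0);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t sum = acc[i * N + j]; /* outs(%acc) is accumulated into */
      for (int k = 0; k < K; ++k) {
        sum += (int32_t)lhs[i * K + k] * (int32_t)rhs[k * N + j];
      }
      acc[i * N + j] = sum;
    }
  }
}
```

Accumulating in i32 is what makes the i8 inputs safe: even (-128) * (-128) = 16384 fits easily.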

Compilation command line targeting aarch64 with the "+i8mm" target CPU feature, enabling data tiling and microkernels:

iree-compile /tmp/matmul_i8.mlir -o /tmp/a.vmfb \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-triple=aarch64-unknown-unknown \
  --iree-llvmcpu-target-cpu-features=+i8mm \
  --iree-flow-enable-data-tiling \
  --iree-llvmcpu-enable-microkernels

Step-by-step look at IR

Add the following flag to iree-compile: --mlir-print-ir-after-all

The first interesting pass is SetEncoding. It is enabled by --iree-flow-enable-data-tiling. It runs early (in Flow) and is target-agnostic. All it does is annotate the LHS, RHS, and accumulator tensors with their roles in matmuls (or in any other op participating in data-tiling, but at the moment that's only matmuls). These roles are called "encodings". The set_encoding op can be thought of as an abstract version of tensor.pack: it will be rewritten into tensor.pack later, in MaterializeEncoding. Likewise, unset_encoding is an abstract version of tensor.unpack and will get rewritten into one.

// -----// IR Dump After SetEncoding (iree-flow-set-encoding) //----- //
func.func @matmul_i8(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {

  // ... (skipping uninteresting lines) ...

  %15 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>
  %16 = iree_linalg_ext.set_encoding %padded_3 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>
  %17 = iree_linalg_ext.set_encoding %padded_6 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>
  %18 = linalg.matmul ins(%15, %16 : tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>, tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>) outs(%17 : tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>
  %19 = iree_linalg_ext.unset_encoding %18 : tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> tensor<?x?xi32>
  %dim_7 = tensor.dim %8, %c0 : tensor<?x?xi32>
  %dim_8 = tensor.dim %8, %c1 : tensor<?x?xi32>
  %extracted_slice = tensor.extract_slice %19[0, 0] [%dim_7, %dim_8] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
  %20 = hal.tensor.export %extracted_slice "output 0" : tensor<?x?xi32>{%6, %7} -> !hal.buffer_view
  return %20 : !hal.buffer_view
}
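To make the "abstract tensor.pack" analogy concrete, here is a minimal C sketch of the two concrete ops that set_encoding and unset_encoding will eventually become: a pack that tiles a row-major 2D matrix into a 4D array of zero-padded tiles (the inner_dims_pos = [0, 1] case), and the matching unpack. The function names and the 2x4 tile size in the self-check are illustrative choices, not IREE's.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

/* Sketch of tensor.pack with inner_dims_pos = [0, 1] and inner_tiles =
 * [T0, T1]: a rows x cols row-major matrix becomes a
 * ceil(rows/T0) x ceil(cols/T1) x T0 x T1 row-major 4D array.
 * Out-of-bounds positions in edge tiles get the padding value 0. */
void pack_i8(const int8_t *src, int rows, int cols, int T0, int T1,
             int8_t *dst) {
  assert(T0 > 0 && T1 > 0);
  int R1 = ceil_div(rows, T0), C1 = ceil_div(cols, T1);
  for (int r1 = 0; r1 < R1; ++r1)
    for (int c1 = 0; c1 < C1; ++c1)
      for (int r0 = 0; r0 < T0; ++r0)
        for (int c0 = 0; c0 < T1; ++c0) {
          int r = r1 * T0 + r0, c = c1 * T1 + c0;
          int8_t v = (r < rows && c < cols) ? src[r * cols + c] : 0;
          dst[((r1 * C1 + c1) * T0 + r0) * T1 + c0] = v;
        }
}

/* Sketch of the inverse, tensor.unpack: copies back only the in-bounds
 * elements, so the padding is dropped. */
void unpack_i8(const int8_t *src, int rows, int cols, int T0, int T1,
               int8_t *dst) {
  assert(T0 > 0 && T1 > 0);
  int C1 = ceil_div(cols, T1);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) {
      int r1 = r / T0, r0 = r % T0, c1 = c / T1, c0 = c % T1;
      dst[r * cols + c] = src[((r1 * C1 + c1) * T0 + r0) * T1 + c0];
    }
}

/* Round trip on a 3x5 matrix with 2x4 tiles: 2x2 tiles of 2x4 elements,
 * i.e. 32 packed bytes for 15 real ones. Returns 1 if unpack(pack(x)) == x. */
int pack_roundtrip_ok(void) {
  int8_t src[15], packed[32], back[15];
  for (int i = 0; i < 15; ++i) src[i] = (int8_t)(i - 7);
  pack_i8(src, 3, 5, 2, 4, packed);
  unpack_i8(packed, 3, 5, 2, 4, back);
  return memcmp(src, back, sizeof src) == 0;
}
```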

Next, the OutlineDispatchRegions pass, still part of Flow, outlines each dispatch region into its own function. Because of the set_encoding and unset_encoding ops created earlier by the SetEncoding pass, the matmul work is split into five dispatches: three set_encoding dispatches (LHS, RHS, and accumulator), the matmul itself, and one unset_encoding dispatch for the result.

// -----// IR Dump After OutlineDispatchRegions (iree-flow-outline-dispatch-regions) //----- //
func.func @matmul_i8_dispatch_0_set_encoding_MATMUL_I8I8I32_LHS_DxD(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: !flow.dispatch.tensor<writeonly:tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>>) {
    // ... (skipping uninteresting lines) ...
    %9 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.
    // ... (skipping uninteresting lines) ...
}
func.func @matmul_i8_dispatch_1_set_encoding_MATMUL_I8I8I32_RHS_DxD(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: !flow.dispatch.tensor<writeonly:tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>>) {
    // ... (skipping uninteresting lines) ...
    %9 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi8> -> tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>
    // ... (skipping uninteresting lines) ...
}
func.func @matmul_i8_dispatch_2_set_encoding_MATMUL_I8I8I32_RESULT_DxD(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?xi32>>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: !flow.dispatch.tensor<writeonly:tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>>) {
    // ... (skipping uninteresting lines) ...
    %9 = iree_linalg_ext.set_encoding %padded : tensor<?x?xi32> -> tensor<?x?xi32, #iree_linalg_ext.
    // ... (skipping uninteresting lines) ...
}
func.func @matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>>, %arg1: !flow.dispatch.tensor<readonly:tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>>, %arg2: !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>>, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index) {
    // ... (skipping uninteresting lines) ...
    %12 = linalg.matmul ins(%9, %10 : tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>, tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>) outs(%11 : tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>) -> tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>
    // ... (skipping uninteresting lines) ...
}
func.func @matmul_i8_dispatch_4_unset_encoding_MATMUL_I8I8I32_RESULT_DxD(%arg0: !flow.dispatch.tensor<readonly:tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>) {
    // ... (skipping uninteresting lines) ...
    %7 = iree_linalg_ext.unset_encoding %6 : tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> tensor<?x?xi32>
    // ... (skipping uninteresting lines) ...
}

// ... (skipping uninteresting lines) ...

func.func @matmul_i8(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    // ... (skipping uninteresting lines) ...
    %15 = flow.dispatch @matmul_i8_dispatch_0::@matmul_i8_dispatch_0_set_encoding_MATMUL_I8I8I32_LHS_DxD[%0, %1, %10, %9](%2, %0, %1, %10, %9) : (tensor<?x?xi8>{%0, %1}, index, index, index, index) -> tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>{%10, %9}
    %16 = flow.dispatch @matmul_i8_dispatch_1::@matmul_i8_dispatch_1_set_encoding_MATMUL_I8I8I32_RHS_DxD[%3, %4, %12, %11](%5, %3, %4, %12, %11) : (tensor<?x?xi8>{%3, %4}, index, index, index, index) -> tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>{%12, %11}
    %17 = flow.dispatch @matmul_i8_dispatch_2::@matmul_i8_dispatch_2_set_encoding_MATMUL_I8I8I32_RESULT_DxD[%6, %7, %14, %13](%8, %6, %7, %14, %13) : (tensor<?x?xi32>{%6, %7}, index, index, index, index) -> tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>{%14, %13}
    %18 = flow.dispatch @matmul_i8_dispatch_3::@matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32[%10, %9, %12, %11, %14, %13](%15, %16, %17, %10, %9, %12, %11, %14, %13) : (tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_LHS>>{%10, %9}, tensor<?x?xi8, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RHS>>{%12, %11}, tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>{%14, %13}, index, index, index, index, index, index) -> %17{%14, %13}
    %19 = flow.dispatch @matmul_i8_dispatch_4::@matmul_i8_dispatch_4_unset_encoding_MATMUL_I8I8I32_RESULT_DxD[%14, %13, %6, %7](%18, %14, %13, %6, %7) : (tensor<?x?xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>>{%14, %13}, index, index, index, index) -> tensor<?x?xi32>{%6, %7}
    %20 = hal.tensor.export %19 "output 0" : tensor<?x?xi32>{%6, %7} -> !hal.buffer_view
    return %20 : !hal.buffer_view
}
}

The next interesting pass is LLVMCPUMaterializeEncoding. It is also enabled by --iree-flow-enable-data-tiling. It runs at the start of codegen (HAL) and is where things start to be target-specialized. Given the encodings set earlier by SetEncoding and the target attributes (target triple and target CPU features), it picks specific tile sizes.

The IR log shows its effect separately on each of the dispatch functions that were formed above. Here is its effect on a set_encoding dispatch, where we see the above set_encoding op now rewritten as a tensor.pack:

// -----// IR Dump After LLVMCPUMaterializeEncoding (iree-llvmcpu-materialize-encoding) //----- //
func.func @matmul_i8_dispatch_0_set_encoding_MATMUL_I8I8I32_LHS_DxD() {
  // ... (skipping uninteresting lines) ...
  %21 = tensor.empty(%19, %20) : tensor<?x?x8x8xi8>
  %pack = tensor.pack %16 padding_value(%c0_i8 : i8) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %21 : tensor<?x?xi8> -> tensor<?x?x8x8xi8>
  %22 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%10]
  %23 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%11]
  flow.dispatch.tensor.store %pack, %15, offsets = [0, 0, 0, 0], sizes = [%22, %23, 8, 8], strides = [1, 1, 1, 1] : tensor<?x?x8x8xi8> -> !flow.dispatch.tensor<writeonly:tensor<?x?x8x8xi8>>{%22, %23}
  return
}

Here we see that the LHS matrix, of i8 element type, gets tiled by 8x8 tiles, which is what we want given that we are targeting the "+i8mm" CPU feature, and given that the matrix dimensions are dynamic.
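To see why 8x8 tiles suit "+i8mm", it helps to look at the instruction that feature provides. The sketch below is a scalar C emulation of SMMLA, based on our reading of the Arm documentation (each SMMLA multiplies a 2x8 i8 matrix by an 8x2 i8 matrix and accumulates into a 2x2 i32 matrix), and shows how one 8x8x8 tile step decomposes into 16 such instructions. It illustrates the arithmetic only: it is not the IREE microkernel, and the function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar emulation of SMMLA as we understand its semantics:
 * acc (2x2 i32, row-major) += a (2x8 i8, row-major) times an 8x2 i8
 * matrix whose two columns are stored as the two 8-byte rows of b. */
static void smmla_emul(int32_t acc[4], const int8_t a[16],
                       const int8_t b[16]) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i * 2 + j] += (int32_t)a[i * 8 + k] * (int32_t)b[j * 8 + k];
}

/* One 8x8x8 tile step in mmt4d form: out (8x8, M0 x N0, i32) +=
 * lhs (8x8, M0 x K0) times the transpose of rhs (8x8, N0 x K0).
 * K0 = 8 matches the depth of one SMMLA, and the 8x8 output tile is
 * covered by 4 x 4 = 16 SMMLA 2x2 blocks. Returns the SMMLA count. */
int tile_8x8x8_via_smmla(int32_t out[64], const int8_t lhs[64],
                         const int8_t rhs[64]) {
  assert(out && lhs && rhs);
  int count = 0;
  for (int bi = 0; bi < 4; ++bi) {
    for (int bj = 0; bj < 4; ++bj) {
      int32_t acc[4];
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          acc[i * 2 + j] = out[(bi * 2 + i) * 8 + bj * 2 + j];
      /* Rows 2*bi, 2*bi+1 of lhs (and 2*bj, 2*bj+1 of rhs) are 16
         contiguous bytes of the tile: one 128-bit register's worth. */
      smmla_emul(acc, &lhs[bi * 16], &rhs[bj * 16]);
      ++count;
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          out[(bi * 2 + i) * 8 + bj * 2 + j] = acc[i * 2 + j];
    }
  }
  return count;
}
```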

If some matrix dimension were statically very small, the LLVMCPUMaterializeEncoding pass would select a smaller tile size accordingly, to avoid padding very narrow matrices excessively. For example, single-row LHS matrices should remain single-row, and we should have dedicated "matrix times vector" kernels to deal with that case. That's work in progress, though.
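The padding cost is easy to quantify. The helper below (a hypothetical name) computes how many elements a matrix occupies once tensor.pack pads it to whole tiles, using the same ceildiv as the affine.apply ops in the dump above. A single-row 1x1000 LHS packed into 8x8 tiles takes 8000 elements, 8x its real size, whereas a 1x8 tile shape (an illustrative choice, not necessarily what the pass picks) adds no padding at all.

```c
#include <assert.h>
#include <stdint.h>

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

/* Number of elements a rows x cols matrix occupies once padded to whole
 * tile0 x tile1 tiles, i.e. the size of the tensor.pack result. The outer
 * dims are ceil_div(rows, tile0) and ceil_div(cols, tile1). */
int64_t packed_elems(int64_t rows, int64_t cols, int64_t tile0,
                     int64_t tile1) {
  assert(tile0 > 0 && tile1 > 0);
  return ceil_div(rows, tile0) * tile0 * ceil_div(cols, tile1) * tile1;
}
```

For large matrices the overhead is small: 1001x1000 pads to 1008x1000, under 1% extra.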

More interesting, though, is what happens to the linalg.matmul itself. As its operands are rewritten to be the results of tensor.pack ops, i.e. tiled matrices, the linalg.matmul gets rewritten into a tiled matmul op, linalg.mmt4d:

// -----// IR Dump After LLVMCPUMaterializeEncoding (iree-llvmcpu-materialize-encoding) //----- //
func.func @matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32() {
  // ... (skipping uninteresting lines) ...
  %40 = linalg.mmt4d ins(%33, %36 : tensor<?x?x8x8xi8>, tensor<?x?x8x8xi8>) outs(%39 : tensor<?x?x8x8xi32>) -> tensor<?x?x8x8xi32>
  %41 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%20]
  %42 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%21]
  flow.dispatch.tensor.store %40, %30, offsets = [0, 0, 0, 0], sizes = [%41, %42, 8, 8], strides = [1, 1, 1, 1] : tensor<?x?x8x8xi32> -> !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xi32>>{%41, %42}
  return
}

At this point, our original linalg.matmul has become a linalg.mmt4d, with some tensor.pack and tensor.unpack ops around it to bring matrices into and out of the 4D shapes representing what will become a tiled data layout once bufferized. We are still before bufferization here, but everything gets bufferized row-major, so we can effectively dictate a tiled bufferized layout by expanding 2D shapes into 4D shapes and transposing indices, which is exactly what tensor.pack does.
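The whole rewrite can be checked numerically. Below is a C sketch, with hypothetical function names, of pack, linalg.mmt4d, and unpack chained together. It assumes mmt4d computes out[m1][n1][m0][n0] += lhs[m1][k1][m0][k0] * rhs[n1][k1][n0][k0], which is why the RHS is packed transposed, into N0 x K0 tiles. Zero padding along K contributes nothing to the sums, and padded rows and columns of the result are dropped by the unpack, so the result should equal a plain matmul for any tile sizes.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

static int cdiv(int a, int b) { return (a + b - 1) / b; }

/* linalg.mmt4d semantics, as assumed above. */
static void mmt4d(const int8_t *lhs, const int8_t *rhs, int32_t *out,
                  int M1, int N1, int K1, int M0, int N0, int K0) {
  for (int m1 = 0; m1 < M1; ++m1)
    for (int n1 = 0; n1 < N1; ++n1)
      for (int k1 = 0; k1 < K1; ++k1)
        for (int m0 = 0; m0 < M0; ++m0)
          for (int n0 = 0; n0 < N0; ++n0) {
            int32_t *o = &out[((m1 * N1 + n1) * M0 + m0) * N0 + n0];
            for (int k0 = 0; k0 < K0; ++k0)
              *o += (int32_t)lhs[((m1 * K1 + k1) * M0 + m0) * K0 + k0] *
                    (int32_t)rhs[((n1 * K1 + k1) * N0 + n0) * K0 + k0];
          }
}

/* pack (set_encoding) -> mmt4d -> unpack (unset_encoding) for C += A * B,
 * with A (MxK) and B (KxN) i8, C (MxN) i32, all row-major.
 * Returns 0 on success, -1 on allocation failure. */
int matmul_via_mmt4d(const int8_t *A, const int8_t *B, int32_t *C,
                     int M, int K, int N, int M0, int N0, int K0) {
  assert(M0 > 0 && N0 > 0 && K0 > 0);
  int M1 = cdiv(M, M0), N1 = cdiv(N, N0), K1 = cdiv(K, K0);
  int8_t *lp = calloc((size_t)M1 * K1 * M0 * K0, 1); /* zero padding */
  int8_t *rp = calloc((size_t)N1 * K1 * N0 * K0, 1);
  int32_t *op = calloc((size_t)M1 * N1 * M0 * N0, sizeof(int32_t));
  if (!lp || !rp || !op) { free(lp); free(rp); free(op); return -1; }
  /* Pack LHS: A[m][k] -> lp[m/M0][k/K0][m%M0][k%K0]. */
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      lp[((m / M0 * K1 + k / K0) * M0 + m % M0) * K0 + k % K0] = A[m * K + k];
  /* Pack RHS transposed: B[k][n] -> rp[n/N0][k/K0][n%N0][k%K0]. */
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      rp[((n / N0 * K1 + k / K0) * N0 + n % N0) * K0 + k % K0] = B[k * N + n];
  /* Pack the accumulator: C[m][n] -> op[m/M0][n/N0][m%M0][n%N0]. */
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      op[((m / M0 * N1 + n / N0) * M0 + m % M0) * N0 + n % N0] = C[m * N + n];
  mmt4d(lp, rp, op, M1, N1, K1, M0, N0, K0);
  /* Unpack: keep only the original M x N part. */
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      C[m * N + n] = op[((m / M0 * N1 + n / N0) * M0 + m % M0) * N0 + n % N0];
  free(lp); free(rp); free(op);
  return 0;
}
```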

Next, the LLVMCPULowerToUKernels pass rewrites each op for which a ukernel exists into a special ukernel.generic op. That op is not yet a function call, but it contains enough information to be lowered to one: in particular, it takes a symbol name and some ABI settings as attributes. Note that this is still before bufferization, so it would not yet be possible to have a function call (there would not yet be any pointers to pass to a function). Having a ukernel.generic op working on tensors makes it very convenient to write patterns rewriting IR into ukernel calls while still on tensor semantics.

// -----// IR Dump After LLVMCPULowerToUKernels (iree-llvmcpu-lower-to-ukernels) //----- //
module {
  func.func @matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32() {
    // ... (skipping uninteresting lines) ...
    %34 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%31, %32 : tensor<?x?x8x8xi8>, tensor<?x?x8x8xi8>) outs(%33 : tensor<?x?x8x8xi32>) (%dim, %dim_0, %dim_1, %c8_i32, %c8_i32, %c8_i32, %c258_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> tensor<?x?x8x8xi32>
    flow.dispatch.tensor.store %34, %24, offsets = [%arg0, %arg1, 0, 0], sizes = [%27, %30, 8, 8], strides = [1, 1, 1, 1] : tensor<?x?x8x8xi32> -> !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xi32>>{%22, %23}
    // ... (skipping uninteresting lines) ...
  }
}

Next, bufferization happens:

// -----// IR Dump After IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
module {
  func.func @matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32() {
    // ... (skipping uninteresting lines) ...
    iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%subview, %subview_0 : memref<?x?x8x8xi8, strided<[?, 64, 8, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<?x?x8x8xi8, strided<[?, 64, 8, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_1 : memref<?x?x8x8xi32, strided<[?, 64, 8, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) (%27, %30, %17, %c8_i32, %c8_i32, %c8_i32, %c258_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"]} strided_outer_dims(1)
    // ... (skipping uninteresting lines) ...
  }
}
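The bufferized type strided<[?, 64, 8, 1], offset: ?> is just the row-major layout of the 4D packed shape: within a tile, rows are 8 elements apart and a whole 8x8 i8 tile spans 64 elements, while the outermost stride (distance between tile rows) and the subview offset are dynamic. A small C sketch, with hypothetical helper names and assuming the whole packed buffer starts at offset 0, shows where an original 2D element (r, c) ends up:

```c
#include <assert.h>
#include <stdint.h>

/* Linear element index, in a strided<[s0, 64, 8, 1], offset: off> memref
 * of shape ?x?x8x8, of the element at 4D position (i0, i1, i2, i3). */
int64_t strided_index(int64_t off, int64_t s0, int64_t i0, int64_t i1,
                      int64_t i2, int64_t i3) {
  return off + i0 * s0 + i1 * 64 + i2 * 8 + i3;
}

/* Where the original 2D element (r, c) of a matrix packed into 8x8 tiles
 * lands, assuming offset 0 and C1 = ceil(cols / 8) tiles per tile-row,
 * so that the dynamic outer stride is s0 = C1 * 64. */
int64_t packed_index_2d(int64_t r, int64_t c, int64_t cols) {
  assert(r >= 0 && c >= 0 && c < cols);
  int64_t C1 = (cols + 7) / 8;
  return strided_index(0, C1 * 64, r / 8, c / 8, r % 8, c % 8);
}
```

For a 20-column matrix, element (0, 7) lands at 7 and (1, 0) at 8: the rows of one tile are contiguous, while (0, 8) jumps to the next tile at 64.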

Next, ukernel ops get lowered to function calls. Note how a declaration of the call target @iree_uk_mmt4d is inserted.

// -----// IR Dump After LowerUKernelOpsToCalls (iree-codegen-lower-ukernel-ops-to-calls) //----- //
module {
  func.func private @iree_uk_mmt4d(memref<i8>, index, index, memref<i8>, index, index, memref<i32>, index, index, index, index, index, i32, i32, i32, i32) attributes {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"], llvm.bareptr = true}
  func.func @matmul_i8_dispatch_3_matmul_DxDxD_i8xi8xi32() {
    // ... (skipping uninteresting lines) ...
    func.call @iree_uk_mmt4d(%base_buffer, %offset, %strides#0, %base_buffer_2, %offset_3, %strides_5#0, %base_buffer_6, %offset_7, %strides_9#0, %27, %30, %17, %c8_i32, %c8_i32, %c8_i32, %c258_i32) : (memref<i8>, index, index, memref<i8>, index, index, memref<i32>, index, index, index, index, index, i32, i32, i32, i32) -> ()
    // ... (skipping uninteresting lines) ...
  }
}
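The declaration shows the calling convention after lowering: each memref operand becomes a bare pointer (llvm.bareptr = true), an offset, and a single outer stride (consistent with strided_outer_dims(1)), followed by three sizes which, judging from the store sizes [%27, %30, 8, 8] earlier, are the M, N, K outer dimensions, then the M0, N0, K0 tile sizes and a flags word. Below is a hypothetical, unoptimized C function with the same parameter order, written only to make that convention concrete. It is not IREE's iree_uk_mmt4d: it assumes offsets and strides are in elements, that index is 64-bit on this aarch64 target, and that the inner tile dimensions are contiguous; it also ignores the flags value (258 here) and always accumulates.

```c
#include <assert.h>
#include <stdint.h>

typedef int64_t index_t; /* assumption: MLIR index is 64-bit here */

/* Hypothetical reference mmt4d with the parameter order of the
 * @iree_uk_mmt4d declaration above. Tiles are contiguous: the LHS inner
 * strides are [M0*K0, K0, 1], matching strided<[?, 64, 8, 1]> for 8x8. */
void mmt4d_ref_abi(const int8_t *lhs, index_t lhs_off, index_t lhs_s0,
                   const int8_t *rhs, index_t rhs_off, index_t rhs_s0,
                   int32_t *out, index_t out_off, index_t out_s0,
                   index_t M, index_t N, index_t K,
                   int32_t M0, int32_t N0, int32_t K0, int32_t flags) {
  assert(M0 > 0 && N0 > 0 && K0 > 0);
  (void)flags; /* not decoded in this sketch */
  lhs += lhs_off;
  rhs += rhs_off;
  out += out_off;
  for (index_t m = 0; m < M; ++m)
    for (index_t n = 0; n < N; ++n) {
      int32_t *o = out + m * out_s0 + n * (index_t)(M0 * N0);
      for (index_t k = 0; k < K; ++k) {
        const int8_t *l = lhs + m * lhs_s0 + k * (index_t)(M0 * K0);
        const int8_t *r = rhs + n * rhs_s0 + k * (index_t)(N0 * K0);
        for (int i = 0; i < M0; ++i)
          for (int j = 0; j < N0; ++j) {
            int32_t acc = o[i * N0 + j];
            for (int kk = 0; kk < K0; ++kk)
              acc += (int32_t)l[i * K0 + kk] * (int32_t)r[j * K0 + kk];
            o[i * N0 + j] = acc;
          }
      }
    }
}
```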

Next, microkernel bitcode is linked.

That part doesn't show directly in the IR log, but the bitcode files can be found in the IREE build directory:

benoitjacob@cloud:~/iree-build-linux$ find . -name '*.bc' | grep mmt4d
./runtime/src/iree/builtins/ukernel/ukernel_bitcode_32bit_base_mmt4d.c.bc
./runtime/src/iree/builtins/ukernel/ukernel_bitcode_32bit_base_mmt4d_tile.c.bc
./runtime/src/iree/builtins/ukernel/arch/arm_64/ukernel_bitcode_arm_64_dotprod_mmt4d_arm_64_dotprod.c.bc
./runtime/src/iree/builtins/ukernel/arch/arm_64/ukernel_bitcode_arm_64_base_mmt4d_arm_64.c.bc
./runtime/src/iree/builtins/ukernel/arch/arm_64/ukernel_bitcode_arm_64_i8mm_mmt4d_arm_64_i8mm.c.bc
./runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64_avx2_fma_mmt4d_x86_64_avx2_fma.c.bc
./runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64_avx512_vnni_mmt4d_x86_64_avx512_vnni.c.bc
./runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64_base_mmt4d_x86_64.c.bc
./runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64_avx512_base_mmt4d_x86_64_avx512_base.c.bc
./runtime/src/iree/builtins/ukernel/ukernel_bitcode_64bit_base_mmt4d_tile.c.bc
./runtime/src/iree/builtins/ukernel/ukernel_bitcode_64bit_base_mmt4d.c.bc
./ukernel_bitcode_x86_64_base_mmt4d_x86_64.c.bc

They are built directly from the C sources under runtime/src/iree/builtins/ukernel/ using our own build of Clang (not the host toolchain); see the CMake function iree_bitcode_library().

Disassembling the generated code.

  • Add the following flags to iree-compile:
    • --iree-llvmcpu-keep-linker-artifacts so it will keep a .so in /tmp and will print its filename.
    • --iree-llvmcpu-link-embedded=false to generate a standard ELF as opposed to IREE's own "embedded" ELF flavour, so that tools such as objdump understand it better.
  • objdump the resulting .so as usual, e.g. $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-objdump -d --mattr=+dotprod,+i8mm /tmp/matmul_i8_linked_llvm_cpu-d384fd.so.