Skip to content

Instantly share code, notes, and snippets.

@Abhishek-Varma
Created June 27, 2024 11:19
Show Gist options
  • Save Abhishek-Varma/909f4d5059ee21c25c376ce1751827a1 to your computer and use it in GitHub Desktop.
Save Abhishek-Varma/909f4d5059ee21c25c376ce1751827a1 to your computer and use it in GitHub Desktop.
non objectfifo bf16 matmul vectorization
This file has been truncated, but you can view the full file.
// -----// IR Dump Before TranslateTargetExecutableVariantsPass (iree-hal-translate-target-executable-variants) //----- //
// Executable variant for target backend "amd-aie" with format "amdaie-xclbin-fb".
// It holds one exported dispatch entry point and the builtin.module defining it.
// Target attributes as dumped: target_arch = "chip-tbd", ukernels = "none".
hal.executable.variant public @amdaie_xclbin_fb target(<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "none"}>) {
// Export at ordinal 0. The pipeline layout has no push constants and a single
// descriptor set 0 with three storage buffers: bindings 0 and 1 are ReadOnly,
// binding 2 carries no flag.
hal.executable.export public @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
// Workgroup-count region: returns the three index results of
// flow.dispatch.workgroup_count_from_slice unchanged.
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
// Dispatch body: C[308x2432] = A[308x9728] * B[9728x2432], all operands and
// the result in bf16, with C zero-initialised before the matmul.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// bf16 zero used as the fill value for the accumulator.
%cst = arith.constant 0.000000e+00 : bf16
// Byte offset 0 shared by all three binding subspans.
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both inputs in full (offsets 0, unit strides, sizes equal to the full shape).
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-filled 308x2432 accumulator.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Untiled matmul over the whole problem; the outs/result element type is bf16.
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Store the full result into binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
}
}
// -----// IR Dump Before TypePropagation (iree-codegen-type-propagation) //----- //
// Dispatch function only (no enclosing variant or module in this snapshot).
// The body is textually the same as the function in the preceding dump:
// C[308x2432] = A[308x9728] * B[9728x2432] in bf16, with C zero-filled first.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// Fill value (bf16 zero) and the zero byte offset used by the subspans.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// %0 / %1: read-only input bindings; %2: write-only output binding.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent loads of both inputs.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-initialised accumulator, then the untiled matmul into it.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Full-extent store of the result into binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before BubbleUpOrdinalOps (iree-codegen-bubble-up-ordinal-ops) //----- //
// Body is textually unchanged from the preceding dump in this file.
// Computes C[308x2432] = A[308x9728] * B[9728x2432] in bf16 on a zero-filled C.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// Constants: bf16 zero (fill value) and index 0 (subspan offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Input bindings 0 and 1 (ReadOnly) and output binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Whole-tensor loads of A (%3) and B (%4).
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Accumulator creation, zero fill, and the matmul that consumes it.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Result is written back in full to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before BufferizeCopyOnlyDispatches (iree-codegen-bufferize-copy-only-dispatches) //----- //
// Body is textually unchanged from the preceding dump in this file.
// Note the function is not copy-only: it contains a linalg.fill and a
// linalg.matmul between the loads and the store.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// bf16 zero for the fill; index 0 for the binding offsets.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Two read-only input subspans and one write-only output subspan.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load A (308x9728) and B (9728x2432) in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-fill the 308x2432 output tensor and accumulate the matmul into it.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Store the matmul result to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before DecomposeSoftmax (iree-codegen-decompose-softmax) //----- //
// Body is textually unchanged from the preceding dump in this file.
// The function contains no softmax op; its only linalg ops are fill and matmul.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// Fill value and subspan offset constants.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Binding subspans: 0 and 1 are inputs (ReadOnly), 2 is the output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full loads of both matmul operands.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-filled accumulator followed by the untiled bf16 matmul.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Write the 308x2432 result to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before MaterializeUserConfigs (iree-codegen-materialize-user-configs) //----- //
// Module-level snapshot: the dispatch function is shown wrapped in its module.
// The function body is textually the same as in the preceding dump, and the
// matmul carries no lowering_config or packing_config attribute at this point.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// bf16 zero (fill value) and index 0 (binding offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Inputs on bindings 0 and 1, output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load A and B in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-fill the accumulator and run the untiled matmul.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Store the result to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
}
// -----// IR Dump Before AMDAIELoweringStrategy (iree-amdaie-lowering-strategy) //----- //
// Module-level snapshot, textually the same as the preceding dump.
// This is the last dump in this file where the function has no translation_info
// and the matmul has no lowering_config / packing_config attributes.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() {
// Constants for the fill value and the binding offset.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Read-only inputs (bindings 0, 1) and write-only output (binding 2).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent operand loads.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Accumulator initialisation and the plain (unannotated) matmul.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Full-extent result store.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
}
// -----// IR Dump Before LowerExecutableUsingTransformDialect (iree-codegen-lower-executable-using-transform-dialect) //----- //
// Module-level snapshot. Two things differ from the preceding dump:
//   * the function now has attribute translation_info<Custom>;
//   * the matmul now has lowering_config and packing_config attributes.
// All other ops are textually unchanged.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// bf16 zero (fill value) and index 0 (binding offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Inputs on bindings 0 and 1, output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent loads of A and B.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-filled accumulator.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Annotated matmul. tile_sizes has three entries; the first, [44, 128], equals
// the scf.forall step seen in later dumps in this file. 308 = 7 * 44 and
// 2432 = 19 * 128, so that tiling divides the output evenly.
// packing_config has two entries with packedSizes [44, 64, 64] and
// [0, 0, 0, 4, 4, 8].
// NOTE(review): the per-level meaning of tile_sizes and of the packing fields
// is defined by iree-amd-aie, not by this dump; confirm there before relying on it.
%7 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Store the result to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
}
// -----// IR Dump Before AMDAIELowerExecutableTarget (iree-amdaie-lower-executable-target) //----- //
// Function-level snapshot (no enclosing module shown). The function body is
// textually the same as in the preceding dump: translation_info<Custom> on the
// function, lowering_config and packing_config on the still-untiled matmul.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Fill value and binding offset constants.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Read-only inputs (bindings 0, 1) and write-only output (binding 2).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent operand loads.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-filled accumulator.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Whole-problem matmul carrying the tiling and packing configuration attributes.
%7 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Full-extent result store.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIETileAndFuse (iree-amdaie-tile-and-fuse) //----- //
// Last snapshot in this file where the matmul is still untiled: fill and matmul
// operate on the full 308x2432 output. The body is textually the same as in the
// preceding dump.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// bf16 zero (fill value) and index 0 (binding offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Inputs on bindings 0 and 1, output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load A (308x9728) and B (9728x2432) in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Zero-filled accumulator feeding the matmul's outs operand.
%5 = tensor.empty() : tensor<308x2432xbf16>
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// The first tile_sizes entry is [44, 128]; the dumps after this one show an
// scf.forall with step (44, 128) over the (308, 2432) output.
%7 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Store the result to binding 2.
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIECleanup (iree-amdaie-cleanup) //----- //
// First snapshot containing the tiled form. An scf.forall over the 308x2432
// output with step (44, 128) now holds a per-tile fill + matmul, and the final
// store consumes the forall result %8.
// The original untiled fill (%6) and matmul (%7) are still present but dead:
// %6 is used only by %7, and %7 has no users.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// bf16 zero (fill value) and index 0 (binding offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Inputs on bindings 0 and 1, output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent loads of A and B; the forall body slices these per tile.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// %5 is the shared_outs initial value of the forall below.
%5 = tensor.empty() : tensor<308x2432xbf16>
// Leftover untiled fill and matmul; their results do not reach the store.
%6 = linalg.fill ins(%cst : bf16) outs(%5 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
%7 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%3, %4 : tensor<308x9728xbf16>, tensor<9728x2432xbf16>) outs(%6 : tensor<308x2432xbf16>) -> tensor<308x2432xbf16>
// Parallel loop over output tiles: %arg0 steps rows by 44 (308 / 44 = 7 tiles),
// %arg1 steps columns by 128 (2432 / 128 = 19 tiles); both divide evenly.
%8 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// A tile: 44 rows starting at %arg0, full K extent (9728).
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
// B tile: full K extent, 128 columns starting at %arg1.
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// Output tile (44x128) at [%arg0, %arg1], zero-filled into %9.
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%9 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Same slice as %extracted_slice_1; it has no users in this function.
%extracted_slice_2 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Per-tile matmul 44x9728 * 9728x128 accumulating into the filled tile %9.
// The lowering_config / packing_config attributes match those on %7.
%10 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%extracted_slice, %extracted_slice_0 : tensor<44x9728xbf16>, tensor<9728x128xbf16>) outs(%9 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Write the finished tile back into the shared output at the same position.
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Loop dimension 0 (rows) is mapped to gpu.block<y>, dimension 1 (columns) to gpu.block<x>.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the forall result (not %7) to binding 2.
flow.dispatch.tensor.store %8, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Tiled form with the leftovers gone. Compared with the preceding dump, the
// untiled fill/matmul and the duplicate output extract_slice are no longer
// present, and SSA values are renumbered (forall = %6, tile fill = %7,
// tile matmul = %8).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// bf16 zero (fill value) and index 0 (binding offset).
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Inputs on bindings 0 and 1, output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent loads of A and B, sliced per tile inside the loop.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Initial value for the forall's shared output.
%5 = tensor.empty() : tensor<308x2432xbf16>
// 7 x 19 grid of 44x128 output tiles (308 = 7 * 44, 2432 = 19 * 128).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// A rows [%arg0, %arg0 + 44) with all 9728 K columns.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
// B columns [%arg1, %arg1 + 128) with all 9728 K rows.
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// Output tile, zero-filled before accumulation.
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Per-tile matmul; K is not tiled at this level (full 9728 reduction per tile).
%8 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%extracted_slice, %extracted_slice_0 : tensor<44x9728xbf16>, tensor<9728x128xbf16>) outs(%7 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Insert the computed tile into the shared output.
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Row dimension -> gpu.block<y>, column dimension -> gpu.block<x>.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled result to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Tiled form, textually the same as the preceding dump in this file.
// One scf.forall over 44x128 output tiles; each iteration slices A and B,
// zero-fills the output tile, runs the matmul and inserts the tile back.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Fill value and binding offset constants.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Read-only inputs (bindings 0, 1) and write-only output (binding 2).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent operand loads.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Shared output initial value.
%5 = tensor.empty() : tensor<308x2432xbf16>
// Loop bounds (308, 2432) with step (44, 128): 7 row tiles by 19 column tiles.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operand slices: A is 44x9728, B is 9728x128.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// 44x128 output tile taken from the shared output, then zero-filled.
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Tile matmul with the lowering/packing configuration still attached.
%8 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%extracted_slice, %extracted_slice_0 : tensor<44x9728xbf16>, tensor<9728x128xbf16>) outs(%7 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Tile write-back into the shared output.
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Mapping attribute: [gpu.block<y>, gpu.block<x>] for (rows, columns).
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the forall result to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEPackAndTranspose (iree-amdaie-pack-and-transpose) //----- //
// Tiled form, textually the same as the preceding dump. This is the last
// snapshot in this file before tensor.pack ops appear: the following dump
// packs %extracted_slice with inner_tiles = [44, 64] into a 1x152x44x64 tensor
// (9728 / 64 = 152).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Fill value and binding offset constants.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Read-only inputs (bindings 0, 1) and write-only output (binding 2).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Full-extent operand loads.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Shared output initial value.
%5 = tensor.empty() : tensor<308x2432xbf16>
// Parallel loop over 44x128 output tiles.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Unpacked per-tile operands: A 44x9728 and B 9728x128.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// Zero-filled 44x128 output tile.
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Tile matmul on unpacked operands. The first packing_config entry has
// packedSizes = [44, 64, 64].
// NOTE(review): the 44x64 inner tiles in the next dump presumably come from
// this entry; confirm against the pack-and-transpose pass.
%8 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} ins(%extracted_slice, %extracted_slice_0 : tensor<44x9728xbf16>, tensor<9728x128xbf16>) outs(%7 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// Tile write-back into the shared output.
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Mapping attribute: [gpu.block<y>, gpu.block<x>] for (rows, columns).
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the forall result to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEPropagateDataLayout (iree-amdaie-propagate-data-layout) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. In this snapshot the per-tile matmul is expressed as
// tensor.pack ops feeding a 6-D linalg.generic, followed by a tensor.unpack.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Walk the 308x2432 output in 44x128 tiles (7 x 19 iterations); %arg0 / %arg1
// are the row / column offsets of the current tile.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operands: 44 LHS rows over the full K extent, 128 RHS columns over
// the full K extent, and the matching 44x128 slice of the shared output.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Zero-fill the output tile. %7 is both packed below as the accumulator init
// and reused as the destination of the final unpack.
%7 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// LHS pack: 44x9728 -> 1x152x44x64 (inner tile 44x64; 9728 / 64 = 152).
%8 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %8 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// NOTE(review): %9 has no uses in this function; %10 is the destination the
// RHS pack actually writes into.
%9 = tensor.empty() : tensor<152x2x64x64xbf16>
%10 = tensor.empty() : tensor<152x2x64x64xbf16>
// RHS pack: 9728x128 -> 152x2x64x64 (inner tile 64x64).
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Accumulator pack: 44x128 -> 1x2x44x64 (inner tile 44x64).
%11 = tensor.empty() : tensor<1x2x44x64xbf16>
%pack_3 = tensor.pack %7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %11 : tensor<44x128xbf16> -> tensor<1x2x44x64xbf16>
// Packed matmul. Per iterator_types, d2 (extent 152) and d5 (extent 64) are
// the reduction dimensions; d0, d1, d3 and d4 are parallel.
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%pack_3 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_4: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are bf16.
%13 = arith.mulf %in, %in_4 : bf16
%14 = arith.addf %out, %13 : bf16
linalg.yield %14 : bf16
} -> tensor<1x2x44x64xbf16>
// Undo the accumulator packing: 1x2x44x64 -> 44x128.
%unpack = tensor.unpack %12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
// Publish this iteration's 44x128 tile into the shared 308x2432 output.
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// %arg0 (rows) maps to block y, %arg1 (columns) maps to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. The per-tile computation is a 6-D linalg.generic over
// packed operands; each tensor.pack here has exactly one tensor.empty
// destination.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// One iteration per 44x128 output tile (308 / 44 = 7 by 2432 / 128 = 19).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Slices for this tile: LHS rows [%arg0, %arg0 + 44), RHS columns
// [%arg1, %arg1 + 128), and the corresponding output region.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// The zero fill is still applied to the unpacked 44x128 tile; its result is
// packed below (%pack_3) and is also the unpack destination.
%7 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_1 : tensor<44x128xbf16>) -> tensor<44x128xbf16>
// LHS: 44x9728 -> 1x152x44x64, inner tile 44x64.
%8 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %8 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// RHS: 9728x128 -> 152x2x64x64, inner tile 64x64.
%9 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %9 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Accumulator: 44x128 -> 1x2x44x64, inner tile 44x64.
%10 = tensor.empty() : tensor<1x2x44x64xbf16>
%pack_3 = tensor.pack %7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x128xbf16> -> tensor<1x2x44x64xbf16>
// Packed matmul: reductions run over d2 (152 K-blocks) and d5 (64 elements
// inside a K-block); the remaining four dimensions are parallel.
%11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%pack_3 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_4: bf16, %out: bf16):
// out += in * in_4, computed in bf16.
%12 = arith.mulf %in, %in_4 : bf16
%13 = arith.addf %out, %12 : bf16
linalg.yield %13 : bf16
} -> tensor<1x2x44x64xbf16>
// Restore the 44x128 layout of the result tile.
%unpack = tensor.unpack %11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
// Insert the finished tile at offsets (%arg0, %arg1) of the shared output.
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Row-tile index maps to block y, column-tile index to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. In this snapshot the accumulator is created and
// zero-filled directly in packed form, so only the two inputs go through
// tensor.pack.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// One iteration per 44x128 output tile (7 x 19 iterations in total).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// The output slice is only used as the destination of the final unpack.
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// LHS: 44x9728 -> 1x152x44x64, inner tile 44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// RHS: 9728x128 -> 152x2x64x64, inner tile 64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Accumulator: a fresh 1x2x44x64 tensor filled with zeros, already in the
// packed layout the generic expects.
%9 = tensor.empty() : tensor<1x2x44x64xbf16>
%10 = linalg.fill ins(%cst : bf16) outs(%9 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Packed matmul: d2 and d5 are reductions, the other dimensions are parallel.
%11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%10 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
// out += in * in_3, computed in bf16.
%12 = arith.mulf %in, %in_3 : bf16
%13 = arith.addf %out, %12 : bf16
linalg.yield %13 : bf16
} -> tensor<1x2x44x64xbf16>
// Unpack 1x2x44x64 -> 44x128 straight into the output slice.
%unpack = tensor.unpack %11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
// Insert the finished tile at offsets (%arg0, %arg1) of the shared output.
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Row-tile index maps to block y, column-tile index to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. Everything is still in tensor form here: the packed
// accumulator comes from tensor.empty + linalg.fill and no memref appears.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile loop over the output: 44-row by 128-column tiles, 7 x 19 iterations.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
// Destination region of the output for this tile (consumed by the unpack).
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Pack the LHS tile into 152 blocks of 44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// Pack the RHS tile into 152 x 2 blocks of 64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Zero-initialized accumulator in packed 1x2x44x64 form.
%9 = tensor.empty() : tensor<1x2x44x64xbf16>
%10 = linalg.fill ins(%cst : bf16) outs(%9 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Packed matmul; iterator_types marks d2 and d5 as the reduction dimensions.
%11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%10 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
// bf16 multiply followed by bf16 add into the running accumulator.
%12 = arith.mulf %in, %in_3 : bf16
%13 = arith.addf %out, %12 : bf16
linalg.yield %13 : bf16
} -> tensor<1x2x44x64xbf16>
// Convert the packed result back to a 44x128 tile in the output slice.
%unpack = tensor.unpack %11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
// Write the tile into the shared 308x2432 output at (%arg0, %arg1).
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// %arg0 maps to block y and %arg1 to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. In this snapshot the packed accumulator is backed by an
// explicit memref.alloc in memory space 1, while the two packed inputs are
// still plain tensors.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// One iteration per 44x128 output tile (7 x 19 iterations in total).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// LHS: 44x9728 -> 1x152x44x64, inner tile 44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// RHS: 9728x128 -> 152x2x64x64, inner tile 64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// NOTE(review): %9 has no uses in this function; the fill below targets the
// tensor view of %alloc instead.
%9 = tensor.empty() : tensor<1x2x44x64xbf16>
// Accumulator storage, allocated once per tile iteration.
// NOTE(review): memory space 1 presumably names a specific level of the
// target's memory hierarchy -- confirm against the AMDAIE backend.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Expose the buffer as a tensor so the tensor-level ops below can write it.
%10 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%11 = linalg.fill ins(%cst : bf16) outs(%10 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Packed matmul accumulating into the alloc-backed tensor; d2 and d5 are the
// reduction dimensions.
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%11 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
// out += in * in_3, computed in bf16.
%13 = arith.mulf %in, %in_3 : bf16
%14 = arith.addf %out, %13 : bf16
linalg.yield %14 : bf16
} -> tensor<1x2x44x64xbf16>
// Unpack 1x2x44x64 -> 44x128 into the output slice.
%unpack = tensor.unpack %12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// The buffer is released only after the unpack that reads the result.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Insert the finished tile at offsets (%arg0, %arg1) of the shared output.
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Row-tile index maps to block y, column-tile index to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEPackAndTranspose (iree-amdaie-pack-and-transpose) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. Operands carry one level of packing (44x64 / 64x64 inner
// tiles) and the accumulator lives in a memref.alloc in memory space 1.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile loop over the output: 44-row by 128-column tiles, 7 x 19 iterations.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Pack the LHS tile into 152 blocks of 44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// Pack the RHS tile into 152 x 2 blocks of 64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// NOTE(review): %9 is not referenced anywhere else in this function.
%9 = tensor.empty() : tensor<1x2x44x64xbf16>
// Accumulator buffer in memory space 1, viewed as a writable tensor and then
// zero-filled.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%10 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%11 = linalg.fill ins(%cst : bf16) outs(%10 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Packed matmul; its packing_config attribute lists a second entry with
// packedSizes [0, 0, 0, 4, 4, 8] that is not yet reflected in the operand
// shapes here.
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x152x44x64xbf16>, tensor<152x2x64x64xbf16>) outs(%11 : tensor<1x2x44x64xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
// bf16 multiply followed by bf16 add into the running accumulator.
%13 = arith.mulf %in, %in_3 : bf16
%14 = arith.addf %out, %13 : bf16
linalg.yield %14 : bf16
} -> tensor<1x2x44x64xbf16>
// Convert the packed result back to a 44x128 tile in the output slice.
%unpack = tensor.unpack %12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Free the accumulator buffer once the unpacked tile has been produced.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Write the tile into the shared 308x2432 output at (%arg0, %arg1).
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// %arg0 maps to block y and %arg1 to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEPropagateDataLayout (iree-amdaie-propagate-data-layout) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. In this snapshot every operand is packed twice: first into
// 44x64 / 64x64 blocks, then into 4x8 (LHS), 8x4 (RHS) and 4x4 (accumulator)
// inner tiles, and the computation is a 9-D linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// One iteration per 44x128 output tile (7 x 19 iterations in total).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level LHS pack: 44x9728 -> 1x152x44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// First-level RHS pack: 9728x128 -> 152x2x64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// NOTE(review): %9 has no uses in this function.
%9 = tensor.empty() : tensor<1x2x44x64xbf16>
// Accumulator buffer in memory space 1, viewed as a writable tensor and
// zero-filled before the second-level pack.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%10 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%11 = linalg.fill ins(%cst : bf16) outs(%10 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Second-level LHS pack: the 44x64 block is split into 4x8 tiles
// (44 / 4 = 11, 64 / 8 = 8) and outer_dims_perm swaps the two tile-count
// dimensions, giving 1x152x8x11x4x8.
// NOTE(review): %12 (the non-permuted shape) has no uses; %13 is the pack
// destination.
%12 = tensor.empty() : tensor<1x152x11x8x4x8xbf16>
%13 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %13 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
// Second-level RHS pack: the 64x64 block is split into 8x4 tiles
// (64 / 8 = 8, 64 / 4 = 16), tile counts swapped, giving 152x2x16x8x8x4.
// NOTE(review): %14 has no uses; %15 is the pack destination.
%14 = tensor.empty() : tensor<152x2x8x16x4x8xbf16>
%15 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %15 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Second-level accumulator pack: the 44x64 block is split into 4x4 tiles
// (11 x 16 of them), tile counts swapped, giving 1x2x16x11x4x4.
// NOTE(review): %16 has no uses; %17 is the pack destination.
%16 = tensor.empty() : tensor<1x2x11x16x4x4xbf16>
%17 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
%pack_5 = tensor.pack %11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %17 : tensor<1x2x44x64xbf16> -> tensor<1x2x16x11x4x4xbf16>
// Doubly packed matmul. Per iterator_types the reductions are d2 (extent 152),
// d5 (extent 8) and d8 (extent 8); the other six dimensions are parallel.
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%pack_5 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_7: bf16, %out: bf16):
// out += in * in_7, computed in bf16.
%19 = arith.mulf %in, %in_7 : bf16
%20 = arith.addf %out, %19 : bf16
linalg.yield %20 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Undo the packing in reverse order: first the 4x4 tiling (back into the
// alloc-backed tensor %11), then the 44x64 blocking (into the output slice).
%unpack = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %11 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// The buffer is released after both unpacks.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Insert the finished tile at offsets (%arg0, %arg1) of the shared output.
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// Row-tile index maps to block y, column-tile index to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Dispatch function for a bf16 matmul of a 308x9728 LHS by a 9728x2432 RHS into
// a 308x2432 result. Operands are packed at two levels and consumed by a 9-D
// linalg.generic; every tensor.empty in this snapshot is the destination of
// exactly one tensor.pack.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile loop over the output: 44-row by 128-column tiles, 7 x 19 iterations.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level LHS pack: 44x9728 -> 1x152x44x64.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
// First-level RHS pack: 9728x128 -> 152x2x64x64.
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Accumulator buffer in memory space 1, exposed as a writable tensor and
// zero-filled.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%10 = linalg.fill ins(%cst : bf16) outs(%9 : tensor<1x2x44x64xbf16>) -> tensor<1x2x44x64xbf16>
// Second-level LHS pack: 4x8 inner tiles, tile-count dims swapped
// -> 1x152x8x11x4x8.
%11 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %11 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
// Second-level RHS pack: 8x4 inner tiles, tile-count dims swapped
// -> 152x2x16x8x8x4.
%12 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %12 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Second-level accumulator pack: 4x4 inner tiles, tile-count dims swapped
// -> 1x2x16x11x4x4.
%13 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
%pack_5 = tensor.pack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %13 : tensor<1x2x44x64xbf16> -> tensor<1x2x16x11x4x4xbf16>
// Doubly packed matmul: d2, d5 and d8 are the reduction dimensions, the
// remaining six are parallel.
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%pack_5 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_7: bf16, %out: bf16):
// bf16 multiply followed by bf16 add into the running accumulator.
%15 = arith.mulf %in, %in_7 : bf16
%16 = arith.addf %out, %15 : bf16
linalg.yield %16 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Two unpacks mirror the two packs: 4x4 tiles -> 1x2x44x64 (into the
// alloc-backed tensor %10), then 44x64 blocks -> 44x128 (into the output slice).
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %10 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Free the accumulator buffer once the unpacked tile has been produced.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Write the tile into the shared 308x2432 output at (%arg0, %arg1).
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
// %arg0 maps to block y and %arg1 to block x.
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// bf16 matmul dispatch: a 308x9728 operand times a 9728x2432 operand into a
// 308x2432 result. The output is tiled into 44x128 tiles by an scf.forall;
// inside each tile both operands are packed twice and contracted by a 9-d
// linalg.generic, then unpacked twice back to the 44x128 tile.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Zero used to initialise the accumulator, and the index used as binding offset.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both full input tensors.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: 308/44 = 7 by 2432/128 = 19 iterations, one 44x128 output tile each.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operands: 44 rows of the LHS, 128 columns of the RHS, and the matching output tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack: LHS into 44x64 tiles (9728/64 = 152 of them), RHS into 64x64 tiles.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Buffer in memory space 1; its tensor view %9 is the destination of the inner unpack below.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack: LHS tiles split into 4x8 inner tiles, RHS tiles into 8x4 inner tiles.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Zero-filled accumulator in the doubly packed output layout.
%12 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are the reduction iterators; the rest are parallel.
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%13 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_6: bf16, %out: bf16):
// Scalar multiply-accumulate, carried out entirely in bf16.
%15 = arith.mulf %in, %in_6 : bf16
%16 = arith.addf %out, %15 : bf16
linalg.yield %16 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Undo the second-level pack into %9, then the first-level pack into the 44x128 output tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_5 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// The buffer is deallocated ahead of the forall terminator.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Insert the finished tile into the shared 308x2432 result at [%arg0, %arg1].
tensor.parallel_insert_slice %unpack_5 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result tensor to output binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// bf16 matmul dispatch (308x9728 times 9728x2432 -> 308x2432). In this dump
// the accumulator is still a plain tensor.empty + linalg.fill; only the
// first-level unpack destination is backed by an explicit memref.alloc.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Two read-only input bindings and one write-only output binding, all at offset %c0.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// One iteration per 44x128 output tile (7 x 19 iterations in total).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Slices feeding this tile: LHS rows, RHS columns, and the destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack with inner tiles 44x64 (LHS) and 64x64 (RHS).
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Explicit allocation (memory space 1) exposed as tensor %9; used as the unpack destination.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack with inner tiles 4x8 (LHS) and 8x4 (RHS); the outer tile dims are swapped.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Accumulator: an empty tensor filled with zero.
%12 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Contraction over the packed operands; the iterators at positions 2, 5 and 8 are reductions.
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%13 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_6: bf16, %out: bf16):
// out + in * in_6, all in bf16.
%15 = arith.mulf %in, %in_6 : bf16
%16 = arith.addf %out, %15 : bf16
linalg.yield %16 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Two unpacks mirror the two packs: first into %9, then into the 44x128 output slice.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_5 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
scf.forall.in_parallel {
// Write this tile into the shared output at offset [%arg0, %arg1].
tensor.parallel_insert_slice %unpack_5 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIETileAndFuse (iree-amdaie-tile-and-fuse) //----- //
// bf16 matmul dispatch (308x9728 times 9728x2432 -> 308x2432). In this dump
// the accumulator is backed by an explicit allocation in memory space 2
// (%alloc_5), alongside the memory-space-1 buffer used by the first unpack.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Input bindings 0 and 1 (read-only) and output binding 2 (write-only).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile the output into 44x128 tiles (7 x 19 iterations).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack of both operands (inner tiles 44x64 and 64x64).
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Memory-space-1 buffer; %9 is the destination of the inner unpack further down.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack of both operands (inner tiles 4x8 and 8x4).
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// NOTE: %12 is not referenced by any other op in this function; the fill below writes into %13.
%12 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
// Accumulator buffer in memory space 2, exposed as tensor %13 and zero-filled.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%14 = linalg.fill ins(%cst : bf16) outs(%13 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Packed contraction accumulating into %14; d2, d5 and d8 are reduction iterators.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%14 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_7: bf16, %out: bf16):
// bf16 multiply followed by bf16 add into the running output element.
%16 = arith.mulf %in, %in_7 : bf16
%17 = arith.addf %out, %16 : bf16
linalg.yield %17 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Unpack the inner 4x4 tiles into %9, then the 44x64 tiles into the output slice.
%unpack = tensor.unpack %15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Both explicit allocations are deallocated ahead of the forall terminator.
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to output binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIECleanup (iree-amdaie-cleanup) //----- //
// bf16 matmul dispatch (308x9728 times 9728x2432 -> 308x2432). This dump
// contains both the original whole-K linalg.generic (%15, no longer referenced)
// and an scf.for over the 152 outer K tiles that performs the same contraction
// one K tile per iteration; only the loop result is consumed downstream.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Input bindings 0 and 1 (read-only) and output binding 2 (write-only).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile the output into 44x128 tiles (7 x 19 iterations).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack of both operands (inner tiles 44x64 and 64x64).
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Memory-space-1 buffer; %9 is the destination of the inner unpack after the loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack of both operands (inner tiles 4x8 and 8x4); still covers all 152 K tiles.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// NOTE: %12 is not referenced by any other op in this function.
%12 = tensor.empty() : tensor<1x2x16x11x4x4xbf16>
// Zero-filled accumulator backed by a memory-space-2 allocation.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%14 = linalg.fill ins(%cst : bf16) outs(%13 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// NOTE: the result %15 of this whole-K contraction has no users in this function;
// the scf.for below starts from %14 and its result %16 is what gets unpacked.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_3, %pack_4 : tensor<1x152x8x11x4x8xbf16>, tensor<152x2x16x8x8x4xbf16>) outs(%14 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_8: bf16, %out: bf16):
%17 = arith.mulf %in, %in_8 : bf16
%18 = arith.addf %out, %17 : bf16
linalg.yield %18 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Loop bounds for the K-tile loop; %c0_6 is a second constant 0 in addition to %c0.
%c0_6 = arith.constant 0 : index
%c152 = arith.constant 152 : index
%c1 = arith.constant 1 : index
// Iterate over the 152 outer K tiles, carrying the accumulator as %arg4.
%16 = scf.for %arg3 = %c0_6 to %c152 step %c1 iter_args(%arg4 = %14) -> (tensor<1x2x16x11x4x4xbf16>) {
// K tile %arg3 of each packed operand.
%extracted_slice_8 = tensor.extract_slice %pack_3[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_9 = tensor.extract_slice %pack_4[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
// Full-size slice of the accumulator: zero offsets, unit strides, same shape as %arg4.
%extracted_slice_10 = tensor.extract_slice %arg4[0, 0, 0, 0, 0, 0] [1, 2, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x2x16x11x4x4xbf16>
// Same contraction as above, restricted to a single outer K tile.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_8, %extracted_slice_9 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%extracted_slice_10 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_11: bf16, %out: bf16):
// bf16 multiply-accumulate.
%18 = arith.mulf %in, %in_11 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Full-size insert back into the accumulator (covers every element of %arg4).
%inserted_slice = tensor.insert_slice %17 into %arg4[0, 0, 0, 0, 0, 0] [1, 2, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x2x16x11x4x4xbf16>
}
// Unpack the loop result: inner 4x4 tiles into %9, then 44x64 tiles into the output slice.
%unpack = tensor.unpack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to output binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// bf16 matmul dispatch (308x9728 times 9728x2432 -> 308x2432). In this dump
// the contraction appears only inside an scf.for over the 152 outer K tiles,
// which accumulates directly into the loop-carried tensor %arg4.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop step / upper bound, accumulator init value, and the shared zero index.
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Input bindings 0 and 1 (read-only) and output binding 2 (write-only).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Tile the output into 44x128 tiles (7 x 19 iterations).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack of both operands (inner tiles 44x64 and 64x64), outside the K loop.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Memory-space-1 buffer; %9 is the destination of the inner unpack after the loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack of both operands (inner tiles 4x8 and 8x4), also outside the K loop.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Zero-filled accumulator backed by a memory-space-2 allocation.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// K loop: 152 iterations, accumulator carried through iter_args.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// Select outer K tile %arg3 from each doubly packed operand.
%extracted_slice_7 = tensor.extract_slice %pack_3[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_8 = tensor.extract_slice %pack_4[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
// Contraction for this K tile, accumulating into %arg4; d2, d5 and d8 are reductions.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_7, %extracted_slice_8 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
// bf16 multiply-accumulate.
%16 = arith.mulf %in, %in_9 : bf16
%17 = arith.addf %out, %16 : bf16
linalg.yield %17 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %15 : tensor<1x2x16x11x4x4xbf16>
}
// Unpack the loop result: inner 4x4 tiles into %9, then 44x64 tiles into the output slice.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to output binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// bf16 matmul dispatch (308x9728 times 9728x2432 -> 308x2432): per 44x128
// output tile, both operands are packed at two levels, an scf.for walks the
// 152 outer K tiles accumulating into a zero-initialised packed tile, and the
// result is unpacked twice and inserted into the shared output.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Read-only inputs on bindings 0 and 1; write-only output on binding 2.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling over 44x128 output tiles (7 x 19 iterations).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First-level pack: these ops sit outside the scf.for and cover the whole K extent.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Memory-space-1 buffer whose tensor view %9 receives the inner unpack.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second-level pack: likewise outside the scf.for.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Memory-space-2 accumulator, zero-filled before the K loop.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// K loop over the 152 outer reduction tiles; %arg4 is the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_7 = tensor.extract_slice %pack_3[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_8 = tensor.extract_slice %pack_4[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
// Per-K-tile contraction writing into the loop-carried accumulator.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_7, %extracted_slice_8 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
// out + in * in_9, all in bf16.
%16 = arith.mulf %in, %in_9 : bf16
%17 = arith.addf %out, %16 : bf16
linalg.yield %17 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %15 : tensor<1x2x16x11x4x4xbf16>
}
// Reverse the two pack levels on the accumulated tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
// Insert the 44x128 tile into the shared result at [%arg0, %arg1].
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to output binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEFusePackIntoLoop (iree-amdaie-fuse-pack-into-loop) //----- //
// Tiled and packed bf16 matmul: (308x9728) x (9728x2432) -> 308x2432.
// In this snapshot all four tensor.pack ops sit outside the scf.for reduction
// loop, so both operands are packed over the full K extent (152 tiles of 64)
// before the loop starts; the loop only slices the pre-packed tensors.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer parallel loop: one iteration per 44x128 output tile (7 x 19 tiles).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operands: 44 rows of LHS (all of K), 128 columns of RHS (all of K).
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// First packing level: LHS into 44x64 tiles, RHS into 64x64 tiles.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Destination of the first tensor.unpack below; allocated in memory space 1
// (presumably the AIE shared/L2 memory -- confirm against the target).
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Second packing level: LHS inner tiles 4x8, RHS inner tiles 8x4, with the
// two tiled outer dims swapped (outer_dims_perm = [0, 1, 3, 2]).
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Accumulator in packed layout, allocated in memory space 2 (presumably
// core-local/L1 memory -- confirm) and zero-filled before the reduction.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop over the 152 K-tiles; %arg4 carries the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// Select K-tile %arg3 from each fully packed operand.
%extracted_slice_7 = tensor.extract_slice %pack_3[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_8 = tensor.extract_slice %pack_4[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
// Packed matmul as a 9-d linalg.generic; d2, d5 and d8 are the reduction dims.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_7, %extracted_slice_8 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_9: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are computed in bf16.
%16 = arith.mulf %in, %in_9 : bf16
%17 = arith.addf %out, %16 : bf16
linalg.yield %17 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %15 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels to recover the plain 44x128 result tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Publish the tile into the shared 308x2432 output at (%arg0, %arg1).
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Tiled and packed bf16 matmul: (308x9728) x (9728x2432) -> 308x2432.
// In this snapshot the linalg.generic consumes packs created inside the
// scf.for loop (one K-tile at a time). The whole-K packs outside the loop are
// still present, but the only in-loop ops that read them are slices with no
// users.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer parallel loop: one iteration per 44x128 output tile (7 x 19 tiles).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Whole-K packs (%pack, %pack_2, %pack_3, %pack_4): no longer inputs of the
// linalg.generic; inside the loop they are read only by unused slices.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %7 : tensor<44x9728xbf16> -> tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
%pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %8 : tensor<9728x128xbf16> -> tensor<152x2x64x64xbf16>
// Destination of the first tensor.unpack below; allocated in memory space 1.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%pack_3 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x152x44x64xbf16> -> tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
%pack_4 = tensor.pack %pack_2 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<152x2x64x64xbf16> -> tensor<152x2x16x8x8x4xbf16>
// Accumulator in packed layout (memory space 2), zero-filled before the loop.
%alloc_5 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop over the 152 K-tiles; %arg4 carries the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// LHS: slice K columns [%arg3 * 64, +64) and pack them in two levels.
%15 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, %15] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %7[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
%pack_9 = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_8 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// %extracted_slice_10 has no users in this function.
%extracted_slice_10 = tensor.extract_slice %pack[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_11 = tensor.extract_slice %10[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_12 = tensor.pack %pack_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_11 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// %extracted_slice_13 has no users in this function.
%extracted_slice_13 = tensor.extract_slice %pack_3[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
// RHS: slice K rows [%arg3 * 64, +64) and pack them in two levels.
// %16 computes the same value as %15.
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice_0[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%extracted_slice_15 = tensor.extract_slice %8[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
%pack_16 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %extracted_slice_15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// %extracted_slice_17 has no users in this function.
%extracted_slice_17 = tensor.extract_slice %pack_2[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
%extracted_slice_18 = tensor.extract_slice %11[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
%pack_19 = tensor.pack %pack_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_18 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// %extracted_slice_20 has no users in this function.
%extracted_slice_20 = tensor.extract_slice %pack_4[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
// Packed matmul on the in-loop packs; d2, d5 and d8 are the reduction dims.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_12, %pack_19 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_21: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are computed in bf16.
%18 = arith.mulf %in, %in_21 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %17 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels to recover the plain 44x128 result tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_6 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_5 : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Publish the tile into the shared 308x2432 output at (%arg0, %arg1).
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_6 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Tiled and packed bf16 matmul: (308x9728) x (9728x2432) -> 308x2432.
// In this snapshot no tensor.pack remains outside the scf.for loop; only the
// tensor.empty destinations (%7, %8, %10, %11) are created up front and every
// pack operates on a single K-tile inside the loop. The loop body still holds
// two identical affine.apply ops (%15 and %16).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer parallel loop: one iteration per 44x128 output tile (7 x 19 tiles).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Whole-K destination tensors; the loop below slices one K-tile out of each.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
// Destination of the first tensor.unpack below; allocated in memory space 1.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
// Accumulator in packed layout (memory space 2), zero-filled before the loop.
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop over the 152 K-tiles; %arg4 carries the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// LHS: slice K columns [%arg3 * 64, +64), pack to 44x64, then to 4x8 tiles.
%15 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %15] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%extracted_slice_5 = tensor.extract_slice %7[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_5 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_6 = tensor.extract_slice %10[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_7 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_6 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: slice K rows [%arg3 * 64, +64), pack to 64x64, then to 8x4 tiles.
// %16 computes the same value as %15.
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%extracted_slice_9 = tensor.extract_slice %8[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
%pack_10 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %extracted_slice_9 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%extracted_slice_11 = tensor.extract_slice %11[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
%pack_12 = tensor.pack %pack_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Packed matmul on the in-loop packs; d2, d5 and d8 are the reduction dims.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_7, %pack_12 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_13: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are computed in bf16.
%18 = arith.mulf %in, %in_13 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %17 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels to recover the plain 44x128 result tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Publish the tile into the shared 308x2432 output at (%arg0, %arg1).
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// Tiled and packed bf16 matmul: (308x9728) x (9728x2432) -> 308x2432.
// In this snapshot a single affine.apply (%15) provides the K offset for both
// operand slices, and every in-loop tensor.pack still writes into a slice of a
// tensor.empty (%7, %8, %10, %11) rather than into an explicit allocation.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer parallel loop: one iteration per 44x128 output tile (7 x 19 tiles).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Whole-K destination tensors; the loop below slices one K-tile out of each.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
// Destination of the first tensor.unpack below; allocated in memory space 1.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
// Accumulator in packed layout (memory space 2), zero-filled before the loop.
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop over the 152 K-tiles; %arg4 carries the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// %15 = %arg3 * 64: element offset of the current K-tile, used for LHS and RHS.
%15 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// LHS: slice 44x64, pack to 1x1x44x64, then to 4x8 inner tiles.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %15] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%extracted_slice_5 = tensor.extract_slice %7[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_5 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_6 = tensor.extract_slice %10[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_7 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_6 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: slice 64x128, pack to 1x2x64x64, then to 8x4 inner tiles.
%extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%15, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%extracted_slice_9 = tensor.extract_slice %8[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
%pack_10 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %extracted_slice_9 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%extracted_slice_11 = tensor.extract_slice %11[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
%pack_12 = tensor.pack %pack_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Packed matmul on the in-loop packs; d2, d5 and d8 are the reduction dims.
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_7, %pack_12 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_13: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are computed in bf16.
%17 = arith.mulf %in, %in_13 : bf16
%18 = arith.addf %out, %17 : bf16
linalg.yield %18 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels to recover the plain 44x128 result tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Publish the tile into the shared 308x2432 output at (%arg0, %arg1).
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIETileAndFuse (iree-amdaie-tile-and-fuse) //----- //
// Tiled and packed bf16 matmul: (308x9728) x (9728x2432) -> 308x2432.
// In this snapshot the first-level packs inside the scf.for loop write into
// explicit memory-space-1 allocations (%alloc_6, %alloc_11) that are freed at
// the end of every iteration. The second-level packs still write into slices
// of tensor.empty (%10, %11).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer parallel loop: one iteration per 44x128 output tile (7 x 19 tiles).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
// Destination of the first tensor.unpack below; allocated in memory space 1.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
// Accumulator in packed layout (memory space 2), zero-filled before the loop.
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop over the 152 K-tiles; %arg4 carries the running accumulator.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// %15 = %arg3 * 64: element offset of the current K-tile, used for LHS and RHS.
%15 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %15] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
// %extracted_slice_5 has no users here: the LHS pack targets %16 instead.
%extracted_slice_5 = tensor.extract_slice %7[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
// Per-iteration LHS tile buffer in memory space 1.
%alloc_6 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%16 = bufferization.to_tensor %alloc_6 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %16 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_7 = tensor.extract_slice %10[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_8 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_7 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_9 = tensor.extract_slice %extracted_slice_0[%15, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
// %extracted_slice_10 has no users here: the RHS pack targets %17 instead.
%extracted_slice_10 = tensor.extract_slice %8[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
// Per-iteration RHS tile buffer in memory space 1.
%alloc_11 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%17 = bufferization.to_tensor %alloc_11 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_12 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %17 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%extracted_slice_13 = tensor.extract_slice %11[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
%pack_14 = tensor.pack %pack_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_13 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Packed matmul on the in-loop packs; d2, d5 and d8 are the reduction dims.
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_8, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_15: bf16, %out: bf16):
// Multiply-accumulate; both the product and the sum are computed in bf16.
%19 = arith.mulf %in, %in_15 : bf16
%20 = arith.addf %out, %19 : bf16
linalg.yield %20 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Free the per-iteration tile buffers before yielding the accumulator.
memref.dealloc %alloc_6 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_11 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %18 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels to recover the plain 44x128 result tile.
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Publish the tile into the shared 308x2432 output at (%arg0, %arg1).
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIECleanup (iree-amdaie-cleanup) //----- //
// Dispatch function as it enters the iree-amdaie-cleanup pass.
// Computes a bf16 matmul (308x9728) x (9728x2432) -> 308x2432: operands are
// loaded from bindings 0 and 1, the result is stored to binding 2.
// In this snapshot %extracted_slice_5, %extracted_slice_10 and the untiled
// linalg.generic %18 have no users (the scf.for yields %19 instead).
// NOTE(review): presumably these are the leftovers the cleanup pass removes — confirm.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0/1 are read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (steps of 44 rows and
// 128 columns), mapped to #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operands: 44 full rows of the LHS, 128 full columns of the RHS,
// and the matching 44x128 window of the shared output.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7/%8 only feed %extracted_slice_5/%extracted_slice_10, which are unused.
%7 = tensor.empty() : tensor<1x152x44x64xbf16>
%8 = tensor.empty() : tensor<152x2x64x64xbf16>
// %alloc (memory space 1) is the destination of the first unpack after the loop.
// NOTE(review): memory spaces 1 and 2 presumably denote two levels of on-chip
// AIE memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%9 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
// Full-K destinations for the second-level packs; sliced per iteration below.
%10 = tensor.empty() : tensor<1x152x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<152x2x16x8x8x4xbf16>
// Packed accumulator in memory space 2, zero-filled before the reduction loop.
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%13 = linalg.fill ins(%cst : bf16) outs(%12 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop: 152 steps, each consuming a 64-wide slice of the 9728-long
// reduction dimension (152 * 64 = 9728); the accumulator is carried in %arg4.
%14 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
// %15 = %arg3 * 64: offset of this step along the reduction dimension.
%15 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// LHS: 44x64 slice -> 1x1x44x64 (into a memory-space-1 buffer) -> 1x1x8x11x4x8
// (4x8 inner tiles, outer tile dims swapped by outer_dims_perm).
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %15] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%extracted_slice_5 = tensor.extract_slice %7[0, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x152x44x64xbf16> to tensor<1x1x44x64xbf16>
%alloc_6 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%16 = bufferization.to_tensor %alloc_6 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %16 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_7 = tensor.extract_slice %10[0, %arg3, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x152x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_8 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_7 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: 64x128 slice -> 1x2x64x64 (into a memory-space-1 buffer) -> 1x2x16x8x8x4
// (8x4 inner tiles, outer tile dims swapped by outer_dims_perm).
%extracted_slice_9 = tensor.extract_slice %extracted_slice_0[%15, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%extracted_slice_10 = tensor.extract_slice %8[%arg3, 0, 0, 0] [1, 2, 64, 64] [1, 1, 1, 1] : tensor<152x2x64x64xbf16> to tensor<1x2x64x64xbf16>
%alloc_11 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%17 = bufferization.to_tensor %alloc_11 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_12 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %17 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%extracted_slice_13 = tensor.extract_slice %11[%arg3, 0, 0, 0, 0, 0] [1, 2, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<152x2x16x8x8x4xbf16> to tensor<1x2x16x8x8x4xbf16>
%pack_14 = tensor.pack %pack_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_13 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Untiled packed matmul over the whole 1x2x16x11x4x4 accumulator. Its result
// %18 has no users in this snapshot; the scf.forall below recomputes the same
// contraction per slice and its result %19 is what the loop yields.
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_8, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x2x16x8x8x4xbf16>) outs(%arg4 : tensor<1x2x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_15: bf16, %out: bf16):
%20 = arith.mulf %in, %in_15 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x2x16x11x4x4xbf16>
// Inner tiling over a 1x2 grid mapped to #gpu.thread<y>/#gpu.thread<x>: each
// iteration takes one of the two RHS/accumulator slices along dim 1.
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_15 = tensor.extract_slice %pack_8[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_14[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%extracted_slice_17 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul on one slice: d2, d5 and d8 are the reduction dims; the body
// is a bf16 multiply followed by a bf16 add into the accumulator element.
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_15, %extracted_slice_16 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_17 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_18: bf16, %out: bf16):
%21 = arith.mulf %in, %in_18 : bf16
%22 = arith.addf %out, %21 : bf16
linalg.yield %22 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %20 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-iteration operand buffers are released before yielding the
// updated accumulator to the next reduction step.
memref.dealloc %alloc_6 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_11 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then insert
// the tile into the shared output at [%arg0, %arg1].
%unpack = tensor.unpack %14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %9 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result back to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Dispatch function as it enters the canonicalize pass.
// Computes a bf16 matmul (308x9728) x (9728x2432) -> 308x2432: operands are
// loaded from bindings 0 and 1, the result is stored to binding 2.
// Structure: scf.forall over 44x128 output tiles -> scf.for over 152 reduction
// steps -> scf.forall over a 1x2 grid running the packed linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0/1 are read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %alloc (memory space 1) is the destination of the first unpack after the
// loop; %alloc_2 (memory space 2) backs the zero-filled packed accumulator.
// NOTE(review): memory spaces 1 and 2 presumably denote two levels of on-chip
// AIE memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Destinations of the second-level LHS/RHS packs, defined once outside the
// reduction loop and reused by every iteration.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, each consuming a 64-wide slice of the 9728-long
// reduction dimension (152 * 64 = 9728); the accumulator is carried in %arg4.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 = %arg3 * 64: offset of this step along the reduction dimension.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// LHS: 44x64 slice -> 1x1x44x64 (memory-space-1 buffer) -> 1x1x8x11x4x8.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%pack_6 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: 64x128 slice -> 1x2x64x64 (memory-space-1 buffer) -> 1x2x16x8x8x4.
%extracted_slice_7 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_8 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_8 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_7 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%pack_10 = tensor.pack %pack_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Inner tiling over a 1x2 grid mapped to #gpu.thread<y>/#gpu.thread<x>: each
// iteration takes one of the two RHS/accumulator slices along dim 1.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_11 = tensor.extract_slice %pack_6[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_12 = tensor.extract_slice %pack_10[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%extracted_slice_13 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul on one slice: d2, d5 and d8 are the reduction dims; the body
// is a bf16 multiply followed by a bf16 add into the accumulator element.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_11, %extracted_slice_12 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_13 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_14: bf16, %out: bf16):
%18 = arith.mulf %in, %in_14 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-iteration operand buffers are released before yielding the
// updated accumulator to the next reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_8 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then insert
// the tile into the shared output at [%arg0, %arg1].
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result back to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Dispatch function as it enters the cse pass.
// Computes a bf16 matmul (308x9728) x (9728x2432) -> 308x2432: operands are
// loaded from bindings 0 and 1, the result is stored to binding 2.
// Structure: scf.forall over 44x128 output tiles -> scf.for over 152 reduction
// steps -> scf.forall over a 1x2 grid running the packed linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0/1 are read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %alloc (memory space 1) is the destination of the first unpack after the
// loop; %alloc_2 (memory space 2) backs the zero-filled packed accumulator.
// NOTE(review): memory spaces 1 and 2 presumably denote two levels of on-chip
// AIE memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Destinations of the second-level LHS/RHS packs, defined once outside the
// reduction loop and reused by every iteration.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, each consuming a 64-wide slice of the 9728-long
// reduction dimension (152 * 64 = 9728); the accumulator is carried in %arg4.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 = %arg3 * 64: offset of this step along the reduction dimension.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// LHS: 44x64 slice -> 1x1x44x64 (memory-space-1 buffer) -> 1x1x8x11x4x8.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%pack_6 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: 64x128 slice -> 1x2x64x64 (memory-space-1 buffer) -> 1x2x16x8x8x4.
%extracted_slice_7 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_8 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_8 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_7 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%pack_10 = tensor.pack %pack_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Inner tiling over a 1x2 grid mapped to #gpu.thread<y>/#gpu.thread<x>: each
// iteration takes one of the two RHS/accumulator slices along dim 1.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_11 = tensor.extract_slice %pack_6[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_12 = tensor.extract_slice %pack_10[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%extracted_slice_13 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul on one slice: d2, d5 and d8 are the reduction dims; the body
// is a bf16 multiply followed by a bf16 add into the accumulator element.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_11, %extracted_slice_12 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_13 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_14: bf16, %out: bf16):
%18 = arith.mulf %in, %in_14 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-iteration operand buffers are released before yielding the
// updated accumulator to the next reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_8 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then insert
// the tile into the shared output at [%arg0, %arg1].
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result back to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEFusePackIntoLoop (iree-amdaie-fuse-pack-into-loop) //----- //
// Dispatch function as it enters the iree-amdaie-fuse-pack-into-loop pass.
// Computes a bf16 matmul (308x9728) x (9728x2432) -> 308x2432: operands are
// loaded from bindings 0 and 1, the result is stored to binding 2.
// At this point all four tensor.pack ops sit in the scf.for body, outside the
// inner scf.forall that consumes their results through tensor.extract_slice.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0/1 are read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %alloc (memory space 1) is the destination of the first unpack after the
// loop; %alloc_2 (memory space 2) backs the zero-filled packed accumulator.
// NOTE(review): memory spaces 1 and 2 presumably denote two levels of on-chip
// AIE memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Destinations of the second-level LHS/RHS packs, defined once outside the
// reduction loop and reused by every iteration.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, each consuming a 64-wide slice of the 9728-long
// reduction dimension (152 * 64 = 9728); the accumulator is carried in %arg4.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 = %arg3 * 64: offset of this step along the reduction dimension.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// LHS: 44x64 slice -> 1x1x44x64 (memory-space-1 buffer) -> 1x1x8x11x4x8.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%pack_6 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// RHS: 64x128 slice -> 1x2x64x64 (memory-space-1 buffer) -> 1x2x16x8x8x4.
%extracted_slice_7 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_8 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_8 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_7 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%pack_10 = tensor.pack %pack_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Inner tiling over a 1x2 grid mapped to #gpu.thread<y>/#gpu.thread<x>: each
// iteration slices %pack_6/%pack_10 and one of the two accumulator slices.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_11 = tensor.extract_slice %pack_6[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%extracted_slice_12 = tensor.extract_slice %pack_10[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%extracted_slice_13 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul on one slice: d2, d5 and d8 are the reduction dims; the body
// is a bf16 multiply followed by a bf16 add into the accumulator element.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_11, %extracted_slice_12 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_13 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_14: bf16, %out: bf16):
%18 = arith.mulf %in, %in_14 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-iteration operand buffers are released before yielding the
// updated accumulator to the next reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_8 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then insert
// the tile into the shared output at [%arg0, %arg1].
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result back to binding 2.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Tiled bf16 matmul dispatch: loads a 308x9728 and a 9728x2432 tensor and stores a 308x2432 result.
// Loop structure: outer scf.forall over 44x128 output tiles -> scf.for over 152 reduction steps
// (64 elements each) -> inner scf.forall over a 1x2 grid that runs the packed linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Per-tile scratch buffers: %alloc (memory space 1) is the destination of the first unpack below;
// %alloc_2 (memory space 2) is the packed accumulator, zero-filled before the reduction loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, carrying the packed accumulator through iter_args.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 is this step's offset along the reduction dimension (iteration index * 64).
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS slice into a fresh memory-space-1 buffer.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// NOTE(review): in this dump %pack_6 feeds only %extracted_slice_14, which has no users.
%pack_6 = tensor.pack %pack outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// First-level pack of the 64x128 RHS slice into two 64x64 blocks (memory space 1).
%extracted_slice_7 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_8 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_8 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_7 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// NOTE(review): in this dump %pack_10 feeds only %extracted_slice_18, which has no users.
%pack_10 = tensor.pack %pack_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %11 : tensor<1x2x64x64xbf16> -> tensor<1x2x16x8x8x4xbf16>
// Inner 1x2 grid mapped to #gpu.thread<y>/<x>; %arg6 selects one of the two RHS/accumulator blocks.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block (inner tiles 4x8), done per grid iteration.
%extracted_slice_11 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %10[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_13 = tensor.pack %extracted_slice_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_12 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_14 = tensor.extract_slice %pack_6[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the selected RHS block (inner tiles 8x4).
%extracted_slice_15 = tensor.extract_slice %pack_9[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%extracted_slice_16 = tensor.extract_slice %11[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%pack_17 = tensor.pack %extracted_slice_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_16 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %pack_10[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: out += in * in_20, with d2, d5 and d8 as the reduction iterators.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%18 = arith.mulf %in, %in_20 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-step input buffers are released at the end of every reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_8 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then publish the tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Tiled bf16 matmul dispatch: loads a 308x9728 and a 9728x2432 tensor and stores a 308x2432 result.
// Loop structure: outer scf.forall over 44x128 output tiles -> scf.for over 152 reduction steps
// (64 elements each) -> inner scf.forall over a 1x2 grid that runs the packed linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Per-tile scratch buffers: %alloc (memory space 1) is the destination of the first unpack below;
// %alloc_2 (memory space 2) is the packed accumulator, zero-filled before the reduction loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// %10 and %11 are tensor.empty placeholders that serve as destinations for the second-level packs.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, carrying the packed accumulator through iter_args.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 is this step's offset along the reduction dimension (iteration index * 64).
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS slice into a fresh memory-space-1 buffer.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS slice into two 64x64 blocks (memory space 1).
%extracted_slice_6 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_7 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_7 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_8 = tensor.pack %extracted_slice_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 grid mapped to #gpu.thread<y>/<x>; %arg6 selects one of the two RHS/accumulator blocks.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block (inner tiles 4x8) into a slice of %10.
%extracted_slice_9 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_10 = tensor.extract_slice %10[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_11 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the selected RHS block (inner tiles 8x4) into a slice of %11.
%extracted_slice_12 = tensor.extract_slice %pack_8[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%extracted_slice_13 = tensor.extract_slice %11[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%pack_14 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_13 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_15 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: out += in * in_16, with d2, d5 and d8 as the reduction iterators.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_11, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_15 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_16: bf16, %out: bf16):
%18 = arith.mulf %in, %in_16 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-step input buffers are released at the end of every reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_7 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then publish the tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// Tiled bf16 matmul dispatch: loads a 308x9728 and a 9728x2432 tensor and stores a 308x2432 result.
// Loop structure: outer scf.forall over 44x128 output tiles -> scf.for over 152 reduction steps
// (64 elements each) -> inner scf.forall over a 1x2 grid that runs the packed linalg.generic.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Per-tile scratch buffers: %alloc (memory space 1) is the destination of the first unpack below;
// %alloc_2 (memory space 2) is the packed accumulator, zero-filled before the reduction loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// %10 and %11 are tensor.empty placeholders; the second-level packs below still write into
// slices of them rather than into explicit allocations.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, carrying the packed accumulator through iter_args.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 is this step's offset along the reduction dimension (iteration index * 64).
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS slice into a fresh memory-space-1 buffer.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS slice into two 64x64 blocks (memory space 1).
%extracted_slice_6 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_7 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_7 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_8 = tensor.pack %extracted_slice_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 grid mapped to #gpu.thread<y>/<x>; %arg6 selects one of the two RHS/accumulator blocks.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block (inner tiles 4x8) into a slice of %10.
%extracted_slice_9 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_10 = tensor.extract_slice %10[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
%pack_11 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %extracted_slice_10 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the selected RHS block (inner tiles 8x4) into a slice of %11.
%extracted_slice_12 = tensor.extract_slice %pack_8[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%extracted_slice_13 = tensor.extract_slice %11[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
%pack_14 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %extracted_slice_13 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_15 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: out += in * in_16, with d2, d5 and d8 as the reduction iterators.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_11, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_15 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_16: bf16, %out: bf16):
%18 = arith.mulf %in, %in_16 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-step input buffers are released at the end of every reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_7 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then publish the tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before HoistStaticallyBoundAllocations (iree-hoist-statically-bound-allocations) //----- //
// Tiled bf16 matmul dispatch: loads a 308x9728 and a 9728x2432 tensor and stores a 308x2432 result.
// Loop structure: outer scf.forall over 44x128 output tiles -> scf.for over 152 reduction steps
// (64 elements each) -> inner scf.forall over a 1x2 grid that runs the packed linalg.generic.
// In this version every pack destination is an explicit memref.alloc, still nested inside the loops.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_0 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Per-tile scratch buffers: %alloc (memory space 1) is the destination of the first unpack below;
// %alloc_2 (memory space 2) is the packed accumulator, zero-filled before the reduction loop.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%7 = bufferization.to_tensor %alloc restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%8 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, carrying the packed accumulator through iter_args.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// %13 is this step's offset along the reduction dimension (iteration index * 64).
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS slice into a fresh memory-space-1 buffer.
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%14 = bufferization.to_tensor %alloc_5 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS slice into two 64x64 blocks (memory space 1).
%extracted_slice_6 = tensor.extract_slice %extracted_slice_0[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%alloc_7 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%15 = bufferization.to_tensor %alloc_7 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_8 = tensor.pack %extracted_slice_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 grid mapped to #gpu.thread<y>/<x>; %arg6 selects one of the two RHS/accumulator blocks.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_9 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
// NOTE(review): %extracted_slice_10 has no users in this dump; the pack below targets %17 instead.
%extracted_slice_10 = tensor.extract_slice %10[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the LHS block (inner tiles 4x8) into a memory-space-2 allocation.
%alloc_11 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%17 = bufferization.to_tensor %alloc_11 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_12 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_13 = tensor.extract_slice %pack_8[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
// NOTE(review): %extracted_slice_14 has no users in this dump; the pack below targets %18 instead.
%extracted_slice_14 = tensor.extract_slice %11[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
// Second-level pack of the selected RHS block (inner tiles 8x4) into a memory-space-2 allocation.
%alloc_15 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%18 = bufferization.to_tensor %alloc_15 restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_17 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: out += in * in_18, with d2, d5 and d8 as the reduction iterators.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_12, %pack_16 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_17 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_18: bf16, %out: bf16):
%20 = arith.mulf %in, %in_18 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// The per-iteration operand buffers are released right after the contraction.
memref.dealloc %alloc_11 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_15 : memref<1x1x16x8x8x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// The per-step input buffers are released at the end of every reduction step.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_7 : memref<1x2x64x64xbf16, 1 : i32>
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo both packing levels: 1x2x16x11x4x4 -> 1x2x44x64 -> 44x128, then publish the tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_3 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_1 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x2x16x11x4x4xbf16, 2 : i32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_3 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Dispatch function: multiplies a 308x9728 bf16 tensor (binding 0) by a
// 9728x2432 bf16 tensor (binding 1) and stores the 308x2432 bf16 result to
// binding 2. The computation is tiled three ways: an outer scf.forall over
// 44x128 output tiles, an scf.for over the 9728-wide reduction dimension in
// steps of 64, and an inner scf.forall over a 1x2 grid.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Scratch buffers, all allocated once at function entry and freed just before
// the return. The trailing `1 : i32` / `2 : i32` is the memref memory space.
// NOTE(review): presumably space 1 is shared (memtile) memory and space 2 is
// core-local memory on the AIE target - confirm against the backend.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Loop bounds for the reduction loop (0 to 152, step 1) and the zero used to
// initialise the accumulator.
%c1 = arith.constant 1 : index
%c152 = arith.constant 152 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// Bindings 0 and 1 are read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both operands in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (308/44 = 7 by
// 2432/128 = 19 tiles), mapped to #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row panel of the LHS, column panel of the RHS, and the destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Tensor views of the scratch buffers: %7 is the destination of the first
// unpack below, %8 is the packed accumulator.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Zero the packed accumulator before the reduction loop.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// %10 and %11 are only read by %extracted_slice_12 / %extracted_slice_15
// below, and those two slices have no users in this function.
%10 = tensor.empty() : tensor<1x1x8x11x4x8xbf16>
%11 = tensor.empty() : tensor<1x2x16x8x8x4xbf16>
// Reduction loop: 152 steps, each consuming a 64-wide slice of the 9728-long
// reduction dimension (152 * 64 = 9728). The accumulator is carried in %arg4.
%12 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_8 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_8 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_9 = tensor.extract_slice %extracted_slice_5[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_10 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>;
// %arg6 selects which of the two 64-wide RHS blocks this iteration handles.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_11 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
// %extracted_slice_12 has no users in this function.
%extracted_slice_12 = tensor.extract_slice %10[%arg5, 0, 0, 0, 0, 0] [1, 1, 8, 11, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_13 = tensor.pack %extracted_slice_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_14 = tensor.extract_slice %pack_10[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
// %extracted_slice_15 has no users in this function.
%extracted_slice_15 = tensor.extract_slice %11[0, %arg6, 0, 0, 0, 0] [1, 1, 16, 8, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x8x8x4xbf16> to tensor<1x1x16x8x8x4xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_17 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction over nine loop dimensions, of which d2, d5 and d8 are
// reductions. The region below multiplies the two inputs and adds the product
// to the output element, all in bf16.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_16 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_17 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_18: bf16, %out: bf16):
%20 = arith.mulf %in, %in_18 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo the two packing levels: 4x4 inner tiles back to 1x2x44x64 (into %7),
// then 44x64 blocks back to the plain 44x128 output tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Place the finished tile at (%arg0, %arg1) of the full result.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Release the scratch buffers allocated at function entry.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Dispatch function: multiplies a 308x9728 bf16 tensor (binding 0) by a
// 9728x2432 bf16 tensor (binding 1) and stores the 308x2432 bf16 result to
// binding 2. Tiling structure: outer scf.forall over 44x128 output tiles,
// scf.for over the reduction dimension in 64-wide steps, inner scf.forall
// over a 1x2 grid.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Constants: reduction loop bounds (0 to 152, step 1) and the bf16 zero used
// to initialise the accumulator.
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c152 = arith.constant 152 : index
%c1 = arith.constant 1 : index
// Scratch buffers in memref memory spaces 1 and 2, allocated once here and
// freed just before the return.
// NOTE(review): presumably space 1 is shared (memtile) memory and space 2 is
// core-local memory on the AIE target - confirm against the backend.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both operands in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (7 x 19 tiles), mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row panel of the LHS, column panel of the RHS, and the destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Tensor views of the scratch buffers: %7 receives the first unpack below,
// %8 is the packed accumulator.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Zero the packed accumulator before the reduction loop.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop: 152 steps of 64 elements each (152 * 64 = 9728), with the
// accumulator carried in %arg4.
%10 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%11 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_8 = tensor.extract_slice %extracted_slice[0, %11] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_8 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %12 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_9 = tensor.extract_slice %extracted_slice_5[%11, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%13 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_10 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %13 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>;
// %arg6 selects which of the two 64-wide RHS blocks this iteration handles.
%14 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%extracted_slice_11 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%15 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_12 = tensor.pack %extracted_slice_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %15 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%extracted_slice_13 = tensor.extract_slice %pack_10[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%16 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_14 = tensor.pack %extracted_slice_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %16 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_15 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction over nine loop dimensions, of which d2, d5 and d8 are
// reductions. The region below multiplies the two inputs and adds the product
// to the output element, all in bf16.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_12, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_15 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_16: bf16, %out: bf16):
%18 = arith.mulf %in, %in_16 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %14 : tensor<1x2x16x11x4x4xbf16>
}
// Undo the two packing levels: 4x4 inner tiles back to 1x2x44x64 (into %7),
// then 44x64 blocks back to the plain 44x128 output tile.
%unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Place the finished tile at (%arg0, %arg1) of the full result.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Release the scratch buffers allocated at function entry.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEPeelForLoop (iree-amdaie-peel-for-loop) //----- //
// Dispatch function: multiplies a 308x9728 bf16 tensor (binding 0) by a
// 9728x2432 bf16 tensor (binding 1) and stores the 308x2432 bf16 result to
// binding 2. At this point the reduction over the 9728-long dimension is still
// a single scf.for running from 0 to 152.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Constants: reduction loop bounds (0 to 152, step 1) and the bf16 zero used
// to initialise the accumulator.
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c152 = arith.constant 152 : index
%c1 = arith.constant 1 : index
// Scratch buffers in memref memory spaces 1 and 2, allocated once here and
// freed just before the return.
// NOTE(review): presumably space 1 is shared (memtile) memory and space 2 is
// core-local memory on the AIE target - confirm against the backend.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both operands in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (7 x 19 tiles), mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row panel of the LHS, column panel of the RHS, and the destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Tensor views of the scratch buffers: %7 receives the first unpack below,
// %8 is the packed accumulator.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Zero the packed accumulator before the reduction loop.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Reduction loop: 152 steps of 64 elements each (152 * 64 = 9728), with the
// accumulator carried in %arg4.
%10 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%11 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_8 = tensor.extract_slice %extracted_slice[0, %11] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%12 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_8 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %12 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_9 = tensor.extract_slice %extracted_slice_5[%11, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%13 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_10 = tensor.pack %extracted_slice_9 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %13 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>;
// %arg6 selects which of the two 64-wide RHS blocks this iteration handles.
%14 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%extracted_slice_11 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%15 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_12 = tensor.pack %extracted_slice_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %15 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%extracted_slice_13 = tensor.extract_slice %pack_10[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%16 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_14 = tensor.pack %extracted_slice_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %16 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_15 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction over nine loop dimensions, of which d2, d5 and d8 are
// reductions. The region below multiplies the two inputs and adds the product
// to the output element, all in bf16.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_12, %pack_14 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_15 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_16: bf16, %out: bf16):
%18 = arith.mulf %in, %in_16 : bf16
%19 = arith.addf %out, %18 : bf16
linalg.yield %19 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %17 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %14 : tensor<1x2x16x11x4x4xbf16>
}
// Undo the two packing levels: 4x4 inner tiles back to 1x2x44x64 (into %7),
// then 44x64 blocks back to the plain 44x128 output tile.
%unpack = tensor.unpack %10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_7 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Place the finished tile at (%arg0, %arg1) of the full result.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_7 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Release the scratch buffers allocated at function entry.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Dispatch function: multiplies a 308x9728 bf16 tensor (binding 0) by a
// 9728x2432 bf16 tensor (binding 1) and stores the 308x2432 bf16 result to
// binding 2. In this snapshot the reduction over the 9728-long dimension is
// expressed as three consecutive scf.for loops covering [0, 1), [1, 151) and
// [151, 152); their bodies are the same and the accumulator is threaded
// %9 -> %10 -> %11 -> %12.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Constants: loop bounds/step and the bf16 zero used to initialise the
// accumulator.
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c152 = arith.constant 152 : index
%c1 = arith.constant 1 : index
// Scratch buffers in memref memory spaces 1 and 2, allocated once here and
// freed just before the return. All three reduction loops below reuse them.
// NOTE(review): presumably space 1 is shared (memtile) memory and space 2 is
// core-local memory on the AIE target - confirm against the backend.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both operands in full.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (7 x 19 tiles), mapped to
// #gpu.block<y>/#gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row panel of the LHS, column panel of the RHS, and the destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Tensor views of the scratch buffers: %7 receives the first unpack below,
// %8 is the packed accumulator.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Zero the packed accumulator before the first reduction loop.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// Loop 1 of 3: the single iteration %arg3 = 0, starting from the zero-filled
// accumulator %9.
%c1_7 = arith.constant 1 : index
%10 = scf.for %arg3 = %c0 to %c1_7 step %c1 iter_args(%arg4 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_9 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_10 = tensor.extract_slice %extracted_slice_5[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%extracted_slice_12 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%extracted_slice_14 = tensor.extract_slice %pack_11[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_16 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are reductions; the region multiplies the
// two inputs and adds the product to the output element, all in bf16.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_17: bf16, %out: bf16):
%20 = arith.mulf %in, %in_17 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Loop 2 of 3: iterations %arg3 = 1 .. 150, continuing from %10.
%c151 = arith.constant 151 : index
%11 = scf.for %arg3 = %c1_7 to %c151 step %c1 iter_args(%arg4 = %10) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_9 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_10 = tensor.extract_slice %extracted_slice_5[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%extracted_slice_12 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%extracted_slice_14 = tensor.extract_slice %pack_11[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_16 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are reductions; the region multiplies the
// two inputs and adds the product to the output element, all in bf16.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_17: bf16, %out: bf16):
%20 = arith.mulf %in, %in_17 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Loop 3 of 3: the single iteration %arg3 = 151, continuing from %11.
%12 = scf.for %arg3 = %c151 to %c152 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset into the reduction dimension for this step.
%13 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// First-level pack of the 44x64 LHS tile into %alloc_2.
%extracted_slice_9 = tensor.extract_slice %extracted_slice[0, %13] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
// First-level pack of the 64x128 RHS tile into %alloc_1 as two 64x64 blocks.
%extracted_slice_10 = tensor.extract_slice %extracted_slice_5[%13, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner tiling over a 1x2 grid, mapped to #gpu.thread<y>/#gpu.thread<x>.
%16 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of the LHS block into 4x8 inner tiles, destination %alloc_0.
%extracted_slice_12 = tensor.extract_slice %pack[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
// Second-level pack of the RHS block into 8x4 inner tiles, destination %alloc.
%extracted_slice_14 = tensor.extract_slice %pack_11[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This iteration's slice of the carried accumulator.
%extracted_slice_16 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are reductions; the region multiplies the
// two inputs and adds the product to the output element, all in bf16.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_13, %pack_15 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_16 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_17: bf16, %out: bf16):
%20 = arith.mulf %in, %in_17 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Write the updated partial result back into the shared accumulator.
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %16 : tensor<1x2x16x11x4x4xbf16>
}
// Undo the two packing levels on the final accumulator %12: 4x4 inner tiles
// back to 1x2x44x64 (into %7), then 44x64 blocks back to the 44x128 tile.
%unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_8 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Place the finished tile at (%arg0, %arg1) of the full result.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_8 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to the output binding.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Release the scratch buffers allocated at function entry.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEFuseFillIntoForall (iree-amdaie-fuse-fill-into-forall) //----- //
// Tiled bf16 matmul dispatch: (308x9728) x (9728x2432) -> 308x2432.
//
// Shape of the computation in this snapshot:
//   * a block-level scf.forall walks the output in 44x128 tiles (7 x 19 iterations);
//   * the K dimension (9728) is consumed in 152 chunks of 64 elements: chunk 0 and
//     chunk 151 (offset 9664) are peeled, chunks 1..150 run inside the scf.for;
//   * every K chunk runs a thread-level scf.forall over (1, 2) sub-tiles that packs its
//     operands and performs the multiply-accumulate linalg.generic.
// The accumulator zero-fill (%9) still sits outside the first thread-level forall and
// reaches it through shared_outs.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c1 / %c151 bound the middle K-chunk loop; %cst is the accumulator init value.
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Staging buffers, re-wrapped with bufferization.to_tensor at every use below.
// NOTE(review): memory spaces 1 and 2 presumably denote two different on-chip memory
// levels of the AIE target -- confirm against the amdaie backend.
//   %alloc   : B operand in its second-level packed layout (inner tiles 8x4)
//   %alloc_0 : A operand in its second-level packed layout (inner tiles 4x8)
//   %alloc_1 : B operand in its first-level packed layout (two 64x64 tiles)
//   %alloc_2 : A operand in its first-level packed layout (one 44x64 tile)
//   %alloc_3 : packed accumulator (inner tiles 4x4)
//   %alloc_4 : destination of the first unpack of the accumulator
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands; binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Block-level tiling: each iteration produces one 44x128 tile of the result.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Full-K row panel of A, full-K column panel of B, and the output tile being written.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// Zero the whole packed accumulator once per output tile.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// ---- Peeled first K chunk (K offset 0) ----
// First-level pack of the 44x64 A chunk and the 64x128 B chunk.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%10 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%11 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %11 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Thread-level forall: %arg4 selects which of the two 64-wide B tiles to multiply;
// the accumulator starts from the zero-filled %9.
%12 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %9) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of A (4x8 inner tiles) and B (8x4 inner tiles).
%extracted_slice_15 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_17 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// This thread's slice of the accumulator.
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are the reduction dims; the body is out += a * b.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_16, %pack_18 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%20 = arith.mulf %in, %in_20 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Middle K chunks 1..150; the accumulator is carried through iter_args ----
%13 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %12) -> (tensor<1x2x16x11x4x4xbf16>) {
// K offset of this chunk (chunk index * 64).
%17 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_15 = tensor.extract_slice %extracted_slice[0, %17] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%18 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_16 = tensor.pack %extracted_slice_15 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %18 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_17 = tensor.extract_slice %extracted_slice_5[%17, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%19 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %19 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Same thread-level body as the peeled chunk, accumulating onto the carried value.
%20 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_19 = tensor.extract_slice %pack_16[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%21 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_20 = tensor.pack %extracted_slice_19 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %21 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_21 = tensor.extract_slice %pack_18[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%22 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_22 = tensor.pack %extracted_slice_21 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %22 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_23 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_20, %pack_22 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_23 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_24: bf16, %out: bf16):
%24 = arith.mulf %in, %in_24 : bf16
%25 = arith.addf %out, %24 : bf16
linalg.yield %25 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %23 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %20 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Peeled last K chunk (K offset 9664 = 151 * 64) ----
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Final accumulation step, seeded with the scf.for result %13.
%16 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_15 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_17 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_16, %pack_18 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%20 = arith.mulf %in, %in_20 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo both packing levels: 4x4 inner tiles -> 1x2x44x64 (into %7), then -> 44x128.
%unpack = tensor.unpack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_14 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Publish this block's 44x128 tile into the shared result tensor.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_14 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the full result to binding 2, then release every staging buffer.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEFuseConsumerIntoLoop (iree-amdaie-fuse-consumer-into-loop) //----- //
// Tiled bf16 matmul dispatch: (308x9728) x (9728x2432) -> 308x2432.
//
// Shape of the computation in this snapshot:
//   * a block-level scf.forall walks the output in 44x128 tiles (7 x 19 iterations);
//   * the K dimension (9728) is consumed in 152 chunks of 64 elements: chunk 0 and
//     chunk 151 (offset 9664) are peeled, chunks 1..150 run inside the scf.for;
//   * every K chunk runs a thread-level scf.forall over (1, 2) sub-tiles that packs its
//     operands and performs the multiply-accumulate linalg.generic.
// In this snapshot the accumulator zero-fill is performed per thread inside the first
// thread-level forall (%19); the whole-buffer fill %9 is still present but has no users.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c1 / %c151 bound the middle K-chunk loop; %cst is the accumulator init value.
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Staging buffers, re-wrapped with bufferization.to_tensor at every use below.
// NOTE(review): memory spaces 1 and 2 presumably denote two different on-chip memory
// levels of the AIE target -- confirm against the amdaie backend.
//   %alloc   : B operand in its second-level packed layout (inner tiles 8x4)
//   %alloc_0 : A operand in its second-level packed layout (inner tiles 4x8)
//   %alloc_1 : B operand in its first-level packed layout (two 64x64 tiles)
//   %alloc_2 : A operand in its first-level packed layout (one 44x64 tile)
//   %alloc_3 : packed accumulator (inner tiles 4x4)
//   %alloc_4 : destination of the first unpack of the accumulator
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands; binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Block-level tiling: each iteration produces one 44x128 tile of the result.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Full-K row panel of A, full-K column panel of B, and the output tile being written.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// NOTE(review): %9 has no users in this function; the forall below takes %8 directly
// and re-creates the fill per thread. Presumably a later cleanup removes it -- confirm.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// ---- Peeled first K chunk (K offset 0) ----
// First-level pack of the 44x64 A chunk and the 64x128 B chunk.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%10 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%11 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %11 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Thread-level forall: %arg4 selects which of the two 64-wide B tiles to multiply.
// shared_outs is the raw accumulator buffer %8; zeroing happens inside the body.
%12 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack of A (4x8 inner tiles) and B (8x4 inner tiles).
%extracted_slice_15 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_17 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Zero only this thread's slice of the accumulator; %19 seeds the contraction below.
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_19 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %extracted_slice_20 duplicates %extracted_slice_19 and has no users
// in this function.
%extracted_slice_20 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed contraction: d2, d5 and d8 are the reduction dims; the body is out += a * b.
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_16, %pack_18 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_21: bf16, %out: bf16):
%21 = arith.mulf %in, %in_21 : bf16
%22 = arith.addf %out, %21 : bf16
linalg.yield %22 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %20 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Middle K chunks 1..150; the accumulator is carried through iter_args ----
%13 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %12) -> (tensor<1x2x16x11x4x4xbf16>) {
// K offset of this chunk (chunk index * 64).
%17 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_15 = tensor.extract_slice %extracted_slice[0, %17] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%18 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_16 = tensor.pack %extracted_slice_15 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %18 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_17 = tensor.extract_slice %extracted_slice_5[%17, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%19 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %19 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Same pack + contraction as the peeled chunk, but accumulating onto the carried
// value (no fill here).
%20 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_19 = tensor.extract_slice %pack_16[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%21 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_20 = tensor.pack %extracted_slice_19 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %21 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_21 = tensor.extract_slice %pack_18[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%22 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_22 = tensor.pack %extracted_slice_21 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %22 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_23 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_20, %pack_22 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_23 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_24: bf16, %out: bf16):
%24 = arith.mulf %in, %in_24 : bf16
%25 = arith.addf %out, %24 : bf16
linalg.yield %25 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %23 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %20 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Peeled last K chunk (K offset 9664 = 151 * 64) ----
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Final accumulation step, seeded with the scf.for result %13.
%16 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %13) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_15 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_16 = tensor.pack %extracted_slice_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_17 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_18 = tensor.pack %extracted_slice_17 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_16, %pack_18 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%20 = arith.mulf %in, %in_20 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo both packing levels: 4x4 inner tiles -> 1x2x44x64 (into %7), then -> 44x128.
%unpack = tensor.unpack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %7 : tensor<1x2x16x11x4x4xbf16> -> tensor<1x2x44x64xbf16>
%unpack_14 = tensor.unpack %unpack inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
// Publish this block's 44x128 tile into the shared result tensor.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_14 into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the full result to binding 2, then release every staging buffer.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEFusePackIntoLoop (iree-amdaie-fuse-pack-into-loop) //----- //
// IR snapshot of the bf16 matmul dispatch as it enters the AMDAIEFusePackIntoLoop pass.
// Computes tensor<308x9728xbf16> (binding 0) times tensor<9728x2432xbf16> (binding 1) and
// stores the tensor<308x2432xbf16> result to binding 2.
// Shape of the computation below:
//   * an outer scf.forall over 44x128 output tiles (7 x 19 tiles, mapped to gpu.block y/x);
//   * inside each tile, the 9728-long reduction is walked in 152 steps of 64:
//     step 0 is a standalone first step, steps 1..150 run in an scf.for, and step 151
//     (offset 9664) is a standalone last step that also unpacks the accumulator.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c1 / %c151 are the bounds of the scf.for over the middle reduction steps; %cst is the zero
// used to initialise the accumulator.
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Scratch buffers, allocated once and re-wrapped with bufferization.to_tensor wherever needed.
// NOTE(review): memory spaces `1 : i32` and `2 : i32` presumably denote two different levels of
// the AIE memory hierarchy — confirm against the iree-amd-aie documentation.
//   %alloc   : second-level packed RHS tile      (1x1x16x8x8x4)
//   %alloc_0 : second-level packed LHS tile      (1x1x8x11x4x8)
//   %alloc_1 : first-level packed RHS tile       (1x2x64x64)
//   %alloc_2 : first-level packed LHS tile       (1x1x44x64)
//   %alloc_3 : packed accumulator                (1x2x16x11x4x4)
//   %alloc_4 : partially unpacked output tile    (1x2x44x64)
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Interface bindings: 0 and 1 are read-only inputs, 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
// Load both operands in full; tiling happens through tensor.extract_slice below.
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling over the output: %arg0 steps rows by 44, %arg1 steps columns by 128.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operand views: a 44-row band of the LHS, a 128-column band of the RHS, and the
// 44x128 destination tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// NOTE(review): %9 has no visible use in this function — the forall below takes %8 as its
// shared_out and zero-fills each slice itself (%19 there).
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// ---- First reduction step (offset 0): first-level packs of the LHS and RHS 64-wide slices.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%10 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%11 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %11 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 forall (mapped to gpu.thread y/x): each iteration handles one of the two 64-column
// halves of the output tile.
%12 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level packs: LHS into 4x8 inner tiles, RHS into 8x4 inner tiles.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Zero-initialise this iteration's accumulator slice; only the first reduction step does this.
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %extracted_slice_19 has no visible use — the generic below accumulates into %19.
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul: 9-d iteration space with three reduction dims (d2, d5, d8); the body is a
// bf16 multiply followed by a bf16 add into the output element.
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%21 = arith.mulf %in, %in_20 : bf16
%22 = arith.addf %out, %21 : bf16
linalg.yield %22 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %20 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Middle reduction steps: %arg3 runs 1..150, carrying the accumulator in %arg4.
%13 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %12) -> (tensor<1x2x16x11x4x4xbf16>) {
// Offset along the reduction dimension for this step: %arg3 * 64.
%17 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %17] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%18 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %18 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%17, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%19 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %19 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Same inner 1x2 forall as the first step, but with no fill: it accumulates into the slice of
// the carried accumulator %arg7.
%20 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%21 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %21 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%22 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %22 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_19, %pack_21 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_22 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_23: bf16, %out: bf16):
%24 = arith.mulf %in, %in_23 : bf16
%25 = arith.addf %out, %24 : bf16
linalg.yield %25 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %23 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %20 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Last reduction step (offset 9664 = 151 * 64), with the accumulator unpack folded in.
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// This forall has two shared_outs: the accumulator (%arg5, from the scf.for result) and the
// 1x2x44x64 output staging tensor (%arg6, from %alloc_4). Only result #1 is consumed afterwards.
%16:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %13, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_22: bf16, %out: bf16):
%20 = arith.mulf %in, %in_22 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %inserted_slice is only read by %extracted_slice_20, which itself has no
// visible use — the unpack below consumes %19 directly.
%inserted_slice = tensor.insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_20 = tensor.extract_slice %inserted_slice[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Undo the second-level (4x4) packing of the finished accumulator slice into the staging tile.
%unpack_21 = tensor.unpack %19 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_21 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo the first-level (44x64) packing to obtain the 44x128 output tile and publish it.
%unpack = tensor.unpack %16#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled result to the output binding, then release the scratch buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEBufferizeToAllocation (iree-amdaie-bufferize-to-allocation) //----- //
// IR snapshot of the bf16 matmul dispatch as it enters the AMDAIEBufferizeToAllocation pass.
// NOTE(review): the op sequence appears identical to the snapshot taken before
// AMDAIEFusePackIntoLoop earlier in this file, i.e. that pass seems to have changed nothing
// for this function — confirm with a textual diff.
// Computes tensor<308x9728xbf16> (binding 0) times tensor<9728x2432xbf16> (binding 1) and
// stores the tensor<308x2432xbf16> result to binding 2, using an outer scf.forall over 44x128
// output tiles and a 152-step reduction (64 elements per step) split into a first step, an
// scf.for over steps 1..150, and a last step at offset 9664.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop bounds for the middle reduction steps (%c1, %c151) and the accumulator init value (%cst).
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Scratch buffers shared by all tiles; each is re-wrapped via bufferization.to_tensor at its
// points of use.
// NOTE(review): memory spaces `1 : i32` and `2 : i32` presumably map to two distinct levels of
// the AIE memory hierarchy — confirm against the iree-amd-aie documentation.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands; binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer forall: one iteration per 44x128 output tile (7 x 19 iterations, gpu.block y/x mapping).
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row band of the LHS, column band of the RHS, and the destination tile for this iteration.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// NOTE(review): the fill result %9 is not referenced anywhere else in this function; the
// forall below is seeded with %8 and performs its own per-slice fill.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// ---- Reduction step 0: first-level packing of the 44x64 LHS and 64x128 RHS slices.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%10 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%11 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %11 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 forall (gpu.thread y/x mapping): %arg4 selects which 64-column half is computed.
%12 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level packing: 4x8 inner tiles for the LHS, 8x4 inner tiles for the RHS.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Accumulator slice is zeroed here, in the first reduction step only.
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %extracted_slice_19 is not referenced again; the generic writes into %19.
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul over a 9-d iteration space; d2, d5 and d8 are the reduction dims and the body
// multiplies the two inputs and adds the product to the output element (all in bf16).
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%21 = arith.mulf %in, %in_20 : bf16
%22 = arith.addf %out, %21 : bf16
linalg.yield %22 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %20 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Reduction steps 1..150: the accumulator is threaded through iter_args as %arg4.
%13 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %12) -> (tensor<1x2x16x11x4x4xbf16>) {
// Reduction-dimension offset of this step (%arg3 * 64), used to slice both operands.
%17 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %17] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%18 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %18 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%17, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%19 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %19 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 forall without a fill: the generic accumulates onto the carried slice of %arg7.
%20 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%21 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %21 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%22 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %22 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_19, %pack_21 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_22 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_23: bf16, %out: bf16):
%24 = arith.mulf %in, %in_23 : bf16
%25 = arith.addf %out, %24 : bf16
linalg.yield %25 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %23 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %20 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Reduction step 151 (offset 9664 = 151 * 64), combined with unpacking the accumulator.
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Two shared_outs here: %arg5 is the accumulator coming out of the scf.for, %arg6 is the
// 1x2x44x64 staging tensor backed by %alloc_4. Only %16#1 is read after the forall.
%16:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %13, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_22: bf16, %out: bf16):
%20 = arith.mulf %in, %in_22 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %inserted_slice feeds only %extracted_slice_20, and %extracted_slice_20 is not
// referenced again — the unpack below reads %19 directly.
%inserted_slice = tensor.insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_20 = tensor.extract_slice %inserted_slice[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Reverse the second-level (4x4) packing of the finished accumulator slice into the staging tile.
%unpack_21 = tensor.unpack %19 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_21 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Reverse the first-level (44x64) packing to get the 44x128 tile, then insert it into the result.
%unpack = tensor.unpack %16#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to binding 2 and free every scratch buffer allocated at the top.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIELowerToUKernels (iree-amdaie-lower-to-ukernels) //----- //
// bf16 matmul: loads a 308x9728 and a 9728x2432 operand, accumulates their product in packed
// tiles, and stores a 308x2432 result. The reduction dimension is walked in 64-wide tiles:
// tile 0 is handled before the scf.for, tiles 1..150 inside it, and tile 151 (offset 9664) after it.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c1 / %c151 are the bounds of the reduction loop below; %cst is the zero used to initialise the accumulator.
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Scratch buffers for the packed operand/result tiles, in memory spaces `1 : i32` and `2 : i32`.
// NOTE(review): presumably 1 is a shared tile memory and 2 is core-local memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (7 x 19 tiles), mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row strip of the first operand, column strip of the second operand, and the destination output tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7 wraps the unpacked-result buffer, %8 the packed accumulator buffer.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// NOTE(review): %9 has no uses in this function; the forall below takes %8 and re-fills each slice itself.
%9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<1x2x16x11x4x4xbf16>) -> tensor<1x2x16x11x4x4xbf16>
// ---- Reduction tile 0 (offset 0): pack a 44x64 tile and a 64x128 tile into the space-1 buffers.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%10 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %10 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%11 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %11 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 tiling mapped to #gpu.thread<y>/<x>: each iteration handles one 64-wide half of the 128 output columns.
// NOTE(review): both iterations wrap the same %alloc / %alloc_0 buffers; presumably each mapped core gets its own copy after later lowering — confirm.
%12 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level packing into the space-2 buffers: 4x8 inner tiles for the first operand, 8x4 for the second.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Zero this iteration's accumulator slice; only the first reduction tile does this.
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %extracted_slice_19 has no uses in this forall body (the generic below writes into %19).
%extracted_slice_19 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Packed matmul over 9 loop dims (d2, d5, d8 are reductions); the body is multiply then accumulate in bf16.
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%19 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_20: bf16, %out: bf16):
%21 = arith.mulf %in, %in_20 : bf16
%22 = arith.addf %out, %21 : bf16
linalg.yield %22 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %20 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Reduction tiles 1..150: same pack + matmul, accumulating into the tensor carried as %arg4.
%13 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %12) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset of this tile along the reduction dimension (tile index * 64).
%17 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %17] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%18 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %18 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%17, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%19 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %19 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%20 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%21 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %21 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%22 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %22 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// No fill here: the generic accumulates directly onto the slice of the carried accumulator.
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_19, %pack_21 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_22 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_23: bf16, %out: bf16):
%24 = arith.mulf %in, %in_23 : bf16
%25 = arith.addf %out, %24 : bf16
linalg.yield %25 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %23 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %20 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Reduction tile 151 (offset 9664 = 151 * 64): final accumulation, with the unpack done in the same forall.
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%14 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %14 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%15 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %15 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Two shared outputs: %arg5 is the packed accumulator, %arg6 receives the unpacked 44x64 result tiles.
%16:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %13, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%17 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %17 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%18 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %18 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_22: bf16, %out: bf16):
%20 = arith.mulf %in, %in_22 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// NOTE(review): %inserted_slice is only read by %extracted_slice_20 below, which itself has no uses; the unpack reads %19 directly.
%inserted_slice = tensor.insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%extracted_slice_20 = tensor.extract_slice %inserted_slice[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Undo the 4x4 inner tiling of the accumulator to get this iteration's 44x64 result tile.
%unpack_21 = tensor.unpack %19 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_21 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo the outer 44x64 packing and write the 44x128 tile into the shared output tensor.
%unpack = tensor.unpack %16#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to binding 2, then release the scratch buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIECleanup (iree-amdaie-cleanup) //----- //
// bf16 matmul: loads a 308x9728 and a 9728x2432 operand, accumulates their product in packed
// tiles, and stores a 308x2432 result. The reduction dimension is walked in 64-wide tiles:
// tile 0 is handled before the scf.for, tiles 1..150 inside it, and tile 151 (offset 9664) after it.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c1 / %c151 are the bounds of the reduction loop below; %cst is the zero used to initialise the accumulator.
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Scratch buffers for the packed operand/result tiles, in memory spaces `1 : i32` and `2 : i32`.
// NOTE(review): presumably 1 is a shared tile memory and 2 is core-local memory — confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only inputs, binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (7 x 19 tiles), mapped to #gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Row strip of the first operand, column strip of the second operand, and the destination output tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7 wraps the unpacked-result buffer, %8 the packed accumulator buffer.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// ---- Reduction tile 0 (offset 0): pack a 44x64 tile and a 64x128 tile into the space-1 buffers.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner 1x2 tiling mapped to #gpu.thread<y>/<x>: each iteration handles one 64-wide half of the 128 output columns.
// NOTE(review): both iterations wrap the same %alloc / %alloc_0 buffers; presumably each mapped core gets its own copy after later lowering — confirm.
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level packing into the space-2 buffers: 4x8 inner tiles for the first operand, 8x4 for the second.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Zero this iteration's accumulator slice; only the first reduction tile does this.
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// Packed matmul over 9 loop dims (d2, d5, d8 are reductions); the body is multiply then accumulate in bf16.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_19: bf16, %out: bf16):
%20 = arith.mulf %in, %in_19 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Reduction tiles 1..150: same pack + matmul, accumulating into the tensor carried as %arg4.
%12 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
// Element offset of this tile along the reduction dimension (tile index * 64).
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%17 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// No fill here: the generic accumulates directly onto the slice of the carried accumulator.
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_19, %pack_21 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_22 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_23: bf16, %out: bf16):
%23 = arith.mulf %in, %in_23 : bf16
%24 = arith.addf %out, %23 : bf16
linalg.yield %24 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Reduction tile 151 (offset 9664 = 151 * 64): final accumulation, with the unpack done in the same forall.
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Two shared outputs: %arg5 is the packed accumulator, %arg6 receives the unpacked 44x64 result tiles.
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_21: bf16, %out: bf16):
%19 = arith.mulf %in, %in_21 : bf16
%20 = arith.addf %out, %19 : bf16
linalg.yield %20 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Undo the 4x4 inner tiling of the accumulator to get this iteration's 44x64 result tile.
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_20 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_20 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo the outer 44x64 packing and write the 44x128 tile into the shared output tensor.
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full result to binding 2, then release the scratch buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEInsertLoopsForVectorization (iree-amdaie-insert-loops-for-vectorization) //----- //
// bf16 matmul dispatch: (308x9728) x (9728x2432) -> 308x2432.
// Structure visible in this snapshot:
//   * an outer scf.forall walks the output in 44x128 tiles (7 x 19 tiles);
//   * the K dimension (9728) is consumed in 64-wide chunks: chunk 0 is peeled
//     (it also zero-fills the accumulator), chunks 1..150 run in an scf.for,
//     and the last chunk (offset 9664) is peeled again because it also
//     unpacks the accumulator into the output layout;
//   * inside every chunk an scf.forall over (1, 2) splits the 128-wide N tile
//     into two 64-wide halves and runs a packed linalg.generic on each.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c151 = arith.constant 151 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1 = arith.constant 1 : index
// Staging buffers. The types carry memory space 2 (%alloc, %alloc_0, %alloc_3)
// or memory space 1 (%alloc_1, %alloc_2, %alloc_4).
// NOTE(review): presumably 1 = shared/memtile level and 2 = core-local level
// on the AIE target -- confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Full-K row panel of the LHS, full-K column panel of the RHS, and the output tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7 is the unpacked-result staging tensor, %8 the packed accumulator.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// ---- Peeled first K chunk (K offset 0) ----
// First-level pack: LHS 44x64 -> 1x1x44x64, RHS 64x128 -> 1x2x64x64.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner forall over (1, 2): %arg4 selects one of the two 64-wide N halves.
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack: LHS into 4x8 inner tiles, RHS into 8x4 inner tiles.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Only the first K chunk zero-initialises the accumulator slice.
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// Packed matmul over 9 iteration dims; d2, d5 and d8 are the reductions.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_19: bf16, %out: bf16):
// Scalar multiply-accumulate, performed entirely in bf16.
%20 = arith.mulf %in, %in_19 : bf16
%21 = arith.addf %out, %20 : bf16
linalg.yield %21 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Steady-state K chunks: %arg3 = 1..150, K offset = %arg3 * 64 ----
// The accumulator from the previous chunk is threaded through iter_args.
%12 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
// The same staging buffers (%alloc_2, %alloc_1) are re-wrapped for every chunk.
%17 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// No fill here: the generic accumulates onto the running partial sums.
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_19, %pack_21 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_22 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_23: bf16, %out: bf16):
%23 = arith.mulf %in, %in_23 : bf16
%24 = arith.addf %out, %23 : bf16
linalg.yield %24 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Peeled last K chunk (K offset 9664 = 151 * 64) ----
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// This forall carries two results: the packed accumulator (#0) and the
// first-level-unpacked 1x2x44x64 tile (#1); only #1 is consumed afterwards.
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack_15, %pack_17 : tensor<1x1x8x11x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_21: bf16, %out: bf16):
%19 = arith.mulf %in, %in_21 : bf16
%20 = arith.addf %out, %19 : bf16
linalg.yield %20 : bf16
} -> tensor<1x1x16x11x4x4xbf16>
// Undo the second-level (4x4) packing of the finished accumulator.
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_20 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_20 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo the first-level (44x64) packing and publish the 44x128 output tile.
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled result back to binding 2, then release the staging buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before AMDAIEVectorization (iree-amdaie-vectorization) //----- //
// bf16 matmul dispatch: (308x9728) x (9728x2432) -> 308x2432, snapshot in
// which every packed linalg.generic is wrapped in a six-deep scf.for nest.
// The nest iterates iteration dims d0..d5 one index at a time (trip counts
// 1, 1, 1, 11, 16, 8), so each remaining linalg.generic only sees a
// 4x8 (LHS) by 8x4 (RHS) -> 4x4 (accumulator) problem over d6, d7, d8.
// The surrounding structure is: outer scf.forall over 44x128 output tiles,
// K consumed in 64-wide chunks (peeled first chunk, scf.for over chunks
// 1..150, peeled last chunk at offset 9664), inner scf.forall over (1, 2).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// The index constants from %c1 through %c0_24 below are not referenced
// anywhere in this function body; the loop nests use constants that are
// redefined locally inside each scf.forall region.
// NOTE(review): presumably leftovers of the loop-insertion rewrite that a
// later canonicalization removes -- confirm.
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c1_1 = arith.constant 1 : index
%c1_2 = arith.constant 1 : index
%c1_3 = arith.constant 1 : index
%c1_4 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c0_5 = arith.constant 0 : index
%c0_6 = arith.constant 0 : index
%c1_7 = arith.constant 1 : index
%c1_8 = arith.constant 1 : index
%c1_9 = arith.constant 1 : index
%c1_10 = arith.constant 1 : index
%c1_11 = arith.constant 1 : index
%c1_12 = arith.constant 1 : index
%c0_13 = arith.constant 0 : index
%c0_14 = arith.constant 0 : index
%c0_15 = arith.constant 0 : index
%c1_16 = arith.constant 1 : index
%c1_17 = arith.constant 1 : index
%c1_18 = arith.constant 1 : index
%c1_19 = arith.constant 1 : index
%c1_20 = arith.constant 1 : index
%c1_21 = arith.constant 1 : index
%c0_22 = arith.constant 0 : index
%c0_23 = arith.constant 0 : index
%c0_24 = arith.constant 0 : index
// Constants that are actually used: K-loop bound/step, binding offset, fill value.
%c151 = arith.constant 151 : index
%c0_25 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c1_26 = arith.constant 1 : index
// Staging buffers in memory space 2 (%alloc, %alloc_27, %alloc_30) and
// memory space 1 (%alloc_28, %alloc_29, %alloc_31).
// NOTE(review): presumably 1 = shared/memtile level and 2 = core-local level
// on the AIE target -- confirm against the AMD-AIE backend conventions.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_27 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_28 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_29 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_30 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_31 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands, binding 2 is the write-only result.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0_25) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0_25) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0_25) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile, mapped to gpu.block<y>/<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_32 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_33 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7 is the unpacked-result staging tensor, %8 the packed accumulator.
%7 = bufferization.to_tensor %alloc_31 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_30 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// ---- Peeled first K chunk (K offset 0) ----
%extracted_slice_34 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_29 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_34 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_35 = tensor.extract_slice %extracted_slice_32[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_28 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_36 = tensor.pack %extracted_slice_35 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner forall over (1, 2): %arg4 selects one of the two 64-wide N halves.
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level pack: LHS into 4x8 inner tiles, RHS into 8x4 inner tiles.
%extracted_slice_41 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_27 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_42 = tensor.pack %extracted_slice_41 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_43 = tensor.extract_slice %pack_36[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_44 = tensor.pack %extracted_slice_43 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_45 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Only the first K chunk zero-initialises the accumulator slice.
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_45 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// Lower bound / upper bound / step triples for the six loops below.
%c0_46 = arith.constant 0 : index
%c1_47 = arith.constant 1 : index
%c1_48 = arith.constant 1 : index
%c0_49 = arith.constant 0 : index
%c1_50 = arith.constant 1 : index
%c1_51 = arith.constant 1 : index
%c0_52 = arith.constant 0 : index
%c1_53 = arith.constant 1 : index
%c1_54 = arith.constant 1 : index
%c0_55 = arith.constant 0 : index
%c11 = arith.constant 11 : index
%c1_56 = arith.constant 1 : index
%c0_57 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1_58 = arith.constant 1 : index
%c0_59 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_60 = arith.constant 1 : index
// Loop nest over d0..d5, outermost first:
//   %arg6 = d0 (1), %arg8 = d1 (1), %arg10 = d2 (1),
//   %arg12 = d3 (11), %arg14 = d4 (16), %arg16 = d5 (8, innermost reduction).
// The accumulator tensor is threaded through every level via iter_args.
%19 = scf.for %arg6 = %c0_46 to %c1_47 step %c1_48 iter_args(%arg7 = %18) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg8 = %c0_49 to %c1_50 step %c1_51 iter_args(%arg9 = %arg7) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg10 = %c0_52 to %c1_53 step %c1_54 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg12 = %c0_55 to %c11 step %c1_56 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg14 = %c0_57 to %c16 step %c1_58 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg16 = %c0_59 to %c8 step %c1_60 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
// Slice offsets follow the generic's indexing maps:
// LHS (d0, d2, d5, d3), RHS (d2, d1, d4, d5), accumulator (d0, d1, d4, d3).
%extracted_slice_61 = tensor.extract_slice %pack_42[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_62 = tensor.extract_slice %pack_44[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_63 = tensor.extract_slice %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
// Unit-outer-dim generic: effectively a 4x8 by 8x4 tile multiply-accumulate.
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_61, %extracted_slice_62 : tensor<1x1x1x1x4x8xbf16>, tensor<1x1x1x1x8x4xbf16>) outs(%extracted_slice_63 : tensor<1x1x1x1x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_64: bf16, %out: bf16):
// Scalar multiply-accumulate, performed entirely in bf16.
%26 = arith.mulf %in, %in_64 : bf16
%27 = arith.addf %out, %26 : bf16
linalg.yield %27 : bf16
} -> tensor<1x1x1x1x4x4xbf16>
// Write the updated 4x4 tile back into the loop-carried accumulator.
%inserted_slice = tensor.insert_slice %25 into %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Steady-state K chunks: %arg3 = 1..150, K offset = %arg3 * 64 ----
%12 = scf.for %arg3 = %c1_26 to %c151 step %c1_26 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_41 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
// The same staging buffers (%alloc_29, %alloc_28) are re-wrapped for every chunk.
%17 = bufferization.to_tensor %alloc_29 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_42 = tensor.pack %extracted_slice_41 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_43 = tensor.extract_slice %extracted_slice_32[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_28 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_44 = tensor.pack %extracted_slice_43 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_45 = tensor.extract_slice %pack_42[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_27 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_46 = tensor.pack %extracted_slice_45 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_47 = tensor.extract_slice %pack_44[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_48 = tensor.pack %extracted_slice_47 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// No fill here: the loop nest accumulates onto the running partial sums.
%extracted_slice_49 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Lower bound / upper bound / step triples for the six loops below.
%c0_50 = arith.constant 0 : index
%c1_51 = arith.constant 1 : index
%c1_52 = arith.constant 1 : index
%c0_53 = arith.constant 0 : index
%c1_54 = arith.constant 1 : index
%c1_55 = arith.constant 1 : index
%c0_56 = arith.constant 0 : index
%c1_57 = arith.constant 1 : index
%c1_58 = arith.constant 1 : index
%c0_59 = arith.constant 0 : index
%c11 = arith.constant 11 : index
%c1_60 = arith.constant 1 : index
%c0_61 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1_62 = arith.constant 1 : index
%c0_63 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_64 = arith.constant 1 : index
// Loop nest over d0..d5, outermost first:
//   %arg8 = d0 (1), %arg10 = d1 (1), %arg12 = d2 (1),
//   %arg14 = d3 (11), %arg16 = d4 (16), %arg18 = d5 (8).
%22 = scf.for %arg8 = %c0_50 to %c1_51 step %c1_52 iter_args(%arg9 = %extracted_slice_49) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg10 = %c0_53 to %c1_54 step %c1_55 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg12 = %c0_56 to %c1_57 step %c1_58 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%25 = scf.for %arg14 = %c0_59 to %c11 step %c1_60 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%26 = scf.for %arg16 = %c0_61 to %c16 step %c1_62 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
%27 = scf.for %arg18 = %c0_63 to %c8 step %c1_64 iter_args(%arg19 = %arg17) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_65 = tensor.extract_slice %pack_46[%arg8, %arg12, %arg18, %arg14, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_66 = tensor.extract_slice %pack_48[%arg12, %arg10, %arg16, %arg18, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_67 = tensor.extract_slice %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%28 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_65, %extracted_slice_66 : tensor<1x1x1x1x4x8xbf16>, tensor<1x1x1x1x8x4xbf16>) outs(%extracted_slice_67 : tensor<1x1x1x1x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_68: bf16, %out: bf16):
%29 = arith.mulf %in, %in_68 : bf16
%30 = arith.addf %out, %29 : bf16
linalg.yield %30 : bf16
} -> tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %28 into %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %27 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %26 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %25 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Peeled last K chunk (K offset 9664 = 151 * 64) ----
%extracted_slice_37 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_29 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_38 = tensor.pack %extracted_slice_37 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_39 = tensor.extract_slice %extracted_slice_32[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_28 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_40 = tensor.pack %extracted_slice_39 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// This forall carries two results: the packed accumulator (#0) and the
// first-level-unpacked 1x2x44x64 tile (#1); only #1 is consumed afterwards.
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_41 = tensor.extract_slice %pack_38[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_27 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_42 = tensor.pack %extracted_slice_41 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_43 = tensor.extract_slice %pack_40[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_44 = tensor.pack %extracted_slice_43 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_45 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Lower bound / upper bound / step triples for the six loops below.
%c0_46 = arith.constant 0 : index
%c1_47 = arith.constant 1 : index
%c1_48 = arith.constant 1 : index
%c0_49 = arith.constant 0 : index
%c1_50 = arith.constant 1 : index
%c1_51 = arith.constant 1 : index
%c0_52 = arith.constant 0 : index
%c1_53 = arith.constant 1 : index
%c1_54 = arith.constant 1 : index
%c0_55 = arith.constant 0 : index
%c11 = arith.constant 11 : index
%c1_56 = arith.constant 1 : index
%c0_57 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1_58 = arith.constant 1 : index
%c0_59 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_60 = arith.constant 1 : index
// Loop nest over d0..d5, outermost first:
//   %arg7 = d0 (1), %arg9 = d1 (1), %arg11 = d2 (1),
//   %arg13 = d3 (11), %arg15 = d4 (16), %arg17 = d5 (8).
%18 = scf.for %arg7 = %c0_46 to %c1_47 step %c1_48 iter_args(%arg8 = %extracted_slice_45) -> (tensor<1x1x16x11x4x4xbf16>) {
%19 = scf.for %arg9 = %c0_49 to %c1_50 step %c1_51 iter_args(%arg10 = %arg8) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg11 = %c0_52 to %c1_53 step %c1_54 iter_args(%arg12 = %arg10) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg13 = %c0_55 to %c11 step %c1_56 iter_args(%arg14 = %arg12) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg15 = %c0_57 to %c16 step %c1_58 iter_args(%arg16 = %arg14) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg17 = %c0_59 to %c8 step %c1_60 iter_args(%arg18 = %arg16) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_63 = tensor.extract_slice %pack_42[%arg7, %arg11, %arg17, %arg13, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_64 = tensor.extract_slice %pack_44[%arg11, %arg9, %arg15, %arg17, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_65 = tensor.extract_slice %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_63, %extracted_slice_64 : tensor<1x1x1x1x4x8xbf16>, tensor<1x1x1x1x8x4xbf16>) outs(%extracted_slice_65 : tensor<1x1x1x1x4x4xbf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[44, 128], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [44, 64, 64], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
^bb0(%in: bf16, %in_66: bf16, %out: bf16):
%25 = arith.mulf %in, %in_66 : bf16
%26 = arith.addf %out, %25 : bf16
linalg.yield %26 : bf16
} -> tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %24 into %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %19 : tensor<1x1x16x11x4x4xbf16>
}
// Undo the second-level (4x4) packing of the finished accumulator.
%extracted_slice_61 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_62 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_61 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_62 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Undo the first-level (44x64) packing and publish the 44x128 output tile.
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_33 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Write the assembled result back to binding 2, then release the staging buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_31 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_30 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_29 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_28 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_27 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before EliminateEmptyTensors (iree-eliminate-empty-tensors) //----- //
// Dispatch function for the bf16 matmul named in the symbol: a 308x9728 input
// and a 9728x2432 input are combined into a 308x2432 result.
// Structure: an outer scf.forall walks 44x128 output tiles. Inside it the
// 9728-wide shared dimension is consumed in 64-wide steps. The first step
// (offset 0) and the last step (offset 9664) are peeled out of the main
// scf.for: the first one carries the zero fill of the accumulator, the last
// one carries the unpack of the accumulator into the output tile.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop bounds: 8, 16 and 11 are upper bounds of the innermost tile loops;
// 151 is the exclusive upper bound of the main reduction loop.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. The trailing integer in each memref type is its memory
// space (1 or 2).
// NOTE(review): presumably these are two levels of on-chip AIE memory; confirm
// against the target description.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only inputs; binding 2 is the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// Destination tensor for the whole result; at this point it is a tensor.empty.
%5 = tensor.empty() : tensor<308x2432xbf16>
// Outer tiling: 7 x 19 iterations (308/44 by 2432/128), mapped to
// #gpu.block<y> / #gpu.block<x>.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// 44-row band of the first input, 128-column band of the second input, and
// the 44x128 output tile for this iteration.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// Tensor views of two scratch buffers: %7 later receives the unpacked result
// tile, %8 is the packed accumulator carried through all reduction steps.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// ---- Peeled first reduction step (offset 0) ----
// Pack the 44x64 and 64x128 operand tiles into the memory-space-1 buffers.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Inner parallel loop: 1 x 2 iterations mapped to #gpu.thread<y> / <x>.
// %arg4 selects one of the two 64x64 blocks of the packed second operand.
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Re-pack both operands into 4x8 and 8x4 inner tiles in the memory-space-2
// buffers.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
// Zero-initialise this iteration's slice of the accumulator (only done in
// this first step).
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// Tile loop nest: three single-trip loops, then 11 iterations (accumulator
// dim 3), 16 iterations (accumulator dim 2) and 8 iterations over the
// dimension shared by both operands (the reduction).
%19 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %18) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg12 = %c0 to %c11 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg14 = %c0 to %c16 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg16 = %c0 to %c8 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
// Select one 4x8 tile, one 8x4 tile and the matching 4x4 accumulator tile.
%extracted_slice_19 = tensor.extract_slice %pack_15[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_21 = tensor.extract_slice %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%25 = vector.transfer_read %extracted_slice_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%26 = vector.transfer_read %extracted_slice_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%27 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
// (4x8) * (8x4) contraction, added into the 4x4 accumulator vector %27.
%28 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %26, %27 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Write the updated 4x4 tile back and re-insert it into the loop-carried
// accumulator.
%29 = vector.transfer_write %28, %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %29 into %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Main reduction loop: %arg3 runs over 1..150 ----
// The accumulator from the first step (%11) is threaded through iter_args.
%12 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
// Offset into the shared dimension for this step: %arg3 * 64.
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%17 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Same inner parallel loop as the first step, but the accumulator slice is
// used as-is (no linalg.fill).
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Same 1 x 1 x 1 x 11 x 16 x 8 tile loop nest and contraction as in the
// first step.
%22 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %extracted_slice_22) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg12 = %c0 to %c1 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%25 = scf.for %arg14 = %c0 to %c11 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%26 = scf.for %arg16 = %c0 to %c16 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
%27 = scf.for %arg18 = %c0 to %c8 step %c1 iter_args(%arg19 = %arg17) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_23 = tensor.extract_slice %pack_19[%arg8, %arg12, %arg18, %arg14, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_24 = tensor.extract_slice %pack_21[%arg12, %arg10, %arg16, %arg18, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_25 = tensor.extract_slice %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%28 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%29 = vector.transfer_read %extracted_slice_24[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%30 = vector.transfer_read %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%31 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %28, %29, %30 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%32 = vector.transfer_write %31, %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %32 into %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %27 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %26 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %25 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Peeled last reduction step (offset 9664 = 151 * 64) ----
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// This forall has two shared outputs: the accumulator (%arg5, from the main
// loop) and %arg6, the view of %alloc_4 that receives the unpacked result.
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// Same tile loop nest and contraction as in the earlier steps.
%18 = scf.for %arg7 = %c0 to %c1 step %c1 iter_args(%arg8 = %extracted_slice_18) -> (tensor<1x1x16x11x4x4xbf16>) {
%19 = scf.for %arg9 = %c0 to %c1 step %c1 iter_args(%arg10 = %arg8) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg11 = %c0 to %c1 step %c1 iter_args(%arg12 = %arg10) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg13 = %c0 to %c11 step %c1 iter_args(%arg14 = %arg12) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg15 = %c0 to %c16 step %c1 iter_args(%arg16 = %arg14) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg17 = %c0 to %c8 step %c1 iter_args(%arg18 = %arg16) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_21 = tensor.extract_slice %pack_15[%arg7, %arg11, %arg17, %arg13, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_22 = tensor.extract_slice %pack_17[%arg11, %arg9, %arg15, %arg17, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_23 = tensor.extract_slice %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%24 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%25 = vector.transfer_read %extracted_slice_22[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%26 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%27 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %26 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%28 = vector.transfer_write %27, %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %28 into %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %19 : tensor<1x1x16x11x4x4xbf16>
}
// Unpack the finished 1x1x16x11x4x4 accumulator slice (4x4 inner tiles) back
// into a 1x1x44x64 tile of the second shared output.
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_20 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_20 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Unpack the 1x2x44x64 result into the 44x128 output tile and insert it into
// the full result at (%arg0, %arg1).
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the full 308x2432 result to binding 2, then release the scratch
// buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before EmptyTensorToAllocTensor (empty-tensor-to-alloc-tensor) //----- //
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>> -> tensor<308x2432xbf16>
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
%19 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %18) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg12 = %c0 to %c11 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg14 = %c0 to %c16 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg16 = %c0 to %c8 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_19 = tensor.extract_slice %pack_15[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_21 = tensor.extract_slice %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%25 = vector.transfer_read %extracted_slice_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%26 = vector.transfer_read %extracted_slice_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%27 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%28 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %26, %27 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%29 = vector.transfer_write %28, %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %29 into %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
%12 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%17 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%22 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %extracted_slice_22) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg12 = %c0 to %c1 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%25 = scf.for %arg14 = %c0 to %c11 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%26 = scf.for %arg16 = %c0 to %c16 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
%27 = scf.for %arg18 = %c0 to %c8 step %c1 iter_args(%arg19 = %arg17) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_23 = tensor.extract_slice %pack_19[%arg8, %arg12, %arg18, %arg14, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_24 = tensor.extract_slice %pack_21[%arg12, %arg10, %arg16, %arg18, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_25 = tensor.extract_slice %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%28 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%29 = vector.transfer_read %extracted_slice_24[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%30 = vector.transfer_read %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%31 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %28, %29, %30 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%32 = vector.transfer_write %31, %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %32 into %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %27 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %26 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %25 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = scf.for %arg7 = %c0 to %c1 step %c1 iter_args(%arg8 = %extracted_slice_18) -> (tensor<1x1x16x11x4x4xbf16>) {
%19 = scf.for %arg9 = %c0 to %c1 step %c1 iter_args(%arg10 = %arg8) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg11 = %c0 to %c1 step %c1 iter_args(%arg12 = %arg10) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg13 = %c0 to %c11 step %c1 iter_args(%arg14 = %arg12) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg15 = %c0 to %c16 step %c1 iter_args(%arg16 = %arg14) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg17 = %c0 to %c8 step %c1 iter_args(%arg18 = %arg16) -> (tensor<1x1x16x11x4x4xbf16>) {
%extracted_slice_21 = tensor.extract_slice %pack_15[%arg7, %arg11, %arg17, %arg13, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_22 = tensor.extract_slice %pack_17[%arg11, %arg9, %arg15, %arg17, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_23 = tensor.extract_slice %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%24 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%25 = vector.transfer_read %extracted_slice_22[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%26 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%27 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %26 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%28 = vector.transfer_write %27, %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %28 into %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %19 : tensor<1x1x16x11x4x4xbf16>
}
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_20 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_20 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
// Tiled bf16 matmul dispatch: contracts a 308x9728 operand (binding 0) with a
// 9728x2432 operand (binding 1) and stores a 308x2432 result to binding 2.
// The reduction dimension (9728) is consumed in 64-wide chunks in three stages:
//   1. first chunk at offset 0, which also zero-fills the accumulator,
//   2. an scf.for over chunks 1..150 (offset = iv * 64),
//   3. the last chunk at offset 9664, which also unpacks the accumulator into
//      the output tile.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers, re-wrapped with bufferization.to_tensor in every stage below.
// The trailing `1 : i32` / `2 : i32` is the memref memory space.
// NOTE(review): presumably 1 = shared (L2) and 2 = core-local (L1) memory on the AIE target — confirm.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bind the two read-only operands and the output buffer, then load each as a whole tensor.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [308, 9728], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<308x9728xbf16>> -> tensor<308x9728xbf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [9728, 2432], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9728x2432xbf16>> -> tensor<9728x2432xbf16>
// %5 only seeds the forall's shared output; each 44x128 tile of it is replaced by the parallel_insert_slice at the end of the forall.
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>> -> tensor<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (steps 44 and 128 over 308x2432), mapped to gpu.block y/x.
%6 = scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) shared_outs(%arg2 = %5) -> (tensor<308x2432xbf16>) {
// Per-tile operand slices: 44 full rows of the LHS, 128 full columns of the RHS, and the 44x128 output tile.
%extracted_slice = tensor.extract_slice %3[%arg0, 0] [44, 9728] [1, 1] : tensor<308x9728xbf16> to tensor<44x9728xbf16>
%extracted_slice_5 = tensor.extract_slice %4[0, %arg1] [9728, 128] [1, 1] : tensor<9728x2432xbf16> to tensor<9728x128xbf16>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<308x2432xbf16> to tensor<44x128xbf16>
// %7 is the 1x2x44x64 staging tensor for the unpacked result; %8 is the 1x2x16x11x4x4 accumulator in packed layout.
%7 = bufferization.to_tensor %alloc_4 restrict writable : memref<1x2x44x64xbf16, 1 : i32>
%8 = bufferization.to_tensor %alloc_3 restrict writable : memref<1x2x16x11x4x4xbf16, 2 : i32>
// ---- Stage 1: first reduction chunk (offset 0, width 64). ----
// Pack the 44x64 LHS chunk and the 64x128 RHS chunk into the memory-space-1 buffers.
%extracted_slice_7 = tensor.extract_slice %extracted_slice[0, 0] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack = tensor.pack %extracted_slice_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %9 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_8 = tensor.extract_slice %extracted_slice_5[0, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%10 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_9 = tensor.pack %extracted_slice_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %10 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// 1x2 parallel region mapped to gpu.thread y/x: each iteration takes one 64x64 half of the RHS chunk (%arg4) and its accumulator slice.
%11 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %8) -> (tensor<1x2x16x11x4x4xbf16>) {
// Second-level packing into the memory-space-2 buffers: LHS as 4x8 inner tiles, RHS as 8x4 inner tiles.
%extracted_slice_14 = tensor.extract_slice %pack[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_9[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
// The accumulator is zero-filled only here, in the first stage; later stages accumulate onto the carried value.
%18 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_18 : tensor<1x1x16x11x4x4xbf16>) -> tensor<1x1x16x11x4x4xbf16>
// Six-deep loop nest with trip counts 1, 1, 1, 11, 16, 8; the accumulator is threaded through as the iter_arg.
%19 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %18) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg12 = %c0 to %c11 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg14 = %c0 to %c16 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg16 = %c0 to %c8 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
// Innermost body: read one 4x8 LHS tile, one 8x4 RHS tile and the 4x4 accumulator tile,
// run vector.contract (kind add), and insert the 4x4 result back into the iter_arg.
%extracted_slice_19 = tensor.extract_slice %pack_15[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_21 = tensor.extract_slice %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%25 = vector.transfer_read %extracted_slice_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%26 = vector.transfer_read %extracted_slice_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%27 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%28 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %26, %27 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%29 = vector.transfer_write %28, %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %29 into %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %19 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Stage 2: reduction chunks 1..150. ----
// %arg3 runs from 1 to 150 and %16 = %arg3 * 64 is the chunk offset; the accumulator from stage 1 (%11) is carried as %arg4.
%12 = scf.for %arg3 = %c1 to %c151 step %c1 iter_args(%arg4 = %11) -> (tensor<1x2x16x11x4x4xbf16>) {
%16 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%extracted_slice_14 = tensor.extract_slice %extracted_slice[0, %16] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%17 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_15 = tensor.pack %extracted_slice_14 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %17 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%16, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%18 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %18 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// Same 1x2 parallel region as stage 1, but with no linalg.fill: the loops start from the carried accumulator slice.
%19 = scf.forall (%arg5, %arg6) in (1, 2) shared_outs(%arg7 = %arg4) -> (tensor<1x2x16x11x4x4xbf16>) {
%extracted_slice_18 = tensor.extract_slice %pack_15[%arg5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%20 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_19 = tensor.pack %extracted_slice_18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %20 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_20 = tensor.extract_slice %pack_17[0, %arg6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%21 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_21 = tensor.pack %extracted_slice_20 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %21 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_22 = tensor.extract_slice %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%22 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %extracted_slice_22) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (tensor<1x1x16x11x4x4xbf16>) {
%24 = scf.for %arg12 = %c0 to %c1 step %c1 iter_args(%arg13 = %arg11) -> (tensor<1x1x16x11x4x4xbf16>) {
%25 = scf.for %arg14 = %c0 to %c11 step %c1 iter_args(%arg15 = %arg13) -> (tensor<1x1x16x11x4x4xbf16>) {
%26 = scf.for %arg16 = %c0 to %c16 step %c1 iter_args(%arg17 = %arg15) -> (tensor<1x1x16x11x4x4xbf16>) {
%27 = scf.for %arg18 = %c0 to %c8 step %c1 iter_args(%arg19 = %arg17) -> (tensor<1x1x16x11x4x4xbf16>) {
// Same 4x8 by 8x4 accumulate-into-4x4 contraction as in stage 1.
%extracted_slice_23 = tensor.extract_slice %pack_19[%arg8, %arg12, %arg18, %arg14, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_24 = tensor.extract_slice %pack_21[%arg12, %arg10, %arg16, %arg18, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_25 = tensor.extract_slice %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%28 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%29 = vector.transfer_read %extracted_slice_24[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%30 = vector.transfer_read %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%31 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %28, %29, %30 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%32 = vector.transfer_write %31, %extracted_slice_25[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %32 into %arg19[%arg8, %arg10, %arg16, %arg14, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %27 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %26 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %25 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %24 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg7[%arg5, %arg6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %19 : tensor<1x2x16x11x4x4xbf16>
}
// ---- Stage 3: last reduction chunk at offset 9664 (= 151 * 64), plus write-out. ----
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, 9664] [44, 64] [1, 1] : tensor<44x9728xbf16> to tensor<44x64xbf16>
%13 = bufferization.to_tensor %alloc_2 restrict writable : memref<1x1x44x64xbf16, 1 : i32>
%pack_11 = tensor.pack %extracted_slice_10 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %13 : tensor<44x64xbf16> -> tensor<1x1x44x64xbf16>
%extracted_slice_12 = tensor.extract_slice %extracted_slice_5[9664, 0] [64, 128] [1, 1] : tensor<9728x128xbf16> to tensor<64x128xbf16>
%14 = bufferization.to_tensor %alloc_1 restrict writable : memref<1x2x64x64xbf16, 1 : i32>
%pack_13 = tensor.pack %extracted_slice_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %14 : tensor<64x128xbf16> -> tensor<1x2x64x64xbf16>
// This forall carries two shared outputs: the accumulator from stage 2 (%12) and the 1x2x44x64 staging tensor (%7).
%15:2 = scf.forall (%arg3, %arg4) in (1, 2) shared_outs(%arg5 = %12, %arg6 = %7) -> (tensor<1x2x16x11x4x4xbf16>, tensor<1x2x44x64xbf16>) {
%extracted_slice_14 = tensor.extract_slice %pack_11[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> to tensor<1x1x44x64xbf16>
%16 = bufferization.to_tensor %alloc_0 restrict writable : memref<1x1x8x11x4x8xbf16, 2 : i32>
%pack_15 = tensor.pack %extracted_slice_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %16 : tensor<1x1x44x64xbf16> -> tensor<1x1x8x11x4x8xbf16>
%extracted_slice_16 = tensor.extract_slice %pack_13[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x2x64x64xbf16> to tensor<1x1x64x64xbf16>
%17 = bufferization.to_tensor %alloc restrict writable : memref<1x1x16x8x8x4xbf16, 2 : i32>
%pack_17 = tensor.pack %extracted_slice_16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %17 : tensor<1x1x64x64xbf16> -> tensor<1x1x16x8x8x4xbf16>
%extracted_slice_18 = tensor.extract_slice %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x2x16x11x4x4xbf16> to tensor<1x1x16x11x4x4xbf16>
%18 = scf.for %arg7 = %c0 to %c1 step %c1 iter_args(%arg8 = %extracted_slice_18) -> (tensor<1x1x16x11x4x4xbf16>) {
%19 = scf.for %arg9 = %c0 to %c1 step %c1 iter_args(%arg10 = %arg8) -> (tensor<1x1x16x11x4x4xbf16>) {
%20 = scf.for %arg11 = %c0 to %c1 step %c1 iter_args(%arg12 = %arg10) -> (tensor<1x1x16x11x4x4xbf16>) {
%21 = scf.for %arg13 = %c0 to %c11 step %c1 iter_args(%arg14 = %arg12) -> (tensor<1x1x16x11x4x4xbf16>) {
%22 = scf.for %arg15 = %c0 to %c16 step %c1 iter_args(%arg16 = %arg14) -> (tensor<1x1x16x11x4x4xbf16>) {
%23 = scf.for %arg17 = %c0 to %c8 step %c1 iter_args(%arg18 = %arg16) -> (tensor<1x1x16x11x4x4xbf16>) {
// Same 4x8 by 8x4 accumulate-into-4x4 contraction as in the earlier stages.
%extracted_slice_21 = tensor.extract_slice %pack_15[%arg7, %arg11, %arg17, %arg13, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x11x4x8xbf16> to tensor<1x1x1x1x4x8xbf16>
%extracted_slice_22 = tensor.extract_slice %pack_17[%arg11, %arg9, %arg15, %arg17, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x8x8x4xbf16> to tensor<1x1x1x1x8x4xbf16>
%extracted_slice_23 = tensor.extract_slice %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> to tensor<1x1x1x1x4x4xbf16>
%24 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x8xbf16>, vector<1x1x1x1x4x8xbf16>
%25 = vector.transfer_read %extracted_slice_22[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x8x4xbf16>, vector<1x1x1x1x8x4xbf16>
%26 = vector.transfer_read %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : tensor<1x1x1x1x4x4xbf16>, vector<1x1x1x1x4x4xbf16>
%27 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %26 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%28 = vector.transfer_write %27, %extracted_slice_23[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, tensor<1x1x1x1x4x4xbf16>
%inserted_slice = tensor.insert_slice %28 into %arg18[%arg7, %arg9, %arg15, %arg13, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x1x1x4x4xbf16> into tensor<1x1x16x11x4x4xbf16>
scf.yield %inserted_slice : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %23 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %22 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %21 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %20 : tensor<1x1x16x11x4x4xbf16>
}
scf.yield %19 : tensor<1x1x16x11x4x4xbf16>
}
// Unpack the finished 16x11 grid of 4x4 accumulator tiles into a 44x64 slice of the staging tensor.
%extracted_slice_19 = tensor.extract_slice %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x2x44x64xbf16> to tensor<1x1x44x64xbf16>
%unpack_20 = tensor.unpack %18 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %extracted_slice_19 : tensor<1x1x16x11x4x4xbf16> -> tensor<1x1x44x64xbf16>
// Both shared outputs are written back: the unpacked slice into %arg6 and the packed accumulator into %arg5.
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack_20 into %arg6[%arg3, %arg4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : tensor<1x1x44x64xbf16> into tensor<1x2x44x64xbf16>
tensor.parallel_insert_slice %18 into %arg5[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : tensor<1x1x16x11x4x4xbf16> into tensor<1x2x16x11x4x4xbf16>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Only the second result (%15#1, the staging tensor) is used: undo the 44x64 tiling to get the 44x128 output tile.
%unpack = tensor.unpack %15#1 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %extracted_slice_6 : tensor<1x2x44x64xbf16> -> tensor<44x128xbf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %unpack into %arg2[%arg0, %arg1] [44, 128] [1, 1] : tensor<44x128xbf16> into tensor<308x2432xbf16>
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Store the assembled 308x2432 result to binding 2, then release the scratch buffers.
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [308, 2432], strides = [1, 1] : tensor<308x2432xbf16> -> !flow.dispatch.tensor<writeonly:tensor<308x2432xbf16>>
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before ResolveShapedTypeResultDims (resolve-shaped-type-result-dims) //----- //
// Bufferized, tiled bf16 matmul dispatch (per the function name: 308x2432 result, K = 9728).
// Structure visible in this dump:
//   * an outer scf.forall over 44x128 output tiles, mapped to #gpu.block<y>/<x>;
//   * the K dimension is walked in 64-wide slices, peeled into three pieces:
//     a first slice (offset 0, which also zero-fills the accumulator), an scf.for
//     over slices 1..150, and a last slice (offset 9664, which also unpacks the result);
//   * each piece runs an inner scf.forall in (1, 2) mapped to #gpu.thread<y>/<x>
//     whose body is an 11 x 16 x 8 loop nest of 4x8 * 8x4 -> 4x4 vector.contract ops.
// Several linalg.generic ops in this dump copy a view onto an identical view of the
// same buffer (same base value, same offsets/sizes/strides); they are pointed out inline.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop bounds and the bf16 zero used both as fill value and as transfer_read padding.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Staging buffers. Memory space 2 holds the 6-D inner-tiled operands and the accumulator
// (%alloc: RHS 8x4 tiles, %alloc_0: LHS 4x8 tiles, %alloc_3: 4x4 accumulator tiles);
// memory space 1 holds the 4-D packed tiles (%alloc_1: RHS, %alloc_2: LHS, %alloc_4: result).
// NOTE(review): the hardware meaning of memory spaces 1 and 2 is not stated in this dump — confirm against the target.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Dispatch bindings: %0 = 308x9728 (ReadOnly), %1 = 9728x2432 (ReadOnly), %2 = 308x2432 output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// Outer tiling: 308/44 = 7 by 2432/128 = 19 output tiles of 44x128.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// Per-tile views: 44 rows of %0 (full K), 128 columns of %1 (full K), and the 44x128 output tile of %2.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ---- Peeled first K slice (K offset 0): pack 44x64 of LHS and 64x128 of RHS into the memory-space-1 buffers.
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner forall: 1 x 2 iterations; %arg3 selects which 64-column half of the packed RHS is used.
scf.forall (%arg2, %arg3) in (1, 2) {
// Second-level pack into memory space 2: LHS into 4x8 inner tiles, RHS into 8x4 inner tiles.
// NOTE(review): %alloc and %alloc_0 are allocated once outside this forall, so both %arg3
// iterations write the same buffer values (with different RHS data for %alloc). Presumably
// each mapped thread gets its own instance after later lowering — confirm.
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this (%arg2, %arg3); zero-filled only here, in the first K slice.
%subview_14 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_14 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Loop nest: three single-trip loops (%c0 to %c1), then 11 (%arg10), 16 (%arg12) and 8 (%arg14, reduction) trips.
// Every loop yields its iter_arg unchanged, so the memref threaded through is always %subview_14;
// the accumulation happens through in-place writes, not through the loop-carried value.
%4 = scf.for %arg4 = %c0 to %c1 step %c1 iter_args(%arg5 = %subview_14) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %arg5) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c11 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c16 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c8 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
// Select one 4x8 LHS tile, one 8x4 RHS tile and the matching 4x4 accumulator tile.
%subview_16 = memref.subview %alloc_0[%arg4, %arg8, %arg14, %arg10, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[%arg8, %arg6, %arg12, %arg14, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Read-modify-write of the accumulator tile: acc += lhs * rhs (kind = add), all transfers fully in bounds.
%10 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%11 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%12 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%13 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %12 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %13, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// %subview_19 has the same base, offsets, sizes and strides as %subview_18, so the copy below
// is an identity copy of a tile onto itself.
%subview_19 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_18 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_19 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// %4 is %subview_14 threaded through the loops, and %subview_15 is the same slice of %alloc_3,
// so this is another copy of a region onto itself.
%subview_15 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%4 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Main K loop: slices 1..150 (K offset = %arg2 * 64). The accumulator %alloc_3 is carried
// as an iter_arg but yielded unchanged, so %3 refers to the same buffer as %alloc_3.
// Same body as the first slice, minus the linalg.fill.
%3 = scf.for %arg2 = %c1 to %c151 step %c1 iter_args(%arg3 = %alloc_3) -> (memref<1x2x16x11x4x4xbf16, 2 : i32>) {
%4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_12 = memref.subview %subview[0, %4] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_13 = memref.subview %subview_5[%4, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg4, %arg5) in (1, 2) {
%subview_14 = memref.subview %alloc_2[%arg4, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg5, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice taken from the loop-carried %arg3; no fill, so partial sums from earlier slices are kept.
%subview_16 = memref.subview %arg3[%arg4, %arg5, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %subview_16) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c11 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c16 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%10 = scf.for %arg16 = %c0 to %c8 step %c1 iter_args(%arg17 = %arg15) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%subview_18 = memref.subview %alloc_0[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%11 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%12 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%13 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%14 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %11, %12, %13 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %14, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Identity copy: %subview_21 is the same view as %subview_20.
%subview_21 = memref.subview %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %10 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// Identity copy: %5 is %subview_16 threaded through the loops, %subview_17 is the same slice of %arg3.
%subview_17 = memref.subview %arg3[%arg4, %arg5, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %arg3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
}
// ---- Peeled last K slice (K offset 9664 = 151 * 64): same compute, followed by the result unpack.
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice taken from %3, the result of the main K loop.
%subview_14 = memref.subview %3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = scf.for %arg4 = %c0 to %c1 step %c1 iter_args(%arg5 = %subview_14) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %arg5) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c11 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c16 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c8 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%subview_18 = memref.subview %alloc_0[%arg4, %arg8, %arg14, %arg10, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[%arg8, %arg6, %arg12, %arg14, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%10 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%11 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%12 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%13 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %12 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %13, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Identity copy: %subview_21 is the same view as %subview_20.
%subview_21 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// Unpack the finished 4x4-tiled accumulator slice into this iteration's 44x64 slot of %alloc_4 (memory space 1).
%subview_15 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %4 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_15 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
// Identity copy: %subview_16 is the same view of %alloc_4 as %subview_15.
%subview_16 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_15 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) outs(%subview_16 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
// Identity copy: %4 is %subview_14 threaded through the loops, %subview_17 is the same slice of %3.
%subview_17 = memref.subview %3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%4 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Unpack the two 44x64 result tiles from %alloc_4 into the 44x128 output tile of %2.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
// Identity copy: %subview_11 is the same view of %2 as %subview_6.
%subview_11 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_6 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_11 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Whole-buffer identity copy of the output binding %2 onto itself.
// NOTE(review): presumably this and the other self-copies are folded away by the canonicalizer
// that the next dump header names — confirm in the following dump.
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>) outs(%2 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
// Release the staging buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// Bufferized (memref-form) snapshot of the 308x2432x9728 bf16 matmul dispatch.
// Operands are the three HAL bindings below: %0 (308x9728, read-only, LHS),
// %1 (9728x2432, read-only, RHS) and %2 (308x2432, result).
// The reduction dimension (9728) is consumed in 64-wide slices: slice 0 is
// handled before the scf.for, slices 1..150 inside it, and the slice at
// offset 9664 (= 151 * 64) after it.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. The trailing `1 : i32` / `2 : i32` is the memref memory space.
// NOTE(review): presumably two levels of on-chip memory of the amd-aie target - confirm.
//   %alloc   : RHS tile packed with 8x4 inner tiles (memory space 2)
//   %alloc_0 : LHS tile packed with 4x8 inner tiles (memory space 2)
//   %alloc_1 : 64x128 RHS slice stored as two 64x64 tiles (memory space 1)
//   %alloc_2 : 44x64 LHS slice (memory space 1)
//   %alloc_3 : accumulator stored as 4x4 inner tiles (memory space 2)
//   %alloc_4 : unpacked 44x128 result stored as two 44x64 tiles (memory space 1)
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// One iteration per 44x128 tile of the result (308/44 = 7 by 2432/128 = 19 tiles),
// mapped to #gpu.block<y> / #gpu.block<x>.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// %subview: 44 rows of the LHS over the full reduction dimension.
// %subview_5: 128 columns of the RHS over the full reduction dimension.
// %subview_6: the 44x128 result tile owned by this iteration.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ---- First reduction slice (offset 0): copy the 44x64 LHS and 64x128 RHS slices
// into the memory-space-1 buffers.
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// 1x2 iterations mapped to #gpu.thread<y> / #gpu.thread<x>; %arg3 selects one of the
// two 64-column RHS tiles and the matching accumulator slice.
// NOTE(review): both iterations pack into the same %alloc / %alloc_0 buffers;
// presumably a later pass gives each iteration its own copy - confirm.
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Zero this iteration's accumulator slice; only the first reduction slice does this.
%subview_14 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_14 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Loop nest over the packed tiles. The %arg4 / %arg6 / %arg8 loops run a single
// iteration (0 to 1); %arg10 walks the 11 row blocks, %arg12 the 16 column blocks
// and %arg14 the 8 reduction blocks. The accumulator memref is threaded through
// iter_args but every loop yields it unchanged.
%4 = scf.for %arg4 = %c0 to %c1 step %c1 iter_args(%arg5 = %subview_14) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %arg5) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c11 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c16 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c8 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
// Select one 4x8 LHS tile, one 8x4 RHS tile and the 4x4 accumulator tile.
%subview_16 = memref.subview %alloc_0[%arg4, %arg8, %arg14, %arg10, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[%arg8, %arg6, %arg12, %arg14, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%10 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%11 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%12 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// (4x8) * (8x4) multiply-accumulate into the 4x4 tile %12, then store it back in place.
%13 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %12 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %13, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// %subview_19 addresses the same slice as %subview_18, so the generic below copies
// the tile onto itself.
%subview_19 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_18 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_19 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// %subview_15 is the same slice of %alloc_3 as %subview_14 (the loop's init value),
// so this generic copies the accumulator slice onto itself.
%subview_15 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%4 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Middle reduction slices: %arg2 runs from 1 to 150 and %4 = %arg2 * 64 is the
// offset of the slice along the reduction dimension. The accumulator %alloc_3 is
// carried as %arg3 and is not re-zeroed here.
%3 = scf.for %arg2 = %c1 to %c151 step %c1 iter_args(%arg3 = %alloc_3) -> (memref<1x2x16x11x4x4xbf16, 2 : i32>) {
%4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_12 = memref.subview %subview[0, %4] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_13 = memref.subview %subview_5[%4, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Same pack + tile loop nest as the first slice, minus the linalg.fill.
scf.forall (%arg4, %arg5) in (1, 2) {
%subview_14 = memref.subview %alloc_2[%arg4, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg5, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_16 = memref.subview %arg3[%arg4, %arg5, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %subview_16) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c1 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c11 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c16 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%10 = scf.for %arg16 = %c0 to %c8 step %c1 iter_args(%arg17 = %arg15) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%subview_18 = memref.subview %alloc_0[%arg6, %arg10, %arg16, %arg12, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[%arg10, %arg8, %arg14, %arg16, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%11 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%12 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%13 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Accumulate the 4x8 by 8x4 product onto the existing 4x4 partial sum %13.
%14 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %11, %12, %13 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %14, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// %subview_21 addresses the same slice as %subview_20: the copy below is a self-copy.
%subview_21 = memref.subview %arg17[%arg6, %arg8, %arg14, %arg12, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %10 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// %subview_17 is the same slice of %arg3 as %subview_16 (the loop's init value): self-copy.
%subview_17 = memref.subview %arg3[%arg4, %arg5, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.yield %arg3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
}
// ---- Last reduction slice (offset 9664 = 151 * 64), followed by the write-back of
// the accumulated result.
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// The accumulator is taken from %3, the result of the middle-slice loop.
%subview_14 = memref.subview %3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = scf.for %arg4 = %c0 to %c1 step %c1 iter_args(%arg5 = %subview_14) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%5 = scf.for %arg6 = %c0 to %c1 step %c1 iter_args(%arg7 = %arg5) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%6 = scf.for %arg8 = %c0 to %c1 step %c1 iter_args(%arg9 = %arg7) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%7 = scf.for %arg10 = %c0 to %c11 step %c1 iter_args(%arg11 = %arg9) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%8 = scf.for %arg12 = %c0 to %c16 step %c1 iter_args(%arg13 = %arg11) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%9 = scf.for %arg14 = %c0 to %c8 step %c1 iter_args(%arg15 = %arg13) -> (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
%subview_18 = memref.subview %alloc_0[%arg4, %arg8, %arg14, %arg10, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[%arg8, %arg6, %arg12, %arg14, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%10 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%11 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%12 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Accumulate the 4x8 by 8x4 product onto the existing 4x4 partial sum %12.
%13 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %12 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %13, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// %subview_21 addresses the same slice as %subview_20: the copy below is a self-copy.
%subview_21 = memref.subview %arg15[%arg4, %arg6, %arg12, %arg10, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
scf.yield %arg15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %9 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %8 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %7 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %6 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
scf.yield %5 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
// Unpack the 4x4-tiled accumulator slice into this iteration's 44x64 tile of %alloc_4.
%subview_15 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %4 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_15 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
// %subview_16 is the same slice of %alloc_4 as %subview_15: self-copy.
%subview_16 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_15 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) outs(%subview_16 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
// %subview_17 is the same slice of %3 as %subview_14 (the loop's init value): self-copy.
%subview_17 = memref.subview %3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%4 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Write the two 44x64 tiles of %alloc_4 back into the 44x128 result tile %subview_6.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
// %subview_11 is the same slice of %2 as %subview_6: self-copy.
%subview_11 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_6 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_11 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Identity copy of the whole result buffer %2 onto itself.
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>) outs(%2 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
// Release the scratch buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before CSE (cse) //----- //
// Snapshot of the bf16 matmul dispatch function (308x9728 times 9728x2432 into 308x2432).
// The reduction dimension is processed in 64-wide slices in three stages:
//   1. the first slice (offset 0), which also zero-fills the accumulator,
//   2. an scf.for over slice indices 1..150,
//   3. the last slice (offset 9664), which also unpacks the accumulator to the output.
// NOTE(review): the loop nests tagged #gpu.block / #gpu.thread are presumably distributed
// over AIE cores later in the pipeline; nothing in this function shows that - confirm.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop bounds used below: 11, 16 and 8 are the trip counts of the inner tile loops,
// 151 is the exclusive upper bound of the middle reduction-slice loop.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. Buffers typed with memory space 2 hold the innermost packed tiles
// (%alloc: RHS, %alloc_0: LHS, %alloc_3: accumulator); buffers typed with memory space 1
// hold the intermediate packed tiles (%alloc_1: RHS, %alloc_2: LHS, %alloc_4: result).
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 are the read-only operands; binding 2 is the 308x2432 output buffer.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// Outer tiling: each iteration owns one 44x128 tile of the output (%subview_6), together
// with the matching 44x9728 rows of %0 and 9728x128 columns of %1.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Stage 1: first reduction slice (offset 0). Pack the 44x64 LHS slice into %alloc_2 and
// the 64x128 RHS slice into %alloc_1.
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner 1x2 tiling: re-pack the operands into 4x8 (LHS) and 8x4 (RHS) tiles and select
// the accumulator slice of %alloc_3 indexed by (%arg2, %arg3).
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Only this first stage zero-initializes the accumulator slice.
linalg.fill ins(%cst : bf16) outs(%subview_14 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// 11 x 16 x 8 tile loops: read a 4x8 LHS tile, an 8x4 RHS tile and the 4x4 accumulator
// tile, contract with kind = add, and write the 4x4 result back to the same tile.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_14[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// %subview_19 is taken from %subview_14 with the same offsets, sizes and strides as
// %subview_18, so the element-wise copy below has the same region as source and target.
%subview_19 = memref.subview %subview_14[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_18 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_19 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
// %subview_15 repeats the %subview_14 subview of %alloc_3; the copy below is again
// between identical regions.
%subview_15 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_14 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Stage 2: reduction slices 1..150. %3 = %arg2 * 64 is the slice offset. Same structure
// as stage 1 but without the linalg.fill, so results accumulate into %alloc_3.
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_12 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_13 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_14 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_15 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_16 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_18 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %subview_16[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%subview_21 = memref.subview %subview_16[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
%subview_17 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_16 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// Stage 3: last reduction slice at offset 9664 (= 151 * 64). After accumulating, each
// inner tile unpacks its accumulator slice from %alloc_3 into its slice of %alloc_4.
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_18 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_19 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_20 = memref.subview %subview_14[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_19[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_20[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_20[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%subview_21 = memref.subview %subview_14[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_20 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_21 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
// Unpack the 16x11 grid of 4x4 accumulator tiles into a 44x64 slice of %alloc_4.
%subview_15 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_15 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
%subview_16 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_15 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) outs(%subview_16 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
%subview_17 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_14 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Unpack the full 1x2x44x64 result buffer into this iteration's 44x128 output tile.
// %subview_11 repeats the %subview_6 subview of %2, so the copy that follows is again
// between identical regions.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
%subview_11 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_6 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_11 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release the scratch buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_16 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_16 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_18 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_18 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_15 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_17 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_17 : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
}
}
}
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview_14 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) outs(%subview_14 : memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview_6 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_6 : memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before CleanupBufferAllocView (iree-codegen-cleanup-buffer-alloc-view) //----- //
// Bufferized, tiled and vectorized bf16 matmul dispatch: (308x9728) * (9728x2432) -> (308x2432).
// Structure visible in this dump:
//   * an outer scf.forall over 44x128 output tiles (mapped to #gpu.block<y>/<x>),
//   * the K dimension (9728) consumed in 152 slices of 64: slice 0 is peeled as a prologue
//     (it also zero-fills the accumulator), slices 1..150 run in an scf.for, and slice 151
//     (offset 9664) is peeled as an epilogue that also unpacks the result,
//   * per slice, an inner scf.forall over (1, 2) (mapped to #gpu.thread<y>/<x>) that re-packs the
//     operands into 4x8 / 8x4 tiles and accumulates 4x4 output tiles with vector.contract.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Index constants used as loop bounds/steps below, plus the bf16 zero used for the fill and as
// the transfer_read padding value.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers, allocated once outside all loops. Buffers in memory space 2 hold the
// innermost-tiled operands/accumulator; buffers in memory space 1 hold the first-level packed tiles.
// NOTE(review): the meaning of memory spaces 1 and 2 (e.g. which AIE memory level) is not
// established by this IR - confirm against the target description.
// %alloc   : RHS tile, 16x8 outer grid of 8x4 inner tiles.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
// %alloc_0 : LHS tile, 8x11 outer grid of 4x8 inner tiles.
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
// %alloc_1 : RHS 64x128 K-slice packed as two 64x64 tiles.
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
// %alloc_2 : LHS 44x64 K-slice packed as one 44x64 tile.
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// %alloc_3 : accumulator, two slices of 16x11 outer grid of 4x4 inner tiles.
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
// %alloc_4 : unpacked result staging, two 44x64 tiles.
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Dispatch bindings: binding 0 = LHS (read-only), binding 1 = RHS (read-only), binding 2 = output.
// Each is asserted to be 64-byte aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// Outer tiling: 308/44 = 7 by 2432/128 = 19 output tiles of 44x128.
// NOTE(review): all iterations of this forall (and of the nested foralls) share the single set of
// scratch buffers allocated above; presumably a later pass distributes/privatizes them - confirm.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// %subview   : 44 full-K rows of the LHS for this tile.
// %subview_5 : 128 full-K columns of the RHS for this tile.
// %subview_6 : the 44x128 output tile.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ---- Prologue: K slice 0 (K offsets 0..63) ----
// Pack the first 44x64 LHS slice and 64x128 RHS slice into the memory-space-1 buffers.
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner distribution over (1, 2): %arg3 selects which 64-column half of the RHS/output is handled.
scf.forall (%arg2, %arg3) in (1, 2) {
// Second-level pack: LHS into 4x8 tiles (%alloc_0), selected RHS half into 8x4 tiles (%alloc).
// Both pack destinations are the whole buffers, i.e. they are not indexed by %arg2/%arg3.
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this (%arg2, %arg3); zero-initialized only here, in the prologue.
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Micro-kernel loops: %arg4 over 11 row tiles (11*4 = 44), %arg5 over 16 column tiles (16*4 = 64),
// %arg6 over 8 reduction tiles (8*8 = 64).
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Select one 4x8 LHS tile, one 8x4 RHS tile and the matching 4x4 accumulator tile.
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Load the three tiles as vectors (all dimensions marked in-bounds).
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// acc(4x4) += lhs(4x8) * rhs(8x4); operands and the accumulator/result are all bf16.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated accumulator tile back in place.
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Steady state: K slices 1..150 (K offset = %arg2 * 64) ----
// Same pack + micro-kernel sequence as the prologue, but without the accumulator fill.
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Same accumulator slice as in the prologue; it carries the partial sums across K slices.
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ---- Epilogue: last K slice (K offset 9664 = 151 * 64) ----
// Same pack + micro-kernel sequence, followed by unpacking the finished accumulator.
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
// Unpack this accumulator slice (4x4 tiles) into its 44x64 slot of the staging buffer %alloc_4.
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Unpack the two 44x64 staging tiles into the 44x128 output tile in the storage buffer.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release the scratch buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before HoistStaticallyBoundAllocations (iree-hoist-statically-bound-allocations) //----- //
// Dispatch entry point for a bf16 matmul: binding 0 (308x9728) times binding 1 (9728x2432) into binding 2 (308x2432).
// The output is processed as 44x128 tiles. For each tile the 9728-long reduction dimension is consumed in 152 chunks
// of 64: chunk 0 is handled on its own (it also zero-fills the accumulator), chunks 1..150 run inside an scf.for, and
// chunk 151 is handled on its own again (it also unpacks the accumulator towards the output).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Index constants: 8, 16 and 11 are the upper bounds of the innermost tile loops, 1 is their step, and 1..151 is
// the range of the reduction-chunk loop.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
// Zero value used both as the accumulator fill value and as the padding operand of the vector.transfer_read ops.
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. They live in memory spaces 1 and 2; what those spaces map to on the target is not visible here.
// %alloc   : second operand re-packed into 8x4 inner tiles (memory space 2).
// %alloc_0 : first operand re-packed into 4x8 inner tiles (memory space 2).
// %alloc_1 : 64x128 slab of binding 1, staged as two 64x64 blocks (memory space 1).
// %alloc_2 : 44x64 slab of binding 0 (memory space 1).
// %alloc_3 : accumulator in 4x4 inner tiles, one 16x11 grid of tiles per index of dim 1 (memory space 2).
// %alloc_4 : staging buffer for the result, two 44x64 blocks (memory space 1).
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bind the three storage buffers (0 and 1 are flagged ReadOnly) and record an alignment of 64 for each.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// One iteration per 44x128 output tile (308/44 = 7 by 2432/128 = 19 iterations), mapped to gpu.block y/x.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// Slices for this tile: 44 full rows of binding 0, 128 full columns of binding 1, and the 44x128 output window.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ---- Reduction chunk 0: reduction offset 0, staged into the memory-space-1 buffers ----
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Two iterations (%arg2 = 0, %arg3 = 0 or 1), one per 64-column half of the output tile, mapped to gpu.thread y/x.
// NOTE(review): %alloc_0 and %alloc are written whole by both iterations (they are not indexed by %arg3), and the
// data packed into %alloc differs per iteration. If the iterations can run concurrently this is a shared-buffer
// hazard; presumably a later pass privatises these buffers -- confirm.
scf.forall (%arg2, %arg3) in (1, 2) {
// Re-pack both operands into the inner-tile layouts read by the contraction below.
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// This iteration's slice of the accumulator. Only chunk 0 clears it; later chunks accumulate on top.
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// 11 x 16 x 8 tile loops: %arg4 selects the 4-row tile, %arg5 the 4-column tile, %arg6 the 8-deep reduction tile.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Select the 4x8 left tile, the 8x4 right tile and the 4x4 accumulator tile for this step.
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Load the three tiles as vectors; every dimension is marked in-bounds.
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Multiply-accumulate into the accumulator tile %5; d2, d5 and d8 are the reduction iterators.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated accumulator tile back in place.
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Reduction chunks 1..150: same sequence as chunk 0 minus the fill; the reduction offset is %arg2 * 64 ----
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
// Stage the current 44x64 and 64x128 slabs into the memory-space-1 buffers.
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Two iterations, one per 64-column half; %alloc_0 and %alloc are again written whole by both of them.
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this half; it is not cleared here, so the partial sums from earlier chunks are kept.
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// 11 x 16 x 8 tile loops: %arg5 selects the 4-row tile, %arg6 the 4-column tile, %arg7 the 8-deep reduction tile.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Multiply-accumulate into the accumulator tile %6, then write it back in place.
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ---- Reduction chunk 151: reduction offset 9664 (= 151 * 64), followed by the write-out of the result ----
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Two iterations, one per 64-column half; %alloc_0 and %alloc are again written whole by both of them.
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// 11 x 16 x 8 tile loops: %arg4 selects the 4-row tile, %arg5 the 4-column tile, %arg6 the 8-deep reduction tile.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Multiply-accumulate into the accumulator tile %5, then write it back in place.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
// All chunks accumulated for this half: unpack the 4x4-tiled accumulator into its 44x64 slot of %alloc_4.
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Unpack both 44x64 halves of the staged result into the 44x128 output window.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release the scratch buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
// -----// IR Dump Before LowerUKernelOpsToCalls (iree-codegen-lower-ukernel-ops-to-calls) //----- //
// Module holding a single dispatch function: a bf16 matmul of binding 0 (308x9728) by binding 1 (9728x2432) into
// binding 2 (308x2432).
module {
// The output is computed in 44x128 tiles. Per tile, the 9728-long reduction dimension is split into 152 chunks of 64:
// a peeled first chunk (zero-fills the accumulator), an scf.for over chunks 1..150, and a peeled last chunk (also
// unpacks the accumulator towards the output buffer).
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Index constants: bounds 8 / 16 / 11 and step 1 for the innermost tile loops; 1 and 151 bound the chunk loop.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
// bf16 zero: accumulator fill value and padding operand for the vector.transfer_read ops.
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers in memory spaces 1 and 2 (the hardware meaning of these spaces is not visible in this dump).
// %alloc   : right operand packed into 8x4 inner tiles.       %alloc_0 : left operand packed into 4x8 inner tiles.
// %alloc_1 : 64x128 right slab as two 64x64 blocks.            %alloc_2 : 44x64 left slab.
// %alloc_3 : accumulator packed into 4x4 inner tiles.          %alloc_4 : result staging, two 44x64 blocks.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Storage-buffer bindings: 0 and 1 carry the ReadOnly flag, 2 does not; each is asserted to have alignment 64.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// Outer tiling: 7 x 19 iterations, each owning one 44x128 output tile; mapped to gpu.block y/x.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// Per-tile views: a 44x9728 row band of binding 0, a 9728x128 column band of binding 1, the 44x128 output window.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ==== First chunk (reduction offset 0): copy the operand slabs into %alloc_2 and %alloc_1 ====
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner tiling: 1 x 2 iterations mapped to gpu.thread y/x; %arg3 picks which 64x64 block of %alloc_1 is used.
// NOTE(review): the pack destinations %alloc_0 and %alloc are the whole buffers in both iterations, while the
// source packed into %alloc depends on %arg3. Concurrent execution of the two iterations would therefore overlap
// on these buffers -- presumably resolved by a later pass; confirm.
scf.forall (%arg2, %arg3) in (1, 2) {
// Convert the staged slabs to the tiled layouts (4x8 for the left operand, 8x4 for the right operand).
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice owned by this iteration, zeroed here because this is the first chunk.
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Tile loops: %arg4 over 11 row tiles (4 rows each), %arg5 over 16 column tiles (4 columns each), %arg6 over 8
// reduction tiles (8 elements each).
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Views of the current 4x8 left tile, 8x4 right tile and 4x4 accumulator tile.
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Read the tiles into vectors (all dimensions in-bounds, so the padding value %cst is never used for fill-in).
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Contraction with additive combining kind: %3 and %4 are multiplied and reduced over d2, d5, d8 on top of %5.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Write the new partial sum back over the same accumulator tile.
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ==== Middle chunks (%arg2 = 1..150): identical steps without the fill; reduction offset %3 = %arg2 * 64 ====
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
// Copy this chunk's operand slabs into %alloc_2 and %alloc_1.
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner 1 x 2 tiling; here too both iterations pack into the whole of %alloc_0 and %alloc.
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this iteration; no fill here, so earlier partial sums are preserved.
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Tile loops: %arg5 over row tiles, %arg6 over column tiles, %arg7 over reduction tiles.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Accumulate the tile product onto %6 and store it back to the same accumulator tile.
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ==== Last chunk (reduction offset 9664 = 151 * 64), then write-out of the accumulated tile ====
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
// Inner 1 x 2 tiling; here too both iterations pack into the whole of %alloc_0 and %alloc.
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Tile loops: %arg4 over row tiles, %arg5 over column tiles, %arg6 over reduction tiles.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Accumulate the tile product onto %5 and store it back to the same accumulator tile.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
// Every chunk has now been accumulated: convert the 4x4-tiled accumulator slice back to a plain 44x64 block of %alloc_4.
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Copy the two staged 44x64 blocks into the 44x128 window of the output binding.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Free all six scratch buffers (reverse allocation order).
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before EraseHALDescriptorTypeFromMemRef (iree-codegen-erase-hal-descriptor-type-from-memref) //----- //
module {
// Tiled bf16 matmul dispatch: contracts the 308x9728 buffer (binding 0) with the
// 9728x2432 buffer (binding 1) and writes the 308x2432 result into binding 2.
//
// Structure visible in this snapshot:
//   * an outer scf.forall over 44x128 output tiles, mapped to #gpu.block<y>/<x>;
//   * the 9728-long reduction dimension is consumed in 64-wide slices: slice 0 is peeled
//     (it also zero-fills the accumulator), slices 1..150 run inside an scf.for, and the
//     last slice (offset 9664) is peeled again and followed by the unpack of the result;
//   * per slice, an inner scf.forall over (1, 2), mapped to #gpu.thread<y>/<x>, re-packs
//     the operands into 4x8 and 8x4 tiles and accumulates 4x4 tiles with vector.contract.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop bounds used below: 8 / 16 / 11 are the trip counts of the innermost tile loops,
// 151 is the exclusive upper bound of the reduction-slice loop, %cst is the bf16 zero used
// both as the fill value and as the transfer_read padding value.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers, allocated once for the whole function and freed before return.
// NOTE(review): memory spaces 1 and 2 are target-defined; presumably 1 is the shared (L2)
// level and 2 the core-local (L1) level on this target - confirm against the backend.
//   %alloc   : 8x4-tiled copy of one 64x64 slice of the second operand  (space 2)
//   %alloc_0 : 4x8-tiled copy of the 44x64 slice of the first operand   (space 2)
//   %alloc_1 : two 64x64 blocks of the second operand slice             (space 1)
//   %alloc_2 : one 44x64 block of the first operand slice               (space 1)
//   %alloc_3 : 4x4-tiled accumulators, one per inner-forall iteration   (space 2)
//   %alloc_4 : un-tiled 44x64 results, one per inner-forall iteration   (space 1)
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Dispatch bindings: 0 and 1 are flagged ReadOnly (inputs), 2 is the output buffer.
// All three are declared 64-byte aligned and still carry #hal.descriptor_type<storage_buffer>.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>>
// Outer tiling: %arg0 steps rows by 44 (308 / 44 = 7 tiles), %arg1 steps columns by 128
// (2432 / 128 = 19 tiles).
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// Row band of the first operand, column band of the second operand, and the 44x128
// output tile owned by this iteration.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16, #hal.descriptor_type<storage_buffer>> to memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// ---- Peeled first reduction slice (offset 0) ----
// Copy the 44x64 and 64x128 operand slices into %alloc_2 / %alloc_1.
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
// Second-level packing into 4x8 (first operand) and 8x4 (second operand) tiles.
// NOTE(review): both iterations of this forall pack into the whole of %alloc_0 and %alloc
// (the destinations are not indexed by %arg2/%arg3); this is only race-free if every
// iteration ends up with a private instance of these buffers - confirm in later passes.
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this (%arg2, %arg3) iteration; zeroed only in this first slice.
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Tile loops: %arg4 walks the 11 row tiles (11 * 4 = 44), %arg5 the 16 column tiles
// (16 * 4 = 64), %arg6 the 8 reduction tiles (8 * 8 = 64).
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Select the 4x8 lhs tile, the 8x4 rhs tile and the 4x4 accumulator tile.
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Load the three tiles as vectors (all dimensions marked in-bounds).
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// acc += lhs * rhs. d2, d5 and d8 are the reduction iterators; with the unit outer
// dimensions this is a 4x8 by 8x4 product accumulated into the 4x4 tile, in bf16.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated accumulator tile back in place.
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Steady-state reduction slices 1..150 ----
// Same pack + contract sequence as above, at column/row offset %arg2 * 64, without the fill.
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Same accumulator slice as in the first slice; it is not re-initialised here.
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Row-tile (%arg5), column-tile (%arg6) and reduction-tile (%arg7) loops.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Read-modify-write of the 4x4 accumulator tile: %7 = %4 * %5 + %6.
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ---- Peeled last reduction slice (offset 9664 = 151 * 64) ----
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Final accumulation over the 11 x 16 x 8 tile space.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
// Reduction finished: undo the 4x4 tiling of the accumulator into this iteration's
// 44x64 slice of %alloc_4.
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Write the two 44x64 result blocks back as the 44x128 tile of the output buffer.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release every scratch buffer allocated at the top of the function.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
module {
// Tiled bf16 matmul dispatch: contracts the 308x9728 buffer (binding 0) with the
// 9728x2432 buffer (binding 1) and writes the 308x2432 result into binding 2.
// In this snapshot the binding memrefs are plain memref<...xbf16> types (no memory-space
// attribute), and the operand slices are still expressed as chains of memref.subview ops.
//
// Structure:
//   * outer scf.forall over 44x128 output tiles (#gpu.block<y>/<x>);
//   * reduction dimension consumed in 64-wide slices: slice 0 peeled (with accumulator
//     zero-fill), slices 1..150 in an scf.for, slice 151 (offset 9664) peeled and followed
//     by the unpack of the result;
//   * per slice, an inner scf.forall over (1, 2) (#gpu.thread<y>/<x>) that packs operands
//     into 4x8 / 8x4 tiles and accumulates 4x4 tiles with vector.contract.
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Trip counts of the innermost tile loops (8, 16, 11), the exclusive upper bound of the
// reduction-slice loop (151), and the bf16 zero used for fill / transfer_read padding.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers in memory spaces 1 and 2, freed before return.
// NOTE(review): the meaning of spaces 1 / 2 is target-defined (presumably shared vs.
// core-local memory) - confirm against the backend.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Dispatch bindings: 0 and 1 are ReadOnly inputs, 2 is the output; all 64-byte aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// 7 x 19 output tiles: rows step by 44 up to 308, columns step by 128 up to 2432.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
// 44-row band of input 0, 128-column band of input 1, and the 44x128 output tile.
%subview = memref.subview %0[%arg0, 0] [44, 9728] [1, 1] : memref<308x9728xbf16> to memref<44x9728xbf16, strided<[9728, 1], offset: ?>>
%subview_5 = memref.subview %1[0, %arg1] [9728, 128] [1, 1] : memref<9728x2432xbf16> to memref<9728x128xbf16, strided<[2432, 1], offset: ?>>
%subview_6 = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled first reduction slice (offset 0) ----
%subview_7 = memref.subview %subview[0, 0] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %subview_5[0, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
// Re-tile the operands: 4x8 tiles into %alloc_0, 8x4 tiles into %alloc.
// NOTE(review): the pack destinations %alloc_0 / %alloc are not indexed by %arg2/%arg3,
// so both forall iterations target the same buffers; this relies on each iteration
// getting its own instance later in the pipeline - confirm.
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Per-iteration accumulator slice, zero-initialised only here.
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_13 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// %arg4: 11 row tiles of 4; %arg5: 16 column tiles of 4; %arg6: 8 reduction tiles of 8.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// 4x8 lhs tile, 8x4 rhs tile, 4x4 accumulator tile.
%subview_14 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_15 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_14[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// %6 = %3 * %4 + %5, reducing over d2, d5 and d8 (effectively 4x8 by 8x4 into 4x4).
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_16[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Steady-state reduction slices 1..150 (offset = %arg2 * 64), no fill ----
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_11 = memref.subview %subview[0, %3] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_11 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_12 = memref.subview %subview_5[%3, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_13 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_14 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_14 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice carried over from the previous slice (not re-initialised).
%subview_15 = memref.subview %alloc_3[%arg3, %arg4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Row-tile (%arg5), column-tile (%arg6) and reduction-tile (%arg7) loops.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%subview_16 = memref.subview %alloc_0[0, 0, %arg7, %arg5, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %alloc[0, 0, %arg6, %arg7, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_18 = memref.subview %subview_15[0, 0, %arg6, %arg5, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %subview_18[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Read-modify-write of the 4x4 accumulator tile: %7 = %4 * %5 + %6.
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %subview_18[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ---- Peeled last reduction slice (offset 9664 = 151 * 64) ----
%subview_9 = memref.subview %subview[0, 9664] [44, 64] [1, 1] : memref<44x9728xbf16, strided<[9728, 1], offset: ?>> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_10 = memref.subview %subview_5[9664, 0] [64, 128] [1, 1] : memref<9728x128xbf16, strided<[2432, 1], offset: ?>> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Final accumulation over the 11 x 16 x 8 tile space.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%subview_15 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x8x11x4x8xbf16, 2 : i32> to memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>
%subview_16 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x8x8x4xbf16, 2 : i32> to memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>
%subview_17 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%3 = vector.transfer_read %subview_15[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x8xbf16, strided<[2816, 2816, 352, 32, 8, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %subview_16[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x8x4xbf16, strided<[4096, 4096, 256, 32, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %subview_17[%c0, %c0, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %subview_17[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x1x1x1x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
}
}
}
// Reduction finished: undo the 4x4 tiling of the accumulator into this iteration's
// 44x64 slice of %alloc_4.
%subview_14 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_13 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_14 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Write the two 44x64 result blocks back as the 44x128 tile of the output buffer.
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview_6 : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release every scratch buffer allocated at the top of the function.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before AMDAIEBridgeToAIR (iree-amdaie-bridge-to-air) //----- //
// Snapshot of the dispatch as it enters the AMDAIEBridgeToAIR pass.
// The function multiplies a 308x9728 bf16 buffer (binding 0) by a 9728x2432 bf16 buffer
// (binding 1) and writes a 308x2432 bf16 buffer (binding 2).
// The K = 9728 reduction is processed 64 elements at a time and is peeled into three parts:
//   1. a first step at K offset 0, which also zero-fills the accumulator,
//   2. an scf.for over steps 1..150 (K offsets 64..9600),
//   3. a last step at K offset 9664, which also unpacks the accumulator into the output tile.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Loop constants: 11, 16 and 8 are the trip counts of the innermost tile loops;
// 151 is the exclusive upper bound of the peeled K loop.
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. The trailing "1 : i32" / "2 : i32" is the memref memory space.
// NOTE(review): presumably space 1 is the shared (L2 / memtile) level and space 2 is
// core-local (L1) memory on the AIE target - confirm against the target description.
//   %alloc   : RHS tile packed as 16x8 tiles of 8x4
//   %alloc_0 : LHS tile packed as 8x11 tiles of 4x8
//   %alloc_1 : RHS 64x128 slab split into two 64x64 halves
//   %alloc_2 : LHS 44x64 slab
//   %alloc_3 : accumulator, two halves of 16x11 tiles of 4x4
//   %alloc_4 : unpacked 44x128 result, stored as two 44x64 halves
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 carry the ReadOnly flag; binding 2 is the destination of the final unpack.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// One iteration per 44x128 output tile: 308/44 = 7 by 2432/128 = 19 iterations,
// tagged with the gpu.block<y>/<x> mapping at the closing brace.
scf.forall (%arg0, %arg1) = (0, 0) to (308, 2432) step (44, 128) {
%subview = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled first K step (K offset 0) ----
// Stage the 44x64 LHS slab and the 64x128 RHS slab into the space-1 buffers.
%subview_5 = memref.subview %0[%arg0, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_5 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_6 = memref.subview %1[0, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
// (1, 2) iterations tagged gpu.thread<y>/<x>; %arg3 selects which 64-column half of the
// 128-wide tile this iteration computes.
// NOTE(review): both iterations pack into the whole of %alloc_0 and %alloc rather than a
// per-iteration subview; presumably each thread ends up with its own copy of the space-2
// buffers after lowering - confirm.
scf.forall (%arg2, %arg3) in (1, 2) {
// Re-tile the LHS slab into 4x8 tiles (result type 1x1x8x11x4x8).
%subview_9 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
// Re-tile this iteration's 64x64 RHS half into 8x4 tiles (result type 1x1x16x8x8x4).
%subview_10 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Zero this iteration's accumulator slice; only the first K step does this.
%subview_11 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_11 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Tile loops: %arg4 walks the 11 row tiles, %arg5 the 16 column tiles, %arg6 the 8 K tiles.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Load one 4x8 LHS tile, one 8x4 RHS tile and the current 4x4 accumulator tile.
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// With these vector shapes d6 (4), d7 (4) and d8 (8) are the only non-unit dims, so this is
// a 4x8 by 8x4 multiply accumulated into the 4x4 tile %5.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated accumulator tile back to the location it was read from.
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// ---- Main K loop: steps 1..150, K offset = %arg2 * 64 ----
scf.for %arg2 = %c1 to %c151 step %c1 {
// %3 and %4 compute the same offset; it is applied to the LHS columns and the RHS rows.
%3 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_9 = memref.subview %0[%arg0, %3] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%4 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg2)
%subview_10 = memref.subview %1[%4, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
// Same per-half work as the first step, but without the fill: the accumulator in %alloc_3
// keeps the partial sums written by earlier K steps.
scf.forall (%arg3, %arg4) in (1, 2) {
%subview_11 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Tile loops: %arg5 walks the 11 row tiles, %arg6 the 16 column tiles, %arg7 the 8 K tiles.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%5 = vector.transfer_read %alloc_0[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 by 8x4 multiply accumulated into the 4x4 tile %7.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
}
// ---- Peeled last K step (K offset 9664 = 151 * 64) ----
%subview_7 = memref.subview %0[%arg0, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %1[9664, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
scf.forall (%arg2, %arg3) in (1, 2) {
%subview_9 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_10 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this half; not filled here, and used as the unpack source below.
%subview_11 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 by 8x4 multiply accumulated into the 4x4 tile %5.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Reduction finished: turn the 16x11 grid of 4x4 tiles back into this half's 44x64 slice of %alloc_4.
%subview_12 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_12 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
// Write the two 44x64 halves out as the 44x128 tile of the result buffer (binding 2).
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>>)
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
// Release every scratch buffer allocated at the top of the function.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before AMDAIEDecomposeLinalgExtPackUnPackToAIR (iree-amdaie-decompose-pack-unpack-to-air) //----- //
// Snapshot of the dispatch as it enters the AMDAIEDecomposeLinalgExtPackUnPackToAIR pass.
// The tiling is expressed with scf.parallel loops (explicit constant bounds, scf.reduce
// terminators); data movement is still written as iree_linalg_ext.pack / unpack ops.
// The function multiplies a 308x9728 bf16 buffer (binding 0) by a 9728x2432 bf16 buffer
// (binding 1) into a 308x2432 bf16 buffer (binding 2). The K = 9728 reduction runs in
// 64-wide steps, peeled into a first step (K offset 0, with accumulator fill), a loop over
// steps 1..150, and a last step (K offset 9664, followed by the unpack to the output).
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c2, %c128, %c44, %c2432 and %c308 are the bounds and steps of the scf.parallel loops;
// %c8, %c16 and %c11 are the trip counts of the innermost tile loops; %c151 is the
// exclusive upper bound of the peeled K loop.
%c2 = arith.constant 2 : index
%c128 = arith.constant 128 : index
%c44 = arith.constant 44 : index
%c2432 = arith.constant 2432 : index
%c308 = arith.constant 308 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Scratch buffers. The trailing "1 : i32" / "2 : i32" is the memref memory space.
// NOTE(review): presumably space 1 is the shared (L2 / memtile) level and space 2 is
// core-local (L1) memory on the AIE target - confirm against the target description.
//   %alloc   : RHS tile packed as 16x8 tiles of 8x4
//   %alloc_0 : LHS tile packed as 8x11 tiles of 4x8
//   %alloc_1 : RHS 64x128 slab split into two 64x64 halves
//   %alloc_2 : LHS 44x64 slab
//   %alloc_3 : accumulator, two halves of 16x11 tiles of 4x4
//   %alloc_4 : unpacked 44x128 result, stored as two 44x64 halves
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0 and 1 carry the ReadOnly flag; binding 2 is the destination of the final unpack.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// One iteration per 44x128 output tile: 308/44 = 7 by 2432/128 = 19 iterations.
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c308, %c2432) step (%c44, %c128) {
%subview = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled first K step (K offset 0) ----
// Stage the 44x64 LHS slab and the 64x128 RHS slab into the space-1 buffers.
%subview_5 = memref.subview %0[%arg0, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_5 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_6 = memref.subview %1[0, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
// 1 x 2 iterations; %arg3 selects which 64-column half of the 128-wide tile is computed.
// NOTE(review): both iterations pack into the whole of %alloc_0 and %alloc rather than a
// per-iteration subview; presumably each iteration ends up with its own copy of the
// space-2 buffers after lowering - confirm.
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
// Re-tile the LHS slab into 4x8 tiles (result type 1x1x8x11x4x8).
%subview_9 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
// Re-tile this iteration's 64x64 RHS half into 8x4 tiles (result type 1x1x16x8x8x4).
%subview_10 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Zero this iteration's accumulator slice; only the first K step does this.
%subview_11 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_11 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Tile loops: %arg4 walks the 11 row tiles, %arg5 the 16 column tiles, %arg6 the 8 K tiles.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
// Load one 4x8 LHS tile, one 8x4 RHS tile and the current 4x4 accumulator tile.
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// With these vector shapes d6 (4), d7 (4) and d8 (8) are the only non-unit dims, so this is
// a 4x8 by 8x4 multiply accumulated into the 4x4 tile %5.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated accumulator tile back to the location it was read from.
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Operand-less scf.reduce: terminator of the scf.parallel body, no value is reduced.
scf.reduce
}
// ---- Main K loop: steps 1..150, K offset = %arg2 * 64 ----
scf.for %arg2 = %c1 to %c151 step %c1 {
// %3 and %4 compute the same offset; it is applied to the LHS columns and the RHS rows.
%3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
%subview_9 = memref.subview %0[%arg0, %3] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_9 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
%subview_10 = memref.subview %1[%4, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
// Same per-half work as the first step, but without the fill: the accumulator in %alloc_3
// keeps the partial sums written by earlier K steps.
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_11 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_12 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_12 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Tile loops: %arg5 walks the 11 row tiles, %arg6 the 16 column tiles, %arg7 the 8 K tiles.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%5 = vector.transfer_read %alloc_0[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 by 8x4 multiply accumulated into the 4x4 tile %7.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
scf.reduce
}
}
// ---- Peeled last K step (K offset 9664 = 151 * 64) ----
%subview_7 = memref.subview %0[%arg0, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
iree_linalg_ext.pack %subview_7 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %alloc_2 : (memref<44x64xbf16, strided<[9728, 1], offset: ?>> memref<1x1x44x64xbf16, 1 : i32>)
%subview_8 = memref.subview %1[9664, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_1 : (memref<64x128xbf16, strided<[2432, 1], offset: ?>> memref<1x2x64x64xbf16, 1 : i32>)
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_9 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_0 : (memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> memref<1x1x8x11x4x8xbf16, 2 : i32>)
%subview_10 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.pack %subview_10 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 4] into %alloc : (memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> memref<1x1x16x8x8x4xbf16, 2 : i32>)
// Accumulator slice for this half; not filled here, and used as the unpack source below.
%subview_11 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 by 8x4 multiply accumulated into the 4x4 tile %5.
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Reduction finished: turn the 16x11 grid of 4x4 tiles back into this half's 44x64 slice of %alloc_4.
%subview_12 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
iree_linalg_ext.unpack %subview_11 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 4] into %subview_12 : (memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>)
scf.reduce
}
// Write the two 44x64 halves out as the 44x128 tile of the result buffer (binding 2).
iree_linalg_ext.unpack %alloc_4 inner_dims_pos = [0, 1] inner_tiles = [44, 64] into %subview : (memref<1x2x44x64xbf16, 1 : i32> memref<44x128xbf16, strided<[2432, 1], offset: ?>>)
scf.reduce
}
// Release every scratch buffer allocated at the top of the function.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before CSE (cse) //----- //
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c2 = arith.constant 2 : index
%c128 = arith.constant 128 : index
%c44 = arith.constant 44 : index
%c2432 = arith.constant 2432 : index
%c308 = arith.constant 308 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c308, %c2432) step (%c44, %c128) {
%subview = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
%subview_5 = memref.subview %0[%arg0, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_6 = memref.subview %1[0, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_13 = memref.expand_shape %subview_12 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_14 = memref.transpose %expand_shape_13 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_14[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_16 = memref.expand_shape %subview_15 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_17 = memref.transpose %expand_shape_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_17[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_18 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_18 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
scf.reduce
}
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
%subview_12 = memref.subview %0[%arg0, %3] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_12[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
%subview_13 = memref.subview %1[%4, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_15[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_16 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_17 = memref.expand_shape %subview_16 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_18 = memref.transpose %expand_shape_17 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_18[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_19 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_20 = memref.expand_shape %subview_19 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_21 = memref.transpose %expand_shape_20 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_21[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%5 = vector.transfer_read %alloc_0[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
scf.reduce
}
}
%subview_7 = memref.subview %0[%arg0, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_7[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_8 = memref.subview %1[9664, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_9 = memref.expand_shape %subview_8 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_10[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_13 = memref.expand_shape %subview_12 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_14 = memref.transpose %expand_shape_13 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_14[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_16 = memref.expand_shape %subview_15 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_17 = memref.transpose %expand_shape_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_17[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_18 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
%subview_19 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_20 = memref.transpose %subview_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_19[] [] [], %transpose_20[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
scf.reduce
}
%transpose_11 = memref.transpose %alloc_4 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_11[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
scf.reduce
}
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before ParallelToHerd (air-par-to-herd) //----- //
// Snapshot of the bf16 matmul dispatch: a 308x9728 operand times a 9728x2432
// operand, producing a 308x2432 result (shapes taken from the three bindings
// below). The reduction dimension (9728) is consumed in 152 chunks of 64:
// one peeled first chunk (which also zero-fills the accumulator), a loop over
// chunks 1..150, and one peeled last chunk (which also writes the result out).
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Index constants used as loop bounds / steps below.
%c2 = arith.constant 2 : index
%c128 = arith.constant 128 : index
%c44 = arith.constant 44 : index
%c2432 = arith.constant 2432 : index
%c308 = arith.constant 308 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
// bf16 zero: accumulator fill value and padding operand of the transfer_reads.
%cst = arith.constant 0.000000e+00 : bf16
// Staging buffers. The trailing `1 : i32` / `2 : i32` is the memref memory space.
// NOTE(review): presumably 1 = shared L2 and 2 = core-local L1 in the AIR
// convention -- confirm against the mlir-air documentation.
//   %alloc   : RHS tile in space 2, laid out [1][1][16 n-blk][8 k-blk][8 k][4 n]
//   %alloc_0 : LHS tile in space 2, laid out [1][1][8 k-blk][11 m-blk][4 m][8 k]
//   %alloc_1 : RHS 64x128 chunk in space 1, split into two 64x64 halves
//   %alloc_2 : LHS 44x64 chunk in space 1
//   %alloc_3 : accumulator in space 2, [1][2 n-half][16 n-blk][11 m-blk][4 m][4 n]
//   %alloc_4 : result staging in space 1, two 44x64 halves
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Binding 0: read-only 308x9728 LHS. Binding 1: read-only 9728x2432 RHS.
// Binding 2: 308x2432 output. Each is asserted to be 64-byte aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// Outer tiling: one iteration per 44x128 output tile (308/44 = 7 by 2432/128 = 19).
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c308, %c2432) step (%c44, %c128) {
// Output tile this iteration is responsible for.
%subview = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled first reduction chunk (k offset 0) ----
// Stage the 44x64 LHS chunk into %alloc_2.
%subview_5 = memref.subview %0[%arg0, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
// Stage the 64x128 RHS chunk into %alloc_1, viewed as two 64x64 column halves
// (expand 128 -> 2x64, then move the half index in front of the row index).
%subview_6 = memref.subview %1[0, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Inner 1x2 parallel: %arg3 selects which 64-column half of the tile is computed.
// NOTE(review): both iterations write the same %alloc / %alloc_0 buffers;
// presumably these become per-core buffers once this loop is turned into a
// herd by the pass this dump precedes -- confirm.
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
// LHS: 44x64 -> [11][4] x [8][8] -> reordered to [8 k-blk][11 m-blk][4][8], copied to %alloc_0.
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_13 = memref.expand_shape %subview_12 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_14 = memref.transpose %expand_shape_13 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_14[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// RHS half %arg3: 64x64 -> [8][8] x [16][4] -> reordered to [16 n-blk][8 k-blk][8][4], copied to %alloc.
%subview_15 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_16 = memref.expand_shape %subview_15 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_17 = memref.transpose %expand_shape_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_17[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Zero this half's accumulator slice; only the first chunk does this.
%subview_18 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_18 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Micro-kernel: for each m-block (%arg4), n-block (%arg5), k-block (%arg6),
// acc[4x4] += lhs[4x8] * rhs[8x4], read-modify-write through %alloc_3.
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
scf.reduce
}
// ---- Steady state: reduction chunks 1..150 (upper bound 151 is exclusive) ----
scf.for %arg2 = %c1 to %c151 step %c1 {
// Element offset of this chunk along the reduction dimension (chunk index * 64).
%3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
// Stage the LHS 44x64 and RHS 64x128 chunks at that offset, as in the first chunk.
%subview_12 = memref.subview %0[%arg0, %3] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_12[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_13 = memref.subview %1[%3, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_15[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Same 1x2 inner parallel as above, but without the fill: the accumulator
// in %alloc_3 carries over from the previous chunk.
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_16 = memref.subview %alloc_2[%arg3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_17 = memref.expand_shape %subview_16 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_18 = memref.transpose %expand_shape_17 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_18[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_19 = memref.subview %alloc_1[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_20 = memref.expand_shape %subview_19 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_21 = memref.transpose %expand_shape_20 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_21[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Micro-kernel, identical to the first chunk's: %arg5 = m-block, %arg6 = n-block, %arg7 = k-block.
scf.for %arg5 = %c0 to %c11 step %c1 {
scf.for %arg6 = %c0 to %c16 step %c1 {
scf.for %arg7 = %c0 to %c8 step %c1 {
%4 = vector.transfer_read %alloc_0[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%5 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%6 = vector.transfer_read %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %6 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %7, %alloc_3[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
scf.reduce
}
}
// ---- Peeled last reduction chunk (k offset 9664 = 151 * 64) ----
%subview_7 = memref.subview %0[%arg0, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_7[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_8 = memref.subview %1[9664, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_9 = memref.expand_shape %subview_8 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_10[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Final accumulate per 64-column half, followed by the copy-out of that half.
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c1, %c2) step (%c1, %c1) {
%subview_12 = memref.subview %alloc_2[%arg2, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_13 = memref.expand_shape %subview_12 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_14 = memref.transpose %expand_shape_13 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_14[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_15 = memref.subview %alloc_1[0, %arg3, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_16 = memref.expand_shape %subview_15 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_17 = memref.transpose %expand_shape_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc[] [] [], %transpose_17[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Accumulator slice for this half; used only as the copy-out source below (no fill here).
%subview_18 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg4 = %c0 to %c11 step %c1 {
scf.for %arg5 = %c0 to %c16 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%3 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%4 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%5 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %5 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %6, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Copy-out of this half: reorder the accumulator from [16][11][4][4] to
// [11][4][16][4] (i.e. 44 rows x 64 columns) and copy it into the matching
// 44x64 half of %alloc_4.
%subview_19 = memref.subview %alloc_4[%arg2, %arg3, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_20 = memref.transpose %subview_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_19[] [] [], %transpose_20[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
scf.reduce
}
// Interleave the two 44x64 halves ([2][44][64] -> [44][2][64]) and write the
// resulting 44x128 tile into the output subview.
%transpose_11 = memref.transpose %alloc_4 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_11[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
scf.reduce
}
// Release the staging buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before ParallelToLaunch (air-par-to-launch) //----- //
// Tiled bf16 matmul: %2 (308x2432) is produced from %0 (308x9728) and %1 (9728x2432).
// The output is covered by an scf.parallel over 44x128 tiles. For each tile the
// 9728-long reduction is consumed in 64-wide slices: a first slice at offset 0
// (which also zero-fills the accumulator), an scf.for over slices 1..150, and a
// final slice at offset 9664 (which also copies the accumulator out).
// NOTE(review): memory spaces `1 : i32` and `2 : i32` presumably map to the AIR
// L2 and L1 levels respectively -- confirm against the AIR dialect docs.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Index and zero constants for the outer loops.
// NOTE(review): %c2, %c8, %c16, %c11 and %cst appear to have no users in this
// function body (each herd region defines its own copies) -- presumably left
// for a later cleanup pass.
%c2 = arith.constant 2 : index
%c128 = arith.constant 128 : index
%c44 = arith.constant 44 : index
%c2432 = arith.constant 2432 : index
%c308 = arith.constant 308 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
// Staging buffers, allocated once ahead of the scf.parallel and released after it:
// %alloc, %alloc_0 and %alloc_3 are in memory space 2; %alloc_1, %alloc_2 and
// %alloc_4 are in memory space 1. Every parallel iteration references these same buffers.
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_0 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_2 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%alloc_3 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
// Bindings 0..2 of set 0 (two read-only inputs, one output), each asserted 64-aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// 7 x 19 output tiles (308/44 by 2432/128); %arg0 / %arg1 are the tile's row / column offsets.
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c308, %c2432) step (%c44, %c128) {
// Destination window of the output for this tile.
%subview = memref.subview %2[%arg0, %arg1] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// --- K-slice 0 (reduction offset 0) ---
// Stage the 44x64 slice of %0 into %alloc_2.
%subview_5 = memref.subview %0[%arg0, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
// Stage the 64x128 slice of %1 into %alloc_1, viewed so that the 128 columns
// become two 64-column panels (1x64x2x64 -> 1x2x64x64).
%subview_6 = memref.subview %1[0, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Only %c1_13 and %c2_14 are consumed below (herd extent 1x2); the other six
// constants in this run have no users in this module.
%c1_7 = arith.constant 1 : index
%c0_8 = arith.constant 0 : index
%c1_9 = arith.constant 1 : index
%c2_10 = arith.constant 2 : index
%c0_11 = arith.constant 0 : index
%c1_12 = arith.constant 1 : index
%c1_13 = arith.constant 1 : index
%c2_14 = arith.constant 2 : index
// 1x2 herd: tile (%arg2, %arg3) works on the 64-column panel selected by %arg3.
air.herd @herd_0 tile (%arg2, %arg3) in (%arg4=%c1_13, %arg5=%c2_14) args(%arg6=%alloc_2, %arg7=%alloc_0, %arg8=%alloc_1, %arg9=%alloc, %arg10=%alloc_3) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%cst_28 = arith.constant 0.000000e+00 : bf16
%c0_29 = arith.constant 0 : index
%c8_30 = arith.constant 8 : index
%c1_31 = arith.constant 1 : index
%c16_32 = arith.constant 16 : index
%c11_33 = arith.constant 11 : index
// Identity maps of the herd tile coordinates.
%3 = affine.apply affine_map<(d0) -> (d0)>(%arg2)
%4 = affine.apply affine_map<(d0) -> (d0)>(%arg3)
// Repack the staged 44x64 slice as an 8x11 grid of 4x8 blocks and copy it into %arg7 (memory space 2).
%subview_34 = memref.subview %arg6[%3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_35 = memref.expand_shape %subview_34 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_36 = memref.transpose %expand_shape_35 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg7[] [] [], %transpose_36[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// Select this tile's 64x64 panel, repack it as a 16x8 grid of 8x4 blocks and copy it into %arg9 (memory space 2).
%subview_37 = memref.subview %arg8[0, %4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_38 = memref.expand_shape %subview_37 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_39 = memref.transpose %expand_shape_38 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg9[] [] [], %transpose_39[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Zero this tile's 16x11 grid of 4x4 accumulator blocks; only this first-slice herd does so.
%subview_40 = memref.subview %arg10[%3, %4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst_28 : bf16) outs(%subview_40 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Block loops: %arg11 over 11 row blocks, %arg12 over 16 column blocks, %arg13 over 8 reduction blocks.
scf.for %arg11 = %c0_29 to %c11_33 step %c1_31 {
scf.for %arg12 = %c0_29 to %c16_32 step %c1_31 {
scf.for %arg13 = %c0_29 to %c8_30 step %c1_31 {
// acc(4x4) += lhs(4x8) * rhs(8x4); the accumulator block is re-read from and
// written back to %arg10 on every reduction step, in bf16.
%5 = vector.transfer_read %arg7[%c0_29, %c0_29, %arg13, %arg11, %c0_29, %c0_29], %cst_28 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg9[%c0_29, %c0_29, %arg12, %arg13, %c0_29, %c0_29], %cst_28 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg10[%3, %4, %arg12, %arg11, %c0_29, %c0_29], %cst_28 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg10[%3, %4, %arg12, %arg11, %c0_29, %c0_29] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// --- K-slices 1..150 (reduction offset %arg2 * 64) ---
// Same staging and block matmul as the first slice, but the accumulator in
// %alloc_3 is not re-zeroed, so partial sums carry over between slices.
scf.for %arg2 = %c1 to %c151 step %c1 {
%3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg2]
%subview_28 = memref.subview %0[%arg0, %3] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_28[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_29 = memref.subview %1[%3, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_30 = memref.expand_shape %subview_29 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_31 = memref.transpose %expand_shape_30 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_31[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Only %c1_38 and %c2_39 are consumed below (herd extent 1x2).
%c1_32 = arith.constant 1 : index
%c0_33 = arith.constant 0 : index
%c1_34 = arith.constant 1 : index
%c2_35 = arith.constant 2 : index
%c0_36 = arith.constant 0 : index
%c1_37 = arith.constant 1 : index
%c1_38 = arith.constant 1 : index
%c2_39 = arith.constant 2 : index
air.herd @herd_0 tile (%arg3, %arg4) in (%arg5=%c1_38, %arg6=%c2_39) args(%arg7=%alloc_2, %arg8=%alloc_0, %arg9=%alloc_1, %arg10=%alloc, %arg11=%alloc_3) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_40 = arith.constant 0 : index
%cst_41 = arith.constant 0.000000e+00 : bf16
%c8_42 = arith.constant 8 : index
%c1_43 = arith.constant 1 : index
%c16_44 = arith.constant 16 : index
%c11_45 = arith.constant 11 : index
%4 = affine.apply affine_map<(d0) -> (d0)>(%arg3)
%5 = affine.apply affine_map<(d0) -> (d0)>(%arg4)
// Repack the 44x64 slice into 4x8 blocks and copy it into %arg8.
%subview_46 = memref.subview %arg7[%4, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_47 = memref.expand_shape %subview_46 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_48 = memref.transpose %expand_shape_47 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg8[] [] [], %transpose_48[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// Repack this tile's 64x64 panel into 8x4 blocks and copy it into %arg10.
%subview_49 = memref.subview %arg9[0, %5, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_50 = memref.expand_shape %subview_49 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_51 = memref.transpose %expand_shape_50 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg10[] [] [], %transpose_51[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Block loops: %arg12 over 11 row blocks, %arg13 over 16 column blocks, %arg14 over 8 reduction blocks.
scf.for %arg12 = %c0_40 to %c11_45 step %c1_43 {
scf.for %arg13 = %c0_40 to %c16_44 step %c1_43 {
scf.for %arg14 = %c0_40 to %c8_42 step %c1_43 {
// acc(4x4) += lhs(4x8) * rhs(8x4), accumulating in place in %arg11.
%6 = vector.transfer_read %arg8[%c0_40, %c0_40, %arg14, %arg12, %c0_40, %c0_40], %cst_41 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg10[%c0_40, %c0_40, %arg13, %arg14, %c0_40, %c0_40], %cst_41 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg11[%4, %5, %arg13, %arg12, %c0_40, %c0_40], %cst_41 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg11[%4, %5, %arg13, %arg12, %c0_40, %c0_40] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// --- Last K-slice (reduction offset 9664 = 151 * 64) ---
%subview_15 = memref.subview %0[%arg0, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_2[] [] [], %subview_15[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_16 = memref.subview %1[9664, %arg1] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_17 = memref.expand_shape %subview_16 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_18 = memref.transpose %expand_shape_17 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose_18[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Only %c1_25 and %c2_26 are consumed below (herd extent 1x2).
%c1_19 = arith.constant 1 : index
%c0_20 = arith.constant 0 : index
%c1_21 = arith.constant 1 : index
%c2_22 = arith.constant 2 : index
%c0_23 = arith.constant 0 : index
%c1_24 = arith.constant 1 : index
%c1_25 = arith.constant 1 : index
%c2_26 = arith.constant 2 : index
// Same block matmul as the loop-body herd, with %alloc_4 passed in additionally
// (%arg11) as the destination for the finished accumulator.
air.herd @herd_0 tile (%arg2, %arg3) in (%arg4=%c1_25, %arg5=%c2_26) args(%arg6=%alloc_2, %arg7=%alloc_0, %arg8=%alloc_1, %arg9=%alloc, %arg10=%alloc_3, %arg11=%alloc_4) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_28 = arith.constant 0 : index
%cst_29 = arith.constant 0.000000e+00 : bf16
%c8_30 = arith.constant 8 : index
%c1_31 = arith.constant 1 : index
%c16_32 = arith.constant 16 : index
%c11_33 = arith.constant 11 : index
%3 = affine.apply affine_map<(d0) -> (d0)>(%arg2)
%4 = affine.apply affine_map<(d0) -> (d0)>(%arg3)
// Repack the 44x64 slice into 4x8 blocks and copy it into %arg7.
%subview_34 = memref.subview %arg6[%3, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_35 = memref.expand_shape %subview_34 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_36 = memref.transpose %expand_shape_35 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg7[] [] [], %transpose_36[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// Repack this tile's 64x64 panel into 8x4 blocks and copy it into %arg9.
%subview_37 = memref.subview %arg8[0, %4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_38 = memref.expand_shape %subview_37 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_39 = memref.transpose %expand_shape_38 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg9[] [] [], %transpose_39[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// This tile's accumulator window; it is not re-zeroed here and is used as the copy-out source below.
%subview_40 = memref.subview %arg10[%3, %4, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Block loops: %arg12 over 11 row blocks, %arg13 over 16 column blocks, %arg14 over 8 reduction blocks.
scf.for %arg12 = %c0_28 to %c11_33 step %c1_31 {
scf.for %arg13 = %c0_28 to %c16_32 step %c1_31 {
scf.for %arg14 = %c0_28 to %c8_30 step %c1_31 {
// acc(4x4) += lhs(4x8) * rhs(8x4), accumulating in place in %arg10.
%5 = vector.transfer_read %arg7[%c0_28, %c0_28, %arg14, %arg12, %c0_28, %c0_28], %cst_29 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg9[%c0_28, %c0_28, %arg13, %arg14, %c0_28, %c0_28], %cst_29 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg10[%3, %4, %arg13, %arg12, %c0_28, %c0_28], %cst_29 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg10[%3, %4, %arg13, %arg12, %c0_28, %c0_28] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Copy the finished accumulator, viewed as 11x4x16x4 (16x11x4x4 with the block
// dims permuted), into this tile's 44x64 panel of %arg11 (memory space 1).
%subview_41 = memref.subview %arg11[%3, %4, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_42 = memref.transpose %subview_40 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_41[] [] [], %transpose_42[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
air.herd_terminator
}
// Interleave the two 44x64 panels (1x2x44x64 viewed as 1x44x2x64) and write the
// resulting 44x128 tile to the output window.
%transpose_27 = memref.transpose %alloc_4 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_27[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
scf.reduce
}
// Release the staging buffers in reverse order of allocation.
memref.dealloc %alloc_4 : memref<1x2x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_1 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_0 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
return
}
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c2 = arith.constant 2 : index
%c128 = arith.constant 128 : index
%c44 = arith.constant 44 : index
%c2432 = arith.constant 2432 : index
%c308 = arith.constant 308 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c151 = arith.constant 151 : index
%cst = arith.constant 0.000000e+00 : bf16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
%c7 = arith.constant 7 : index
%c0_0 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%c19 = arith.constant 19 : index
%c0_2 = arith.constant 0 : index
%c1_3 = arith.constant 1 : index
%c7_4 = arith.constant 7 : index
%c19_5 = arith.constant 19 : index
air.launch (%arg0, %arg1) in (%arg2=%c7_4, %arg3=%c19_5) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg2, %arg10=%arg3, %arg11=%arg4, %arg12=%arg5, %arg13=%arg6) : index, index, index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c1_6 = arith.constant 1 : index
%c151_7 = arith.constant 151 : index
%3 = affine.apply affine_map<(d0) -> (d0 * 44)>(%arg7)
%4 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg8)
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_8 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_9 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_10 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_11 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_12 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%subview = memref.subview %arg11[%3, %4] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
%subview_13 = memref.subview %arg12[%3, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_12[] [] [], %subview_13[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_14 = memref.subview %arg13[0, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_14 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_11[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
%c1_15 = arith.constant 1 : index
%c0_16 = arith.constant 0 : index
%c1_17 = arith.constant 1 : index
%c2_18 = arith.constant 2 : index
%c0_19 = arith.constant 0 : index
%c1_20 = arith.constant 1 : index
%c1_21 = arith.constant 1 : index
%c2_22 = arith.constant 2 : index
air.herd @herd_0 tile (%arg14, %arg15) in (%arg16=%c1_21, %arg17=%c2_22) args(%arg18=%alloc_12, %arg19=%alloc_10, %arg20=%alloc_11, %arg21=%alloc_9, %arg22=%alloc_8) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%cst_36 = arith.constant 0.000000e+00 : bf16
%c0_37 = arith.constant 0 : index
%c8_38 = arith.constant 8 : index
%c1_39 = arith.constant 1 : index
%c16_40 = arith.constant 16 : index
%c11_41 = arith.constant 11 : index
%5 = affine.apply affine_map<(d0) -> (d0)>(%arg14)
%6 = affine.apply affine_map<(d0) -> (d0)>(%arg15)
%subview_42 = memref.subview %arg18[%5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_43 = memref.expand_shape %subview_42 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_44 = memref.transpose %expand_shape_43 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_44[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_45 = memref.subview %arg20[0, %6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_46 = memref.expand_shape %subview_45 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_47 = memref.transpose %expand_shape_46 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg21[] [] [], %transpose_47[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_48 = memref.subview %arg22[%5, %6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst_36 : bf16) outs(%subview_48 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
scf.for %arg23 = %c0_37 to %c11_41 step %c1_39 {
scf.for %arg24 = %c0_37 to %c16_40 step %c1_39 {
scf.for %arg25 = %c0_37 to %c8_38 step %c1_39 {
%7 = vector.transfer_read %arg19[%c0_37, %c0_37, %arg25, %arg23, %c0_37, %c0_37], %cst_36 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%8 = vector.transfer_read %arg21[%c0_37, %c0_37, %arg24, %arg25, %c0_37, %c0_37], %cst_36 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%9 = vector.transfer_read %arg22[%5, %6, %arg24, %arg23, %c0_37, %c0_37], %cst_36 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%10 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %8, %9 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %10, %arg22[%5, %6, %arg24, %arg23, %c0_37, %c0_37] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
scf.for %arg14 = %c1_6 to %c151_7 step %c1_6 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg14]
%subview_36 = memref.subview %arg12[%3, %5] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_12[] [] [], %subview_36[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_37 = memref.subview %arg13[%5, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_38 = memref.expand_shape %subview_37 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_39 = memref.transpose %expand_shape_38 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_39[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
%c1_40 = arith.constant 1 : index
%c0_41 = arith.constant 0 : index
%c1_42 = arith.constant 1 : index
%c2_43 = arith.constant 2 : index
%c0_44 = arith.constant 0 : index
%c1_45 = arith.constant 1 : index
%c1_46 = arith.constant 1 : index
%c2_47 = arith.constant 2 : index
air.herd @herd_0 tile (%arg15, %arg16) in (%arg17=%c1_46, %arg18=%c2_47) args(%arg19=%alloc_12, %arg20=%alloc_10, %arg21=%alloc_11, %arg22=%alloc_9, %arg23=%alloc_8) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_48 = arith.constant 0 : index
%cst_49 = arith.constant 0.000000e+00 : bf16
%c8_50 = arith.constant 8 : index
%c1_51 = arith.constant 1 : index
%c16_52 = arith.constant 16 : index
%c11_53 = arith.constant 11 : index
%6 = affine.apply affine_map<(d0) -> (d0)>(%arg15)
%7 = affine.apply affine_map<(d0) -> (d0)>(%arg16)
%subview_54 = memref.subview %arg19[%6, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_55 = memref.expand_shape %subview_54 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_56 = memref.transpose %expand_shape_55 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg20[] [] [], %transpose_56[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_57 = memref.subview %arg21[0, %7, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_58 = memref.expand_shape %subview_57 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_59 = memref.transpose %expand_shape_58 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg22[] [] [], %transpose_59[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
scf.for %arg24 = %c0_48 to %c11_53 step %c1_51 {
scf.for %arg25 = %c0_48 to %c16_52 step %c1_51 {
scf.for %arg26 = %c0_48 to %c8_50 step %c1_51 {
%8 = vector.transfer_read %arg20[%c0_48, %c0_48, %arg26, %arg24, %c0_48, %c0_48], %cst_49 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%9 = vector.transfer_read %arg22[%c0_48, %c0_48, %arg25, %arg26, %c0_48, %c0_48], %cst_49 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%10 = vector.transfer_read %arg23[%6, %7, %arg25, %arg24, %c0_48, %c0_48], %cst_49 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%11 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %9, %10 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %11, %arg23[%6, %7, %arg25, %arg24, %c0_48, %c0_48] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
%subview_23 = memref.subview %arg12[%3, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_12[] [] [], %subview_23[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_24 = memref.subview %arg13[9664, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_25 = memref.expand_shape %subview_24 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_26 = memref.transpose %expand_shape_25 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_26[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
%c1_27 = arith.constant 1 : index
%c0_28 = arith.constant 0 : index
%c1_29 = arith.constant 1 : index
%c2_30 = arith.constant 2 : index
%c0_31 = arith.constant 0 : index
%c1_32 = arith.constant 1 : index
%c1_33 = arith.constant 1 : index
%c2_34 = arith.constant 2 : index
air.herd @herd_0 tile (%arg14, %arg15) in (%arg16=%c1_33, %arg17=%c2_34) args(%arg18=%alloc_12, %arg19=%alloc_10, %arg20=%alloc_11, %arg21=%alloc_9, %arg22=%alloc_8, %arg23=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_36 = arith.constant 0 : index
%cst_37 = arith.constant 0.000000e+00 : bf16
%c8_38 = arith.constant 8 : index
%c1_39 = arith.constant 1 : index
%c16_40 = arith.constant 16 : index
%c11_41 = arith.constant 11 : index
%5 = affine.apply affine_map<(d0) -> (d0)>(%arg14)
%6 = affine.apply affine_map<(d0) -> (d0)>(%arg15)
%subview_42 = memref.subview %arg18[%5, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_43 = memref.expand_shape %subview_42 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_44 = memref.transpose %expand_shape_43 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_44[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_45 = memref.subview %arg20[0, %6, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_46 = memref.expand_shape %subview_45 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_47 = memref.transpose %expand_shape_46 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg21[] [] [], %transpose_47[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_48 = memref.subview %arg22[%5, %6, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg24 = %c0_36 to %c11_41 step %c1_39 {
scf.for %arg25 = %c0_36 to %c16_40 step %c1_39 {
scf.for %arg26 = %c0_36 to %c8_38 step %c1_39 {
%7 = vector.transfer_read %arg19[%c0_36, %c0_36, %arg26, %arg24, %c0_36, %c0_36], %cst_37 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%8 = vector.transfer_read %arg21[%c0_36, %c0_36, %arg25, %arg26, %c0_36, %c0_36], %cst_37 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%9 = vector.transfer_read %arg22[%5, %6, %arg25, %arg24, %c0_36, %c0_36], %cst_37 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%10 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %8, %9 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %10, %arg22[%5, %6, %arg25, %arg24, %c0_36, %c0_36] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
%subview_49 = memref.subview %arg23[%5, %6, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_50 = memref.transpose %subview_48 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_49[] [] [], %transpose_50[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
air.herd_terminator
}
%transpose_35 = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_35[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
memref.dealloc %alloc_12 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_11 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_10 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_9 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_8 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before CSE (cse) //----- //
// Dispatch function for a bf16 contraction C(308x2432) += A(308x9728) * B(9728x2432),
// expressed with AIR launch/segment/herd ops. Work is split into 44x128 output tiles
// (7 x 19 launch grid) and 64-wide steps along the 9728 reduction dimension.
// The first and last reduction steps are peeled out of the loop: the first one
// zero-fills the accumulator, the last one copies the accumulator out.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Launch grid extents: 7 * 44 = 308 rows, 19 * 128 = 2432 columns.
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// Binding 0: read-only 308x9728 operand (called A in the comments below).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
// Binding 1: read-only 9728x2432 operand (called B below).
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
// Binding 2: 308x2432 result buffer (called C below).
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// %arg0 in [0,7) and %arg1 in [0,19) select the output tile; C, A, B are passed in as %arg4, %arg5, %arg6.
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
// Inside the segment: %arg7/%arg8 are the launch indices, %arg9 = C, %arg10 = A, %arg11 = B.
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
// Upper bound of the reduction loop below, which runs steps 1..150 (step 0 and step 151 are peeled).
%c151 = arith.constant 151 : index
// Row offset (%3) and column offset (%4) of this launch iteration's 44x128 output tile.
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
// Staging buffers in memory spaces 1 and 2.
// NOTE(review): presumably space 1 = L2 (memtile) and space 2 = L1 (core-local) per MLIR-AIR convention — confirm.
// %alloc   : output staging tile, 2 x (44x64), space 1.
// %alloc_0 : accumulator, 2 x (16x11 blocks of 4x4), space 2.
// %alloc_1 : B tile as 16x8 blocks of 8x4, space 2.
// %alloc_2 : A tile as 8x11 blocks of 4x8, space 2.
// %alloc_3 : B tile, 2 x (64x64), space 1.
// %alloc_4 : A tile, 44x64, space 1.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_0 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_4 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// Destination 44x128 tile of C; only written at the very end of the segment.
%subview = memref.subview %arg9[%3, %4] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled first reduction step (K offset 0) ----
// Copy the 44x64 slice of A into %alloc_4.
%subview_5 = memref.subview %arg10[%3, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
// Copy the 64x128 slice of B into %alloc_3, viewed as two 64x64 column halves (1x2x64x64).
%subview_6 = memref.subview %arg11[0, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// 1x2 herd: tile index %arg13 selects which 64-column half of the B tile / accumulator it works on.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%cst = arith.constant 0.000000e+00 : bf16
%c0_12 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// A tile: view 44x64 as 11x4x8x8, move the K-block dim outermost (8x11x4x8), and copy into %arg17.
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// B tile: take this herd tile's 64x64 half, view as 8x8x16x4, move the N-block dim outermost (16x8x8x4), and copy into %arg19.
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Zero this herd tile's accumulator slice; only this first (peeled) herd does the fill.
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_20 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Loop nest: %arg21 over 11 row blocks (4 rows each), %arg22 over 16 column blocks (4 cols each),
// %arg23 over 8 reduction blocks (8 deep each).
scf.for %arg21 = %c0_12 to %c11 step %c1_13 {
scf.for %arg22 = %c0_12 to %c16 step %c1_13 {
scf.for %arg23 = %c0_12 to %c8 step %c1_13 {
// Read a 4x8 block of A, an 8x4 block of B and the 4x4 accumulator block.
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg23, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg22, %arg23, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// With the unit dims dropped this is acc[d6, d7] += A[d6, d8] * B[d8, d7] (4x8 times 8x4 into 4x4).
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// The accumulator block is written back to memory on every innermost iteration.
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// ---- Main reduction loop: steps 1..150, K offset = step * 64 ----
scf.for %arg12 = %c1 to %c151 step %c1 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
// Refill %alloc_4 with the next 44x64 slice of A.
%subview_12 = memref.subview %arg10[%3, %5] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_12[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
// Refill %alloc_3 with the next 64x128 slice of B, again as two 64x64 halves.
%subview_13 = memref.subview %arg11[%5, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_15[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Same herd body as the peeled first step, minus the linalg.fill: it accumulates onto the existing contents of %alloc_0.
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_4, %arg18=%alloc_2, %arg19=%alloc_3, %arg20=%alloc_1, %arg21=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_16 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_17 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// A tile: 44x64 -> 8x11 blocks of 4x8, copied into %arg18.
%subview_18 = memref.subview %arg17[%arg13, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_19 = memref.expand_shape %subview_18 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_20 = memref.transpose %expand_shape_19 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg18[] [] [], %transpose_20[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// B tile: this herd tile's 64x64 half -> 16x8 blocks of 8x4, copied into %arg20.
%subview_21 = memref.subview %arg19[0, %arg14, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_22 = memref.expand_shape %subview_21 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_23 = memref.transpose %expand_shape_22 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg20[] [] [], %transpose_23[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Loop nest: %arg22 = row block (11), %arg23 = column block (16), %arg24 = reduction block (8).
scf.for %arg22 = %c0_16 to %c11 step %c1_17 {
scf.for %arg23 = %c0_16 to %c16 step %c1_17 {
scf.for %arg24 = %c0_16 to %c8 step %c1_17 {
%6 = vector.transfer_read %arg18[%c0_16, %c0_16, %arg24, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_16, %c0_16, %arg23, %arg24, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 times 8x4 accumulated into the 4x4 block read above.
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// ---- Peeled last reduction step (K offset 9664 = 151 * 64) ----
%subview_7 = memref.subview %arg10[%3, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_7[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_8 = memref.subview %arg11[9664, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_9 = memref.expand_shape %subview_8 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_10[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// This herd takes one extra argument, %arg21 = %alloc, which receives the finished accumulator after the loop nest.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_12 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// A tile: 44x64 -> 8x11 blocks of 4x8, copied into %arg17.
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// B tile: this herd tile's 64x64 half -> 16x8 blocks of 8x4, copied into %arg19.
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// This herd tile's accumulator slice; no fill here, it is only used as the source of the copy-out below.
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Loop nest: %arg22 = row block (11), %arg23 = column block (16), %arg24 = reduction block (8).
scf.for %arg22 = %c0_12 to %c11 step %c1_13 {
scf.for %arg23 = %c0_12 to %c16 step %c1_13 {
scf.for %arg24 = %c0_12 to %c8 step %c1_13 {
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg24, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg23, %arg24, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// 4x8 times 8x4 accumulated into the 4x4 block read above.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Copy-out: permute the accumulator from 16x11x4x4 to 11x4x16x4 and DMA it into this herd tile's 44x64 slice of %arg21.
%subview_21 = memref.subview %arg21[%arg12, %arg13, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_22 = memref.transpose %subview_20 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_21[] [] [], %transpose_22[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
air.herd_terminator
}
// Interleave the two 44x64 halves of %alloc into a 44x(2x64) view and DMA it to the 44x128 tile of C.
%transpose_11 = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_11[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
// Release all six staging buffers allocated at the top of the segment.
memref.dealloc %alloc_4 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_0 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before AMDAIECanonicalizeDma (iree-amdaie-canonicalize-dma) //----- //
// Tiled bf16 matmul: (308x9728) * (9728x2432) -> (308x2432), expressed with AIR
// launch / segment / herd ops and explicit air.dma_memcpy_nd transfers.
//
// Structure visible below:
//   * air.launch iterates a 7 x 19 grid; each iteration owns a 44 x 128 output
//     tile (7 * 44 = 308 rows, 19 * 128 = 2432 columns).
//   * The 9728-long reduction dimension is consumed in slices of 64:
//       - slice 0 is peeled before the scf.for and zero-fills the accumulator,
//       - slices 1..150 run inside the scf.for (%c1 to %c151),
//       - slice 151 (offset 9664 = 151 * 64) is peeled after the loop and also
//         copies the accumulated result out.
//   * Each air.herd is a 1 x 2 tile array; herd column %arg13/%arg14 selects one
//     64-column half of the 128-column tile.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// Binding 0: read-only 308x9728 operand (LHS of the contraction below).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
// Binding 1: read-only 9728x2432 operand (RHS of the contraction below).
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
// Binding 2: 308x2432 result buffer.
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// 7 x 19 launch grid. Note the argument order: %arg4 = result, %arg5 = LHS, %arg6 = RHS.
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
// %3 = row offset of this launch iteration's output tile (launch id * 44).
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
// %4 = column offset of this launch iteration's output tile (launch id * 128).
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
// Buffers in memory space 1 (%alloc, %alloc_3, %alloc_4) and memory space 2
// (%alloc_0, %alloc_1, %alloc_2).
// NOTE(review): presumably space 1 / space 2 are the L2 / L1 levels of the AIR
// memory hierarchy -- confirm against the AIR dialect definition.
//   %alloc   : staging buffer for the finished 44x128 result tile (as 2 x 44x64).
//   %alloc_0 : accumulator, one 16x11x4x4 slab per herd column.
//   %alloc_1 : packed RHS tile consumed by vector.transfer_read.
//   %alloc_2 : packed LHS tile consumed by vector.transfer_read.
//   %alloc_3 : RHS slice for the current reduction step (2 x 64x64).
//   %alloc_4 : LHS slice for the current reduction step (44x64).
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_0 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_4 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// 44x128 window of the result buffer that this launch iteration writes at the end.
%subview = memref.subview %arg9[%3, %4] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
// ---- Peeled reduction step 0 (reduction offset 0) ----
// LHS slice: rows [%3, %3+44), reduction columns [0, 64) -> %alloc_4.
%subview_5 = memref.subview %arg10[%3, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
// RHS slice: reduction rows [0, 64), columns [%4, %4+128); the 128 columns are
// split into 2 x 64 and the split index is moved in front of the row index
// before the copy into %alloc_3.
%subview_6 = memref.subview %arg11[0, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// First herd (1 x 2 tiles): zero-fills the accumulator slab, then accumulates step 0.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%cst = arith.constant 0.000000e+00 : bf16
%c0_12 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Repack the 44x64 LHS slice as 11x4 by 8x8 and permute to 8x11x4x8 while
// copying it from memory space 1 into %arg17 (memory space 2).
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// Select this herd column's 64x64 RHS half (%arg13), repack as 8x8 by 16x4 and
// permute to 16x8x8x4 while copying it into %arg19 (memory space 2).
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// Zero-initialise this tile's accumulator slab; only this first herd does so.
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_20 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Loop nest: %arg21 over 11 LHS row blocks, %arg22 over 16 RHS column blocks,
// %arg23 over 8 reduction blocks. Each iteration does a 4x8 by 8x4 contraction
// accumulated into a 4x4 block that is read from and written back to %arg20.
scf.for %arg21 = %c0_12 to %c11 step %c1_13 {
scf.for %arg22 = %c0_12 to %c16 step %c1_13 {
scf.for %arg23 = %c0_12 to %c8 step %c1_13 {
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg23, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg22, %arg23, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// d2, d5 and d8 are the reduction iterators; the result has the accumulator's shape.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// ---- Steady-state reduction steps 1..150 ----
// Same data movement and compute as step 0, but without the linalg.fill, so
// each step adds onto the values already held in %alloc_0.
scf.for %arg12 = %c1 to %c151 step %c1 {
// %5 = reduction offset of this step (step index * 64).
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
%subview_12 = memref.subview %arg10[%3, %5] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_12[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_13 = memref.subview %arg11[%5, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_15[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Herd body for the steady state; tile ids are %arg13 (row) and %arg14 (column).
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_4, %arg18=%alloc_2, %arg19=%alloc_3, %arg20=%alloc_1, %arg21=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_16 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_17 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// LHS slice -> packed 8x11x4x8 layout in %arg18.
%subview_18 = memref.subview %arg17[%arg13, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_19 = memref.expand_shape %subview_18 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_20 = memref.transpose %expand_shape_19 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg18[] [] [], %transpose_20[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// This herd column's RHS half -> packed 16x8x8x4 layout in %arg20.
%subview_21 = memref.subview %arg19[0, %arg14, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_22 = memref.expand_shape %subview_21 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_23 = memref.transpose %expand_shape_22 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg20[] [] [], %transpose_23[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// 11 x 16 x 8 block loop nest; accumulator blocks live in %arg21 (= %alloc_0).
scf.for %arg22 = %c0_16 to %c11 step %c1_17 {
scf.for %arg23 = %c0_16 to %c16 step %c1_17 {
scf.for %arg24 = %c0_16 to %c8 step %c1_17 {
%6 = vector.transfer_read %arg18[%c0_16, %c0_16, %arg24, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_16, %c0_16, %arg23, %arg24, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// ---- Peeled final reduction step (offset 9664 = 151 * 64; 9664 + 64 = 9728) ----
%subview_7 = memref.subview %arg10[%3, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_7[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_8 = memref.subview %arg11[9664, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_9 = memref.expand_shape %subview_8 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_10[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
// Last herd: takes the extra %arg21 = %alloc operand so it can copy the finished
// accumulator out after the loop nest.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_12 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// LHS slice -> packed 8x11x4x8 layout in %arg17.
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
// This herd column's RHS half -> packed 16x8x8x4 layout in %arg19.
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
// View of this tile's accumulator slab; used only as the source of the
// write-back DMA after the loop nest (there is no fill in this herd).
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg22 = %c0_12 to %c11 step %c1_13 {
scf.for %arg23 = %c0_12 to %c16 step %c1_13 {
scf.for %arg24 = %c0_12 to %c8 step %c1_13 {
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg24, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg23, %arg24, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Write-back: permute the 16x11x4x4 accumulator slab to 11x4x16x4 and copy it
// into this tile's 44x64 slot of %arg21 (= %alloc, memory space 1).
%subview_21 = memref.subview %arg21[%arg12, %arg13, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_22 = memref.transpose %subview_20 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_21[] [] [], %transpose_22[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
air.herd_terminator
}
// Interleave the two 44x64 halves of %alloc back into a 44 x (2 x 64) view and
// copy it into the 44x128 window of the result buffer.
%transpose_11 = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_11[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
// Release every buffer allocated at the top of the segment.
memref.dealloc %alloc_4 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_0 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before CopyToDma (air-copy-to-dma) //----- //
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_0 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_4 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
%subview = memref.subview %arg9[%3, %4] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
%subview_5 = memref.subview %arg10[%3, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_5[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_6 = memref.subview %arg11[0, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_6 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%cst = arith.constant 0.000000e+00 : bf16
%c0_12 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_20 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
scf.for %arg21 = %c0_12 to %c11 step %c1_13 {
scf.for %arg22 = %c0_12 to %c16 step %c1_13 {
scf.for %arg23 = %c0_12 to %c8 step %c1_13 {
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg23, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg22, %arg23, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
scf.for %arg12 = %c1 to %c151 step %c1 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
%subview_12 = memref.subview %arg10[%3, %5] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_12[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_13 = memref.subview %arg11[%5, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_15[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_4, %arg18=%alloc_2, %arg19=%alloc_3, %arg20=%alloc_1, %arg21=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_16 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_17 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%subview_18 = memref.subview %arg17[%arg13, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_19 = memref.expand_shape %subview_18 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_20 = memref.transpose %expand_shape_19 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg18[] [] [], %transpose_20[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_21 = memref.subview %arg19[0, %arg14, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_22 = memref.expand_shape %subview_21 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_23 = memref.transpose %expand_shape_22 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg20[] [] [], %transpose_23[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
scf.for %arg22 = %c0_16 to %c11 step %c1_17 {
scf.for %arg23 = %c0_16 to %c16 step %c1_17 {
scf.for %arg24 = %c0_16 to %c8 step %c1_17 {
%6 = vector.transfer_read %arg18[%c0_16, %c0_16, %arg24, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_16, %c0_16, %arg23, %arg24, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_16, %c0_16] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
%subview_7 = memref.subview %arg10[%3, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_4[] [] [], %subview_7[] [] []) : (memref<1x1x44x64xbf16, 1 : i32>, memref<44x64xbf16, strided<[9728, 1], offset: ?>>)
%subview_8 = memref.subview %arg11[9664, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_9 = memref.expand_shape %subview_8 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_10[] [] []) : (memref<1x2x64x64xbf16, 1 : i32>, memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>)
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_12 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_13 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%subview_14 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_15 = memref.expand_shape %subview_14 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_16 = memref.transpose %expand_shape_15 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg17[] [] [], %transpose_16[] [] []) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>)
%subview_17 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_18 = memref.expand_shape %subview_17 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_19 = memref.transpose %expand_shape_18 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%arg19[] [] [], %transpose_19[] [] []) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>)
%subview_20 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
scf.for %arg22 = %c0_12 to %c11 step %c1_13 {
scf.for %arg23 = %c0_12 to %c16 step %c1_13 {
scf.for %arg24 = %c0_12 to %c8 step %c1_13 {
%5 = vector.transfer_read %arg17[%c0_12, %c0_12, %arg24, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_12, %c0_12, %arg23, %arg24, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_12, %c0_12] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
%subview_21 = memref.subview %arg21[%arg12, %arg13, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_22 = memref.transpose %subview_20 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
air.dma_memcpy_nd (%subview_21[] [] [], %transpose_22[] [] []) : (memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>, memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>)
air.herd_terminator
}
%transpose_11 = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
air.dma_memcpy_nd (%subview[] [] [], %transpose_11[] [] []) : (memref<44x128xbf16, strided<[2432, 1], offset: ?>>, memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>)
memref.dealloc %alloc_4 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_0 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
// AIR-level form of a bf16 matmul: binding 0 (308x9728) times binding 1 (9728x2432) into binding 2 (308x2432).
// The air.launch grid is 7x19; each instance produces one 44x128 output tile whose origin is
// (%arg7 * 44, %arg8 * 128), and consumes the 9728-long reduction dimension in 64-wide slices:
//   - slice at offset 0      : first air.herd, which also zero-fills the accumulator,
//   - slices 1..150 (x 64)   : the scf.for loop, accumulating into the same buffer,
//   - slice at offset 9664   : last air.herd, which additionally copies the accumulator out.
// Many constants are defined several times and several subview/expand_shape/transpose results are not
// referenced again; the air.dma_memcpy_nd ops carry explicit offsets/sizes/strides instead.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// Launch grid extents (19 and 7) and the zero byte offset used by the bindings below.
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// Binding 0 and binding 1 are read-only inputs; binding 2 is the 308x2432 result buffer.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// 7 x 19 launch: 7 * 44 = 308 rows and 19 * 128 = 2432 columns of the result.
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
// Segment arguments: %arg7/%arg8 are the launch indices, %arg9 the result, %arg10/%arg11 the two inputs.
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
// %3 = row origin of this launch's tile, %4 = column origin of this launch's tile.
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
// Staging buffers. Memory space 1 holds whole 44x64 / 64x64 tiles, memory space 2 holds the
// re-blocked 4x8 / 8x4 / 4x4 micro-tiles used by the vector loops.
// NOTE(review): spaces 1 and 2 are presumably the AIR L2 and L1 levels - confirm against the AIR docs.
//   %alloc   : result staging for the final write-back (2 column halves of 44x64)
//   %alloc_0 : accumulator, 16 x 11 blocks of 4x4 per column half
//   %alloc_1 : second contract operand, 16 x 8 blocks of 8x4
//   %alloc_2 : first contract operand, 8 x 11 blocks of 4x8
//   %alloc_3 : 64x128 slice of binding 1 stored as 2 halves of 64x64
//   %alloc_4 : 44x64 slice of binding 0
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_0 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_4 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// The two subviews below are not referenced again in this module; the DMAs address %arg9/%arg10 directly.
%subview = memref.subview %arg9[%3, %4] [44, 128] [1, 1] : memref<308x2432xbf16> to memref<44x128xbf16, strided<[2432, 1], offset: ?>>
%subview_5 = memref.subview %arg10[%3, 0] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
%c0_6 = arith.constant 0 : index
%c0_7 = arith.constant 0 : index
%c9728 = arith.constant 9728 : index
%c1_8 = arith.constant 1 : index
%c44 = arith.constant 44 : index
%c64 = arith.constant 64 : index
// Reduction slice 0: copy rows [%3, %3+44) x columns [0, 64) of binding 0 into %alloc_4.
air.dma_memcpy_nd (%alloc_4[] [] [], %arg10[%3, %c0_7] [%c44, %c64] [%c9728, %c1_8]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
// View chain describing the 64x128 slice of binding 1 as 1x2x64x64; %transpose itself is not used below,
// the same access pattern is spelled out as offsets/sizes/strides on the DMA.
%subview_9 = memref.subview %arg11[0, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview_9 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
%c0_10 = arith.constant 0 : index
%c0_11 = arith.constant 0 : index
%c0_12 = arith.constant 0 : index
%c0_13 = arith.constant 0 : index
%c155648 = arith.constant 155648 : index
%c64_14 = arith.constant 64 : index
%c2432 = arith.constant 2432 : index
%c1_15 = arith.constant 1 : index
%c1_16 = arith.constant 1 : index
%c2_17 = arith.constant 2 : index
%c64_18 = arith.constant 64 : index
%c64_19 = arith.constant 64 : index
// Reduction slice 0: copy rows [0, 64) x columns [%4, %4+128) of binding 1 into %alloc_3 as two 64x64 halves
// (sizes [1, 2, 64, 64], strides [155648, 64, 2432, 1]).
air.dma_memcpy_nd (%alloc_3[] [] [], %arg11[%c0_13, %c0_12, %c0_11, %4] [%c1_16, %c2_17, %c64_18, %c64_19] [%c155648, %c64_14, %c2432, %c1_15]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// First herd (1 x 2 tiles): handles reduction slice 0 and is the only herd that zero-fills the accumulator.
// %arg13 selects which 64-column half of the tile this herd tile works on.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
// %cst is both the fill value and the padding value of the transfer_reads (all reads are in_bounds).
%cst = arith.constant 0.000000e+00 : bf16
%c0_52 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_53 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// View chain re-blocking the 44x64 tile into 8 x 11 blocks of 4x8; %transpose_56 is not used below.
%subview_54 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_55 = memref.expand_shape %subview_54 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_56 = memref.transpose %expand_shape_55 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
%c0_57 = arith.constant 0 : index
%c0_58 = arith.constant 0 : index
%c0_59 = arith.constant 0 : index
%c0_60 = arith.constant 0 : index
%c0_61 = arith.constant 0 : index
%c0_62 = arith.constant 0 : index
%c2816_63 = arith.constant 2816 : index
%c2816_64 = arith.constant 2816 : index
%c8_65 = arith.constant 8 : index
%c256 = arith.constant 256 : index
%c64_66 = arith.constant 64 : index
%c1_67 = arith.constant 1 : index
%c1_68 = arith.constant 1 : index
%c1_69 = arith.constant 1 : index
%c8_70 = arith.constant 8 : index
%c11_71 = arith.constant 11 : index
%c4 = arith.constant 4 : index
%c8_72 = arith.constant 8 : index
// Copy the 44x64 tile from %arg16 into %arg17, reading it with sizes [1, 1, 8, 11, 4, 8] and
// strides [2816, 2816, 8, 256, 64, 1].
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_58, %c0_61, %c0_62, %c0_59, %c0_60] [%c1_68, %c1_69, %c8_70, %c11_71, %c4, %c8_72] [%c2816_63, %c2816_64, %c8_65, %c256, %c64_66, %c1_67]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// View chain re-blocking this tile's 64x64 half into 16 x 8 blocks of 8x4; %transpose_75 is not used below.
%subview_73 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_74 = memref.expand_shape %subview_73 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_75 = memref.transpose %expand_shape_74 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
%c0_76 = arith.constant 0 : index
%c0_77 = arith.constant 0 : index
%c0_78 = arith.constant 0 : index
%c0_79 = arith.constant 0 : index
%c0_80 = arith.constant 0 : index
%c0_81 = arith.constant 0 : index
%c8192 = arith.constant 8192 : index
%c4096 = arith.constant 4096 : index
%c4_82 = arith.constant 4 : index
%c512 = arith.constant 512 : index
%c64_83 = arith.constant 64 : index
%c1_84 = arith.constant 1 : index
%c1_85 = arith.constant 1 : index
%c1_86 = arith.constant 1 : index
%c16_87 = arith.constant 16 : index
%c8_88 = arith.constant 8 : index
%c8_89 = arith.constant 8 : index
%c4_90 = arith.constant 4 : index
// Copy half %arg13 of %arg18 into %arg19, reading it with sizes [1, 1, 16, 8, 8, 4] and
// strides [8192, 4096, 4, 512, 64, 1].
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_77, %arg13, %c0_80, %c0_81, %c0_78, %c0_79] [%c1_85, %c1_86, %c16_87, %c8_88, %c8_89, %c4_90] [%c8192, %c4096, %c4_82, %c512, %c64_83, %c1_84]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Zero this herd tile's slot of the accumulator before the first accumulation.
%subview_91 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview_91 : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Micro-kernel: %arg21 walks the 11 row blocks, %arg22 the 16 column blocks, %arg23 the 8 reduction blocks.
scf.for %arg21 = %c0_52 to %c11 step %c1_53 {
scf.for %arg22 = %c0_52 to %c16 step %c1_53 {
scf.for %arg23 = %c0_52 to %c8 step %c1_53 {
// Load a 4x8 block, an 8x4 block and the current 4x4 accumulator block.
%5 = vector.transfer_read %arg17[%c0_52, %c0_52, %arg23, %arg21, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_52, %c0_52, %arg22, %arg23, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// Multiply-accumulate: %8 = %7 + contract(%5, %6), with bf16 operands and a bf16 accumulator.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated 4x4 block back to the same accumulator location.
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_52, %c0_52] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// Reduction slices 1..150: same load / re-block / accumulate sequence, without re-zeroing the accumulator.
scf.for %arg12 = %c1 to %c151 step %c1 {
// %5 = starting offset of this 64-wide reduction slice.
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
// %subview_52 is not referenced again; the DMA below addresses %arg10 directly.
%subview_52 = memref.subview %arg10[%3, %5] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
%c0_53 = arith.constant 0 : index
%c9728_54 = arith.constant 9728 : index
%c1_55 = arith.constant 1 : index
%c44_56 = arith.constant 44 : index
%c64_57 = arith.constant 64 : index
// Copy rows [%3, %3+44) x columns [%5, %5+64) of binding 0 into %alloc_4.
air.dma_memcpy_nd (%alloc_4[] [] [], %arg10[%3, %5] [%c44_56, %c64_57] [%c9728_54, %c1_55]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
// View chain for the matching 64x128 slice of binding 1; %transpose_60 is not used below.
%subview_58 = memref.subview %arg11[%5, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_59 = memref.expand_shape %subview_58 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_60 = memref.transpose %expand_shape_59 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
%c0_61 = arith.constant 0 : index
%c0_62 = arith.constant 0 : index
%c0_63 = arith.constant 0 : index
%c155648_64 = arith.constant 155648 : index
%c64_65 = arith.constant 64 : index
%c2432_66 = arith.constant 2432 : index
%c1_67 = arith.constant 1 : index
%c1_68 = arith.constant 1 : index
%c2_69 = arith.constant 2 : index
%c64_70 = arith.constant 64 : index
%c64_71 = arith.constant 64 : index
// Copy rows [%5, %5+64) x columns [%4, %4+128) of binding 1 into %alloc_3 as two 64x64 halves.
air.dma_memcpy_nd (%alloc_3[] [] [], %arg11[%c0_63, %c0_62, %5, %4] [%c1_68, %c2_69, %c64_70, %c64_71] [%c155648_64, %c64_65, %c2432_66, %c1_67]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Loop-body herd (1 x 2 tiles): note there is no linalg.fill here, so %arg21 keeps the running sum.
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_4, %arg18=%alloc_2, %arg19=%alloc_3, %arg20=%alloc_1, %arg21=%alloc_0) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c0_72 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_73 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// View chain for the 4x8 re-blocking of the 44x64 tile; %transpose_76 is not used below.
%subview_74 = memref.subview %arg17[%arg13, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_75 = memref.expand_shape %subview_74 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_76 = memref.transpose %expand_shape_75 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
%c0_77 = arith.constant 0 : index
%c0_78 = arith.constant 0 : index
%c0_79 = arith.constant 0 : index
%c0_80 = arith.constant 0 : index
%c0_81 = arith.constant 0 : index
%c0_82 = arith.constant 0 : index
%c2816_83 = arith.constant 2816 : index
%c2816_84 = arith.constant 2816 : index
%c8_85 = arith.constant 8 : index
%c256 = arith.constant 256 : index
%c64_86 = arith.constant 64 : index
%c1_87 = arith.constant 1 : index
%c1_88 = arith.constant 1 : index
%c1_89 = arith.constant 1 : index
%c8_90 = arith.constant 8 : index
%c11_91 = arith.constant 11 : index
%c4 = arith.constant 4 : index
%c8_92 = arith.constant 8 : index
// Copy the 44x64 tile from %arg17 into %arg18 in 8 x 11 blocks of 4x8.
air.dma_memcpy_nd (%arg18[] [] [], %arg17[%arg13, %c0_78, %c0_81, %c0_82, %c0_79, %c0_80] [%c1_88, %c1_89, %c8_90, %c11_91, %c4, %c8_92] [%c2816_83, %c2816_84, %c8_85, %c256, %c64_86, %c1_87]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// View chain for the 8x4 re-blocking of this tile's 64x64 half; %transpose_95 is not used below.
%subview_93 = memref.subview %arg19[0, %arg14, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_94 = memref.expand_shape %subview_93 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_95 = memref.transpose %expand_shape_94 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
%c0_96 = arith.constant 0 : index
%c0_97 = arith.constant 0 : index
%c0_98 = arith.constant 0 : index
%c0_99 = arith.constant 0 : index
%c0_100 = arith.constant 0 : index
%c0_101 = arith.constant 0 : index
%c8192 = arith.constant 8192 : index
%c4096 = arith.constant 4096 : index
%c4_102 = arith.constant 4 : index
%c512 = arith.constant 512 : index
%c64_103 = arith.constant 64 : index
%c1_104 = arith.constant 1 : index
%c1_105 = arith.constant 1 : index
%c1_106 = arith.constant 1 : index
%c16_107 = arith.constant 16 : index
%c8_108 = arith.constant 8 : index
%c8_109 = arith.constant 8 : index
%c4_110 = arith.constant 4 : index
// Copy half %arg14 of %arg19 into %arg20 in 16 x 8 blocks of 8x4.
air.dma_memcpy_nd (%arg20[] [] [], %arg19[%c0_97, %arg14, %c0_100, %c0_101, %c0_98, %c0_99] [%c1_105, %c1_106, %c16_107, %c8_108, %c8_109, %c4_110] [%c8192, %c4096, %c4_102, %c512, %c64_103, %c1_104]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Micro-kernel: %arg22 walks the 11 row blocks, %arg23 the 16 column blocks, %arg24 the 8 reduction blocks.
scf.for %arg22 = %c0_72 to %c11 step %c1_73 {
scf.for %arg23 = %c0_72 to %c16 step %c1_73 {
scf.for %arg24 = %c0_72 to %c8 step %c1_73 {
%6 = vector.transfer_read %arg18[%c0_72, %c0_72, %arg24, %arg22, %c0_72, %c0_72], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_72, %c0_72, %arg23, %arg24, %c0_72, %c0_72], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_72, %c0_72], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// %9 = %8 + contract(%6, %7): adds this slice's contribution to the running 4x4 block.
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_72, %c0_72] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// Last reduction slice (offset 9664 = 151 * 64): load, accumulate, then write the result out.
// %subview_20 is not referenced again; the DMA below addresses %arg10 directly.
%subview_20 = memref.subview %arg10[%3, 9664] [44, 64] [1, 1] : memref<308x9728xbf16> to memref<44x64xbf16, strided<[9728, 1], offset: ?>>
%c0_21 = arith.constant 0 : index
%c9664 = arith.constant 9664 : index
%c9728_22 = arith.constant 9728 : index
%c1_23 = arith.constant 1 : index
%c44_24 = arith.constant 44 : index
%c64_25 = arith.constant 64 : index
// Copy rows [%3, %3+44) x columns [9664, 9728) of binding 0 into %alloc_4.
air.dma_memcpy_nd (%alloc_4[] [] [], %arg10[%3, %c9664] [%c44_24, %c64_25] [%c9728_22, %c1_23]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
// View chain for the last 64x128 slice of binding 1; %transpose_28 is not used below.
%subview_26 = memref.subview %arg11[9664, %4] [64, 128] [1, 1] : memref<9728x2432xbf16> to memref<64x128xbf16, strided<[2432, 1], offset: ?>>
%expand_shape_27 = memref.expand_shape %subview_26 [[0, 1], [2, 3]] output_shape [1, 64, 2, 64] : memref<64x128xbf16, strided<[2432, 1], offset: ?>> into memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>>
%transpose_28 = memref.transpose %expand_shape_27 (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x64x2x64xbf16, strided<[155648, 2432, 64, 1], offset: ?>> to memref<1x2x64x64xbf16, strided<[155648, 64, 2432, 1], offset: ?>>
%c0_29 = arith.constant 0 : index
%c9664_30 = arith.constant 9664 : index
%c0_31 = arith.constant 0 : index
%c0_32 = arith.constant 0 : index
%c155648_33 = arith.constant 155648 : index
%c64_34 = arith.constant 64 : index
%c2432_35 = arith.constant 2432 : index
%c1_36 = arith.constant 1 : index
%c1_37 = arith.constant 1 : index
%c2_38 = arith.constant 2 : index
%c64_39 = arith.constant 64 : index
%c64_40 = arith.constant 64 : index
// Copy rows [9664, 9728) x columns [%4, %4+128) of binding 1 into %alloc_3 as two 64x64 halves.
air.dma_memcpy_nd (%alloc_3[] [] [], %arg11[%c0_32, %c0_31, %c9664_30, %4] [%c1_37, %c2_38, %c64_39, %c64_40] [%c155648_33, %c64_34, %c2432_35, %c1_36]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Last herd (1 x 2 tiles): takes %alloc as an extra argument (%arg21) to receive the finished accumulator.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_4, %arg17=%alloc_2, %arg18=%alloc_3, %arg19=%alloc_1, %arg20=%alloc_0, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c0_52 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_53 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// View chain for the 4x8 re-blocking of the 44x64 tile; %transpose_56 is not used below.
%subview_54 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x1x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32>
%expand_shape_55 = memref.expand_shape %subview_54 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 11, 4, 8, 8] : memref<1x1x44x64xbf16, strided<[2816, 2816, 64, 1], offset: ?>, 1 : i32> into memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32>
%transpose_56 = memref.transpose %expand_shape_55 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x11x4x8x8xbf16, strided<[2816, 2816, 256, 64, 8, 1], offset: ?>, 1 : i32> to memref<1x1x8x11x4x8xbf16, strided<[2816, 2816, 8, 256, 64, 1], offset: ?>, 1 : i32>
%c0_57 = arith.constant 0 : index
%c0_58 = arith.constant 0 : index
%c0_59 = arith.constant 0 : index
%c0_60 = arith.constant 0 : index
%c0_61 = arith.constant 0 : index
%c0_62 = arith.constant 0 : index
%c2816_63 = arith.constant 2816 : index
%c2816_64 = arith.constant 2816 : index
%c8_65 = arith.constant 8 : index
%c256 = arith.constant 256 : index
%c64_66 = arith.constant 64 : index
%c1_67 = arith.constant 1 : index
%c1_68 = arith.constant 1 : index
%c1_69 = arith.constant 1 : index
%c8_70 = arith.constant 8 : index
%c11_71 = arith.constant 11 : index
%c4 = arith.constant 4 : index
%c8_72 = arith.constant 8 : index
// Copy the 44x64 tile from %arg16 into %arg17 in 8 x 11 blocks of 4x8.
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_58, %c0_61, %c0_62, %c0_59, %c0_60] [%c1_68, %c1_69, %c8_70, %c11_71, %c4, %c8_72] [%c2816_63, %c2816_64, %c8_65, %c256, %c64_66, %c1_67]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// View chain for the 8x4 re-blocking of this tile's 64x64 half; %transpose_75 is not used below.
%subview_73 = memref.subview %arg18[0, %arg13, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x2x64x64xbf16, 1 : i32> to memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32>
%expand_shape_74 = memref.expand_shape %subview_73 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 8, 8, 16, 4] : memref<1x1x64x64xbf16, strided<[8192, 4096, 64, 1], offset: ?>, 1 : i32> into memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32>
%transpose_75 = memref.transpose %expand_shape_74 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x8x8x16x4xbf16, strided<[8192, 4096, 512, 64, 4, 1], offset: ?>, 1 : i32> to memref<1x1x16x8x8x4xbf16, strided<[8192, 4096, 4, 512, 64, 1], offset: ?>, 1 : i32>
%c0_76 = arith.constant 0 : index
%c0_77 = arith.constant 0 : index
%c0_78 = arith.constant 0 : index
%c0_79 = arith.constant 0 : index
%c0_80 = arith.constant 0 : index
%c0_81 = arith.constant 0 : index
%c8192 = arith.constant 8192 : index
%c4096 = arith.constant 4096 : index
%c4_82 = arith.constant 4 : index
%c512 = arith.constant 512 : index
%c64_83 = arith.constant 64 : index
%c1_84 = arith.constant 1 : index
%c1_85 = arith.constant 1 : index
%c1_86 = arith.constant 1 : index
%c16_87 = arith.constant 16 : index
%c8_88 = arith.constant 8 : index
%c8_89 = arith.constant 8 : index
%c4_90 = arith.constant 4 : index
// Copy half %arg13 of %arg18 into %arg19 in 16 x 8 blocks of 8x4.
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_77, %arg13, %c0_80, %c0_81, %c0_78, %c0_79] [%c1_85, %c1_86, %c16_87, %c8_88, %c8_89, %c4_90] [%c8192, %c4096, %c4_82, %c512, %c64_83, %c1_84]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Accumulator slot of this herd tile; only feeds %transpose_93 below (no fill in this herd).
%subview_91 = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Micro-kernel: %arg22 walks the 11 row blocks, %arg23 the 16 column blocks, %arg24 the 8 reduction blocks.
scf.for %arg22 = %c0_52 to %c11 step %c1_53 {
scf.for %arg23 = %c0_52 to %c16 step %c1_53 {
scf.for %arg24 = %c0_52 to %c8 step %c1_53 {
%5 = vector.transfer_read %arg17[%c0_52, %c0_52, %arg24, %arg22, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_52, %c0_52, %arg23, %arg24, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_52, %c0_52], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
// %8 = %7 + contract(%5, %6): final contribution to the running 4x4 block.
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_52, %c0_52] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Views describing the accumulator-to-staging copy; neither %subview_92 nor %transpose_93 is used below,
// the DMA at the end of the herd addresses %arg21 and %arg20 directly.
%subview_92 = memref.subview %arg21[%arg12, %arg13, 0, 0] [1, 1, 44, 64] [1, 1, 1, 1] : memref<1x2x44x64xbf16, 1 : i32> to memref<1x1x44x64xbf16, strided<[5632, 2816, 64, 1], offset: ?>, 1 : i32>
%transpose_93 = memref.transpose %subview_91 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x11x4x16x4xbf16, strided<[5632, 2816, 16, 4, 176, 1], offset: ?>, 2 : i32>
%c0_94 = arith.constant 0 : index
%c0_95 = arith.constant 0 : index
%c0_96 = arith.constant 0 : index
%c0_97 = arith.constant 0 : index
%c0_98 = arith.constant 0 : index
%c0_99 = arith.constant 0 : index
%c0_100 = arith.constant 0 : index
%c5632_101 = arith.constant 5632 : index
%c2816_102 = arith.constant 2816 : index
%c16_103 = arith.constant 16 : index
%c4_104 = arith.constant 4 : index
%c176 = arith.constant 176 : index
%c1_105 = arith.constant 1 : index
%c1_106 = arith.constant 1 : index
%c1_107 = arith.constant 1 : index
%c11_108 = arith.constant 11 : index
%c4_109 = arith.constant 4 : index
%c16_110 = arith.constant 16 : index
%c4_111 = arith.constant 4 : index
%c5632_112 = arith.constant 5632 : index
%c2816_113 = arith.constant 2816 : index
%c64_114 = arith.constant 64 : index
%c1_115 = arith.constant 1 : index
%c1_116 = arith.constant 1 : index
%c1_117 = arith.constant 1 : index
%c44_118 = arith.constant 44 : index
%c64_119 = arith.constant 64 : index
// Un-block the accumulator: read %arg20 as [1, 1, 11, 4, 16, 4] with strides [5632, 2816, 16, 4, 176, 1]
// and write it as a plain 44x64 tile into slot (%arg12, %arg13) of %arg21.
air.dma_memcpy_nd (%arg21[%arg12, %arg13, %c0_99, %c0_100] [%c1_116, %c1_117, %c44_118, %c64_119] [%c5632_112, %c2816_113, %c64_114, %c1_115], %arg20[%arg12, %arg13, %c0_96, %c0_97, %c0_95, %c0_98] [%c1_106, %c1_107, %c11_108, %c4_109, %c16_110, %c4_111] [%c5632_101, %c2816_102, %c16_103, %c4_104, %c176, %c1_105]) : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
// %transpose_41 is not used below; the write-back DMA encodes the same interleaving via strides.
%transpose_41 = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<1x2x44x64xbf16, 1 : i32> to memref<1x44x2x64xbf16, strided<[5632, 64, 2816, 1]>, 1 : i32>
%c0_42 = arith.constant 0 : index
%c5632 = arith.constant 5632 : index
%c64_43 = arith.constant 64 : index
%c2816 = arith.constant 2816 : index
%c1_44 = arith.constant 1 : index
%c1_45 = arith.constant 1 : index
%c44_46 = arith.constant 44 : index
%c2_47 = arith.constant 2 : index
%c64_48 = arith.constant 64 : index
%c2432_49 = arith.constant 2432 : index
%c1_50 = arith.constant 1 : index
%c44_51 = arith.constant 44 : index
%c128 = arith.constant 128 : index
// Write-back: read %alloc as [1, 44, 2, 64] with strides [5632, 64, 2816, 1] (row-major over both 64-column
// halves) and store it as the 44x128 tile at (%3, %4) of the result buffer.
air.dma_memcpy_nd (%arg9[%3, %4] [%c44_51, %c128] [%c2432_49, %c1_50], %alloc[%c0_42, %c0_42, %c0_42, %c0_42] [%c1_45, %c44_46, %c2_47, %c64_48] [%c5632, %c64_43, %c2816, %c1_44]) : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
// Release all six staging buffers allocated at the top of the segment.
memref.dealloc %alloc_4 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_2 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_0 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before CSE (cse) //----- //
// Tiled bf16 matmul dispatch. Per the binding types below, binding 0 is a
// 308x9728 read-only operand, binding 1 is a 9728x2432 read-only operand and
// binding 2 is the 308x2432 result. The work is split over a 7x19 air.launch
// grid (7 * 44 = 308 rows, 19 * 128 = 2432 columns); the 9728-long reduction
// dimension is walked in 152 steps of 64 (one peeled first step, a loop over
// steps 1..150, and one peeled last step at offset 9664).
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// HAL buffer bindings: LHS (%0), RHS (%1) and output (%2), each assumed 64-aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// Launch ids (%arg0, %arg1) range over 7x19; they are forwarded to the segment
// as %arg7/%arg8 together with the output (%arg9), LHS (%arg10) and RHS (%arg11).
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c128 = arith.constant 128 : index
%c2816 = arith.constant 2816 : index
%c5632 = arith.constant 5632 : index
%c9664 = arith.constant 9664 : index
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
// %3 = first output row of this launch tile (44 rows per tile);
// %4 = first output column of this launch tile (128 columns per tile).
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
// Staging buffers in memory space 1 and packed buffers in memory space 2
// (presumably AIR's L2 and L1 levels respectively -- confirm against the target):
//   %alloc   : 2 x (44x64) unpacked result tile (space 1)
//   %alloc_1 : packed accumulator, 2 x [16 x 11 blocks of 4x4] (space 2)
//   %alloc_2 : packed RHS tile, 16 x 8 blocks of 8x4 (space 2)
//   %alloc_3 : packed LHS tile, 8 x 11 blocks of 4x8 (space 2)
//   %alloc_4 : 2 x (64x64) RHS tile (space 1)
//   %alloc_5 : 44x64 LHS tile (space 1)
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_1 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// Reduction step 0 (peeled): fetch the 44x64 LHS tile at [%3, 0] and the RHS
// rows 0..63 at columns %4..%4+127, split into two 64-wide column blocks
// (stride 64 between blocks, row stride 2432).
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %c0_0] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %c0_0, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// 1x2 herd: tile (%arg12, %arg13) handles the %arg13-th 64-column half. This
// first herd is the only one that zero-initialises the accumulator.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_5, %arg17=%alloc_3, %arg18=%alloc_4, %arg19=%alloc_2, %arg20=%alloc_1) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_6 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_7 = arith.constant 2816 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0_8 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_9 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Pack the 44x64 LHS tile into [8 col-blocks][11 row-blocks][4][8]
// (strides 8 / 256 / 64 / 1 over the row-major 44x64 source).
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_8, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c8, %c11, %c4, %c8] [%c2816_7, %c2816_7, %c8, %c256, %c64_6, %c1_9]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Pack this tile's 64x64 RHS block (selected by %arg13, stride 4096) into
// [16 col-blocks][8 row-blocks][8][4] (strides 4 / 512 / 64 / 1).
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_8, %arg13, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_6, %c1_9]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Zero this tile's slice of the accumulator before the first accumulation.
%subview = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Micro-kernel loops: %arg21 over 11 row-blocks, %arg22 over 16 column-blocks,
// %arg23 over 8 reduction-blocks. Each iteration reads a 4x8 LHS block, an 8x4
// RHS block and the 4x4 accumulator block, contracts (kind = add) and writes
// the 4x4 block back. Accumulation is carried in bf16.
scf.for %arg21 = %c0_8 to %c11 step %c1_9 {
scf.for %arg22 = %c0_8 to %c16 step %c1_9 {
scf.for %arg23 = %c0_8 to %c8 step %c1_9 {
%5 = vector.transfer_read %arg17[%c0_8, %c0_8, %arg23, %arg21, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_8, %c0_8, %arg22, %arg23, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_8, %c0_8] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// Reduction steps 1..150: same fetch/pack/contract sequence at reduction offset
// %5 = 64 * %arg12, accumulating onto the existing accumulator (no fill here).
scf.for %arg12 = %c1 to %c151 step %c1 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %5] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %5, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_5, %arg18=%alloc_3, %arg19=%alloc_4, %arg20=%alloc_2, %arg21=%alloc_1) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_6 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_7 = arith.constant 2816 : index
%c0_8 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_9 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Re-pack the freshly fetched LHS and RHS tiles for this reduction step.
air.dma_memcpy_nd (%arg18[] [] [], %arg17[%arg13, %c0_8, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c8, %c11, %c4, %c8] [%c2816_7, %c2816_7, %c8, %c256, %c64_6, %c1_9]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
air.dma_memcpy_nd (%arg20[] [] [], %arg19[%c0_8, %arg14, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_6, %c1_9]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Same 11 x 16 x 8 micro-kernel; %arg21 is the accumulator being updated in place.
scf.for %arg22 = %c0_8 to %c11 step %c1_9 {
scf.for %arg23 = %c0_8 to %c16 step %c1_9 {
scf.for %arg24 = %c0_8 to %c8 step %c1_9 {
%6 = vector.transfer_read %arg18[%c0_8, %c0_8, %arg24, %arg22, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_8, %c0_8, %arg23, %arg24, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_8, %c0_8] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// Last reduction step (peeled, offset 9664 = 151 * 64). This herd additionally
// receives %alloc (%arg21) and writes the finished accumulator out to it.
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %c9664] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %c9664, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_5, %arg17=%alloc_3, %arg18=%alloc_4, %arg19=%alloc_2, %arg20=%alloc_1, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c44_6 = arith.constant 44 : index
%c176 = arith.constant 176 : index
%c5632_7 = arith.constant 5632 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_8 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_9 = arith.constant 2816 : index
%c0_10 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_11 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Pack LHS / RHS tiles for the final reduction step.
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_10, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c8, %c11, %c4, %c8] [%c2816_9, %c2816_9, %c8, %c256, %c64_8, %c1_11]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_10, %arg13, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_8, %c1_11]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
scf.for %arg22 = %c0_10 to %c11 step %c1_11 {
scf.for %arg23 = %c0_10 to %c16 step %c1_11 {
scf.for %arg24 = %c0_10 to %c8 step %c1_11 {
%5 = vector.transfer_read %arg17[%c0_10, %c0_10, %arg24, %arg22, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_10, %c0_10, %arg23, %arg24, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_10, %c0_10] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Unpack this tile's accumulator: read it as [11 row-blocks][4 rows][16 col-blocks][4 cols]
// (strides 16 / 4 / 176 / 1) and store it as a row-major 44x64 slab of %arg21.
air.dma_memcpy_nd (%arg21[%arg12, %arg13, %c0_10, %c0_10] [%c1_11, %c1_11, %c44_6, %c64_8] [%c5632_7, %c2816_9, %c64_8, %c1_11], %arg20[%arg12, %arg13, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c11, %c4, %c16, %c4] [%c5632_7, %c2816_9, %c16, %c4, %c176, %c1_11]) : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
// Write the 44x128 result tile to the output at [%3, %4]; the source is read as
// [44 rows][2 halves][64 cols] so the two 44x64 slabs sit side by side per row.
air.dma_memcpy_nd (%arg9[%3, %4] [%c44, %c128] [%c2432, %c1], %alloc[%c0_0, %c0_0, %c0_0, %c0_0] [%c1, %c44, %c2, %c64] [%c5632, %c64, %c2816, %c1]) : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
// Release all six buffers allocated at the top of the segment.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_4 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before AIRDependency (air-dependency) //----- //
// Synchronous (pre-dependency-analysis) form of the tiled bf16 matmul dispatch:
// a 308x9728 operand (binding 0) times a 9728x2432 operand (binding 1) into a
// 308x2432 result (binding 2). A 7x19 air.launch grid assigns each instance a
// 44x128 output tile; the reduction dimension is consumed 64 elements at a time
// (first step peeled, steps 1..150 in an scf.for, last step at 9664 peeled).
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// Bind the two read-only inputs and the output buffer; each is asserted 64-aligned.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
memref.assume_alignment %0, 64 : memref<308x9728xbf16>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
memref.assume_alignment %1, 64 : memref<9728x2432xbf16>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
memref.assume_alignment %2, 64 : memref<308x2432xbf16>
// 7x19 launch; inside the segment %arg9 = output, %arg10 = LHS, %arg11 = RHS,
// and %arg7/%arg8 are the launch tile indices.
air.launch (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%2, %arg5=%0, %arg6=%1) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> {
%c128 = arith.constant 128 : index
%c2816 = arith.constant 2816 : index
%c5632 = arith.constant 5632 : index
%c9664 = arith.constant 9664 : index
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
// Row offset (%3 = 44 * tile row) and column offset (%4 = 128 * tile column)
// of this instance's output tile.
%3 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
%4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
// Six scratch buffers: three in memory space 1 (result, RHS and LHS staging
// tiles) and three in memory space 2 (blocked accumulator, RHS and LHS).
// NOTE(review): spaces 1 / 2 presumably correspond to AIR's L2 / L1 -- confirm.
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
%alloc_1 = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
%alloc_3 = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
%alloc_4 = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
%alloc_5 = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
// First reduction step: LHS rows %3..%3+43, reduction columns 0..63 -> %alloc_5;
// RHS reduction rows 0..63, columns %4..%4+127 as two 64x64 blocks -> %alloc_4.
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %c0_0] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %c0_0, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Herd of 1x2 tiles; tile column %arg13 picks one of the two 64-wide RHS blocks
// and the matching half of the accumulator. Only this herd contains the fill.
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_5, %arg17=%alloc_3, %arg18=%alloc_4, %arg19=%alloc_2, %arg20=%alloc_1) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_6 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_7 = arith.constant 2816 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0_8 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_9 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// LHS 44x64 -> blocked 8 x 11 x (4x8): outer stride 8 steps across columns,
// stride 256 steps four rows down, inner 4x8 block uses strides 64 / 1.
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_8, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c8, %c11, %c4, %c8] [%c2816_7, %c2816_7, %c8, %c256, %c64_6, %c1_9]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// RHS 64x64 block %arg13 -> blocked 16 x 8 x (8x4): outer stride 4 steps across
// columns, stride 512 steps eight rows down, inner 8x4 block uses strides 64 / 1.
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_8, %arg13, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_6, %c1_9]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Clear the [%arg12, %arg13] slice of the accumulator to 0.0 before accumulating.
%subview = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
// Loop nest 11 (row-blocks) x 16 (column-blocks) x 8 (reduction-blocks); the body
// is a read-modify-write of one 4x4 accumulator block via a 4x8 * 8x4 contraction.
// The accumulator element type is bf16, same as the inputs.
scf.for %arg21 = %c0_8 to %c11 step %c1_9 {
scf.for %arg22 = %c0_8 to %c16 step %c1_9 {
scf.for %arg23 = %c0_8 to %c8 step %c1_9 {
%5 = vector.transfer_read %arg17[%c0_8, %c0_8, %arg23, %arg21, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_8, %c0_8, %arg22, %arg23, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg22, %arg21, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg22, %arg21, %c0_8, %c0_8] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
// Middle reduction steps (%arg12 = 1..150). %5 = 64 * %arg12 is the reduction
// offset used for both the LHS column and the RHS row of the fetched tiles.
scf.for %arg12 = %c1 to %c151 step %c1 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %5] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %5, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Same herd shape as the first step, minus the accumulator fill.
air.herd @herd_0 tile (%arg13, %arg14) in (%arg15=%c1, %arg16=%c2) args(%arg17=%alloc_5, %arg18=%alloc_3, %arg19=%alloc_4, %arg20=%alloc_2, %arg21=%alloc_1) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_6 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_7 = arith.constant 2816 : index
%c0_8 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_9 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Block the new LHS / RHS tiles into the layouts the micro-kernel reads.
air.dma_memcpy_nd (%arg18[] [] [], %arg17[%arg13, %c0_8, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c8, %c11, %c4, %c8] [%c2816_7, %c2816_7, %c8, %c256, %c64_6, %c1_9]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
air.dma_memcpy_nd (%arg20[] [] [], %arg19[%c0_8, %arg14, %c0_8, %c0_8, %c0_8, %c0_8] [%c1_9, %c1_9, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_6, %c1_9]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Accumulate this step's partial products into %arg21 in place.
scf.for %arg22 = %c0_8 to %c11 step %c1_9 {
scf.for %arg23 = %c0_8 to %c16 step %c1_9 {
scf.for %arg24 = %c0_8 to %c8 step %c1_9 {
%6 = vector.transfer_read %arg18[%c0_8, %c0_8, %arg24, %arg22, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%7 = vector.transfer_read %arg20[%c0_8, %c0_8, %arg23, %arg24, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%8 = vector.transfer_read %arg21[%arg13, %arg14, %arg23, %arg22, %c0_8, %c0_8], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %9, %arg21[%arg13, %arg14, %arg23, %arg22, %c0_8, %c0_8] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
air.herd_terminator
}
}
// Final reduction step at offset 9664 (= 151 * 64). The herd takes one extra
// argument, %alloc (%arg21), which receives the unblocked result.
air.dma_memcpy_nd (%alloc_5[] [] [], %arg10[%3, %c9664] [%c44, %c64] [%c9728, %c1]) : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
air.dma_memcpy_nd (%alloc_4[] [] [], %arg11[%c0_0, %c0_0, %c9664, %4] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%alloc_5, %arg17=%alloc_3, %arg18=%alloc_4, %arg19=%alloc_2, %arg20=%alloc_1, %arg21=%alloc) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> {
%c44_6 = arith.constant 44 : index
%c176 = arith.constant 176 : index
%c5632_7 = arith.constant 5632 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_8 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_9 = arith.constant 2816 : index
%c0_10 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_11 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Block the last LHS / RHS tiles, then run the micro-kernel one more time.
air.dma_memcpy_nd (%arg17[] [] [], %arg16[%arg12, %c0_10, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c8, %c11, %c4, %c8] [%c2816_9, %c2816_9, %c8, %c256, %c64_8, %c1_11]) : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
air.dma_memcpy_nd (%arg19[] [] [], %arg18[%c0_10, %arg13, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_8, %c1_11]) : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
scf.for %arg22 = %c0_10 to %c11 step %c1_11 {
scf.for %arg23 = %c0_10 to %c16 step %c1_11 {
scf.for %arg24 = %c0_10 to %c8 step %c1_11 {
%5 = vector.transfer_read %arg17[%c0_10, %c0_10, %arg24, %arg22, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
%6 = vector.transfer_read %arg19[%c0_10, %c0_10, %arg23, %arg24, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
%7 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg22, %c0_10, %c0_10], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
%8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %7 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
vector.transfer_write %8, %arg20[%arg12, %arg13, %arg23, %arg22, %c0_10, %c0_10] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
}
}
}
// Un-block the accumulator slice: source traversal order is row-block (stride 16),
// row-in-block (4), column-block (176), column (1); destination is a plain
// 44x64 slab at [%arg12, %arg13] of %arg21.
air.dma_memcpy_nd (%arg21[%arg12, %arg13, %c0_10, %c0_10] [%c1_11, %c1_11, %c44_6, %c64_8] [%c5632_7, %c2816_9, %c64_8, %c1_11], %arg20[%arg12, %arg13, %c0_10, %c0_10, %c0_10, %c0_10] [%c1_11, %c1_11, %c11, %c4, %c16, %c4] [%c5632_7, %c2816_9, %c16, %c4, %c176, %c1_11]) : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
// Store the finished 44x128 tile into the output at [%3, %4]. Reading %alloc as
// [44][2][64] with strides 64 / 2816 / 1 concatenates the two 44x64 slabs row by row.
air.dma_memcpy_nd (%arg9[%3, %4] [%c44, %c128] [%c2432, %c1], %alloc[%c0_0, %c0_0, %c0_0, %c0_0] [%c1, %c44, %c2, %c64] [%c5632, %c64, %c2816, %c1]) : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
// Free every scratch buffer allocated in this segment.
memref.dealloc %alloc_5 : memref<1x1x44x64xbf16, 1 : i32>
memref.dealloc %alloc_4 : memref<1x2x64x64xbf16, 1 : i32>
memref.dealloc %alloc_3 : memref<1x1x8x11x4x8xbf16, 2 : i32>
memref.dealloc %alloc_2 : memref<1x1x16x8x8x4xbf16, 2 : i32>
memref.dealloc %alloc_1 : memref<1x2x16x11x4x4xbf16, 2 : i32>
memref.dealloc %alloc : memref<1x2x44x64xbf16, 1 : i32>
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before AIRDependencyScheduleOpt (air-dependency-schedule-opt) //----- //
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
%async_token, %results = air.execute -> (memref<308x9728xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
air.execute_terminator %1 : memref<308x9728xbf16>
} {id = 1 : i32}
%async_token_0 = air.execute [%async_token] {
memref.assume_alignment %results, 64 : memref<308x9728xbf16>
} {id = 2 : i32}
%async_token_1, %results_2 = air.execute -> (memref<9728x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
air.execute_terminator %1 : memref<9728x2432xbf16>
} {id = 3 : i32}
%async_token_3 = air.execute [%async_token_1] {
memref.assume_alignment %results_2, 64 : memref<9728x2432xbf16>
} {id = 4 : i32}
%async_token_4, %results_5 = air.execute -> (memref<308x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
air.execute_terminator %1 : memref<308x2432xbf16>
} {id = 5 : i32}
%async_token_6 = air.execute [%async_token_4] {
memref.assume_alignment %results_5, 64 : memref<308x2432xbf16>
} {id = 6 : i32}
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 5 : i32} {
%1 = air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 4 : i32} {
%c128 = arith.constant 128 : index
%c2816 = arith.constant 2816 : index
%c5632 = arith.constant 5632 : index
%c9664 = arith.constant 9664 : index
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_7 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
%async_token_8, %results_9 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
air.execute_terminator %11 : index
} {id = 7 : i32}
%async_token_10, %results_11 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
air.execute_terminator %11 : index
} {id = 8 : i32}
%async_token_12, %results_13 = air.execute -> (memref<1x2x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x44x64xbf16, 1 : i32>
} {id = 9 : i32}
%async_token_14, %results_15 = air.execute -> (memref<1x2x16x11x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 10 : i32}
%async_token_16, %results_17 = air.execute -> (memref<1x1x16x8x8x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 11 : i32}
%async_token_18, %results_19 = air.execute -> (memref<1x1x8x11x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 12 : i32}
%async_token_20, %results_21 = air.execute -> (memref<1x2x64x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x64x64xbf16, 1 : i32>
} {id = 13 : i32}
%async_token_22, %results_23 = air.execute -> (memref<1x1x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x1x44x64xbf16, 1 : i32>
} {id = 14 : i32}
%2 = air.dma_memcpy_nd async [%async_token_8, %async_token_22] (%results_23[] [] [], %arg10[%results_9, %c0_7] [%c44, %c64] [%c9728, %c1]) {id = 1 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%3 = air.dma_memcpy_nd async [%async_token_10, %async_token_20] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c0_7, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 2 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// air.herd @herd_0 (id = 1): asynchronous herd over a 1 x 2 tile grid with tile
// indices (%arg12, %arg13). %arg16 (1x1x44x64) and %arg18 (1x2x64x64) live in
// memory space 1 and are the sources of the two copies below; %arg17, %arg19
// and %arg20 live in memory space 2 and hold the two re-tiled operands and the
// accumulator. This instance zero-fills its accumulator slice before accumulating.
// NOTE(review): the dependency list holds only the five allocation tokens; the
// copies %2/%3 that write %results_23/%results_21 right before this op are not
// in it. Confirm that their ordering against this herd's reads is enforced elsewhere.
%4 = air.herd @herd_0 async [%async_token_14, %async_token_16, %async_token_18, %async_token_20, %async_token_22] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 1 : i32} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_30 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_31 = arith.constant 2816 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0_32 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_33 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Copy %arg16 into %arg17, re-tiled on the fly: 8 column blocks (stride 8) x
// 11 row blocks (stride 256 = 4 rows * 64) x 4 rows (stride 64) x 8 contiguous
// elements (stride 1).
%11 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%arg12, %c0_32, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c8, %c11, %c4, %c8] [%c2816_31, %c2816_31, %c8, %c256, %c64_30, %c1_33]) {id = 3 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Copy slab %arg13 of %arg18 (slab stride 4096 = 64 * 64) into %arg19,
// re-tiled: 16 column blocks (stride 4) x 8 row blocks (stride 512 = 8 * 64) x
// 8 rows (stride 64) x 4 contiguous elements (stride 1).
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_32, %arg13, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_30, %c1_33]) {id = 4 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Slice of the accumulator selected by this tile's indices (%arg12, %arg13).
%subview = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Zero that accumulator slice.
// NOTE(review): %async_token_34 is not consumed by any op in this herd body, so
// the accumulator reads in the loop nest carry no token dependency on the fill.
// Confirm this is intended.
%async_token_34 = air.execute {
linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
} {id = 15 : i32}
// Join both copies; the resulting token seeds the loop nest.
%13 = air.wait_all async [%11, %12] {id = 6 : i32}
// Loop nest: %arg21 walks the 11 row blocks, %arg23 the 16 column blocks and
// %arg25 the 8 reduction blocks. Every level threads an async token through
// its iter_args.
%14 = scf.for %arg21 = %c0_32 to %c11 step %c1_33 iter_args(%arg22 = %13) -> (!air.async.token) {
%c0_35 = arith.constant 0 : index
%c16_36 = arith.constant 16 : index
%c1_37 = arith.constant 1 : index
%15 = air.wait_all async [%arg22, %arg22] {id = 4 : i32}
%16 = scf.for %arg23 = %c0_35 to %c16_36 step %c1_37 iter_args(%arg24 = %15) -> (!air.async.token) {
%c0_38 = arith.constant 0 : index
%c8_39 = arith.constant 8 : index
%c1_40 = arith.constant 1 : index
%18 = air.wait_all async [%arg24, %arg24] {id = 2 : i32}
%19 = scf.for %arg25 = %c0_38 to %c8_39 step %c1_40 iter_args(%arg26 = %18) -> (!air.async.token) {
%c0_41 = arith.constant 0 : index
%cst_42 = arith.constant 0.000000e+00 : bf16
// Left operand: 4x8 block at (reduction block %arg25, row block %arg21).
%async_token_43, %results_44 = air.execute [%arg26] -> (vector<1x1x1x1x4x8xbf16>) {
%23 = vector.transfer_read %arg17[%c0_41, %c0_41, %arg25, %arg21, %c0_41, %c0_41], %cst_42 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x8xbf16>
} {id = 16 : i32}
// Right operand: 8x4 block at (column block %arg23, reduction block %arg25).
%async_token_45, %results_46 = air.execute [%arg26] -> (vector<1x1x1x1x8x4xbf16>) {
%23 = vector.transfer_read %arg19[%c0_41, %c0_41, %arg23, %arg25, %c0_41, %c0_41], %cst_42 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x8x4xbf16>
} {id = 17 : i32}
// Accumulator: 4x4 block at (column block %arg23, row block %arg21) of this tile's slab.
%async_token_47, %results_48 = air.execute [%arg26] -> (vector<1x1x1x1x4x4xbf16>) {
%23 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg21, %c0_41, %c0_41], %cst_42 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x4xbf16>
} {id = 18 : i32}
// All leading vector dimensions have size 1, so this is
// acc[4x4] += lhs[4x8] * rhs[8x4], with d8 as the size-8 reduction dimension.
%21 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_44, %results_46, %results_48 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated block back to where the accumulator was read from; the
// write is ordered after that read through %async_token_47.
%async_token_49 = air.execute [%arg26, %async_token_47] {
vector.transfer_write %21, %arg20[%arg12, %arg13, %arg23, %arg21, %c0_41, %c0_41] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 19 : i32}
// Iteration token: joins both operand reads and the write-back.
%22 = air.wait_all async [%arg26, %async_token_43, %async_token_45, %async_token_49] {id = 1 : i32}
scf.yield %22 : !air.async.token
}
%20 = air.wait_all async [%arg24, %19] {id = 3 : i32}
scf.yield %20 : !air.async.token
}
%17 = air.wait_all async [%arg22, %16] {id = 5 : i32}
scf.yield %17 : !air.async.token
}
air.herd_terminator
}
%5 = air.wait_all async [%2, %3, %4] {id = 14 : i32}
%6 = scf.for %arg12 = %c1 to %c151 step %c1 iter_args(%arg13 = %5) -> (!air.async.token) {
%c44_30 = arith.constant 44 : index
%c64_31 = arith.constant 64 : index
%c9728_32 = arith.constant 9728 : index
%c1_33 = arith.constant 1 : index
%c0_34 = arith.constant 0 : index
%c2_35 = arith.constant 2 : index
%c155648_36 = arith.constant 155648 : index
%c2432_37 = arith.constant 2432 : index
%async_token_38, %results_39 = air.execute [%arg13] -> (index) {
%15 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
air.execute_terminator %15 : index
} {id = 20 : i32}
%11 = air.dma_memcpy_nd async [%async_token_38, %arg13, %arg13] (%results_23[] [] [], %arg10[%results_9, %results_39] [%c44_30, %c64_31] [%c9728_32, %c1_33]) {id = 5 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%12 = air.dma_memcpy_nd async [%async_token_38, %arg13, %arg13] (%results_21[] [] [], %arg11[%c0_34, %c0_34, %results_39, %results_11] [%c1_33, %c2_35, %c64_31, %c64_31] [%c155648_36, %c64_31, %c2432_37, %c1_33]) {id = 6 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// air.herd @herd_0 (id = 2): herd instance issued once per iteration of the
// enclosing scf.for, over a 1 x 2 tile grid with tile indices (%arg14, %arg15).
// %arg18 / %arg20 (memory space 1) are the copy sources; %arg19, %arg21 and
// %arg22 (memory space 2) hold the two re-tiled operands and the accumulator.
// There is no fill in this instance: the contraction accumulates onto whatever
// %arg22 already holds.
// NOTE(review): the dependency list is [%arg13] only; the copies %11/%12 of the
// enclosing loop body, which write %results_23/%results_21, are not in it.
// Confirm that their ordering against this herd's reads is enforced elsewhere.
%13 = air.herd @herd_0 async [%arg13] tile (%arg14, %arg15) in (%arg16=%c1_33, %arg17=%c2_35) args(%arg18=%results_23, %arg19=%results_19, %arg20=%results_21, %arg21=%results_17, %arg22=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 2 : i32} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_40 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_41 = arith.constant 2816 : index
%c0_42 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_43 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Copy %arg18 into %arg19, re-tiled on the fly: 8 column blocks (stride 8) x
// 11 row blocks (stride 256 = 4 rows * 64) x 4 rows (stride 64) x 8 contiguous
// elements (stride 1).
%15 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%arg14, %c0_42, %c0_42, %c0_42, %c0_42, %c0_42] [%c1_43, %c1_43, %c8, %c11, %c4, %c8] [%c2816_41, %c2816_41, %c8, %c256, %c64_40, %c1_43]) {id = 7 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Copy slab %arg15 of %arg20 (slab stride 4096 = 64 * 64) into %arg21,
// re-tiled: 16 column blocks (stride 4) x 8 row blocks (stride 512 = 8 * 64) x
// 8 rows (stride 64) x 4 contiguous elements (stride 1).
%16 = air.dma_memcpy_nd async (%arg21[] [] [], %arg20[%c0_42, %arg15, %c0_42, %c0_42, %c0_42, %c0_42] [%c1_43, %c1_43, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_40, %c1_43]) {id = 8 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Join both copies; the resulting token seeds the loop nest.
%17 = air.wait_all async [%15, %16] {id = 12 : i32}
// Loop nest: %arg23 walks the 11 row blocks, %arg25 the 16 column blocks and
// %arg27 the 8 reduction blocks. Every level threads an async token through
// its iter_args.
%18 = scf.for %arg23 = %c0_42 to %c11 step %c1_43 iter_args(%arg24 = %17) -> (!air.async.token) {
%c0_44 = arith.constant 0 : index
%c16_45 = arith.constant 16 : index
%c1_46 = arith.constant 1 : index
%19 = air.wait_all async [%arg24, %arg24] {id = 10 : i32}
%20 = scf.for %arg25 = %c0_44 to %c16_45 step %c1_46 iter_args(%arg26 = %19) -> (!air.async.token) {
%c0_47 = arith.constant 0 : index
%c8_48 = arith.constant 8 : index
%c1_49 = arith.constant 1 : index
%22 = air.wait_all async [%arg26, %arg26] {id = 8 : i32}
%23 = scf.for %arg27 = %c0_47 to %c8_48 step %c1_49 iter_args(%arg28 = %22) -> (!air.async.token) {
%c0_50 = arith.constant 0 : index
%cst_51 = arith.constant 0.000000e+00 : bf16
// Left operand: 4x8 block at (reduction block %arg27, row block %arg23).
%async_token_52, %results_53 = air.execute [%arg28] -> (vector<1x1x1x1x4x8xbf16>) {
%27 = vector.transfer_read %arg19[%c0_50, %c0_50, %arg27, %arg23, %c0_50, %c0_50], %cst_51 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x8xbf16>
} {id = 21 : i32}
// Right operand: 8x4 block at (column block %arg25, reduction block %arg27).
%async_token_54, %results_55 = air.execute [%arg28] -> (vector<1x1x1x1x8x4xbf16>) {
%27 = vector.transfer_read %arg21[%c0_50, %c0_50, %arg25, %arg27, %c0_50, %c0_50], %cst_51 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x8x4xbf16>
} {id = 22 : i32}
// Accumulator: 4x4 block at (column block %arg25, row block %arg23) of this tile's slab.
%async_token_56, %results_57 = air.execute [%arg28] -> (vector<1x1x1x1x4x4xbf16>) {
%27 = vector.transfer_read %arg22[%arg14, %arg15, %arg25, %arg23, %c0_50, %c0_50], %cst_51 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x4xbf16>
} {id = 23 : i32}
// All leading vector dimensions have size 1, so this is
// acc[4x4] += lhs[4x8] * rhs[8x4], with d8 as the size-8 reduction dimension.
%25 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_53, %results_55, %results_57 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated block back to where the accumulator was read from; the
// write is ordered after that read through %async_token_56.
%async_token_58 = air.execute [%arg28, %async_token_56] {
vector.transfer_write %25, %arg22[%arg14, %arg15, %arg25, %arg23, %c0_50, %c0_50] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 24 : i32}
// Iteration token: joins both operand reads and the write-back.
%26 = air.wait_all async [%arg28, %async_token_52, %async_token_54, %async_token_58] {id = 7 : i32}
scf.yield %26 : !air.async.token
}
%24 = air.wait_all async [%arg26, %23] {id = 9 : i32}
scf.yield %24 : !air.async.token
}
%21 = air.wait_all async [%arg24, %20] {id = 11 : i32}
scf.yield %21 : !air.async.token
}
air.herd_terminator
}
%14 = air.wait_all async [%arg13, %11, %12, %13] {id = 13 : i32}
scf.yield %14 : !air.async.token
}
%7 = air.dma_memcpy_nd async [%6, %6] (%results_23[] [] [], %arg10[%results_9, %c9664] [%c44, %c64] [%c9728, %c1]) {id = 9 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%8 = air.dma_memcpy_nd async [%6, %6] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c9664, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 10 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// air.herd @herd_0 (id = 3): herd instance issued after the scf.for that
// produces %6, over a 1 x 2 tile grid with tile indices (%arg12, %arg13). It
// performs one more accumulation step and then copies the accumulator out to
// %arg21 (1x2x44x64, memory space 1), an operand the other herd instances do
// not take.
// NOTE(review): the dependency list is [%async_token_12, %6]; the copies %7/%8
// that write %results_23/%results_21 right before this op are not in it.
// Confirm that their ordering against this herd's reads is enforced elsewhere.
%9 = air.herd @herd_0 async [%async_token_12, %6] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15, %arg21=%results_13) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> attributes {id = 3 : i32} {
%c44_30 = arith.constant 44 : index
%c176 = arith.constant 176 : index
%c5632_31 = arith.constant 5632 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_32 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_33 = arith.constant 2816 : index
%c0_34 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_35 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Copy %arg16 into %arg17, re-tiled on the fly: 8 column blocks (stride 8) x
// 11 row blocks (stride 256 = 4 rows * 64) x 4 rows (stride 64) x 8 contiguous
// elements (stride 1).
%11 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%arg12, %c0_34, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c8, %c11, %c4, %c8] [%c2816_33, %c2816_33, %c8, %c256, %c64_32, %c1_35]) {id = 11 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Copy slab %arg13 of %arg18 (slab stride 4096 = 64 * 64) into %arg19,
// re-tiled: 16 column blocks (stride 4) x 8 row blocks (stride 512 = 8 * 64) x
// 8 rows (stride 64) x 4 contiguous elements (stride 1).
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_34, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_32, %c1_35]) {id = 12 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Join both copies; the resulting token seeds the loop nest.
%13 = air.wait_all async [%11, %12] {id = 20 : i32}
// Loop nest: %arg22 walks the 11 row blocks, %arg24 the 16 column blocks and
// %arg26 the 8 reduction blocks. Every level threads an async token through
// its iter_args.
%14 = scf.for %arg22 = %c0_34 to %c11 step %c1_35 iter_args(%arg23 = %13) -> (!air.async.token) {
%c0_36 = arith.constant 0 : index
%c16_37 = arith.constant 16 : index
%c1_38 = arith.constant 1 : index
%16 = air.wait_all async [%arg23, %arg23] {id = 18 : i32}
%17 = scf.for %arg24 = %c0_36 to %c16_37 step %c1_38 iter_args(%arg25 = %16) -> (!air.async.token) {
%c0_39 = arith.constant 0 : index
%c8_40 = arith.constant 8 : index
%c1_41 = arith.constant 1 : index
%19 = air.wait_all async [%arg25, %arg25] {id = 16 : i32}
%20 = scf.for %arg26 = %c0_39 to %c8_40 step %c1_41 iter_args(%arg27 = %19) -> (!air.async.token) {
%c0_42 = arith.constant 0 : index
%cst_43 = arith.constant 0.000000e+00 : bf16
// Left operand: 4x8 block at (reduction block %arg26, row block %arg22).
%async_token_44, %results_45 = air.execute [%arg27] -> (vector<1x1x1x1x4x8xbf16>) {
%24 = vector.transfer_read %arg17[%c0_42, %c0_42, %arg26, %arg22, %c0_42, %c0_42], %cst_43 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x8xbf16>
} {id = 25 : i32}
// Right operand: 8x4 block at (column block %arg24, reduction block %arg26).
%async_token_46, %results_47 = air.execute [%arg27] -> (vector<1x1x1x1x8x4xbf16>) {
%24 = vector.transfer_read %arg19[%c0_42, %c0_42, %arg24, %arg26, %c0_42, %c0_42], %cst_43 {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x8x4xbf16>
} {id = 26 : i32}
// Accumulator: 4x4 block at (column block %arg24, row block %arg22) of this tile's slab.
%async_token_48, %results_49 = air.execute [%arg27] -> (vector<1x1x1x1x4x4xbf16>) {
%24 = vector.transfer_read %arg20[%arg12, %arg13, %arg24, %arg22, %c0_42, %c0_42], %cst_43 {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x4xbf16>
} {id = 27 : i32}
// All leading vector dimensions have size 1, so this is
// acc[4x4] += lhs[4x8] * rhs[8x4], with d8 as the size-8 reduction dimension.
%22 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_45, %results_47, %results_49 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated block back to where the accumulator was read from; the
// write is ordered after that read through %async_token_48.
%async_token_50 = air.execute [%arg27, %async_token_48] {
vector.transfer_write %22, %arg20[%arg12, %arg13, %arg24, %arg22, %c0_42, %c0_42] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 28 : i32}
// Iteration token: joins both operand reads and the write-back.
%23 = air.wait_all async [%arg27, %async_token_44, %async_token_46, %async_token_50] {id = 15 : i32}
scf.yield %23 : !air.async.token
}
%21 = air.wait_all async [%arg25, %20] {id = 17 : i32}
scf.yield %21 : !air.async.token
}
%18 = air.wait_all async [%arg23, %17] {id = 19 : i32}
scf.yield %18 : !air.async.token
}
// Once the loop nest token %14 is ready, copy this tile's accumulator slab out
// of %arg20 into %arg21, undoing the blocking. The source is walked as 11 row
// blocks (stride 16) x 4 rows (stride 4) x 16 column blocks (stride 176) x
// 4 elements (stride 1); the destination is written as 44 rows (stride 64) x
// 64 contiguous elements at slab [%arg12, %arg13].
%15 = air.dma_memcpy_nd async [%14] (%arg21[%arg12, %arg13, %c0_34, %c0_34] [%c1_35, %c1_35, %c44_30, %c64_32] [%c5632_31, %c2816_33, %c64_32, %c1_35], %arg20[%arg12, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c11, %c4, %c16, %c4] [%c5632_31, %c2816_33, %c16, %c4, %c176, %c1_35]) {id = 13 : i32} : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
%10 = air.dma_memcpy_nd async [%async_token_8, %async_token_10, %9] (%arg9[%results_9, %results_11] [%c44, %c128] [%c2432, %c1], %results_13[%c0_7, %c0_7, %c0_7, %c0_7] [%c1, %c44, %c2, %c64] [%c5632, %c64, %c2816, %c1]) {id = 14 : i32} : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
%async_token_24 = air.execute [%9, %7] {
memref.dealloc %results_23 : memref<1x1x44x64xbf16, 1 : i32>
} {id = 29 : i32}
%async_token_25 = air.execute [%9, %8] {
memref.dealloc %results_21 : memref<1x2x64x64xbf16, 1 : i32>
} {id = 30 : i32}
%async_token_26 = air.execute [%9] {
memref.dealloc %results_19 : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 31 : i32}
%async_token_27 = air.execute [%9] {
memref.dealloc %results_17 : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 32 : i32}
%async_token_28 = air.execute [%9] {
memref.dealloc %results_15 : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 33 : i32}
%async_token_29 = air.execute [%10] {
memref.dealloc %results_13 : memref<1x2x44x64xbf16, 1 : i32>
} {id = 34 : i32}
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before AIRSpecializeDmaBroadcast (air-specialize-dma-broadcast) //----- //
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
%async_token, %results = air.execute -> (memref<308x9728xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
air.execute_terminator %1 : memref<308x9728xbf16>
} {id = 1 : i32}
%async_token_0 = air.execute [%async_token] {
memref.assume_alignment %results, 64 : memref<308x9728xbf16>
} {id = 2 : i32}
%async_token_1, %results_2 = air.execute -> (memref<9728x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
air.execute_terminator %1 : memref<9728x2432xbf16>
} {id = 3 : i32}
%async_token_3 = air.execute [%async_token_1] {
memref.assume_alignment %results_2, 64 : memref<9728x2432xbf16>
} {id = 4 : i32}
%async_token_4, %results_5 = air.execute -> (memref<308x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
air.execute_terminator %1 : memref<308x2432xbf16>
} {id = 5 : i32}
%async_token_6 = air.execute [%async_token_4] {
memref.assume_alignment %results_5, 64 : memref<308x2432xbf16>
} {id = 6 : i32}
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 5 : i32} {
%1 = air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 4 : i32} {
%c128 = arith.constant 128 : index
%c2816 = arith.constant 2816 : index
%c5632 = arith.constant 5632 : index
%c9664 = arith.constant 9664 : index
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_7 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
%async_token_8, %results_9 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
air.execute_terminator %11 : index
} {id = 7 : i32}
%async_token_10, %results_11 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
air.execute_terminator %11 : index
} {id = 8 : i32}
%async_token_12, %results_13 = air.execute -> (memref<1x2x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x44x64xbf16, 1 : i32>
} {id = 9 : i32}
%async_token_14, %results_15 = air.execute -> (memref<1x2x16x11x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 10 : i32}
%async_token_16, %results_17 = air.execute -> (memref<1x1x16x8x8x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 11 : i32}
%async_token_18, %results_19 = air.execute -> (memref<1x1x8x11x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 12 : i32}
%async_token_20, %results_21 = air.execute -> (memref<1x2x64x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x64x64xbf16, 1 : i32>
} {id = 13 : i32}
%async_token_22, %results_23 = air.execute -> (memref<1x1x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x1x44x64xbf16, 1 : i32>
} {id = 14 : i32}
%2 = air.dma_memcpy_nd async [%async_token_8, %async_token_22] (%results_23[] [] [], %arg10[%results_9, %c0_7] [%c44, %c64] [%c9728, %c1]) {id = 1 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%3 = air.dma_memcpy_nd async [%async_token_10, %async_token_20] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c0_7, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 2 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// air.herd @herd_0 (id = 1): asynchronous herd over a 1 x 2 tile grid with tile
// indices (%arg12, %arg13). %arg16 (1x1x44x64) and %arg18 (1x2x64x64) live in
// memory space 1 and are the sources of the two copies below; %arg17, %arg19
// and %arg20 live in memory space 2 and hold the two re-tiled operands and the
// accumulator. This instance zero-fills its accumulator slice before accumulating.
// NOTE(review): the dependency list holds only the five allocation tokens; the
// copies %2/%3 that write %results_23/%results_21 right before this op are not
// in it. Confirm that their ordering against this herd's reads is enforced elsewhere.
%4 = air.herd @herd_0 async [%async_token_14, %async_token_16, %async_token_18, %async_token_20, %async_token_22] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 1 : i32} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_30 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_31 = arith.constant 2816 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0_32 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_33 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Copy %arg16 into %arg17, re-tiled on the fly: 8 column blocks (stride 8) x
// 11 row blocks (stride 256 = 4 rows * 64) x 4 rows (stride 64) x 8 contiguous
// elements (stride 1).
// The broadcast_pattern integer set constrains d0 == s0, 0 <= d1 <= 1 and
// s0 == 0; the source offsets use %arg12 but not %arg13. Presumably this marks
// the copy as identical for both tiles along the second herd dimension -
// confirm against the pass that attaches the attribute.
%11 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%arg12, %c0_32, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c8, %c11, %c4, %c8] [%c2816_31, %c2816_31, %c8, %c256, %c64_30, %c1_33]) {broadcast_pattern = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 >= 0)>, id = 3 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Copy slab %arg13 of %arg18 (slab stride 4096 = 64 * 64) into %arg19,
// re-tiled: 16 column blocks (stride 4) x 8 row blocks (stride 512 = 8 * 64) x
// 8 rows (stride 64) x 4 contiguous elements (stride 1). The source depends on
// %arg13 and this copy carries no broadcast_pattern.
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_32, %arg13, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_30, %c1_33]) {id = 4 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Slice of the accumulator selected by this tile's indices (%arg12, %arg13).
%subview = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
// Zero that accumulator slice.
// NOTE(review): %async_token_34 is not consumed by any op in this herd body, so
// the accumulator reads in the loop nest carry no token dependency on the fill.
// Confirm this is intended.
%async_token_34 = air.execute {
linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
} {id = 15 : i32}
// Join both copies; the resulting token seeds the loop nest.
%13 = air.wait_all async [%11, %12] {id = 6 : i32}
// Loop nest: %arg21 walks the 11 row blocks, %arg23 the 16 column blocks and
// %arg25 the 8 reduction blocks. Every level threads an async token through
// its iter_args; bounds and the %cst padding value come from the herd-level constants.
%14 = scf.for %arg21 = %c0_32 to %c11 step %c1_33 iter_args(%arg22 = %13) -> (!air.async.token) {
%15 = air.wait_all async [%arg22, %arg22] {id = 4 : i32}
%16 = scf.for %arg23 = %c0_32 to %c16 step %c1_33 iter_args(%arg24 = %15) -> (!air.async.token) {
%18 = air.wait_all async [%arg24, %arg24] {id = 2 : i32}
%19 = scf.for %arg25 = %c0_32 to %c8 step %c1_33 iter_args(%arg26 = %18) -> (!air.async.token) {
// Left operand: 4x8 block at (reduction block %arg25, row block %arg21).
%async_token_35, %results_36 = air.execute [%arg26] -> (vector<1x1x1x1x4x8xbf16>) {
%23 = vector.transfer_read %arg17[%c0_32, %c0_32, %arg25, %arg21, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x8xbf16>
} {id = 16 : i32}
// Right operand: 8x4 block at (column block %arg23, reduction block %arg25).
%async_token_37, %results_38 = air.execute [%arg26] -> (vector<1x1x1x1x8x4xbf16>) {
%23 = vector.transfer_read %arg19[%c0_32, %c0_32, %arg23, %arg25, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x8x4xbf16>
} {id = 17 : i32}
// Accumulator: 4x4 block at (column block %arg23, row block %arg21) of this tile's slab.
%async_token_39, %results_40 = air.execute [%arg26] -> (vector<1x1x1x1x4x4xbf16>) {
%23 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg21, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x4xbf16>
} {id = 18 : i32}
// All leading vector dimensions have size 1, so this is
// acc[4x4] += lhs[4x8] * rhs[8x4], with d8 as the size-8 reduction dimension.
%21 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_36, %results_38, %results_40 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated block back to where the accumulator was read from; the
// write is ordered after that read through %async_token_39.
%async_token_41 = air.execute [%arg26, %async_token_39] {
vector.transfer_write %21, %arg20[%arg12, %arg13, %arg23, %arg21, %c0_32, %c0_32] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 19 : i32}
// Iteration token: joins both operand reads and the write-back.
%22 = air.wait_all async [%arg26, %async_token_35, %async_token_37, %async_token_41] {id = 1 : i32}
scf.yield %22 : !air.async.token
}
%20 = air.wait_all async [%arg24, %19] {id = 3 : i32}
scf.yield %20 : !air.async.token
}
%17 = air.wait_all async [%arg22, %16] {id = 5 : i32}
scf.yield %17 : !air.async.token
}
air.herd_terminator
}
%5 = air.wait_all async [%2, %3, %4] {id = 14 : i32}
%6 = scf.for %arg12 = %c1 to %c151 step %c1 iter_args(%arg13 = %5) -> (!air.async.token) {
%async_token_30, %results_31 = air.execute [%arg13] -> (index) {
%15 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
air.execute_terminator %15 : index
} {id = 20 : i32}
%11 = air.dma_memcpy_nd async [%async_token_30, %arg13, %arg13] (%results_23[] [] [], %arg10[%results_9, %results_31] [%c44, %c64] [%c9728, %c1]) {id = 5 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%12 = air.dma_memcpy_nd async [%async_token_30, %arg13, %arg13] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %results_31, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 6 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// air.herd @herd_0 (id = 2): herd instance issued once per iteration of the
// enclosing scf.for, over a 1 x 2 tile grid with tile indices (%arg14, %arg15).
// %arg18 / %arg20 (memory space 1) are the copy sources; %arg19, %arg21 and
// %arg22 (memory space 2) hold the two re-tiled operands and the accumulator.
// There is no fill in this instance: the contraction accumulates onto whatever
// %arg22 already holds.
// NOTE(review): the dependency list is [%arg13] only; the copies %11/%12 of the
// enclosing loop body, which write %results_23/%results_21, are not in it.
// Confirm that their ordering against this herd's reads is enforced elsewhere.
%13 = air.herd @herd_0 async [%arg13] tile (%arg14, %arg15) in (%arg16=%c1, %arg17=%c2) args(%arg18=%results_23, %arg19=%results_19, %arg20=%results_21, %arg21=%results_17, %arg22=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 2 : i32} {
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_32 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_33 = arith.constant 2816 : index
%c0_34 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_35 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Copy %arg18 into %arg19, re-tiled on the fly: 8 column blocks (stride 8) x
// 11 row blocks (stride 256 = 4 rows * 64) x 4 rows (stride 64) x 8 contiguous
// elements (stride 1).
// The broadcast_pattern integer set constrains d0 == s0, 0 <= d1 <= 1 and
// s0 == 0; the source offsets use %arg14 but not %arg15. Presumably this marks
// the copy as identical for both tiles along the second herd dimension -
// confirm against the pass that attaches the attribute.
%15 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%arg14, %c0_34, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c8, %c11, %c4, %c8] [%c2816_33, %c2816_33, %c8, %c256, %c64_32, %c1_35]) {broadcast_pattern = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 >= 0)>, id = 7 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
// Copy slab %arg15 of %arg20 (slab stride 4096 = 64 * 64) into %arg21,
// re-tiled: 16 column blocks (stride 4) x 8 row blocks (stride 512 = 8 * 64) x
// 8 rows (stride 64) x 4 contiguous elements (stride 1). The source depends on
// %arg15 and this copy carries no broadcast_pattern.
%16 = air.dma_memcpy_nd async (%arg21[] [] [], %arg20[%c0_34, %arg15, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_32, %c1_35]) {id = 8 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Join both copies; the resulting token seeds the loop nest.
%17 = air.wait_all async [%15, %16] {id = 12 : i32}
// Loop nest: %arg23 walks the 11 row blocks, %arg25 the 16 column blocks and
// %arg27 the 8 reduction blocks. Every level threads an async token through
// its iter_args; bounds and the %cst padding value come from the herd-level constants.
%18 = scf.for %arg23 = %c0_34 to %c11 step %c1_35 iter_args(%arg24 = %17) -> (!air.async.token) {
%19 = air.wait_all async [%arg24, %arg24] {id = 10 : i32}
%20 = scf.for %arg25 = %c0_34 to %c16 step %c1_35 iter_args(%arg26 = %19) -> (!air.async.token) {
%22 = air.wait_all async [%arg26, %arg26] {id = 8 : i32}
%23 = scf.for %arg27 = %c0_34 to %c8 step %c1_35 iter_args(%arg28 = %22) -> (!air.async.token) {
// Left operand: 4x8 block at (reduction block %arg27, row block %arg23).
%async_token_36, %results_37 = air.execute [%arg28] -> (vector<1x1x1x1x4x8xbf16>) {
%27 = vector.transfer_read %arg19[%c0_34, %c0_34, %arg27, %arg23, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x8xbf16>
} {id = 21 : i32}
// Right operand: 8x4 block at (column block %arg25, reduction block %arg27).
%async_token_38, %results_39 = air.execute [%arg28] -> (vector<1x1x1x1x8x4xbf16>) {
%27 = vector.transfer_read %arg21[%c0_34, %c0_34, %arg25, %arg27, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x8x4xbf16>
} {id = 22 : i32}
// Accumulator: 4x4 block at (column block %arg25, row block %arg23) of this tile's slab.
%async_token_40, %results_41 = air.execute [%arg28] -> (vector<1x1x1x1x4x4xbf16>) {
%27 = vector.transfer_read %arg22[%arg14, %arg15, %arg25, %arg23, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x4xbf16>
} {id = 23 : i32}
// All leading vector dimensions have size 1, so this is
// acc[4x4] += lhs[4x8] * rhs[8x4], with d8 as the size-8 reduction dimension.
%25 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_37, %results_39, %results_41 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Store the updated block back to where the accumulator was read from; the
// write is ordered after that read through %async_token_40.
%async_token_42 = air.execute [%arg28, %async_token_40] {
vector.transfer_write %25, %arg22[%arg14, %arg15, %arg25, %arg23, %c0_34, %c0_34] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 24 : i32}
// Iteration token: joins both operand reads and the write-back.
%26 = air.wait_all async [%arg28, %async_token_36, %async_token_38, %async_token_42] {id = 7 : i32}
scf.yield %26 : !air.async.token
}
%24 = air.wait_all async [%arg26, %23] {id = 9 : i32}
scf.yield %24 : !air.async.token
}
%21 = air.wait_all async [%arg24, %20] {id = 11 : i32}
scf.yield %21 : !air.async.token
}
air.herd_terminator
}
%14 = air.wait_all async [%arg13, %11, %12, %13] {id = 13 : i32}
scf.yield %14 : !air.async.token
}
%7 = air.dma_memcpy_nd async [%6, %6] (%results_23[] [] [], %arg10[%results_9, %c9664] [%c44, %c64] [%c9728, %c1]) {id = 9 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%8 = air.dma_memcpy_nd async [%6, %6] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c9664, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 10 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
%9 = air.herd @herd_0 async [%async_token_12, %6] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15, %arg21=%results_13) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> attributes {id = 3 : i32} {
%c44_30 = arith.constant 44 : index
%c176 = arith.constant 176 : index
%c5632_31 = arith.constant 5632 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_32 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_33 = arith.constant 2816 : index
%c0_34 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_35 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
%11 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%arg12, %c0_34, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c8, %c11, %c4, %c8] [%c2816_33, %c2816_33, %c8, %c256, %c64_32, %c1_35]) {broadcast_pattern = affine_set<(d0, d1)[s0] : (d0 - s0 == 0, d1 >= 0, -d1 + 1 >= 0, s0 >= 0, -s0 >= 0)>, id = 11 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_34, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_32, %c1_35]) {id = 12 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
%13 = air.wait_all async [%11, %12] {id = 20 : i32}
%14 = scf.for %arg22 = %c0_34 to %c11 step %c1_35 iter_args(%arg23 = %13) -> (!air.async.token) {
%16 = air.wait_all async [%arg23, %arg23] {id = 18 : i32}
%17 = scf.for %arg24 = %c0_34 to %c16 step %c1_35 iter_args(%arg25 = %16) -> (!air.async.token) {
%19 = air.wait_all async [%arg25, %arg25] {id = 16 : i32}
%20 = scf.for %arg26 = %c0_34 to %c8 step %c1_35 iter_args(%arg27 = %19) -> (!air.async.token) {
%async_token_36, %results_37 = air.execute [%arg27] -> (vector<1x1x1x1x4x8xbf16>) {
%24 = vector.transfer_read %arg17[%c0_34, %c0_34, %arg26, %arg22, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x8xbf16>
} {id = 25 : i32}
%async_token_38, %results_39 = air.execute [%arg27] -> (vector<1x1x1x1x8x4xbf16>) {
%24 = vector.transfer_read %arg19[%c0_34, %c0_34, %arg24, %arg26, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x8x4xbf16>
} {id = 26 : i32}
%async_token_40, %results_41 = air.execute [%arg27] -> (vector<1x1x1x1x4x4xbf16>) {
%24 = vector.transfer_read %arg20[%arg12, %arg13, %arg24, %arg22, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x4xbf16>
} {id = 27 : i32}
%22 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_37, %results_39, %results_41 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
%async_token_42 = air.execute [%arg27, %async_token_40] {
vector.transfer_write %22, %arg20[%arg12, %arg13, %arg24, %arg22, %c0_34, %c0_34] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 28 : i32}
%23 = air.wait_all async [%arg27, %async_token_36, %async_token_38, %async_token_42] {id = 15 : i32}
scf.yield %23 : !air.async.token
}
%21 = air.wait_all async [%arg25, %20] {id = 17 : i32}
scf.yield %21 : !air.async.token
}
%18 = air.wait_all async [%arg23, %17] {id = 19 : i32}
scf.yield %18 : !air.async.token
}
%15 = air.dma_memcpy_nd async [%14] (%arg21[%arg12, %arg13, %c0_34, %c0_34] [%c1_35, %c1_35, %c44_30, %c64_32] [%c5632_31, %c2816_33, %c64_32, %c1_35], %arg20[%arg12, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c11, %c4, %c16, %c4] [%c5632_31, %c2816_33, %c16, %c4, %c176, %c1_35]) {id = 13 : i32} : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
%10 = air.dma_memcpy_nd async [%async_token_8, %async_token_10, %9] (%arg9[%results_9, %results_11] [%c44, %c128] [%c2432, %c1], %results_13[%c0_7, %c0_7, %c0_7, %c0_7] [%c1, %c44, %c2, %c64] [%c5632, %c64, %c2816, %c1]) {id = 14 : i32} : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
%async_token_24 = air.execute [%9, %7] {
memref.dealloc %results_23 : memref<1x1x44x64xbf16, 1 : i32>
} {id = 29 : i32}
%async_token_25 = air.execute [%9, %8] {
memref.dealloc %results_21 : memref<1x2x64x64xbf16, 1 : i32>
} {id = 30 : i32}
%async_token_26 = air.execute [%9] {
memref.dealloc %results_19 : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 31 : i32}
%async_token_27 = air.execute [%9] {
memref.dealloc %results_17 : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 32 : i32}
%async_token_28 = air.execute [%9] {
memref.dealloc %results_15 : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 33 : i32}
%async_token_29 = air.execute [%10] {
memref.dealloc %results_13 : memref<1x2x44x64xbf16, 1 : i32>
} {id = 34 : i32}
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before DmaToChannel (air-dma-to-channel) //----- //
// Module holding the single dispatch function for the 308x2432x9728 bf16 matmul.
// At this point data movement is still written as air.dma_memcpy_nd ops.
module {
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
// %c7 and %c19 are the two extents of the air.launch iteration space that follows;
// %c0 is the byte offset passed to every binding subspan below.
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
// Binding 0 (ReadOnly): 308x9728 bf16 memref, produced inside an air.execute so that
// it carries an async token.
%async_token, %results = air.execute -> (memref<308x9728xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
air.execute_terminator %1 : memref<308x9728xbf16>
} {id = 1 : i32}
// Alignment assumption (64) on binding 0; ordered after the subspan via %async_token.
%async_token_0 = air.execute [%async_token] {
memref.assume_alignment %results, 64 : memref<308x9728xbf16>
} {id = 2 : i32}
// Binding 1 (ReadOnly): 9728x2432 bf16 memref.
%async_token_1, %results_2 = air.execute -> (memref<9728x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
air.execute_terminator %1 : memref<9728x2432xbf16>
} {id = 3 : i32}
// Alignment assumption (64) on binding 1; ordered after its subspan.
%async_token_3 = air.execute [%async_token_1] {
memref.assume_alignment %results_2, 64 : memref<9728x2432xbf16>
} {id = 4 : i32}
// Binding 2: 308x2432 bf16 memref. Unlike bindings 0 and 1 it has no ReadOnly flag.
%async_token_4, %results_5 = air.execute -> (memref<308x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
air.execute_terminator %1 : memref<308x2432xbf16>
} {id = 5 : i32}
// Alignment assumption (64) on binding 2; ordered after its subspan.
%async_token_6 = air.execute [%async_token_4] {
memref.assume_alignment %results_5, 64 : memref<308x2432xbf16>
} {id = 6 : i32}
// Launch over a 7x19 iteration space (%arg0, %arg1), started once the three
// assume_alignment tokens have completed. With the 44- and 128-element offsets
// computed below, 7 * 44 = 308 and 19 * 128 = 2432 match the 308x2432 output memref.
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 5 : i32} {
// Segment receiving both launch indices (%arg7, %arg8) and the three global memrefs:
// %arg9 = 308x2432 output, %arg10 = 308x9728 input, %arg11 = 9728x2432 input.
%1 = air.segment @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 4 : i32} {
// Sizes, strides and offsets used by the segment-level DMAs and the K loop.
// 155648 = 64 * 2432 and 9664 = 9728 - 64.
%c128 = arith.constant 128 : index
%c2816 = arith.constant 2816 : index
%c5632 = arith.constant 5632 : index
%c9664 = arith.constant 9664 : index
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_7 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c151 = arith.constant 151 : index
// Row offset for this launch instance: %arg7 * 44.
%async_token_8, %results_9 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg7]
air.execute_terminator %11 : index
} {id = 7 : i32}
// Column offset for this launch instance: %arg8 * 128.
%async_token_10, %results_11 = air.execute -> (index) {
%11 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg8]
air.execute_terminator %11 : index
} {id = 8 : i32}
// Buffer allocations. The memref types carry memory space 1 or 2.
// NOTE(review): presumably space 1 is the shared (L2) level and space 2 the per-tile
// (L1) level in AIR's convention -- confirm against the AIR dialect docs.
// Memory space 1: 1x2x44x64 staging buffer for the output tile.
%async_token_12, %results_13 = air.execute -> (memref<1x2x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x44x64xbf16, 1 : i32>
} {id = 9 : i32}
// Memory space 2: 1x2x16x11x4x4 accumulator buffer (read and written by vector.contract).
%async_token_14, %results_15 = air.execute -> (memref<1x2x16x11x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x2x16x11x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 10 : i32}
// Memory space 2: 1x1x16x8x8x4 buffer, source of the 8x4 contract operand.
%async_token_16, %results_17 = air.execute -> (memref<1x1x16x8x8x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x16x8x8x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 11 : i32}
// Memory space 2: 1x1x8x11x4x8 buffer, source of the 4x8 contract operand.
%async_token_18, %results_19 = air.execute -> (memref<1x1x8x11x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x11x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 12 : i32}
// Memory space 1: 1x2x64x64 staging buffer filled from %arg11.
%async_token_20, %results_21 = air.execute -> (memref<1x2x64x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x64x64xbf16, 1 : i32>
} {id = 13 : i32}
// Memory space 1: 1x1x44x64 staging buffer filled from %arg10.
%async_token_22, %results_23 = air.execute -> (memref<1x1x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x1x44x64xbf16, 1 : i32>
} {id = 14 : i32}
// First K tile (column offset 0): copy a 44x64 window of %arg10, starting at row
// %results_9 with strides [9728, 1], into %results_23.
%2 = air.dma_memcpy_nd async [%async_token_8, %async_token_22] (%results_23[] [] [], %arg10[%results_9, %c0_7] [%c44, %c64] [%c9728, %c1]) {id = 1 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
// First K tile: copy from %arg11 with a 4-D access pattern of sizes [1, 2, 64, 64] and
// strides [155648, 64, 2432, 1], starting at column %results_11, into %results_21.
%3 = air.dma_memcpy_nd async [%async_token_10, %async_token_20] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c0_7, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 2 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Herd for the first K tile: a 1x2 tile grid (%arg12, %arg13). It depends on the five
// buffer-allocation tokens listed. Its dependency list does not name the segment-level
// DMAs %2 / %3 that fill %results_23 / %results_21; they are only joined with this
// herd's token afterwards.
// NOTE(review): confirm the intended ordering between those DMAs and this herd.
// Herd arguments: %arg16 = %results_23, %arg17 = %results_19, %arg18 = %results_21,
// %arg19 = %results_17, %arg20 = %results_15.
%4 = air.herd @herd_0 async [%async_token_14, %async_token_16, %async_token_18, %async_token_20, %async_token_22] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 1 : i32} {
// Strides, sizes and loop bounds local to the herd body.
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_30 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_31 = arith.constant 2816 : index
%cst = arith.constant 0.000000e+00 : bf16
%c0_32 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1_33 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Guarded copy %arg16 -> %arg17. The integer set holds when %arg12 == 0 and
// 0 <= %arg13 <= 1; the DMA carries the same set as its broadcast_set attribute.
// The else branch yields a dependency-free wait_all token.
%11 = affine.if affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>()[%arg12, %arg13] -> !air.async.token {
%c0_35 = arith.constant 0 : index
%15 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%c0_35, %c0_32, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c8, %c11, %c4, %c8] [%c2816_31, %c2816_31, %c8, %c256, %c64_30, %c1_33]) {broadcast_set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>, id = 3 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
affine.yield %15 : !air.async.token
} else {
%15 = air.wait_all async
affine.yield %15 : !air.async.token
}
// Per-tile copy %arg18 -> %arg19; the second source index is the tile column %arg13.
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_32, %arg13, %c0_32, %c0_32, %c0_32, %c0_32] [%c1_33, %c1_33, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_30, %c1_33]) {id = 4 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Zero-fill this tile's 1x1x16x11x4x4 slice of the accumulator %arg20 with %cst (0.0 bf16).
// NOTE(review): the fill's token %async_token_34 is not consumed by any dependency list
// in this herd (the loop below starts from %13 = wait_all [%11, %12]). Confirm that the
// fill is ordered before the accumulator reads inside the loop nest.
%subview = memref.subview %arg20[%arg12, %arg13, 0, 0, 0, 0] [1, 1, 16, 11, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x2x16x11x4x4xbf16, 2 : i32> to memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>
%async_token_34 = air.execute {
linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x11x4x4xbf16, strided<[5632, 2816, 176, 16, 4, 1], offset: ?>, 2 : i32>)
} {id = 15 : i32}
// Join both input copies before the compute loops.
%13 = air.wait_all async [%11, %12] {id = 6 : i32}
// Loop nest 11 x 16 x 8; each level threads an async token through iter_args.
%14 = scf.for %arg21 = %c0_32 to %c11 step %c1_33 iter_args(%arg22 = %13) -> (!air.async.token) {
%15 = air.wait_all async [%arg22, %arg22] {id = 4 : i32}
%16 = scf.for %arg23 = %c0_32 to %c16 step %c1_33 iter_args(%arg24 = %15) -> (!air.async.token) {
%18 = air.wait_all async [%arg24, %arg24] {id = 2 : i32}
// Innermost loop: %arg25 indexes both input operands (dim 2 of %arg17, dim 3 of %arg19)
// but not the accumulator, so the same 4x4 accumulator block is updated on every iteration.
%19 = scf.for %arg25 = %c0_32 to %c8 step %c1_33 iter_args(%arg26 = %18) -> (!air.async.token) {
// Read the 4x8 operand block from %arg17 at [.., %arg25, %arg21, ..].
%async_token_35, %results_36 = air.execute [%arg26] -> (vector<1x1x1x1x4x8xbf16>) {
%23 = vector.transfer_read %arg17[%c0_32, %c0_32, %arg25, %arg21, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x8xbf16>
} {id = 16 : i32}
// Read the 8x4 operand block from %arg19 at [.., %arg23, %arg25, ..].
%async_token_37, %results_38 = air.execute [%arg26] -> (vector<1x1x1x1x8x4xbf16>) {
%23 = vector.transfer_read %arg19[%c0_32, %c0_32, %arg23, %arg25, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x8x4xbf16>
} {id = 17 : i32}
// Read the current 4x4 accumulator block from %arg20 for this tile.
%async_token_39, %results_40 = air.execute [%arg26] -> (vector<1x1x1x1x4x4xbf16>) {
%23 = vector.transfer_read %arg20[%arg12, %arg13, %arg23, %arg21, %c0_32, %c0_32], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %23 : vector<1x1x1x1x4x4xbf16>
} {id = 18 : i32}
// Contract the 4x8 and 8x4 operands into the 4x4 accumulator (kind = add);
// d2, d5 and d8 are the reduction iterators.
%21 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_36, %results_38, %results_40 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Write the result back to the same accumulator location; ordered after the
// accumulator read via %async_token_39.
%async_token_41 = air.execute [%arg26, %async_token_39] {
vector.transfer_write %21, %arg20[%arg12, %arg13, %arg23, %arg21, %c0_32, %c0_32] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 19 : i32}
// Join the iteration's reads and write; the joined token is carried to the next iteration.
%22 = air.wait_all async [%arg26, %async_token_35, %async_token_37, %async_token_41] {id = 1 : i32}
scf.yield %22 : !air.async.token
}
%20 = air.wait_all async [%arg24, %19] {id = 3 : i32}
scf.yield %20 : !air.async.token
}
%17 = air.wait_all async [%arg22, %16] {id = 5 : i32}
scf.yield %17 : !air.async.token
}
air.herd_terminator
}
// Join the first-tile DMAs (%2, %3) and the first herd (%4) before entering the K loop.
%5 = air.wait_all async [%2, %3, %4] {id = 14 : i32}
// K loop: %arg12 runs from 1 to 150 (upper bound %c151 is exclusive), carrying an async
// token in %arg13. Each iteration refills both staging buffers at K offset %arg12 * 64
// and runs the herd.
%6 = scf.for %arg12 = %c1 to %c151 step %c1 iter_args(%arg13 = %5) -> (!air.async.token) {
// K offset for this iteration: %arg12 * 64.
%async_token_30, %results_31 = air.execute [%arg13] -> (index) {
%15 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg12]
air.execute_terminator %15 : index
} {id = 20 : i32}
// Refill %results_23 from %arg10 (44x64 window at column %results_31) and
// %results_21 from %arg11 (at row %results_31, column %results_11).
%11 = air.dma_memcpy_nd async [%async_token_30, %arg13, %arg13] (%results_23[] [] [], %arg10[%results_9, %results_31] [%c44, %c64] [%c9728, %c1]) {id = 5 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%12 = air.dma_memcpy_nd async [%async_token_30, %arg13, %arg13] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %results_31, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 6 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Herd for this K tile on a 1x2 tile grid (%arg14, %arg15). Inside this herd the
// buffers are named %arg18 = %results_23, %arg19 = %results_19, %arg20 = %results_21,
// %arg21 = %results_17, %arg22 = %results_15.
// NOTE(review): the dependency list names only %arg13, not the refill DMAs %11 / %12
// of this iteration; they are joined with the herd in %14 below. Confirm the intended
// ordering between the refills and the herd's reads of %arg18 / %arg20.
%13 = air.herd @herd_0 async [%arg13] tile (%arg14, %arg15) in (%arg16=%c1, %arg17=%c2) args(%arg18=%results_23, %arg19=%results_19, %arg20=%results_21, %arg21=%results_17, %arg22=%results_15) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32> attributes {id = 2 : i32} {
// Strides, sizes and loop bounds local to the herd body.
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_32 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_33 = arith.constant 2816 : index
%c0_34 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_35 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Guarded copy %arg18 -> %arg19. The integer set holds when %arg14 == 0 and
// 0 <= %arg15 <= 1; the DMA carries the same set as its broadcast_set attribute.
%15 = affine.if affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>()[%arg14, %arg15] -> !air.async.token {
%c0_36 = arith.constant 0 : index
%19 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_36, %c0_34, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c8, %c11, %c4, %c8] [%c2816_33, %c2816_33, %c8, %c256, %c64_32, %c1_35]) {broadcast_set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>, id = 7 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
affine.yield %19 : !air.async.token
} else {
%19 = air.wait_all async
affine.yield %19 : !air.async.token
}
// Per-tile copy %arg20 -> %arg21; the second source index is the tile column %arg15.
%16 = air.dma_memcpy_nd async (%arg21[] [] [], %arg20[%c0_34, %arg15, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_32, %c1_35]) {id = 8 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Join both input copies before the compute loops. This herd has no linalg.fill,
// so the accumulator %arg22 keeps whatever it already holds.
%17 = air.wait_all async [%15, %16] {id = 12 : i32}
// Loop nest 11 x 16 x 8; each level threads an async token through iter_args.
%18 = scf.for %arg23 = %c0_34 to %c11 step %c1_35 iter_args(%arg24 = %17) -> (!air.async.token) {
%19 = air.wait_all async [%arg24, %arg24] {id = 10 : i32}
%20 = scf.for %arg25 = %c0_34 to %c16 step %c1_35 iter_args(%arg26 = %19) -> (!air.async.token) {
%22 = air.wait_all async [%arg26, %arg26] {id = 8 : i32}
// Innermost loop: %arg27 indexes both input operands but not the accumulator.
%23 = scf.for %arg27 = %c0_34 to %c8 step %c1_35 iter_args(%arg28 = %22) -> (!air.async.token) {
// Read the 4x8 operand block from %arg19 at [.., %arg27, %arg23, ..].
%async_token_36, %results_37 = air.execute [%arg28] -> (vector<1x1x1x1x4x8xbf16>) {
%27 = vector.transfer_read %arg19[%c0_34, %c0_34, %arg27, %arg23, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x8xbf16>
} {id = 21 : i32}
// Read the 8x4 operand block from %arg21 at [.., %arg25, %arg27, ..].
%async_token_38, %results_39 = air.execute [%arg28] -> (vector<1x1x1x1x8x4xbf16>) {
%27 = vector.transfer_read %arg21[%c0_34, %c0_34, %arg25, %arg27, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x8x4xbf16>
} {id = 22 : i32}
// Read the current 4x4 accumulator block from %arg22 for this tile.
%async_token_40, %results_41 = air.execute [%arg28] -> (vector<1x1x1x1x4x4xbf16>) {
%27 = vector.transfer_read %arg22[%arg14, %arg15, %arg25, %arg23, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %27 : vector<1x1x1x1x4x4xbf16>
} {id = 23 : i32}
// Contract the 4x8 and 8x4 operands into the 4x4 accumulator (kind = add);
// d2, d5 and d8 are the reduction iterators.
%25 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_37, %results_39, %results_41 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Write the result back to the same accumulator location; ordered after the
// accumulator read via %async_token_40.
%async_token_42 = air.execute [%arg28, %async_token_40] {
vector.transfer_write %25, %arg22[%arg14, %arg15, %arg25, %arg23, %c0_34, %c0_34] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 24 : i32}
// Join the iteration's reads and write; the joined token is carried to the next iteration.
%26 = air.wait_all async [%arg28, %async_token_36, %async_token_38, %async_token_42] {id = 7 : i32}
scf.yield %26 : !air.async.token
}
%24 = air.wait_all async [%arg26, %23] {id = 9 : i32}
scf.yield %24 : !air.async.token
}
%21 = air.wait_all async [%arg24, %20] {id = 11 : i32}
scf.yield %21 : !air.async.token
}
air.herd_terminator
}
// Join the incoming token, both refill DMAs and the herd; this is the token carried
// into the next K iteration.
%14 = air.wait_all async [%arg13, %11, %12, %13] {id = 13 : i32}
scf.yield %14 : !air.async.token
}
// Last K tile: refill both staging buffers at K offset %c9664 (9664 + 64 = 9728),
// ordered after the K loop token %6.
%7 = air.dma_memcpy_nd async [%6, %6] (%results_23[] [] [], %arg10[%results_9, %c9664] [%c44, %c64] [%c9728, %c1]) {id = 9 : i32} : (memref<1x1x44x64xbf16, 1 : i32>, memref<308x9728xbf16>)
%8 = air.dma_memcpy_nd async [%6, %6] (%results_21[] [] [], %arg11[%c0_7, %c0_7, %c9664, %results_11] [%c1, %c2, %c64, %c64] [%c155648, %c64, %c2432, %c1]) {id = 10 : i32} : (memref<1x2x64x64xbf16, 1 : i32>, memref<9728x2432xbf16>)
// Herd for the last K tile on a 1x2 tile grid (%arg12, %arg13). Besides the five
// buffers used by the compute loops, it takes %arg21 = %results_13, the destination of
// the copy-out DMA at the end of the body.
// NOTE(review): the dependency list names %async_token_12 and %6 but not the refill
// DMAs %7 / %8 above. Confirm the intended ordering between those DMAs and the herd's
// reads of %arg16 / %arg18.
%9 = air.herd @herd_0 async [%async_token_12, %6] tile (%arg12, %arg13) in (%arg14=%c1, %arg15=%c2) args(%arg16=%results_23, %arg17=%results_19, %arg18=%results_21, %arg19=%results_17, %arg20=%results_15, %arg21=%results_13) : memref<1x1x44x64xbf16, 1 : i32>, memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>, memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>, memref<1x2x44x64xbf16, 1 : i32> attributes {id = 3 : i32} {
// Strides, sizes and loop bounds local to the herd body.
%c44_30 = arith.constant 44 : index
%c176 = arith.constant 176 : index
%c5632_31 = arith.constant 5632 : index
%c512 = arith.constant 512 : index
%c4096 = arith.constant 4096 : index
%c8192 = arith.constant 8192 : index
%c4 = arith.constant 4 : index
%c64_32 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%c2816_33 = arith.constant 2816 : index
%c0_34 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : bf16
%c8 = arith.constant 8 : index
%c1_35 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c11 = arith.constant 11 : index
// Guarded copy %arg16 -> %arg17. The integer set holds when %arg12 == 0 and
// 0 <= %arg13 <= 1; the DMA carries the same set as its broadcast_set attribute.
%11 = affine.if affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>()[%arg12, %arg13] -> !air.async.token {
%c0_36 = arith.constant 0 : index
%16 = air.dma_memcpy_nd async (%arg17[] [] [], %arg16[%c0_36, %c0_34, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c8, %c11, %c4, %c8] [%c2816_33, %c2816_33, %c8, %c256, %c64_32, %c1_35]) {broadcast_set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>, id = 11 : i32} : (memref<1x1x8x11x4x8xbf16, 2 : i32>, memref<1x1x44x64xbf16, 1 : i32>)
affine.yield %16 : !air.async.token
} else {
%16 = air.wait_all async
affine.yield %16 : !air.async.token
}
// Per-tile copy %arg18 -> %arg19; the second source index is the tile column %arg13.
%12 = air.dma_memcpy_nd async (%arg19[] [] [], %arg18[%c0_34, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c16, %c8, %c8, %c4] [%c8192, %c4096, %c4, %c512, %c64_32, %c1_35]) {id = 12 : i32} : (memref<1x1x16x8x8x4xbf16, 2 : i32>, memref<1x2x64x64xbf16, 1 : i32>)
// Join both input copies before the compute loops.
%13 = air.wait_all async [%11, %12] {id = 20 : i32}
// Loop nest 11 x 16 x 8; each level threads an async token through iter_args.
%14 = scf.for %arg22 = %c0_34 to %c11 step %c1_35 iter_args(%arg23 = %13) -> (!air.async.token) {
%16 = air.wait_all async [%arg23, %arg23] {id = 18 : i32}
%17 = scf.for %arg24 = %c0_34 to %c16 step %c1_35 iter_args(%arg25 = %16) -> (!air.async.token) {
%19 = air.wait_all async [%arg25, %arg25] {id = 16 : i32}
// Innermost loop: %arg26 indexes both input operands but not the accumulator.
%20 = scf.for %arg26 = %c0_34 to %c8 step %c1_35 iter_args(%arg27 = %19) -> (!air.async.token) {
// Read the 4x8 operand block from %arg17 at [.., %arg26, %arg22, ..].
%async_token_36, %results_37 = air.execute [%arg27] -> (vector<1x1x1x1x4x8xbf16>) {
%24 = vector.transfer_read %arg17[%c0_34, %c0_34, %arg26, %arg22, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x11x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x8xbf16>
} {id = 25 : i32}
// Read the 8x4 operand block from %arg19 at [.., %arg24, %arg26, ..].
%async_token_38, %results_39 = air.execute [%arg27] -> (vector<1x1x1x1x8x4xbf16>) {
%24 = vector.transfer_read %arg19[%c0_34, %c0_34, %arg24, %arg26, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x8x4xbf16>
} {id = 26 : i32}
// Read the current 4x4 accumulator block from %arg20 for this tile.
%async_token_40, %results_41 = air.execute [%arg27] -> (vector<1x1x1x1x4x4xbf16>) {
%24 = vector.transfer_read %arg20[%arg12, %arg13, %arg24, %arg22, %c0_34, %c0_34], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x2x16x11x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
air.execute_terminator %24 : vector<1x1x1x1x4x4xbf16>
} {id = 27 : i32}
// Contract the 4x8 and 8x4 operands into the 4x4 accumulator (kind = add);
// d2, d5 and d8 are the reduction iterators.
%22 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %results_37, %results_39, %results_41 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
// Write the result back to the same accumulator location; ordered after the
// accumulator read via %async_token_40.
%async_token_42 = air.execute [%arg27, %async_token_40] {
vector.transfer_write %22, %arg20[%arg12, %arg13, %arg24, %arg22, %c0_34, %c0_34] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 28 : i32}
// Join the iteration's reads and write; the joined token is carried to the next iteration.
%23 = air.wait_all async [%arg27, %async_token_36, %async_token_38, %async_token_42] {id = 15 : i32}
scf.yield %23 : !air.async.token
}
%21 = air.wait_all async [%arg25, %20] {id = 17 : i32}
scf.yield %21 : !air.async.token
}
%18 = air.wait_all async [%arg23, %17] {id = 19 : i32}
scf.yield %18 : !air.async.token
}
// Copy-out after the loop nest (%14): move this tile's accumulator slice of %arg20 into
// the [%arg12, %arg13] 44x64 slot of %arg21. The source is read with sizes
// [1, 1, 11, 4, 16, 4] and strides [5632, 2816, 16, 4, 176, 1], i.e. in a different
// dimension order than the buffer's 16x11x4x4 storage layout.
%15 = air.dma_memcpy_nd async [%14] (%arg21[%arg12, %arg13, %c0_34, %c0_34] [%c1_35, %c1_35, %c44_30, %c64_32] [%c5632_31, %c2816_33, %c64_32, %c1_35], %arg20[%arg12, %arg13, %c0_34, %c0_34, %c0_34, %c0_34] [%c1_35, %c1_35, %c11, %c4, %c16, %c4] [%c5632_31, %c2816_33, %c16, %c4, %c176, %c1_35]) {id = 13 : i32} : (memref<1x2x44x64xbf16, 1 : i32>, memref<1x2x16x11x4x4xbf16, 2 : i32>)
air.herd_terminator
}
// Write the finished tile back to the output: after the last herd (%9) and both offset
// computations, copy %results_13 into a 44x128 window of %arg9 at
// [%results_9, %results_11]. The source is read with sizes [1, 44, 2, 64] and strides
// [5632, 64, 2816, 1].
%10 = air.dma_memcpy_nd async [%async_token_8, %async_token_10, %9] (%arg9[%results_9, %results_11] [%c44, %c128] [%c2432, %c1], %results_13[%c0_7, %c0_7, %c0_7, %c0_7] [%c1, %c44, %c2, %c64] [%c5632, %c64, %c2816, %c1]) {id = 14 : i32} : (memref<308x2432xbf16>, memref<1x2x44x64xbf16, 1 : i32>)
// Deallocations. Each one is ordered after the last listed user of its buffer.
// %results_23: after the last herd and the last DMA that filled it (%7).
%async_token_24 = air.execute [%9, %7] {
memref.dealloc %results_23 : memref<1x1x44x64xbf16, 1 : i32>
} {id = 29 : i32}
// %results_21: after the last herd and the last DMA that filled it (%8).
%async_token_25 = air.execute [%9, %8] {
memref.dealloc %results_21 : memref<1x2x64x64xbf16, 1 : i32>
} {id = 30 : i32}
// The three memory-space-2 buffers are released once the last herd (%9) completes.
%async_token_26 = air.execute [%9] {
memref.dealloc %results_19 : memref<1x1x8x11x4x8xbf16, 2 : i32>
} {id = 31 : i32}
%async_token_27 = air.execute [%9] {
memref.dealloc %results_17 : memref<1x1x16x8x8x4xbf16, 2 : i32>
} {id = 32 : i32}
%async_token_28 = air.execute [%9] {
memref.dealloc %results_15 : memref<1x2x16x11x4x4xbf16, 2 : i32>
} {id = 33 : i32}
// %results_13 is released only after the output write-back DMA (%10) has completed.
%async_token_29 = air.execute [%10] {
memref.dealloc %results_13 : memref<1x2x44x64xbf16, 1 : i32>
} {id = 34 : i32}
air.segment_terminator
}
air.launch_terminator
}
return
}
}
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
module {
air.channel @channel_13 [1, 1]
air.channel @channel_12 [1, 2]
air.channel @channel_11 [1, 2]
air.channel @channel_10 [1, 1]
air.channel @channel_9 [1, 1]
air.channel @channel_8 [1, 2]
air.channel @channel_7 [1, 1]
air.channel @channel_6 [1, 1]
air.channel @channel_5 [1, 2]
air.channel @channel_4 [1, 1]
air.channel @channel_3 [1, 1]
air.channel @channel_2 [1, 1] {broadcast_shape = [1, 2]}
air.channel @channel_1 [1, 1] {broadcast_shape = [1, 2]}
air.channel @channel_0 [1, 1] {broadcast_shape = [1, 2]}
func.func @matmul_bf16_large_dispatch_0_matmul_308x2432x9728_bf16() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
%c19 = arith.constant 19 : index
%c7 = arith.constant 7 : index
%c0 = arith.constant 0 : index
%async_token, %results = air.execute -> (memref<308x9728xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<308x9728xbf16>
air.execute_terminator %1 : memref<308x9728xbf16>
} {id = 1 : i32}
%async_token_0 = air.execute [%async_token] {
memref.assume_alignment %results, 64 : memref<308x9728xbf16>
} {id = 2 : i32}
%async_token_1, %results_2 = air.execute -> (memref<9728x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<9728x2432xbf16>
air.execute_terminator %1 : memref<9728x2432xbf16>
} {id = 3 : i32}
%async_token_3 = air.execute [%async_token_1] {
memref.assume_alignment %results_2, 64 : memref<9728x2432xbf16>
} {id = 4 : i32}
%async_token_4, %results_5 = air.execute -> (memref<308x2432xbf16>) {
%1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<308x2432xbf16>
air.execute_terminator %1 : memref<308x2432xbf16>
} {id = 5 : i32}
%async_token_6 = air.execute [%async_token_4] {
memref.assume_alignment %results_5, 64 : memref<308x2432xbf16>
} {id = 6 : i32}
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c7, %arg3=%c19) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<308x2432xbf16>, memref<308x9728xbf16>, memref<9728x2432xbf16> attributes {id = 5 : i32} {
%c64 = arith.constant 64 : index
%c44 = arith.constant 44 : index
%c9728 = arith.constant 9728 : index
%c0_7 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%async_token_8, %results_9 = air.execute -> (index) {
%72 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg0]
air.execute_terminator %72 : index
} {id = 7 : i32}
%1 = air.channel.put async [%async_token_8] @channel_3[] (%arg5[%results_9, %c0_7] [%c44, %c64] [%c9728, %c1]) : (memref<308x9728xbf16>)
%2 = air.wait_all async
%3 = air.wait_all async
%4 = air.wait_all async
%5 = air.wait_all async
%6 = air.wait_all async
%7 = air.wait_all async
%8 = air.wait_all async
%9 = air.wait_all async
%c2432 = arith.constant 2432 : index
%c155648 = arith.constant 155648 : index
%c64_10 = arith.constant 64 : index
%c0_11 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1_12 = arith.constant 1 : index
%async_token_13, %results_14 = air.execute -> (index) {
%72 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1]
air.execute_terminator %72 : index
} {id = 8 : i32}
%10 = air.channel.put async [%async_token_13] @channel_4[] (%arg6[%c0_11, %c0_11, %c0_11, %results_14] [%c1_12, %c2, %c64_10, %c64_10] [%c155648, %c64_10, %c2432, %c1_12]) : (memref<9728x2432xbf16>)
%11 = air.wait_all async
%12 = air.wait_all async
%13 = air.wait_all async
%14 = air.wait_all async
%15 = air.wait_all async
%16 = air.wait_all async
%17 = air.wait_all async
%18 = air.wait_all async
%c64_15 = arith.constant 64 : index
%c44_16 = arith.constant 44 : index
%c9728_17 = arith.constant 9728 : index
%c1_18 = arith.constant 1 : index
%c151 = arith.constant 151 : index
%async_token_19, %results_20 = air.execute -> (index) {
%72 = affine.apply affine_map<()[s0] -> (s0 * 44)>()[%arg0]
air.execute_terminator %72 : index
} {id = 7 : i32}
%async_token_21, %results_22 = air.execute -> (index) {
%72 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%arg1]
air.execute_terminator %72 : index
} {id = 8 : i32}
%async_token_23, %results_24 = air.execute -> (memref<1x2x64x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x2x64x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x2x64x64xbf16, 1 : i32>
} {id = 13 : i32}
%async_token_25, %results_26 = air.execute -> (memref<1x1x44x64xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<1x1x44x64xbf16, 1 : i32>
air.execute_terminator %alloc : memref<1x1x44x64xbf16, 1 : i32>
} {id = 14 : i32}
%19 = air.wait_all async
%20 = air.wait_all async
%21 = air.wait_all async
%22 = air.wait_all async [%19, %20, %21] {id = 14 : i32}
%23 = scf.for %arg7 = %c1_18 to %c151 step %c1_18 iter_args(%arg8 = %22) -> (!air.async.token) {
%c44_64 = arith.constant 44 : index
%c64_65 = arith.constant 64 : index
%c9728_66 = arith.constant 9728 : index
%c1_67 = arith.constant 1 : index
%c0_68 = arith.constant 0 : index
%c2_69 = arith.constant 2 : index
%c155648_70 = arith.constant 155648 : index
%c2432_71 = arith.constant 2432 : index
%async_token_72, %results_73 = air.execute [%arg8] -> (index) {
%74 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg7]
air.execute_terminator %74 : index
} {id = 20 : i32}
%72 = air.channel.put async [%async_token_72, %async_token_19, %arg8] @channel_6[] (%arg5[%results_20, %results_73] [%c44_16, %c64_15] [%c9728_17, %c1_18]) : (memref<308x9728xbf16>)
%73 = air.wait_all async [%72]
scf.yield %73 : !air.async.token
}
%24 = air.wait_all async
%25 = air.wait_all async
%26 = air.wait_all async
%27 =
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment