@antiagainst
Last active December 19, 2023 07:31
// tools/iree-compile --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=valhall-unknown-android31 ~/models/mhlo-dot.mlir -o /dev/null --mlir-print-ir-after-all --mlir-print-ir-after-change --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=8 -debug-only=iree-spirv-vectorize &>! mhlo-dot.log
// iree-org/iree@a8e4c38c
// -----// IR Dump After mlir::iree_compiler::IREE::HAL::(anonymous namespace)::MaterializeInterfacesPass (iree-hal-materialize-interfaces) //----- //
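// The dispatch wraps a linalg.fill + linalg.matmul (M=128, N=64, K=256) followed by an elementwise
// subtract, targeting Vulkan/SPIR-V for an ARM Mali Valhall GPU (subgroup size 16). At this point the
// export's workgroup count is still the symbolic flow.dispatch.default_workgroup_count; it is
// concretized in the tile-and-distribute dump below.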
#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>}>], legacy_sync}>
#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>
module attributes {hal.device.targets = [#device_target_vulkan]} {
hal.executable private @dot_dispatch_0 {
hal.executable.variant public @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb {
hal.executable.export public @dot_dispatch_0_matmul_128x64x256 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%x, %y, %z = flow.dispatch.default_workgroup_count %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @dot_dispatch_0_matmul_128x64x256() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x64xf32>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<128x64xf32>
%7 = linalg.init_tensor [128, 64] : tensor<128x64xf32>
%8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x64xf32>) -> tensor<128x64xf32>
%9 = linalg.matmul ins(%4, %5 : tensor<128x256xf32>, tensor<256x64xf32>) outs(%8 : tensor<128x64xf32>) -> tensor<128x64xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%9, %6 : tensor<128x64xf32>, tensor<128x64xf32>) outs(%7 : tensor<128x64xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%11 = arith.subf %arg0, %arg1 : f32
linalg.yield %11 : f32
} -> tensor<128x64xf32>
flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : tensor<128x64xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
return
}
}
}
}
func.func @dot(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0 = arith.constant 0 : index
%c131072 = arith.constant 131072 : index
%c65536 = arith.constant 65536 : index
%c32768 = arith.constant 32768 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c553648160_i32 = arith.constant 553648160 : i32
%c1_i32 = arith.constant 1 : i32
%c256 = arith.constant 256 : index
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%c128, %c256]) type(%c553648160_i32) encoding(%c1_i32)
%0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<128x256xf32> in !stream.resource<external>{%c131072}
hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("tensor") shape([%c256, %c64]) type(%c553648160_i32) encoding(%c1_i32)
%1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<256x64xf32> in !stream.resource<external>{%c65536}
hal.buffer_view.assert<%arg2 : !hal.buffer_view> message("tensor") shape([%c128, %c64]) type(%c553648160_i32) encoding(%c1_i32)
%2 = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<128x64xf32> in !stream.resource<external>{%c32768}
%3 = stream.resource.alloc uninitialized : !stream.resource<external>{%c32768}
%4 = stream.cmd.execute with(%0 as %arg3: !stream.resource<external>{%c131072}, %1 as %arg4: !stream.resource<external>{%c65536}, %2 as %arg5: !stream.resource<external>{%c32768}, %3 as %arg6: !stream.resource<external>{%c32768}) {
stream.cmd.dispatch @dot_dispatch_0::@dot_dispatch_0_matmul_128x64x256[%c128, %c64] {
ro %arg3[%c0 for %c131072] : !stream.resource<external>{%c131072},
ro %arg4[%c0 for %c65536] : !stream.resource<external>{%c65536},
ro %arg5[%c0 for %c32768] : !stream.resource<external>{%c32768},
wo %arg6[%c0 for %c32768] : !stream.resource<external>{%c32768}
} attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>]}
} => !stream.timepoint
%5 = stream.timepoint.await %4 => %3 : !stream.resource<external>{%c32768}
%6 = stream.tensor.export %5 : tensor<128x64xf32> in !stream.resource<external>{%c32768} -> !hal.buffer_view
return %6 : !hal.buffer_view
}
}
// -----// IR Dump After TileAndDistributeToWorkgroups (iree-codegen-tile-and-distribute-to-workgroups) //----- //
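// First-level (workgroup) tiling: the lowering config picks tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]
// and the SPIRVVectorize pipeline with workgroup_size = [8, 2, 1]. The export region now returns
// ceildiv(64, 32) x ceildiv(128, 8) x 1 workgroups, and the function walks 8x32 output tiles via
// scf.for loops driven by hal.interface.workgroup.id/count.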
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>}> {
hal.executable.export public @dot_dispatch_0_matmul_128x64x256 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<SPIRVVectorize workload_per_wg = [32, 8]>, workgroup_size = [8 : index, 2 : index, 1 : index]} {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%c1 = arith.constant 1 : index
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg1]
%1 = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg2]
hal.return %1, %0, %c1 : index, index, index
}
builtin.module {
func.func @dot_dispatch_0_matmul_128x64x256() {
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%c8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
%9 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %c32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x?xf32>
%10 = linalg.init_tensor [8, 32] : tensor<8x32xf32>
%11 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%cst : f32) outs(%10 : tensor<8x32xf32>) -> tensor<8x32xf32>
%12 = tensor.cast %9 : tensor<256x?xf32> to tensor<256x32xf32>
%13 = tensor.cast %8 : tensor<?x256xf32> to tensor<8x256xf32>
%14 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%13, %12 : tensor<8x256xf32>, tensor<256x32xf32>) outs(%11 : tensor<8x32xf32>) -> tensor<8x32xf32>
%15 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c8, %c32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<?x?xf32>
%16 = tensor.cast %15 : tensor<?x?xf32> to tensor<8x32xf32>
%17 = linalg.init_tensor [8, 32] : tensor<8x32xf32>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14, %16 : tensor<8x32xf32>, tensor<8x32xf32>) outs(%17 : tensor<8x32xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
%20 = arith.subf %arg2, %arg3 : f32
linalg.yield %20 : f32
} -> tensor<8x32xf32>
%19 = tensor.cast %18 : tensor<8x32xf32> to tensor<?x?xf32>
flow.dispatch.tensor.store %19, %3, offsets = [%arg0, %arg1], sizes = [%c8, %c32], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
}
}
// -----// IR Dump After ConvertToDestinationPassingStyle (iree-codegen-convert-to-destination-passing-style) //----- //
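// Destination-passing style: the 8x32 output tile is now loaded from the result binding (%3) and used
// as the destination of linalg.fill, and the trailing elementwise generic reads the matmul result
// through its outs operand (the subtrahend arrives through ins), so no fresh init tensor feeds the
// value that gets stored.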
func.func @dot_dispatch_0_matmul_128x64x256() {
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%c8_0 = arith.constant 8 : index
%c32_1 = arith.constant 32 : index
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [%c8_0, %c32_1], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<?x?xf32>
%9 = tensor.cast %8 : tensor<?x?xf32> to tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%c8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
%11 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %c32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x?xf32>
%12 = linalg.init_tensor [8, 32] : tensor<8x32xf32>
%13 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%cst : f32) outs(%9 : tensor<8x32xf32>) -> tensor<8x32xf32>
%14 = tensor.cast %11 : tensor<256x?xf32> to tensor<256x32xf32>
%15 = tensor.cast %10 : tensor<?x256xf32> to tensor<8x256xf32>
%16 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%15, %14 : tensor<8x256xf32>, tensor<256x32xf32>) outs(%13 : tensor<8x32xf32>) -> tensor<8x32xf32>
%17 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c8, %c32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<?x?xf32>
%18 = tensor.cast %17 : tensor<?x?xf32> to tensor<8x32xf32>
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18 : tensor<8x32xf32>) outs(%16 : tensor<8x32xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} {
^bb0(%arg2: f32, %arg3: f32):
%21 = arith.subf %arg3, %arg2 : f32
linalg.yield %21 : f32
} -> tensor<8x32xf32>
%20 = tensor.cast %19 : tensor<8x32xf32> to tensor<?x?xf32>
flow.dispatch.tensor.store %20, %3, offsets = [%arg0, %arg1], sizes = [%c8, %c32], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
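// Canonicalization folds away the tensor.cast ops, the dead linalg.init_tensor, and the redundant
// index constants; the dispatch.tensor loads and stores now carry static sizes ([8, 32], [8, 256],
// [256, 32]).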
module {
func.func @dot_dispatch_0_matmul_128x64x256() {
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%cst : f32) outs(%8 : tensor<8x32xf32>) -> tensor<8x32xf32>
%12 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%9, %10 : tensor<8x256xf32>, tensor<256x32xf32>) outs(%11 : tensor<8x32xf32>) -> tensor<8x32xf32>
%13 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<8x32xf32>) outs(%12 : tensor<8x32xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} {
^bb0(%arg2: f32, %arg3: f32):
%15 = arith.subf %arg3, %arg2 : f32
linalg.yield %15 : f32
} -> tensor<8x32xf32>
flow.dispatch.tensor.store %14, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
}
// -----// IR Dump After SPIRVTile (iree-spirv-tile) //----- //
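// Second-level (invocation) tiling: the 8x32 workgroup tile is tiled by [4, 4] and the K=256
// reduction by 4, producing two scf.for loops tagged with iree.spirv.distribute_dim (to be
// distributed across invocations later) plus an inner sequential reduction loop over K.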
func.func @dot_dispatch_0_matmul_128x64x256() {
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%cst : f32) outs(%15 : tensor<4x4xf32>) -> tensor<4x4xf32>
%17 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%18 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%19 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %16) -> (tensor<4x4xf32>) {
%22 = tensor.extract_slice %17[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%23 = tensor.extract_slice %18[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%24 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} ins(%22, %23 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg7 : tensor<4x4xf32>) -> tensor<4x4xf32>
scf.yield %24 : tensor<4x4xf32>
}
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<4x4xf32>) outs(%19 : tensor<4x4xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 32], [4, 4], [0, 0, 4]]>} {
^bb0(%arg6: f32, %arg7: f32):
%22 = arith.subf %arg7, %arg6 : f32
linalg.yield %22 : f32
} -> tensor<4x4xf32>
%21 = tensor.insert_slice %20 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %21 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After vectorization ---
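// Vectorization: linalg.fill becomes a vector.transfer_write of a zero splat, linalg.matmul becomes
// transfer_read / vector.contract / transfer_write over 4x4 vectors, and the elementwise subtract
// becomes a vector-level arith.subf.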
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.transfer_write %cst, %15[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%17 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%18 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%19 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %16) -> (tensor<4x4xf32>) {
%25 = tensor.extract_slice %17[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%26 = tensor.extract_slice %18[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%27 = vector.transfer_read %25[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%28 = vector.transfer_read %26[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%29 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%30 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %27, %28, %29 : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
%31 = vector.transfer_write %30, %arg7[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
scf.yield %31 : tensor<4x4xf32>
}
%20 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%21 = vector.transfer_read %19[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%22 = arith.subf %21, %20 : vector<4x4xf32>
%23 = vector.transfer_write %22, %19[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%24 = tensor.insert_slice %23 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %24 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After peephole optimization ---
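// No peephole changes for this input; the IR below is identical to the previous dump.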
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.transfer_write %cst, %15[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%17 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%18 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%19 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %16) -> (tensor<4x4xf32>) {
%25 = tensor.extract_slice %17[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%26 = tensor.extract_slice %18[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%27 = vector.transfer_read %25[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%28 = vector.transfer_read %26[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%29 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%30 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %27, %28, %29 : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
%31 = vector.transfer_write %30, %arg7[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
scf.yield %31 : tensor<4x4xf32>
}
%20 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%21 = vector.transfer_read %19[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%22 = arith.subf %21, %20 : vector<4x4xf32>
%23 = vector.transfer_write %22, %19[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%24 = tensor.insert_slice %23 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %24 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After lowering multi_reduction ops ---
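// There are no vector.multi_reduction ops to lower here, so the IR is again unchanged.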
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.transfer_write %cst, %15[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%17 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%18 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%19 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %16) -> (tensor<4x4xf32>) {
%25 = tensor.extract_slice %17[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%26 = tensor.extract_slice %18[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%27 = vector.transfer_read %25[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%28 = vector.transfer_read %26[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%29 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%30 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %27, %28, %29 : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
%31 = vector.transfer_write %30, %arg7[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
scf.yield %31 : tensor<4x4xf32>
}
%20 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%21 = vector.transfer_read %19[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%22 = arith.subf %21, %20 : vector<4x4xf32>
%23 = vector.transfer_write %22, %19[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
%24 = tensor.insert_slice %23 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %24 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After unrolling vector ---
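// Vector unrolling: the transfers are split into vector<1x4xf32> rows and each 4x4 contraction is
// decomposed into sixteen vector.contract ops of the form (1x1) x (1x4) -> (1x4), one per unrolled
// (M row, K element) pair.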
func.func @dot_dispatch_0_matmul_128x64x256() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.extract_strided_slice %cst {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
%17 = vector.transfer_write %16, %15[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%18 = vector.extract_strided_slice %cst {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
%19 = vector.transfer_write %18, %17[%c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%20 = vector.extract_strided_slice %cst {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
%21 = vector.transfer_write %20, %19[%c2, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%22 = vector.extract_strided_slice %cst {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
%23 = vector.transfer_write %22, %21[%c3, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%24 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%25 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%26 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %23) -> (tensor<4x4xf32>) {
%44 = tensor.extract_slice %24[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%45 = tensor.extract_slice %25[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%46 = vector.transfer_read %44[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%47 = vector.transfer_read %44[%c1, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%48 = vector.transfer_read %44[%c2, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%49 = vector.transfer_read %44[%c3, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%50 = vector.transfer_read %45[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%51 = vector.transfer_read %45[%c1, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%52 = vector.transfer_read %45[%c2, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%53 = vector.transfer_read %45[%c3, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%54 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%55 = vector.transfer_read %arg7[%c1, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%56 = vector.transfer_read %arg7[%c2, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%57 = vector.transfer_read %arg7[%c3, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%58 = vector.extract_strided_slice %46 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%59 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %58, %50, %54 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%60 = vector.extract_strided_slice %46 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%61 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %60, %51, %59 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%62 = vector.extract_strided_slice %46 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%63 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %62, %52, %61 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%64 = vector.extract_strided_slice %46 {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%65 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %64, %53, %63 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%66 = vector.extract_strided_slice %47 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%67 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %66, %50, %55 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%68 = vector.extract_strided_slice %47 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%69 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %68, %51, %67 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%70 = vector.extract_strided_slice %47 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%71 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %70, %52, %69 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%72 = vector.extract_strided_slice %47 {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%73 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %72, %53, %71 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%74 = vector.extract_strided_slice %48 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%75 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %74, %50, %56 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%76 = vector.extract_strided_slice %48 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%77 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %76, %51, %75 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%78 = vector.extract_strided_slice %48 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%79 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %78, %52, %77 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%80 = vector.extract_strided_slice %48 {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%81 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %80, %53, %79 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%82 = vector.extract_strided_slice %49 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%83 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %82, %50, %57 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%84 = vector.extract_strided_slice %49 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%85 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %84, %51, %83 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%86 = vector.extract_strided_slice %49 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%87 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %86, %52, %85 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%88 = vector.extract_strided_slice %49 {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
%89 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %88, %53, %87 : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
%90 = vector.transfer_write %65, %arg7[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%91 = vector.transfer_write %73, %90[%c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%92 = vector.transfer_write %81, %91[%c2, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%93 = vector.transfer_write %89, %92[%c3, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
scf.yield %93 : tensor<4x4xf32>
}
%27 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%28 = vector.transfer_read %14[%c1, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%29 = vector.transfer_read %14[%c2, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%30 = vector.transfer_read %14[%c3, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%31 = vector.transfer_read %26[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%32 = vector.transfer_read %26[%c1, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%33 = vector.transfer_read %26[%c2, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%34 = vector.transfer_read %26[%c3, %c0], %cst_0 {in_bounds = [true, true]} : tensor<4x4xf32>, vector<1x4xf32>
%35 = arith.subf %31, %27 : vector<1x4xf32>
%36 = arith.subf %32, %28 : vector<1x4xf32>
%37 = arith.subf %33, %29 : vector<1x4xf32>
%38 = arith.subf %34, %30 : vector<1x4xf32>
%39 = vector.transfer_write %35, %26[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%40 = vector.transfer_write %36, %39[%c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%41 = vector.transfer_write %37, %40[%c2, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%42 = vector.transfer_write %38, %41[%c3, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<4x4xf32>
%43 = tensor.insert_slice %42 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %43 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After casting away leading size-1 dims ---
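// Leading unit dimensions are dropped: transfers now operate on vector<4xf32>, the zero-splat rows
// come from vector.extract, the RHS rows are re-broadcast to vector<1x4xf32>, and the contraction
// maps are reduced to one parallel and one reduction dimension producing vector<4xf32> accumulators.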
func.func @dot_dispatch_0_matmul_128x64x256() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf32>
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.extract %cst[0] : vector<4x4xf32>
%17 = vector.transfer_write %16, %15[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.extract %cst[1] : vector<4x4xf32>
%19 = vector.transfer_write %18, %17[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = vector.extract %cst[2] : vector<4x4xf32>
%21 = vector.transfer_write %20, %19[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%22 = vector.extract %cst[3] : vector<4x4xf32>
%23 = vector.transfer_write %22, %21[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%24 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%25 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%26 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %23) -> (tensor<4x4xf32>) {
%44 = tensor.extract_slice %24[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%45 = tensor.extract_slice %25[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%46 = vector.transfer_read %44[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.transfer_read %44[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%48 = vector.transfer_read %44[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.transfer_read %44[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%50 = vector.transfer_read %45[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%51 = vector.broadcast %50 : vector<4xf32> to vector<1x4xf32>
%52 = vector.transfer_read %45[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%53 = vector.broadcast %52 : vector<4xf32> to vector<1x4xf32>
%54 = vector.transfer_read %45[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%55 = vector.broadcast %54 : vector<4xf32> to vector<1x4xf32>
%56 = vector.transfer_read %45[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%57 = vector.broadcast %56 : vector<4xf32> to vector<1x4xf32>
%58 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%59 = vector.transfer_read %arg7[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%60 = vector.transfer_read %arg7[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%61 = vector.transfer_read %arg7[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%62 = vector.extract_strided_slice %46 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%63 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %62, %51, %58 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%64 = vector.extract_strided_slice %46 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%65 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %64, %53, %63 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%66 = vector.extract_strided_slice %46 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%67 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %66, %55, %65 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%68 = vector.extract_strided_slice %46 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%69 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %68, %57, %67 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%70 = vector.extract_strided_slice %47 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%71 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %70, %51, %59 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%72 = vector.extract_strided_slice %47 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%73 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %72, %53, %71 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%74 = vector.extract_strided_slice %47 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%75 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %74, %55, %73 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%76 = vector.extract_strided_slice %47 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%77 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %76, %57, %75 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%78 = vector.extract_strided_slice %48 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%79 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %78, %51, %60 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%80 = vector.extract_strided_slice %48 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%81 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %80, %53, %79 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%82 = vector.extract_strided_slice %48 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%83 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %82, %55, %81 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%84 = vector.extract_strided_slice %48 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%85 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %84, %57, %83 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%86 = vector.extract_strided_slice %49 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%87 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %86, %51, %61 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%88 = vector.extract_strided_slice %49 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%89 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %88, %53, %87 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%90 = vector.extract_strided_slice %49 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%91 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %90, %55, %89 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%92 = vector.extract_strided_slice %49 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%93 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %92, %57, %91 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%94 = vector.transfer_write %69, %arg7[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%95 = vector.transfer_write %77, %94[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%96 = vector.transfer_write %85, %95[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%97 = vector.transfer_write %93, %96[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
scf.yield %97 : tensor<4x4xf32>
}
%27 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %14[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %14[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %14[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = vector.transfer_read %26[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%32 = vector.transfer_read %26[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%33 = vector.transfer_read %26[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%34 = vector.transfer_read %26[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%35 = arith.subf %31, %27 : vector<4xf32>
%36 = arith.subf %32, %28 : vector<4xf32>
%37 = arith.subf %33, %29 : vector<4xf32>
%38 = arith.subf %34, %30 : vector<4xf32>
%39 = vector.transfer_write %35, %26[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%40 = vector.transfer_write %36, %39[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%41 = vector.transfer_write %37, %40[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%42 = vector.transfer_write %38, %41[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%43 = tensor.insert_slice %42 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %43 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After breaking down n-D inserts/extracts ---
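// Annotation: relative to the dump above, the 2-D accumulator initialization
// (vector.extract of rows from the dense<0.0> : vector<4x4xf32> constant followed by
// transfer_writes) has been broken down into direct transfer_writes of a 1-D
// dense<0.0> : vector<4xf32> constant, so only 1-D vectors remain inside the loop nest.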
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c128 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c64 step %7 {
%8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %8) -> (tensor<8x32xf32>) {
%13 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%14 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%15 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = vector.transfer_write %cst, %15[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%17 = vector.transfer_write %cst, %16[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.transfer_write %cst, %17[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%19 = vector.transfer_write %cst, %18[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = tensor.extract_slice %9[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%21 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%22 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %19) -> (tensor<4x4xf32>) {
%40 = tensor.extract_slice %20[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%41 = tensor.extract_slice %21[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%42 = vector.transfer_read %40[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%43 = vector.transfer_read %40[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%44 = vector.transfer_read %40[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%45 = vector.transfer_read %40[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%46 = vector.transfer_read %41[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.broadcast %46 : vector<4xf32> to vector<1x4xf32>
%48 = vector.transfer_read %41[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.broadcast %48 : vector<4xf32> to vector<1x4xf32>
%50 = vector.transfer_read %41[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%51 = vector.broadcast %50 : vector<4xf32> to vector<1x4xf32>
%52 = vector.transfer_read %41[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%53 = vector.broadcast %52 : vector<4xf32> to vector<1x4xf32>
%54 = vector.transfer_read %arg7[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%55 = vector.transfer_read %arg7[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%56 = vector.transfer_read %arg7[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%57 = vector.transfer_read %arg7[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%58 = vector.extract_strided_slice %42 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%59 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %58, %47, %54 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%60 = vector.extract_strided_slice %42 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%61 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %60, %49, %59 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%62 = vector.extract_strided_slice %42 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%63 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %62, %51, %61 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%64 = vector.extract_strided_slice %42 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%65 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %64, %53, %63 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%66 = vector.extract_strided_slice %43 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%67 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %66, %47, %55 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%68 = vector.extract_strided_slice %43 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%69 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %68, %49, %67 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%70 = vector.extract_strided_slice %43 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%71 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %70, %51, %69 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%72 = vector.extract_strided_slice %43 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%73 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %72, %53, %71 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%74 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%75 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %74, %47, %56 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%76 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%77 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %76, %49, %75 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%78 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%79 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %78, %51, %77 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%80 = vector.extract_strided_slice %44 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%81 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %80, %53, %79 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%82 = vector.extract_strided_slice %45 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%83 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %82, %47, %57 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%84 = vector.extract_strided_slice %45 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%85 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %84, %49, %83 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%86 = vector.extract_strided_slice %45 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%87 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %86, %51, %85 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%88 = vector.extract_strided_slice %45 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%89 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %88, %53, %87 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%90 = vector.transfer_write %65, %arg7[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%91 = vector.transfer_write %73, %90[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%92 = vector.transfer_write %81, %91[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%93 = vector.transfer_write %89, %92[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
scf.yield %93 : tensor<4x4xf32>
}
%23 = vector.transfer_read %14[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%24 = vector.transfer_read %14[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%25 = vector.transfer_read %14[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%26 = vector.transfer_read %14[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%27 = vector.transfer_read %22[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %22[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %22[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %22[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = arith.subf %27, %23 : vector<4xf32>
%32 = arith.subf %28, %24 : vector<4xf32>
%33 = arith.subf %29, %25 : vector<4xf32>
%34 = arith.subf %30, %26 : vector<4xf32>
%35 = vector.transfer_write %31, %22[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%36 = vector.transfer_write %32, %35[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%37 = vector.transfer_write %33, %36[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%38 = vector.transfer_write %34, %37[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%39 = tensor.insert_slice %38 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %39 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %13 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After hoisting vector transfers ---
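// Annotation: the accumulator transfer_read/transfer_write pairs have been hoisted out
// of the innermost K loop, which now carries four vector<4xf32> values as iter_args and
// writes them back to the 4x4 tensor only once after the loop. Loop-invariant
// flow.dispatch.tensor.load, tensor.extract_slice, and affine.apply ops have also been
// hoisted to their outermost valid loop level.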
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c128 step %5 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
scf.for %arg1 = %6 to %c64 step %7 {
%9 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %9) -> (tensor<8x32xf32>) {
%13 = tensor.extract_slice %8[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%14 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%15 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%17 = vector.transfer_write %cst, %16[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.transfer_write %cst, %17[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%19 = vector.transfer_write %cst, %18[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = vector.transfer_write %cst, %19[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%21 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%22:4 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%40 = tensor.extract_slice %13[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%41 = tensor.extract_slice %21[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%42 = vector.transfer_read %40[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%43 = vector.transfer_read %40[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%44 = vector.transfer_read %40[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%45 = vector.transfer_read %40[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%46 = vector.transfer_read %41[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.broadcast %46 : vector<4xf32> to vector<1x4xf32>
%48 = vector.transfer_read %41[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.broadcast %48 : vector<4xf32> to vector<1x4xf32>
%50 = vector.transfer_read %41[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%51 = vector.broadcast %50 : vector<4xf32> to vector<1x4xf32>
%52 = vector.transfer_read %41[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%53 = vector.broadcast %52 : vector<4xf32> to vector<1x4xf32>
%54 = vector.extract_strided_slice %42 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%55 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %54, %47, %arg10 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%56 = vector.extract_strided_slice %42 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%57 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %56, %49, %55 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%58 = vector.extract_strided_slice %42 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%59 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %58, %51, %57 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%60 = vector.extract_strided_slice %42 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%61 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %60, %53, %59 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%62 = vector.extract_strided_slice %43 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%63 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %62, %47, %arg9 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%64 = vector.extract_strided_slice %43 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%65 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %64, %49, %63 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%66 = vector.extract_strided_slice %43 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%67 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %66, %51, %65 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%68 = vector.extract_strided_slice %43 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%69 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %68, %53, %67 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%70 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%71 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %70, %47, %arg8 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%72 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%73 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %72, %49, %71 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%74 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%75 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %74, %51, %73 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%76 = vector.extract_strided_slice %44 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%77 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %76, %53, %75 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%78 = vector.extract_strided_slice %45 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%79 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %78, %47, %arg7 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%80 = vector.extract_strided_slice %45 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%81 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %80, %49, %79 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%82 = vector.extract_strided_slice %45 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%83 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %82, %51, %81 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%84 = vector.extract_strided_slice %45 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%85 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %84, %53, %83 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
scf.yield %85, %77, %69, %61 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
%23 = vector.transfer_write %22#3, %20[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%24 = vector.transfer_write %22#2, %23[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%25 = vector.transfer_write %22#1, %24[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%26 = vector.transfer_write %22#0, %25[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%27 = vector.transfer_read %15[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %15[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %15[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %15[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = arith.subf %22#3, %27 : vector<4xf32>
%32 = arith.subf %22#2, %28 : vector<4xf32>
%33 = arith.subf %22#1, %29 : vector<4xf32>
%34 = arith.subf %22#0, %30 : vector<4xf32>
%35 = vector.transfer_write %31, %26[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%36 = vector.transfer_write %32, %35[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%37 = vector.transfer_write %33, %36[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%38 = vector.transfer_write %34, %37[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%39 = tensor.insert_slice %38 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %39 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %14 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After lowering transfer ops ---
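// Annotation: the remaining transfers are already minor-identity, in-bounds 1-D
// reads/writes on tensors, so this stage appears to leave the IR unchanged from the
// previous dump.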
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c128 step %5 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
scf.for %arg1 = %6 to %c64 step %7 {
%9 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %9) -> (tensor<8x32xf32>) {
%13 = tensor.extract_slice %8[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%14 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%15 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%17 = vector.transfer_write %cst, %16[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.transfer_write %cst, %17[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%19 = vector.transfer_write %cst, %18[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = vector.transfer_write %cst, %19[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%21 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%22:4 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%40 = tensor.extract_slice %13[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%41 = tensor.extract_slice %21[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%42 = vector.transfer_read %40[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%43 = vector.transfer_read %40[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%44 = vector.transfer_read %40[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%45 = vector.transfer_read %40[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%46 = vector.transfer_read %41[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.broadcast %46 : vector<4xf32> to vector<1x4xf32>
%48 = vector.transfer_read %41[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.broadcast %48 : vector<4xf32> to vector<1x4xf32>
%50 = vector.transfer_read %41[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%51 = vector.broadcast %50 : vector<4xf32> to vector<1x4xf32>
%52 = vector.transfer_read %41[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%53 = vector.broadcast %52 : vector<4xf32> to vector<1x4xf32>
%54 = vector.extract_strided_slice %42 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%55 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %54, %47, %arg10 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%56 = vector.extract_strided_slice %42 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%57 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %56, %49, %55 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%58 = vector.extract_strided_slice %42 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%59 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %58, %51, %57 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%60 = vector.extract_strided_slice %42 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%61 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %60, %53, %59 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%62 = vector.extract_strided_slice %43 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%63 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %62, %47, %arg9 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%64 = vector.extract_strided_slice %43 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%65 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %64, %49, %63 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%66 = vector.extract_strided_slice %43 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%67 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %66, %51, %65 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%68 = vector.extract_strided_slice %43 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%69 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %68, %53, %67 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%70 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%71 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %70, %47, %arg8 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%72 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%73 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %72, %49, %71 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%74 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%75 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %74, %51, %73 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%76 = vector.extract_strided_slice %44 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%77 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %76, %53, %75 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%78 = vector.extract_strided_slice %45 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%79 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %78, %47, %arg7 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%80 = vector.extract_strided_slice %45 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%81 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %80, %49, %79 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%82 = vector.extract_strided_slice %45 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%83 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %82, %51, %81 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
%84 = vector.extract_strided_slice %45 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%85 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %84, %53, %83 : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
scf.yield %85, %77, %69, %61 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
%23 = vector.transfer_write %22#3, %20[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%24 = vector.transfer_write %22#2, %23[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%25 = vector.transfer_write %22#1, %24[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%26 = vector.transfer_write %22#0, %25[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%27 = vector.transfer_read %15[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %15[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %15[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %15[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = arith.subf %22#3, %27 : vector<4xf32>
%32 = arith.subf %22#2, %28 : vector<4xf32>
%33 = arith.subf %22#1, %29 : vector<4xf32>
%34 = arith.subf %22#0, %30 : vector<4xf32>
%35 = vector.transfer_write %31, %26[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%36 = vector.transfer_write %32, %35[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%37 = vector.transfer_write %33, %36[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%38 = vector.transfer_write %34, %37[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%39 = tensor.insert_slice %38 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %39 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %14 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After lowering contract ops ---
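// Annotation: each 1-D vector.contract has been lowered to a vector.fma: one LHS scalar
// is isolated (extract_strided_slice + broadcast + extract) and multiplied into the
// corresponding RHS row, accumulating into the loop-carried vector<4xf32> values.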
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c128 step %5 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
scf.for %arg1 = %6 to %c64 step %7 {
%9 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %9) -> (tensor<8x32xf32>) {
%13 = tensor.extract_slice %8[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%14 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%15 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%17 = vector.transfer_write %cst, %16[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.transfer_write %cst, %17[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%19 = vector.transfer_write %cst, %18[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = vector.transfer_write %cst, %19[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%21 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%22:4 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%40 = tensor.extract_slice %13[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%41 = tensor.extract_slice %21[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%42 = vector.transfer_read %40[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%43 = vector.transfer_read %40[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%44 = vector.transfer_read %40[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%45 = vector.transfer_read %40[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%46 = vector.transfer_read %41[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.transfer_read %41[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%48 = vector.transfer_read %41[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.transfer_read %41[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%50 = vector.extract_strided_slice %42 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%51 = vector.broadcast %50 : vector<1xf32> to vector<1x4xf32>
%52 = vector.extract %51[0] : vector<1x4xf32>
%53 = vector.fma %52, %46, %arg10 : vector<4xf32>
%54 = vector.extract_strided_slice %42 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%55 = vector.broadcast %54 : vector<1xf32> to vector<1x4xf32>
%56 = vector.extract %55[0] : vector<1x4xf32>
%57 = vector.fma %56, %47, %53 : vector<4xf32>
%58 = vector.extract_strided_slice %42 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%59 = vector.broadcast %58 : vector<1xf32> to vector<1x4xf32>
%60 = vector.extract %59[0] : vector<1x4xf32>
%61 = vector.fma %60, %48, %57 : vector<4xf32>
%62 = vector.extract_strided_slice %42 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%63 = vector.broadcast %62 : vector<1xf32> to vector<1x4xf32>
%64 = vector.extract %63[0] : vector<1x4xf32>
%65 = vector.fma %64, %49, %61 : vector<4xf32>
%66 = vector.extract_strided_slice %43 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%67 = vector.broadcast %66 : vector<1xf32> to vector<1x4xf32>
%68 = vector.extract %67[0] : vector<1x4xf32>
%69 = vector.fma %68, %46, %arg9 : vector<4xf32>
%70 = vector.extract_strided_slice %43 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%71 = vector.broadcast %70 : vector<1xf32> to vector<1x4xf32>
%72 = vector.extract %71[0] : vector<1x4xf32>
%73 = vector.fma %72, %47, %69 : vector<4xf32>
%74 = vector.extract_strided_slice %43 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%75 = vector.broadcast %74 : vector<1xf32> to vector<1x4xf32>
%76 = vector.extract %75[0] : vector<1x4xf32>
%77 = vector.fma %76, %48, %73 : vector<4xf32>
%78 = vector.extract_strided_slice %43 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%79 = vector.broadcast %78 : vector<1xf32> to vector<1x4xf32>
%80 = vector.extract %79[0] : vector<1x4xf32>
%81 = vector.fma %80, %49, %77 : vector<4xf32>
%82 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%83 = vector.broadcast %82 : vector<1xf32> to vector<1x4xf32>
%84 = vector.extract %83[0] : vector<1x4xf32>
%85 = vector.fma %84, %46, %arg8 : vector<4xf32>
%86 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%87 = vector.broadcast %86 : vector<1xf32> to vector<1x4xf32>
%88 = vector.extract %87[0] : vector<1x4xf32>
%89 = vector.fma %88, %47, %85 : vector<4xf32>
%90 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%91 = vector.broadcast %90 : vector<1xf32> to vector<1x4xf32>
%92 = vector.extract %91[0] : vector<1x4xf32>
%93 = vector.fma %92, %48, %89 : vector<4xf32>
%94 = vector.extract_strided_slice %44 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%95 = vector.broadcast %94 : vector<1xf32> to vector<1x4xf32>
%96 = vector.extract %95[0] : vector<1x4xf32>
%97 = vector.fma %96, %49, %93 : vector<4xf32>
%98 = vector.extract_strided_slice %45 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%99 = vector.broadcast %98 : vector<1xf32> to vector<1x4xf32>
%100 = vector.extract %99[0] : vector<1x4xf32>
%101 = vector.fma %100, %46, %arg7 : vector<4xf32>
%102 = vector.extract_strided_slice %45 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%103 = vector.broadcast %102 : vector<1xf32> to vector<1x4xf32>
%104 = vector.extract %103[0] : vector<1x4xf32>
%105 = vector.fma %104, %47, %101 : vector<4xf32>
%106 = vector.extract_strided_slice %45 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%107 = vector.broadcast %106 : vector<1xf32> to vector<1x4xf32>
%108 = vector.extract %107[0] : vector<1x4xf32>
%109 = vector.fma %108, %48, %105 : vector<4xf32>
%110 = vector.extract_strided_slice %45 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%111 = vector.broadcast %110 : vector<1xf32> to vector<1x4xf32>
%112 = vector.extract %111[0] : vector<1x4xf32>
%113 = vector.fma %112, %49, %109 : vector<4xf32>
scf.yield %113, %97, %81, %65 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
%23 = vector.transfer_write %22#3, %20[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%24 = vector.transfer_write %22#2, %23[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%25 = vector.transfer_write %22#1, %24[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%26 = vector.transfer_write %22#0, %25[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%27 = vector.transfer_read %15[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %15[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %15[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %15[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = arith.subf %22#3, %27 : vector<4xf32>
%32 = arith.subf %22#2, %28 : vector<4xf32>
%33 = arith.subf %22#1, %29 : vector<4xf32>
%34 = arith.subf %22#0, %30 : vector<4xf32>
%35 = vector.transfer_write %31, %26[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%36 = vector.transfer_write %32, %35[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%37 = vector.transfer_write %33, %36[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%38 = vector.transfer_write %34, %37[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%39 = tensor.insert_slice %38 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %39 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %14 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
--- After lowering various vector ops ---
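// Note: the dump below differs from the previous one only in how the scalar
// elements of the LHS tile are broadcast before each vector.fma: the
// extract_strided_slice + broadcast + extract chains above are lowered to a
// direct vector.extract of the scalar followed by vector.splat. A minimal
// sketch of that rewrite (SSA names here are illustrative, not taken from the
// dump):
//
//   // before
//   %s = vector.extract_strided_slice %a {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
//   %b = vector.broadcast %s : vector<1xf32> to vector<1x4xf32>
//   %e = vector.extract %b[0] : vector<1x4xf32>
//   %r = vector.fma %e, %rhs, %acc : vector<4xf32>
//
//   // after
//   %s = vector.extract %a[0] : vector<4xf32>
//   %e = vector.splat %s : vector<4xf32>
//   %r = vector.fma %e, %rhs, %acc : vector<4xf32>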
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c128 step %5 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
scf.for %arg1 = %6 to %c64 step %7 {
%9 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %9) -> (tensor<8x32xf32>) {
%13 = tensor.extract_slice %8[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%14 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%15 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%17 = vector.transfer_write %cst, %16[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%18 = vector.transfer_write %cst, %17[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%19 = vector.transfer_write %cst, %18[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%20 = vector.transfer_write %cst, %19[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%21 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%22:4 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%40 = tensor.extract_slice %13[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%41 = tensor.extract_slice %21[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%42 = vector.transfer_read %40[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%43 = vector.transfer_read %40[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%44 = vector.transfer_read %40[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%45 = vector.transfer_read %40[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%46 = vector.transfer_read %41[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%47 = vector.transfer_read %41[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%48 = vector.transfer_read %41[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%49 = vector.transfer_read %41[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%50 = vector.extract %42[0] : vector<4xf32>
%51 = vector.splat %50 : vector<4xf32>
%52 = vector.fma %51, %46, %arg10 : vector<4xf32>
%53 = vector.extract %42[1] : vector<4xf32>
%54 = vector.splat %53 : vector<4xf32>
%55 = vector.fma %54, %47, %52 : vector<4xf32>
%56 = vector.extract %42[2] : vector<4xf32>
%57 = vector.splat %56 : vector<4xf32>
%58 = vector.fma %57, %48, %55 : vector<4xf32>
%59 = vector.extract %42[3] : vector<4xf32>
%60 = vector.splat %59 : vector<4xf32>
%61 = vector.fma %60, %49, %58 : vector<4xf32>
%62 = vector.extract %43[0] : vector<4xf32>
%63 = vector.splat %62 : vector<4xf32>
%64 = vector.fma %63, %46, %arg9 : vector<4xf32>
%65 = vector.extract %43[1] : vector<4xf32>
%66 = vector.splat %65 : vector<4xf32>
%67 = vector.fma %66, %47, %64 : vector<4xf32>
%68 = vector.extract %43[2] : vector<4xf32>
%69 = vector.splat %68 : vector<4xf32>
%70 = vector.fma %69, %48, %67 : vector<4xf32>
%71 = vector.extract %43[3] : vector<4xf32>
%72 = vector.splat %71 : vector<4xf32>
%73 = vector.fma %72, %49, %70 : vector<4xf32>
%74 = vector.extract %44[0] : vector<4xf32>
%75 = vector.splat %74 : vector<4xf32>
%76 = vector.fma %75, %46, %arg8 : vector<4xf32>
%77 = vector.extract %44[1] : vector<4xf32>
%78 = vector.splat %77 : vector<4xf32>
%79 = vector.fma %78, %47, %76 : vector<4xf32>
%80 = vector.extract %44[2] : vector<4xf32>
%81 = vector.splat %80 : vector<4xf32>
%82 = vector.fma %81, %48, %79 : vector<4xf32>
%83 = vector.extract %44[3] : vector<4xf32>
%84 = vector.splat %83 : vector<4xf32>
%85 = vector.fma %84, %49, %82 : vector<4xf32>
%86 = vector.extract %45[0] : vector<4xf32>
%87 = vector.splat %86 : vector<4xf32>
%88 = vector.fma %87, %46, %arg7 : vector<4xf32>
%89 = vector.extract %45[1] : vector<4xf32>
%90 = vector.splat %89 : vector<4xf32>
%91 = vector.fma %90, %47, %88 : vector<4xf32>
%92 = vector.extract %45[2] : vector<4xf32>
%93 = vector.splat %92 : vector<4xf32>
%94 = vector.fma %93, %48, %91 : vector<4xf32>
%95 = vector.extract %45[3] : vector<4xf32>
%96 = vector.splat %95 : vector<4xf32>
%97 = vector.fma %96, %49, %94 : vector<4xf32>
scf.yield %97, %85, %73, %61 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
%23 = vector.transfer_write %22#3, %20[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%24 = vector.transfer_write %22#2, %23[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%25 = vector.transfer_write %22#1, %24[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%26 = vector.transfer_write %22#0, %25[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%27 = vector.transfer_read %15[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%28 = vector.transfer_read %15[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%29 = vector.transfer_read %15[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%30 = vector.transfer_read %15[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%31 = arith.subf %22#3, %27 : vector<4xf32>
%32 = arith.subf %22#2, %28 : vector<4xf32>
%33 = arith.subf %22#1, %29 : vector<4xf32>
%34 = arith.subf %22#0, %30 : vector<4xf32>
%35 = vector.transfer_write %31, %26[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%36 = vector.transfer_write %32, %35[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%37 = vector.transfer_write %33, %36[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%38 = vector.transfer_write %34, %37[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%39 = tensor.insert_slice %38 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %39 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %14 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}
// -----// IR Dump After SPIRVVectorize (iree-spirv-vectorize) //----- //
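// Note: relative to the previous stage, the final output of the pass has also
// dropped the dead vector.transfer_writes on the 4x4 tile (the zero
// initialization and the post-loop writes of the scf.for results); the
// loop-carried vectors are now subtracted from the tile read out of %15 and
// written straight into the %16 slice before the tensor.insert_slice.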
func.func @dot_dispatch_0_matmul_128x64x256() {
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:256x64xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x64xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
%6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %4 to %c128 step %5 {
%8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [8, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<8x256xf32>
scf.for %arg1 = %6 to %c64 step %7 {
%9 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<writeonly:128x64xf32> -> tensor<8x32xf32>
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x64xf32> -> tensor<256x32xf32>
%11 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<8x32xf32>
%12 = scf.for %arg2 = %c0 to %c8 step %c4 iter_args(%arg3 = %9) -> (tensor<8x32xf32>) {
%13 = tensor.extract_slice %8[%arg2, 0] [4, 256] [1, 1] : tensor<8x256xf32> to tensor<4x256xf32>
%14 = scf.for %arg4 = %c0 to %c32 step %c4 iter_args(%arg5 = %arg3) -> (tensor<8x32xf32>) {
%15 = tensor.extract_slice %11[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%16 = tensor.extract_slice %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<8x32xf32> to tensor<4x4xf32>
%17 = tensor.extract_slice %10[0, %arg4] [256, 4] [1, 1] : tensor<256x32xf32> to tensor<256x4xf32>
%18:4 = scf.for %arg6 = %c0 to %c256 step %c4 iter_args(%arg7 = %cst, %arg8 = %cst, %arg9 = %cst, %arg10 = %cst) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%32 = tensor.extract_slice %13[0, %arg6] [4, 4] [1, 1] : tensor<4x256xf32> to tensor<4x4xf32>
%33 = tensor.extract_slice %17[%arg6, 0] [4, 4] [1, 1] : tensor<256x4xf32> to tensor<4x4xf32>
%34 = vector.transfer_read %32[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%35 = vector.transfer_read %32[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%36 = vector.transfer_read %32[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%37 = vector.transfer_read %32[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%38 = vector.transfer_read %33[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%39 = vector.transfer_read %33[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%40 = vector.transfer_read %33[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%41 = vector.transfer_read %33[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%42 = vector.extract %34[0] : vector<4xf32>
%43 = vector.splat %42 : vector<4xf32>
%44 = vector.fma %43, %38, %arg10 : vector<4xf32>
%45 = vector.extract %34[1] : vector<4xf32>
%46 = vector.splat %45 : vector<4xf32>
%47 = vector.fma %46, %39, %44 : vector<4xf32>
%48 = vector.extract %34[2] : vector<4xf32>
%49 = vector.splat %48 : vector<4xf32>
%50 = vector.fma %49, %40, %47 : vector<4xf32>
%51 = vector.extract %34[3] : vector<4xf32>
%52 = vector.splat %51 : vector<4xf32>
%53 = vector.fma %52, %41, %50 : vector<4xf32>
%54 = vector.extract %35[0] : vector<4xf32>
%55 = vector.splat %54 : vector<4xf32>
%56 = vector.fma %55, %38, %arg9 : vector<4xf32>
%57 = vector.extract %35[1] : vector<4xf32>
%58 = vector.splat %57 : vector<4xf32>
%59 = vector.fma %58, %39, %56 : vector<4xf32>
%60 = vector.extract %35[2] : vector<4xf32>
%61 = vector.splat %60 : vector<4xf32>
%62 = vector.fma %61, %40, %59 : vector<4xf32>
%63 = vector.extract %35[3] : vector<4xf32>
%64 = vector.splat %63 : vector<4xf32>
%65 = vector.fma %64, %41, %62 : vector<4xf32>
%66 = vector.extract %36[0] : vector<4xf32>
%67 = vector.splat %66 : vector<4xf32>
%68 = vector.fma %67, %38, %arg8 : vector<4xf32>
%69 = vector.extract %36[1] : vector<4xf32>
%70 = vector.splat %69 : vector<4xf32>
%71 = vector.fma %70, %39, %68 : vector<4xf32>
%72 = vector.extract %36[2] : vector<4xf32>
%73 = vector.splat %72 : vector<4xf32>
%74 = vector.fma %73, %40, %71 : vector<4xf32>
%75 = vector.extract %36[3] : vector<4xf32>
%76 = vector.splat %75 : vector<4xf32>
%77 = vector.fma %76, %41, %74 : vector<4xf32>
%78 = vector.extract %37[0] : vector<4xf32>
%79 = vector.splat %78 : vector<4xf32>
%80 = vector.fma %79, %38, %arg7 : vector<4xf32>
%81 = vector.extract %37[1] : vector<4xf32>
%82 = vector.splat %81 : vector<4xf32>
%83 = vector.fma %82, %39, %80 : vector<4xf32>
%84 = vector.extract %37[2] : vector<4xf32>
%85 = vector.splat %84 : vector<4xf32>
%86 = vector.fma %85, %40, %83 : vector<4xf32>
%87 = vector.extract %37[3] : vector<4xf32>
%88 = vector.splat %87 : vector<4xf32>
%89 = vector.fma %88, %41, %86 : vector<4xf32>
scf.yield %89, %77, %65, %53 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
%19 = vector.transfer_read %15[%c0, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%20 = vector.transfer_read %15[%c1, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%21 = vector.transfer_read %15[%c2, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%22 = vector.transfer_read %15[%c3, %c0], %cst_0 {in_bounds = [true]} : tensor<4x4xf32>, vector<4xf32>
%23 = arith.subf %18#3, %19 : vector<4xf32>
%24 = arith.subf %18#2, %20 : vector<4xf32>
%25 = arith.subf %18#1, %21 : vector<4xf32>
%26 = arith.subf %18#0, %22 : vector<4xf32>
%27 = vector.transfer_write %23, %16[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%28 = vector.transfer_write %24, %27[%c1, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%29 = vector.transfer_write %25, %28[%c2, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%30 = vector.transfer_write %26, %29[%c3, %c0] {in_bounds = [true]} : vector<4xf32>, tensor<4x4xf32>
%31 = tensor.insert_slice %30 into %arg5[%arg2, %arg4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x32xf32>
scf.yield %31 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 0 : index}
scf.yield %14 : tensor<8x32xf32>
} {iree.spirv.distribute_dim = 1 : index}
flow.dispatch.tensor.store %12, %3, offsets = [%arg0, %arg1], sizes = [8, 32], strides = [1, 1] : tensor<8x32xf32> -> !flow.dispatch.tensor<writeonly:128x64xf32>
}
}
return
}