@antiagainst
Created September 7, 2023 18:07

wip-gemv-subgroup.mlir
(This file has been truncated.)
// -----// IR Dump After TypePropagation (iree-codegen-type-propagation) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%16 = flow.dispatch.tensor.load %11, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%17 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%18 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%19 = tensor.empty() : tensor<4096xf16>
%20 = tensor.empty() : tensor<4096x32x128xf16>
%21 = linalg.fill ins(%cst : f16) outs(%19 : tensor<4096xf16>) -> tensor<4096xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %16, %17 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%20 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%24 = arith.extui %in : i4 to i32
%25 = arith.uitofp %24 : i32 to f16
%26 = arith.subf %25, %in_1 : f16
%27 = arith.mulf %26, %in_0 : f16
linalg.yield %27 : f16
} -> tensor<4096x32x128xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%18, %22 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%21 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%24 = arith.mulf %in, %in_0 : f16
%25 = arith.addf %24, %out : f16
linalg.yield %25 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %23, %14, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
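
For reference, the dispatch at this point is a dequantizing GEMV: the first linalg.generic converts the i4 weights to f16 using per-group scales (%16) and zero points (%17), and the second contracts the 32x128 f16 input against the dequantized 4096x32x128 weights into a 4096-element output. A rough NumPy sketch of the same computation (illustration only; the i4 weights are modeled as uint8 values 0..15 and all layouts are assumed row-major):

import numpy as np

def forward_dispatch_3(w_i4, scale, zero, x):
    # w_i4:  (4096, 32, 128) unsigned 4-bit weights, held as uint8 values 0..15
    # scale: (4096, 32) f16 per-group scales       (%16, multiplied second)
    # zero:  (4096, 32) f16 per-group zero points  (%17, subtracted first)
    # x:     (32, 128)  f16 input, viewed as 32 groups of 128
    w = (w_i4.astype(np.float16) - zero[:, :, None]) * scale[:, :, None]  # first linalg.generic
    return np.einsum('jk,ijk->i', x, w).astype(np.float16)               # fill + second linalg.generic
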
// -----// IR Dump After BubbleUpOrdinalOps (iree-codegen-bubble-up-ordinal-ops) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%16 = flow.dispatch.tensor.load %11, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%17 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%18 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%19 = tensor.empty() : tensor<4096xf16>
%20 = tensor.empty() : tensor<4096x32x128xf16>
%21 = linalg.fill ins(%cst : f16) outs(%19 : tensor<4096xf16>) -> tensor<4096xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %16, %17 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%20 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%24 = arith.extui %in : i4 to i32
%25 = arith.uitofp %24 : i32 to f16
%26 = arith.subf %25, %in_1 : f16
%27 = arith.mulf %26, %in_0 : f16
linalg.yield %27 : f16
} -> tensor<4096x32x128xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%18, %22 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%21 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%24 = arith.mulf %in, %in_0 : f16
%25 = arith.addf %24, %out : f16
linalg.yield %25 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %23, %14, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// -----// IR Dump After BufferizeCopyOnlyDispatches (iree-codegen-bufferize-copy-only-dispatches) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%16 = flow.dispatch.tensor.load %11, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%17 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%18 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%19 = tensor.empty() : tensor<4096xf16>
%20 = tensor.empty() : tensor<4096x32x128xf16>
%21 = linalg.fill ins(%cst : f16) outs(%19 : tensor<4096xf16>) -> tensor<4096xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %16, %17 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%20 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%24 = arith.extui %in : i4 to i32
%25 = arith.uitofp %24 : i32 to f16
%26 = arith.subf %25, %in_1 : f16
%27 = arith.mulf %26, %in_0 : f16
linalg.yield %27 : f16
} -> tensor<4096x32x128xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%18, %22 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%21 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%24 = arith.mulf %in, %in_0 : f16
%25 = arith.addf %24, %out : f16
linalg.yield %25 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %23, %14, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// -----// IR Dump After DecomposeSoftmax (iree-linalg-ext-decompose-softmax) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%16 = flow.dispatch.tensor.load %11, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%17 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%18 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%19 = tensor.empty() : tensor<4096xf16>
%20 = tensor.empty() : tensor<4096x32x128xf16>
%21 = linalg.fill ins(%cst : f16) outs(%19 : tensor<4096xf16>) -> tensor<4096xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %16, %17 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%20 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%24 = arith.extui %in : i4 to i32
%25 = arith.uitofp %24 : i32 to f16
%26 = arith.subf %25, %in_1 : f16
%27 = arith.mulf %26, %in_0 : f16
linalg.yield %27 : f16
} -> tensor<4096x32x128xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%18, %22 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%21 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%24 = arith.mulf %in, %in_0 : f16
%25 = arith.addf %24, %out : f16
linalg.yield %25 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %23, %14, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After SPIRVGeneralizeNamedOps (iree-spirv-generalize-named-ops) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [4096, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4096x32x128xi4>
%16 = flow.dispatch.tensor.load %11, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%17 = flow.dispatch.tensor.load %12, offsets = [0, 0], sizes = [4096, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4096x32xf16>
%18 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%19 = tensor.empty() : tensor<4096xf16>
%20 = tensor.empty() : tensor<4096x32x128xf16>
%21 = linalg.fill ins(%cst : f16) outs(%19 : tensor<4096xf16>) -> tensor<4096xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %16, %17 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>) outs(%20 : tensor<4096x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%24 = arith.extui %in : i4 to i32
%25 = arith.uitofp %24 : i32 to f16
%26 = arith.subf %25, %in_1 : f16
%27 = arith.mulf %26, %in_0 : f16
linalg.yield %27 : f16
} -> tensor<4096x32x128xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%18, %22 : tensor<32x128xf16>, tensor<4096x32x128xf16>) outs(%21 : tensor<4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%24 = arith.mulf %in, %in_0 : f16
%25 = arith.addf %24, %out : f16
linalg.yield %25 : f16
} -> tensor<4096xf16>
flow.dispatch.tensor.store %23, %14, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After TileAndDistributeToWorkgroups (iree-codegen-tile-and-distribute-to-workgroups) //----- //
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>}> {
hal.executable.export public @forward_dispatch_3_generic_4096x32x128_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 5, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<SPIRVMatvecPromoteSubgroupReduce>, workgroup_size = [64 : index, 4 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%c1024 = arith.constant 1024 : index
%c1 = arith.constant 1 : index
hal.return %c1024, %c1, %c1 : index, index, index
}
builtin.module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%15 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%16 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%17 = flow.dispatch.tensor.load %10, offsets = [%16, 0, 0], sizes = [%c4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<?x32x128xi4>
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = flow.dispatch.tensor.load %11, offsets = [%18, 0], sizes = [%c4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<?x32xf16>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%21 = flow.dispatch.tensor.load %12, offsets = [%20, 0], sizes = [%c4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<?x32xf16>
%22 = tensor.empty() : tensor<4x32x128xf16>
%cast = tensor.cast %17 : tensor<?x32x128xi4> to tensor<4x32x128xi4>
%cast_0 = tensor.cast %19 : tensor<?x32xf16> to tensor<4x32xf16>
%cast_1 = tensor.cast %21 : tensor<?x32xf16> to tensor<4x32xf16>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast, %cast_0, %cast_1 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%22 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_3: f16, %in_4: f16, %out: f16):
%28 = arith.extui %in : i4 to i32
%29 = arith.uitofp %28 : i32 to f16
%30 = arith.subf %29, %in_4 : f16
%31 = arith.mulf %30, %in_3 : f16
linalg.yield %31 : f16
} -> tensor<4x32x128xf16>
%24 = tensor.empty() : tensor<4xf16>
%25 = linalg.fill ins(%cst : f16) outs(%24 : tensor<4xf16>) -> tensor<4xf16>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%15, %23 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%25 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_3: f16, %out: f16):
%28 = arith.mulf %in, %in_3 : f16
%29 = arith.addf %28, %out : f16
linalg.yield %29 : f16
} -> tensor<4xf16>
%cast_2 = tensor.cast %26 : tensor<4xf16> to tensor<?xf16>
%27 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
flow.dispatch.tensor.store %cast_2, %14, offsets = [%27], sizes = [%c4], strides = [1] : tensor<?xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
}
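
After distribution, the export declares workgroup_size = [64, 4, 1] and a grid of 1024 x-workgroups (hal.return %c1024, %c1, %c1), so each workgroup owns 4 of the 4096 output rows, selected through the affine_map (s0 * 4) on %workgroup_id_x. In the same NumPy terms (a sketch only, with the serial loop standing in for parallel workgroups):

import numpy as np

ROWS_PER_WG = 4                                   # tile_sizes [[4], ...] and affine_map (s0 * 4)

def launch(w_i4, scale, zero, x, out):
    for wg_x in range(1024):                      # hal.return %c1024, %c1, %c1
        r = slice(wg_x * ROWS_PER_WG, (wg_x + 1) * ROWS_PER_WG)
        w = (w_i4[r].astype(np.float16) - zero[r, :, None]) * scale[r, :, None]
        out[r] = np.einsum('jk,ijk->i', x, w).astype(np.float16)
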
// -----// IR Dump After ConvertToDestinationPassingStyle (iree-codegen-convert-to-destination-passing-style) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [%c4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<?xf16>
%cast = tensor.cast %16 : tensor<?xf16> to tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%workgroup_id_x_0 = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%19 = flow.dispatch.tensor.load %10, offsets = [%18, 0, 0], sizes = [%c4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<?x32x128xi4>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%21 = flow.dispatch.tensor.load %11, offsets = [%20, 0], sizes = [%c4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<?x32xf16>
%22 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%23 = flow.dispatch.tensor.load %12, offsets = [%22, 0], sizes = [%c4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<?x32xf16>
%24 = tensor.empty() : tensor<4x32x128xf16>
%cast_1 = tensor.cast %19 : tensor<?x32x128xi4> to tensor<4x32x128xi4>
%cast_2 = tensor.cast %21 : tensor<?x32xf16> to tensor<4x32xf16>
%cast_3 = tensor.cast %23 : tensor<?x32xf16> to tensor<4x32xf16>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1, %cast_2, %cast_3 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%24 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_5: f16, %in_6: f16, %out: f16):
%29 = arith.extui %in : i4 to i32
%30 = arith.uitofp %29 : i32 to f16
%31 = arith.subf %30, %in_6 : f16
%32 = arith.mulf %31, %in_5 : f16
linalg.yield %32 : f16
} -> tensor<4x32x128xf16>
%26 = linalg.fill ins(%cst : f16) outs(%cast : tensor<4xf16>) -> tensor<4xf16>
%27 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%17, %25 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%26 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_5: f16, %out: f16):
%29 = arith.mulf %in, %in_5 : f16
%30 = arith.addf %29, %out : f16
linalg.yield %30 : f16
} -> tensor<4xf16>
%cast_4 = tensor.cast %27 : tensor<4xf16> to tensor<?xf16>
%28 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
flow.dispatch.tensor.store %cast_4, %14, offsets = [%28], sizes = [%c4], strides = [1] : tensor<?xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%workgroup_id_x_0 = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%19 = flow.dispatch.tensor.load %10, offsets = [%18, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%21 = flow.dispatch.tensor.load %11, offsets = [%20, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%22 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
%23 = flow.dispatch.tensor.load %12, offsets = [%22, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%24 = tensor.empty() : tensor<4x32x128xf16>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%19, %21, %23 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%24 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_1: f16, %in_2: f16, %out: f16):
%29 = arith.extui %in : i4 to i32
%30 = arith.uitofp %29 : i32 to f16
%31 = arith.subf %30, %in_2 : f16
%32 = arith.mulf %31, %in_1 : f16
linalg.yield %32 : f16
} -> tensor<4x32x128xf16>
%26 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%27 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%17, %25 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%26 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_1: f16, %out: f16):
%29 = arith.mulf %in, %in_1 : f16
%30 = arith.addf %29, %out : f16
linalg.yield %30 : f16
} -> tensor<4xf16>
%28 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x_0]
flow.dispatch.tensor.store %27, %14, offsets = [%28], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%25 = arith.extui %in : i4 to i32
%26 = arith.uitofp %25 : i32 to f16
%27 = arith.subf %26, %in_1 : f16
%28 = arith.mulf %27, %in_0 : f16
linalg.yield %28 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%17, %22 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%23 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%25 = arith.mulf %in, %in_0 : f16
%26 = arith.addf %25, %out : f16
linalg.yield %26 : f16
} -> tensor<4xf16>
flow.dispatch.tensor.store %24, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// --- After promotion ---
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%24, %22 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%23 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%26 = arith.mulf %in, %in_0 : f16
%27 = arith.addf %26, %out : f16
linalg.yield %27 : f16
} -> tensor<4xf16>
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
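
The only change from the previous dump is the bufferization.alloc_tensor() copy of %17: the 32x128 f16 input shared by all rows is staged through an allocation so that the SPIRVMatvecPromoteSubgroupReduce pipeline can presumably place it in workgroup shared memory instead of re-reading it from global memory for every row. Its footprint is small relative to the target limit:

# size of the promoted tensor<32x128xf16>, if it is placed in workgroup memory (an assumption)
promoted_bytes = 32 * 128 * 2   # = 8192 bytes, vs. max_compute_shared_memory_size = 65536
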
// -----// IR Dump After GPUTensorAlloc (iree-codegen-gpu-tensor-alloc) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%24, %22 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%23 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%26 = arith.mulf %in, %in_0 : f16
%27 = arith.addf %26, %out : f16
linalg.yield %27 : f16
} -> tensor<4xf16>
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
initial GEMV op: %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%24, %22 : tensor<32x128xf16>, tensor<4x32x128xf16>) outs(%23 : tensor<4xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%26 = arith.mulf %in, %in_0 : f16
%27 = arith.addf %26, %out : f16
linalg.yield %27 : f16
} -> tensor<4xf16>
initial workgroup size: [64 : index, 4 : index, 1 : index]
//--- After tiling parallel dimensions ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%c4 = arith.constant 4 : index
%25 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %23) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %24[0, 0] [32, 128] [1, 1] : tensor<32x128xf16> to tensor<32x128xf16>
%extracted_slice_0 = tensor.extract_slice %22[%arg0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x32x128xf16>
%extracted_slice_1 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%extracted_slice, %extracted_slice_0 : tensor<32x128xf16>, tensor<1x32x128xf16>) outs(%extracted_slice_1 : tensor<1xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_2: f16, %out: f16):
%27 = arith.mulf %in, %in_2 : f16
%28 = arith.addf %27, %out : f16
linalg.yield %28 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %26 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
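
Tiling the parallel dimension rewrites the workgroup tile of 4 rows as an scf.forall over 4 iterations mapped to #gpu.thread<y>, matching workgroup_size = [64, 4, 1]: each y-thread reduces one full 32x128 slice into one output element. A per-workgroup sketch in the same NumPy terms (illustration only):

import numpy as np

def workgroup_body(x_wg, w_deq, acc):
    # x_wg: (32, 128) f16 shared input, w_deq: (4, 32, 128) f16 dequantized rows, acc: (4,) f16 output tile
    for ty in range(4):                           # scf.forall (%arg0) in (4) {mapping = [#gpu.thread<y>]}
        acc[ty] = np.float16(np.sum(x_wg * w_deq[ty]))
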
//--- After canonicalization ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %23) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %22[%arg0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x32x128xf16>
%extracted_slice_0 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%24, %extracted_slice : tensor<32x128xf16>, tensor<1x32x128xf16>) outs(%extracted_slice_0 : tensor<1xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[4], [0, 1, 8]]>} {
^bb0(%in: f16, %in_1: f16, %out: f16):
%27 = arith.mulf %in, %in_1 : f16
%28 = arith.addf %27, %out : f16
linalg.yield %28 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %26 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
//--- After tiling reduction dimensions ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %23) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %22[%arg0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x32x128xf16>
%extracted_slice_0 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c0_1 = arith.constant 0 : index
%c1_2 = arith.constant 1 : index
%c8_3 = arith.constant 8 : index
%26 = tensor.empty() : tensor<1x1x8xf16>
%cst_4 = arith.constant 0.000000e+00 : f16
%27 = linalg.fill ins(%cst_4 : f16) outs(%26 : tensor<1x1x8xf16>) -> tensor<1x1x8xf16>
%c0_5 = arith.constant 0 : index
%c1_6 = arith.constant 1 : index
%c0_7 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%28 = scf.for %arg2 = %c0_7 to %c32 step %c1 iter_args(%arg3 = %27) -> (tensor<1x1x8xf16>) {
%c0_8 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%30 = scf.for %arg4 = %c0_8 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_9 = tensor.extract_slice %24[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_10 = tensor.extract_slice %extracted_slice[0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<1x32x128xf16> to tensor<1x1x8xf16>
%extracted_slice_11 = tensor.extract_slice %arg5[0, 0, 0] [%c1_6, 1, %c8] [1, 1, 1] : tensor<1x1x8xf16> to tensor<?x1x?xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_9, %extracted_slice_10 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%extracted_slice_11 : tensor<?x1x?xf16>) {
^bb0(%in: f16, %in_14: f16, %out: f16):
%32 = arith.mulf %in, %in_14 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<?x1x?xf16>
%c0_12 = arith.constant 0 : index
%dim = tensor.dim %31, %c0_12 : tensor<?x1x?xf16>
%c2 = arith.constant 2 : index
%dim_13 = tensor.dim %31, %c2 : tensor<?x1x?xf16>
%inserted_slice = tensor.insert_slice %31 into %arg5[0, 0, 0] [%dim, 1, %dim_13] [1, 1, 1] : tensor<?x1x?xf16> into tensor<1x1x8xf16>
scf.yield %inserted_slice : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %30 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%28 : tensor<1x1x8xf16>) outs(%extracted_slice_0 : tensor<1xf16>) {
^bb0(%in: f16, %out: f16):
%30 = arith.addf %in, %out : f16
linalg.yield %30 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %29 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
//--- After canonicalization ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %23) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%26 = tensor.empty() : tensor<1x1x8xf16>
%27 = linalg.fill ins(%cst : f16) outs(%26 : tensor<1x1x8xf16>) -> tensor<1x1x8xf16>
%28 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %27) -> (tensor<1x1x8xf16>) {
%30 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %24[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %22[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_2: f16, %out: f16):
%32 = arith.mulf %in, %in_2 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<1x1x8xf16>
scf.yield %31 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %30 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%28 : tensor<1x1x8xf16>) outs(%extracted_slice : tensor<1xf16>) {
^bb0(%in: f16, %out: f16):
%30 = arith.addf %in, %out : f16
linalg.yield %30 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %29 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
current gemv op: %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_2: f16, %out: f16):
%32 = arith.mulf %in, %in_2 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<1x1x8xf16>
looking at consumer op: %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_2: f16, %out: f16):
%32 = arith.mulf %in, %in_2 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<1x1x8xf16>
looking at operand: %extracted_slice_0 = tensor.extract_slice %24[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
looking at operand: %extracted_slice_1 = tensor.extract_slice %22[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
yes (this operand's slice comes from a fusable producer)
processing slice: %extracted_slice_1 = tensor.extract_slice %22[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
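// The trace above appears to show the producer-fusion step of the GEMV pipeline: the pass
// walks the operands of the inner 1x8 multiply-accumulate generic, finds that
// %extracted_slice_1 is a slice of the dequantized tensor %22, and fuses that producer.
// In the dump below the i4 -> f16 dequantization generic now sits inside the scf.for nest,
// so each thread dequantizes only the 1x1x8 tile it actually multiplies.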
//--- After fusing producer ops ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %19, %20 : tensor<4x32x128xi4>, tensor<4x32xf16>, tensor<4x32xf16>) outs(%21 : tensor<4x32x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%26 = arith.extui %in : i4 to i32
%27 = arith.uitofp %26 : i32 to f16
%28 = arith.subf %27, %in_1 : f16
%29 = arith.mulf %28, %in_0 : f16
linalg.yield %29 : f16
} -> tensor<4x32x128xf16>
%23 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%24 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%25 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %23) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%26 = tensor.empty() : tensor<1x1x8xf16>
%27 = linalg.fill ins(%cst : f16) outs(%26 : tensor<1x1x8xf16>) -> tensor<1x1x8xf16>
%28 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %27) -> (tensor<1x1x8xf16>) {
%30 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %24[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_2 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_3 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_4 = tensor.extract_slice %21[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1, %extracted_slice_2, %extracted_slice_3 : tensor<1x1x8xi4>, tensor<1x1xf16>, tensor<1x1xf16>) outs(%extracted_slice_4 : tensor<1x1x8xf16>) {
^bb0(%in: i4, %in_6: f16, %in_7: f16, %out: f16):
%33 = arith.extui %in : i4 to i32
%34 = arith.uitofp %33 : i32 to f16
%35 = arith.subf %34, %in_7 : f16
%36 = arith.mulf %35, %in_6 : f16
linalg.yield %36 : f16
} -> tensor<1x1x8xf16>
%extracted_slice_5 = tensor.extract_slice %22[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
%32 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %31 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_6: f16, %out: f16):
%33 = arith.mulf %in, %in_6 : f16
%34 = arith.addf %33, %out : f16
linalg.yield %34 : f16
} -> tensor<1x1x8xf16>
scf.yield %32 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %30 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%28 : tensor<1x1x8xf16>) outs(%extracted_slice : tensor<1xf16>) {
^bb0(%in: f16, %out: f16):
%30 = arith.addf %in, %out : f16
linalg.yield %30 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %29 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %25, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
//--- After canonicalization ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%23 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%24 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %22) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%25 = tensor.empty() : tensor<1x1x8xf16>
%26 = linalg.fill ins(%cst : f16) outs(%25 : tensor<1x1x8xf16>) -> tensor<1x1x8xf16>
%27 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %26) -> (tensor<1x1x8xf16>) {
%29 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %23[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_2 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_3 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_4 = tensor.extract_slice %21[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
%30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1, %extracted_slice_2, %extracted_slice_3 : tensor<1x1x8xi4>, tensor<1x1xf16>, tensor<1x1xf16>) outs(%extracted_slice_4 : tensor<1x1x8xf16>) {
^bb0(%in: i4, %in_5: f16, %in_6: f16, %out: f16):
%32 = arith.extui %in : i4 to i32
%33 = arith.uitofp %32 : i32 to f16
%34 = arith.subf %33, %in_6 : f16
%35 = arith.mulf %34, %in_5 : f16
linalg.yield %35 : f16
} -> tensor<1x1x8xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %30 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_5: f16, %out: f16):
%32 = arith.mulf %in, %in_5 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<1x1x8xf16>
scf.yield %31 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %29 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%28 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%27 : tensor<1x1x8xf16>) outs(%extracted_slice : tensor<1xf16>) {
^bb0(%in: f16, %out: f16):
%29 = arith.addf %in, %out : f16
linalg.yield %29 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %28 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %24, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After SPIRVTileGEMV (iree-spirv-tile-gemv) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = tensor.empty() : tensor<4x32x128xf16>
%22 = linalg.fill ins(%cst : f16) outs(%16 : tensor<4xf16>) -> tensor<4xf16>
%23 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%24 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %22) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%25 = tensor.empty() : tensor<1x1x8xf16>
%26 = linalg.fill ins(%cst : f16) outs(%25 : tensor<1x1x8xf16>) -> tensor<1x1x8xf16>
%27 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %26) -> (tensor<1x1x8xf16>) {
%29 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %23[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_2 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_3 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_4 = tensor.extract_slice %21[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xf16> to tensor<1x1x8xf16>
%30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1, %extracted_slice_2, %extracted_slice_3 : tensor<1x1x8xi4>, tensor<1x1xf16>, tensor<1x1xf16>) outs(%extracted_slice_4 : tensor<1x1x8xf16>) {
^bb0(%in: i4, %in_5: f16, %in_6: f16, %out: f16):
%32 = arith.extui %in : i4 to i32
%33 = arith.uitofp %32 : i32 to f16
%34 = arith.subf %33, %in_6 : f16
%35 = arith.mulf %34, %in_5 : f16
linalg.yield %35 : f16
} -> tensor<1x1x8xf16>
%31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_0, %30 : tensor<1x8xf16>, tensor<1x1x8xf16>) outs(%arg5 : tensor<1x1x8xf16>) {
^bb0(%in: f16, %in_5: f16, %out: f16):
%32 = arith.mulf %in, %in_5 : f16
%33 = arith.addf %32, %out : f16
linalg.yield %33 : f16
} -> tensor<1x1x8xf16>
scf.yield %31 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %29 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%28 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%27 : tensor<1x1x8xf16>) outs(%extracted_slice : tensor<1xf16>) {
^bb0(%in: f16, %out: f16):
%29 = arith.addf %in, %out : f16
linalg.yield %29 : f16
} -> tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %28 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %24, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
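// The next pass vectorizes the tiled loop body: the dequantization and multiply-accumulate
// generics become vector.transfer_read / broadcast / arith ops on vector<1x1x8xf16> values,
// and the trailing 1x1x8 -> 1 reduction becomes a vector.multi_reduction over dims [1, 2].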
// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0_i4 = arith.constant 0 : i4
%cst = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_0, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = tensor.empty() : tensor<1x1x8xf16>
%25 = vector.transfer_write %cst, %24[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
%26 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %25) -> (tensor<1x1x8xf16>) {
%31 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_2 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_3 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_4 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_5 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%32 = vector.transfer_read %extracted_slice_3[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%33 = vector.transfer_read %extracted_slice_4[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%34 = vector.broadcast %33 : vector<1x1xf16> to vector<8x1x1xf16>
%35 = vector.transpose %34, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%36 = vector.transfer_read %extracted_slice_5[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%37 = vector.broadcast %36 : vector<1x1xf16> to vector<8x1x1xf16>
%38 = vector.transpose %37, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%39 = arith.extui %32 : vector<1x1x8xi4> to vector<1x1x8xi32>
%40 = arith.uitofp %39 : vector<1x1x8xi32> to vector<1x1x8xf16>
%41 = arith.subf %40, %38 : vector<1x1x8xf16>
%42 = arith.mulf %41, %35 : vector<1x1x8xf16>
%43 = vector.transfer_read %extracted_slice_2[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%44 = vector.broadcast %43 : vector<1x8xf16> to vector<1x1x8xf16>
%45 = vector.transfer_read %arg5[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%46 = arith.mulf %44, %42 : vector<1x1x8xf16>
%47 = arith.addf %46, %45 : vector<1x1x8xf16>
%48 = vector.transfer_write %47, %arg5[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
scf.yield %48 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %31 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%27 = vector.transfer_read %26[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%28 = vector.transfer_read %extracted_slice[%c0], %cst_1 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%29 = vector.multi_reduction <add>, %27, %28 [1, 2] : vector<1x1x8xf16> to vector<1xf16>
%30 = vector.transfer_write %29, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %30 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
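// The reduction-preprocessing step below rewrites the vector.multi_reduction over dims
// [1, 2] into a plain vector.reduction <add> over the extracted vector<8xf16>, seeded with
// the f16 value read from the 1-element output slice; this is the form the subgroup
// reduction rewrite later matches (see the "reduction op:" line further down).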
//--- After preprocessing reduction ---//
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0_i4 = arith.constant 0 : i4
%cst = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_0, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = tensor.empty() : tensor<1x1x8xf16>
%25 = vector.transfer_write %cst, %24[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
%26 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %25) -> (tensor<1x1x8xf16>) {
%34 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_2 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_3 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_4 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_5 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%35 = vector.transfer_read %extracted_slice_3[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%36 = vector.transfer_read %extracted_slice_4[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%37 = vector.broadcast %36 : vector<1x1xf16> to vector<8x1x1xf16>
%38 = vector.transpose %37, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%39 = vector.transfer_read %extracted_slice_5[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%40 = vector.broadcast %39 : vector<1x1xf16> to vector<8x1x1xf16>
%41 = vector.transpose %40, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%42 = arith.extui %35 : vector<1x1x8xi4> to vector<1x1x8xi32>
%43 = arith.uitofp %42 : vector<1x1x8xi32> to vector<1x1x8xf16>
%44 = arith.subf %43, %41 : vector<1x1x8xf16>
%45 = arith.mulf %44, %38 : vector<1x1x8xf16>
%46 = vector.transfer_read %extracted_slice_2[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%47 = vector.broadcast %46 : vector<1x8xf16> to vector<1x1x8xf16>
%48 = vector.transfer_read %arg5[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%49 = arith.mulf %47, %45 : vector<1x1x8xf16>
%50 = arith.addf %49, %48 : vector<1x1x8xf16>
%51 = vector.transfer_write %50, %arg5[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
scf.yield %51 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %34 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%27 = vector.transfer_read %26[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%28 = vector.transfer_read %extracted_slice[%c0], %cst_1 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%29 = vector.extract %27[0, 0] : vector<1x1x8xf16>
%30 = vector.extract %28[0] : vector<1xf16>
%31 = vector.reduction <add>, %29, %30 : vector<8xf16> into f16
%32 = vector.broadcast %31 : f16 to vector<1xf16>
%33 = vector.transfer_write %32, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %33 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
reduction op: %31 = vector.reduction <add>, %29, %30 : vector<8xf16> into f16
initial workgroup size: [64 : index, 4 : index, 1 : index]
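// The shuffle-based subgroup reduction in the dump below can be read as: each lane first
// reduces its vector<8xf16> accumulator in halves (two vector.reduction <add> over
// vector<4xf16>, packed into a vector<2xf16>), bitcasts the f16 pair to i32 so gpu.shuffle
// can move it, then does a butterfly reduction with xor offsets 1, 2, 4, 8, 16, 32 across
// the 64-wide subgroup, and finally reduces the remaining two f16 lanes and adds the
// original accumulator value read from the output slice.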
// -----// IR Dump After SPIRVVectorizeGEMV (iree-spirv-vectorize-gemv) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0_i4 = arith.constant 0 : i4
%cst = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_0, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = tensor.empty() : tensor<1x1x8xf16>
%25 = vector.transfer_write %cst, %24[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
%26 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %25) -> (tensor<1x1x8xf16>) {
%72 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (tensor<1x1x8xf16>) {
%extracted_slice_18 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_19 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%extracted_slice_20 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_21 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%73 = vector.transfer_read %extracted_slice_19[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%74 = vector.transfer_read %extracted_slice_20[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%75 = vector.broadcast %74 : vector<1x1xf16> to vector<8x1x1xf16>
%76 = vector.transpose %75, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%77 = vector.transfer_read %extracted_slice_21[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%78 = vector.broadcast %77 : vector<1x1xf16> to vector<8x1x1xf16>
%79 = vector.transpose %78, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%80 = arith.extui %73 : vector<1x1x8xi4> to vector<1x1x8xi32>
%81 = arith.uitofp %80 : vector<1x1x8xi32> to vector<1x1x8xf16>
%82 = arith.subf %81, %79 : vector<1x1x8xf16>
%83 = arith.mulf %82, %76 : vector<1x1x8xf16>
%84 = vector.transfer_read %extracted_slice_18[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%85 = vector.broadcast %84 : vector<1x8xf16> to vector<1x1x8xf16>
%86 = vector.transfer_read %arg5[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%87 = arith.mulf %85, %83 : vector<1x1x8xf16>
%88 = arith.addf %87, %86 : vector<1x1x8xf16>
%89 = vector.transfer_write %88, %arg5[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
scf.yield %89 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 0 : index}
scf.yield %72 : tensor<1x1x8xf16>
} {iree.spirv.distribute_delinearize_x = 1 : index}
%27 = vector.transfer_read %26[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%28 = vector.transfer_read %extracted_slice[%c0], %cst_1 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%29 = vector.extract %27[0, 0] : vector<1x1x8xf16>
%30 = vector.extract %28[0] : vector<1xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%31 = vector.extract_strided_slice %29 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %cst_2 [0] : f16 into vector<2xf16>
%34 = vector.extract_strided_slice %29 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%35 = vector.reduction <add>, %34 : vector<4xf16> into f16
%36 = vector.insert %35, %33 [1] : f16 into vector<2xf16>
%37 = vector.bitcast %36 : vector<2xf16> to vector<1xi32>
%38 = vector.extract %37[0] : vector<1xi32>
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 64 : i32
%shuffleResult, %valid = gpu.shuffle xor %38, %c1_i32, %c64_i32 : i32
%39 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%40 = vector.bitcast %39 : vector<1xi32> to vector<2xf16>
%41 = arith.addf %36, %40 : vector<2xf16>
%42 = vector.bitcast %41 : vector<2xf16> to vector<1xi32>
%43 = vector.extract %42[0] : vector<1xi32>
%c2_i32 = arith.constant 2 : i32
%c64_i32_3 = arith.constant 64 : i32
%shuffleResult_4, %valid_5 = gpu.shuffle xor %43, %c2_i32, %c64_i32_3 : i32
%44 = vector.broadcast %shuffleResult_4 : i32 to vector<1xi32>
%45 = vector.bitcast %44 : vector<1xi32> to vector<2xf16>
%46 = arith.addf %41, %45 : vector<2xf16>
%47 = vector.bitcast %46 : vector<2xf16> to vector<1xi32>
%48 = vector.extract %47[0] : vector<1xi32>
%c4_i32 = arith.constant 4 : i32
%c64_i32_6 = arith.constant 64 : i32
%shuffleResult_7, %valid_8 = gpu.shuffle xor %48, %c4_i32, %c64_i32_6 : i32
%49 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%50 = vector.bitcast %49 : vector<1xi32> to vector<2xf16>
%51 = arith.addf %46, %50 : vector<2xf16>
%52 = vector.bitcast %51 : vector<2xf16> to vector<1xi32>
%53 = vector.extract %52[0] : vector<1xi32>
%c8_i32 = arith.constant 8 : i32
%c64_i32_9 = arith.constant 64 : i32
%shuffleResult_10, %valid_11 = gpu.shuffle xor %53, %c8_i32, %c64_i32_9 : i32
%54 = vector.broadcast %shuffleResult_10 : i32 to vector<1xi32>
%55 = vector.bitcast %54 : vector<1xi32> to vector<2xf16>
%56 = arith.addf %51, %55 : vector<2xf16>
%57 = vector.bitcast %56 : vector<2xf16> to vector<1xi32>
%58 = vector.extract %57[0] : vector<1xi32>
%c16_i32 = arith.constant 16 : i32
%c64_i32_12 = arith.constant 64 : i32
%shuffleResult_13, %valid_14 = gpu.shuffle xor %58, %c16_i32, %c64_i32_12 : i32
%59 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%60 = vector.bitcast %59 : vector<1xi32> to vector<2xf16>
%61 = arith.addf %56, %60 : vector<2xf16>
%62 = vector.bitcast %61 : vector<2xf16> to vector<1xi32>
%63 = vector.extract %62[0] : vector<1xi32>
%c32_i32 = arith.constant 32 : i32
%c64_i32_15 = arith.constant 64 : i32
%shuffleResult_16, %valid_17 = gpu.shuffle xor %63, %c32_i32, %c64_i32_15 : i32
%64 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%65 = vector.bitcast %64 : vector<1xi32> to vector<2xf16>
%66 = arith.addf %61, %65 : vector<2xf16>
%67 = vector.reduction <add>, %66 : vector<2xf16> into f16
%68 = arith.addf %67, %30 : f16
%69 = vector.reduction <add>, %29, %30 : vector<8xf16> into f16
%70 = vector.broadcast %68 : f16 to vector<1xf16>
%71 = vector.transfer_write %70, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %71 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
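// The hoisting pass below lifts the redundant transfer_read/transfer_write of the 1x1x8
// accumulator out of the scf.for nest: the loops now carry the accumulator as a
// vector<1x1x8xf16> iter_arg alongside the tensor, and the tensor is only written back once
// after the loops before the subgroup reduction.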
// -----// IR Dump After HoistRedundantVectorTransfers (iree-codegen-hoist-redundant-vector-transfers) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0_i4 = arith.constant 0 : i4
%cst = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_0, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = tensor.empty() : tensor<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 64 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32_3 = arith.constant 64 : i32
%c4_i32 = arith.constant 4 : i32
%c64_i32_4 = arith.constant 64 : i32
%c8_i32 = arith.constant 8 : i32
%c64_i32_5 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c64_i32_6 = arith.constant 64 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32_7 = arith.constant 64 : i32
%24 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%25 = vector.transfer_write %cst, %23[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
%26 = vector.transfer_read %25[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%27:2 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %25, %arg4 = %26) -> (tensor<1x1x8xf16>, vector<1x1x8xf16>) {
%extracted_slice_18 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_19 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%74:2 = scf.for %arg5 = %c0 to %c128 step %c8 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (tensor<1x1x8xf16>, vector<1x1x8xf16>) {
%extracted_slice_20 = tensor.extract_slice %22[%arg2, %arg5] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_21 = tensor.extract_slice %18[%arg0, %arg2, %arg5] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%75 = vector.transfer_read %extracted_slice_21[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%76 = vector.transfer_read %extracted_slice_18[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%77 = vector.broadcast %76 : vector<1x1xf16> to vector<8x1x1xf16>
%78 = vector.transpose %77, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%79 = vector.transfer_read %extracted_slice_19[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%80 = vector.broadcast %79 : vector<1x1xf16> to vector<8x1x1xf16>
%81 = vector.transpose %80, [1, 2, 0] : vector<8x1x1xf16> to vector<1x1x8xf16>
%82 = arith.extui %75 : vector<1x1x8xi4> to vector<1x1x8xi32>
%83 = arith.uitofp %82 : vector<1x1x8xi32> to vector<1x1x8xf16>
%84 = arith.subf %83, %81 : vector<1x1x8xf16>
%85 = arith.mulf %84, %78 : vector<1x1x8xf16>
%86 = vector.transfer_read %extracted_slice_20[%c0, %c0], %cst_1 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%87 = vector.broadcast %86 : vector<1x8xf16> to vector<1x1x8xf16>
%88 = arith.mulf %87, %85 : vector<1x1x8xf16>
%89 = arith.addf %88, %arg7 : vector<1x1x8xf16>
scf.yield %arg6, %89 : tensor<1x1x8xf16>, vector<1x1x8xf16>
}
scf.yield %74#0, %74#1 : tensor<1x1x8xf16>, vector<1x1x8xf16>
}
%28 = vector.transfer_write %27#1, %27#0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x8xf16>, tensor<1x1x8xf16>
%29 = vector.transfer_read %28[%c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true]} : tensor<1x1x8xf16>, vector<1x1x8xf16>
%30 = vector.transfer_read %extracted_slice[%c0], %cst_1 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%31 = vector.extract %29[0, 0] : vector<1x1x8xf16>
%32 = vector.extract %30[0] : vector<1xf16>
%33 = vector.extract_strided_slice %31 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %cst_2 [0] : f16 into vector<2xf16>
%36 = vector.extract_strided_slice %31 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %35 [1] : f16 into vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
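// Subgroup reduction: the eight f16 partial sums were folded into a
// vector<2xf16> and bitcast to a single i32; they are now combined across the
// 64-wide subgroup with xor shuffles at offsets 1, 2, 4, 8, 16, and 32,
// followed by a final scalar add with the value read from the output slice.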
%shuffleResult, %valid = gpu.shuffle xor %40, %c1_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %45, %c2_i32, %c64_i32_3 : i32
%46 = vector.broadcast %shuffleResult_8 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %50, %c4_i32, %c64_i32_4 : i32
%51 = vector.broadcast %shuffleResult_10 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %55, %c8_i32, %c64_i32_5 : i32
%56 = vector.broadcast %shuffleResult_12 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_14, %valid_15 = gpu.shuffle xor %60, %c16_i32, %c64_i32_6 : i32
%61 = vector.broadcast %shuffleResult_14 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_16, %valid_17 = gpu.shuffle xor %65, %c32_i32, %c64_i32_7 : i32
%66 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.reduction <add>, %68 : vector<2xf16> into f16
%70 = arith.addf %69, %32 : f16
%71 = vector.reduction <add>, %31, %32 : vector<8xf16> into f16
%72 = vector.broadcast %70 : f16 to vector<1xf16>
%73 = vector.transfer_write %72, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %73 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %24, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
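// Canonicalization merges the duplicated i32 constants, removes the dead
// tensor<1x1x8xf16> staging buffer and the unused trailing vector.reduction,
// and carries the accumulator purely as a vector<1x1x8xf16> iter_arg.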
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_1, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst_0) -> (vector<1x1x8xf16>) {
%extracted_slice_13 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_14 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%68 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (vector<1x1x8xf16>) {
%extracted_slice_15 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_16 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%69 = vector.transfer_read %extracted_slice_16[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%70 = vector.transfer_read %extracted_slice_13[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%71 = vector.broadcast %70 : vector<1x1xf16> to vector<1x1x8xf16>
%72 = vector.transfer_read %extracted_slice_14[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%73 = vector.broadcast %72 : vector<1x1xf16> to vector<1x1x8xf16>
%74 = arith.extui %69 : vector<1x1x8xi4> to vector<1x1x8xi32>
%75 = arith.uitofp %74 : vector<1x1x8xi32> to vector<1x1x8xf16>
%76 = arith.subf %75, %73 : vector<1x1x8xf16>
%77 = arith.mulf %76, %71 : vector<1x1x8xf16>
%78 = vector.transfer_read %extracted_slice_15[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%79 = vector.broadcast %78 : vector<1x8xf16> to vector<1x1x8xf16>
%80 = arith.mulf %79, %77 : vector<1x1x8xf16>
%81 = arith.addf %80, %arg5 : vector<1x1x8xf16>
scf.yield %81 : vector<1x1x8xf16>
}
scf.yield %68 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %extracted_slice[%c0], %cst_2 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_3 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_5 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
%67 = vector.transfer_write %66, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %67 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After CSE (cse) //----- //
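// CSE finds nothing further to eliminate; this dump is identical to the
// canonicalized IR above.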
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_1, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst_0) -> (vector<1x1x8xf16>) {
%extracted_slice_13 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_14 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%68 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (vector<1x1x8xf16>) {
%extracted_slice_15 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_16 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%69 = vector.transfer_read %extracted_slice_16[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%70 = vector.transfer_read %extracted_slice_13[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%71 = vector.broadcast %70 : vector<1x1xf16> to vector<1x1x8xf16>
%72 = vector.transfer_read %extracted_slice_14[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%73 = vector.broadcast %72 : vector<1x1xf16> to vector<1x1x8xf16>
%74 = arith.extui %69 : vector<1x1x8xi4> to vector<1x1x8xi32>
%75 = arith.uitofp %74 : vector<1x1x8xi32> to vector<1x1x8xf16>
%76 = arith.subf %75, %73 : vector<1x1x8xf16>
%77 = arith.mulf %76, %71 : vector<1x1x8xf16>
%78 = vector.transfer_read %extracted_slice_15[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%79 = vector.broadcast %78 : vector<1x8xf16> to vector<1x1x8xf16>
%80 = arith.mulf %79, %77 : vector<1x1x8xf16>
%81 = arith.addf %80, %arg5 : vector<1x1x8xf16>
scf.yield %81 : vector<1x1x8xf16>
}
scf.yield %68 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %extracted_slice[%c0], %cst_2 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_3 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_5 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
%67 = vector.transfer_write %66, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %67 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
// -----// IR Dump After EliminateEmptyTensors (iree-eliminate-empty-tensors) //----- //
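// No empty tensors feed the dispatch results at this point, so the function
// body is unchanged; the pass output is simply printed at module scope.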
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_1, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst_0) -> (vector<1x1x8xf16>) {
%extracted_slice_13 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_14 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%68 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (vector<1x1x8xf16>) {
%extracted_slice_15 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_16 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%69 = vector.transfer_read %extracted_slice_16[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%70 = vector.transfer_read %extracted_slice_13[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%71 = vector.broadcast %70 : vector<1x1xf16> to vector<1x1x8xf16>
%72 = vector.transfer_read %extracted_slice_14[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%73 = vector.broadcast %72 : vector<1x1xf16> to vector<1x1x8xf16>
%74 = arith.extui %69 : vector<1x1x8xi4> to vector<1x1x8xi32>
%75 = arith.uitofp %74 : vector<1x1x8xi32> to vector<1x1x8xf16>
%76 = arith.subf %75, %73 : vector<1x1x8xf16>
%77 = arith.mulf %76, %71 : vector<1x1x8xf16>
%78 = vector.transfer_read %extracted_slice_15[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%79 = vector.broadcast %78 : vector<1x8xf16> to vector<1x1x8xf16>
%80 = arith.mulf %79, %77 : vector<1x1x8xf16>
%81 = arith.addf %80, %arg5 : vector<1x1x8xf16>
scf.yield %81 : vector<1x1x8xf16>
}
scf.yield %68 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %extracted_slice[%c0], %cst_2 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_3 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_5 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
%67 = vector.transfer_write %66, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %67 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// -----// IR Dump After EmptyTensorToAllocTensor (empty-tensor-to-alloc-tensor) //----- //
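// Nothing to convert: the only staging allocation is already an explicit
// bufferization.alloc_tensor copy, so the IR is unchanged.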
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xf16>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%16 = flow.dispatch.tensor.load %14, offsets = [%15], sizes = [4], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<4096xf16>> -> tensor<4xf16>
%17 = flow.dispatch.tensor.load %13, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xf16>> -> tensor<32x128xf16>
%18 = flow.dispatch.tensor.load %10, offsets = [%15, 0, 0], sizes = [4, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32x128xi4>> -> tensor<4x32x128xi4>
%19 = flow.dispatch.tensor.load %11, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%20 = flow.dispatch.tensor.load %12, offsets = [%15, 0], sizes = [4, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x32xf16>> -> tensor<4x32xf16>
%21 = vector.transfer_write %cst_1, %16[%c0] {in_bounds = [true]} : vector<4xf16>, tensor<4xf16>
%22 = bufferization.alloc_tensor() copy(%17) {bufferization.escape = [false]} : tensor<32x128xf16>
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %21) -> (tensor<4xf16>) {
%extracted_slice = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<4xf16> to tensor<1xf16>
%24 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst_0) -> (vector<1x1x8xf16>) {
%extracted_slice_13 = tensor.extract_slice %19[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%extracted_slice_14 = tensor.extract_slice %20[%arg0, %arg2] [1, 1] [1, 1] : tensor<4x32xf16> to tensor<1x1xf16>
%68 = scf.for %arg4 = %c0 to %c128 step %c8 iter_args(%arg5 = %arg3) -> (vector<1x1x8xf16>) {
%extracted_slice_15 = tensor.extract_slice %22[%arg2, %arg4] [1, 8] [1, 1] : tensor<32x128xf16> to tensor<1x8xf16>
%extracted_slice_16 = tensor.extract_slice %18[%arg0, %arg2, %arg4] [1, 1, 8] [1, 1, 1] : tensor<4x32x128xi4> to tensor<1x1x8xi4>
%69 = vector.transfer_read %extracted_slice_16[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : tensor<1x1x8xi4>, vector<1x1x8xi4>
%70 = vector.transfer_read %extracted_slice_13[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%71 = vector.broadcast %70 : vector<1x1xf16> to vector<1x1x8xf16>
%72 = vector.transfer_read %extracted_slice_14[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x1xf16>, vector<1x1xf16>
%73 = vector.broadcast %72 : vector<1x1xf16> to vector<1x1x8xf16>
%74 = arith.extui %69 : vector<1x1x8xi4> to vector<1x1x8xi32>
%75 = arith.uitofp %74 : vector<1x1x8xi32> to vector<1x1x8xf16>
%76 = arith.subf %75, %73 : vector<1x1x8xf16>
%77 = arith.mulf %76, %71 : vector<1x1x8xf16>
%78 = vector.transfer_read %extracted_slice_15[%c0, %c0], %cst_2 {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
%79 = vector.broadcast %78 : vector<1x8xf16> to vector<1x1x8xf16>
%80 = arith.mulf %79, %77 : vector<1x1x8xf16>
%81 = arith.addf %80, %arg5 : vector<1x1x8xf16>
scf.yield %81 : vector<1x1x8xf16>
}
scf.yield %68 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %extracted_slice[%c0], %cst_2 {in_bounds = [true]} : tensor<1xf16>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_3 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_5 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
%67 = vector.transfer_write %66, %extracted_slice[%c0] {in_bounds = [true]} : vector<1xf16>, tensor<1xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %67 into %arg1[%arg0] [1] [1] : tensor<1xf16> into tensor<4xf16>
}
} {mapping = [#gpu.thread<y>]}
flow.dispatch.tensor.store %23, %14, offsets = [%15], sizes = [4], strides = [1] : tensor<4xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096xf16>>
return
}
}
// -----// IR Dump After IREEComprehensiveBufferize (iree-codegen-iree-comprehensive-bufferize) //----- //
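// Comprehensive bufferization rewrites the dispatch onto memrefs: the binding
// subspans become memrefs with assume_alignment, tensor slices become
// memref.subview ops, the alloc_tensor copy of the 32x128xf16 tile becomes a
// memref.alloc in #gpu.address_space<workgroup> staged between gpu.barrier
// ops, and the scf.forall now writes its result in place through per-thread
// subviews of the output buffer.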
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
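// Stage the 32x128xf16 operand into workgroup memory; the gpu.barrier ops
// bracket the memref.copy marked "copy_to_workgroup_memory" so every thread
// sees the filled tile before its per-thread loops start.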
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_7 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_20 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_21 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_22 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_22[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_20[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_21[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_7[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_8 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_10 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_12 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_14, %valid_15 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_14 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_16, %valid_17 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_7[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview_7, %subview_18 : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
%subview_6 = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview, %subview_6 : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
}
// -----// IR Dump After ResolveShapedTypeResultDims (resolve-shaped-type-result-dims) //----- //
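// No shaped-type result dims remain to resolve; this dump matches the
// bufferized IR above.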
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_7 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_20 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_21 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_22 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_22[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_20[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_21[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_7[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_8 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_10 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_12 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_14, %valid_15 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_14 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_16, %valid_17 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_7[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview_7, %subview_18 : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
%subview_6 = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview, %subview_6 : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
}
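// Note on the reduction tail above (annotation, not pass output): the
// vector<1x1x8xf16> accumulator is reduced in two 4-wide halves with
// vector.reduction <add>, packed into a vector<2xf16>, bitcast to i32, and
// then combined across lanes with gpu.shuffle xor at offsets 1, 2, 4, 8, 16,
// and 32 over a width of 64 -- a butterfly reduction that presumably targets
// a 64-lane subgroup (wave64). Packing two f16 partial sums into one i32 lets
// each shuffle exchange both halves at once; after the offset-32 round every
// lane holds the full 64-lane sum, which is then reduced to a scalar, added
// to the previously stored output element, and written back.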
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_7 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_20 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_21 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_22 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_22[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_20[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_21[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_7[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_8 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_10 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_12 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_14, %valid_15 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_14 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_16, %valid_17 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_7[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview_7, %subview_18 : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
%subview_6 = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview, %subview_6 : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
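// Annotation: in this canonicalized form %subview_7 and %subview_18 are both
// %subview[%arg0] [1] [1], i.e. the same one-element view of the output, and
// the final %subview_6 is the same four-element view as %subview itself, so
// the two trailing memref.copy ops are redundant. The CSE run below merges
// the duplicated subviews, which turns those copies into self-copies.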
// -----// IR Dump After CSE (cse) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_6 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_17 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_20 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_20[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_17[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_18[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_6[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_6[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview_6, %subview_6 : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
memref.copy %subview, %subview : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
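// Annotation: CSE deduplicated the identical subview ops, so the per-thread
// copy and the trailing workgroup copy now read "memref.copy %subview_6,
// %subview_6" and "memref.copy %subview, %subview" -- self-copies that the
// following canonicalization erases.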
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_6 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_17 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_20 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_20[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_17[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_18[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_6[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_6[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
return
}
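// Annotation: with the self-copies folded away, each thread's reduced f16 is
// now written straight to its element of the output subview by the final
// vector.transfer_write inside the scf.forall.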
// -----// IR Dump After CleanupBufferAllocView (iree-codegen-cleanup-buffer-alloc-view) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.forall (%arg0) in (4) {
%subview_6 = memref.subview %subview[%arg0] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_17 = memref.subview %subview_4[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview_5[%arg0, %arg1] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%59 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %arg2) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %alloc[%arg1, %arg3] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_20 = memref.subview %subview_3[%arg0, %arg1, %arg3] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%60 = vector.transfer_read %subview_20[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%61 = vector.transfer_read %subview_17[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%62 = vector.broadcast %61 : vector<1x1xf16> to vector<1x1x8xf16>
%63 = vector.transfer_read %subview_18[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%64 = vector.broadcast %63 : vector<1x1xf16> to vector<1x1x8xf16>
%65 = arith.extui %60 : vector<1x1x8xi4> to vector<1x1x8xi32>
%66 = arith.uitofp %65 : vector<1x1x8xi32> to vector<1x1x8xf16>
%67 = arith.subf %66, %64 : vector<1x1x8xf16>
%68 = arith.mulf %67, %62 : vector<1x1x8xf16>
%69 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%70 = vector.broadcast %69 : vector<1x8xf16> to vector<1x1x8xf16>
%71 = arith.mulf %70, %68 : vector<1x1x8xf16>
%72 = arith.addf %71, %arg4 : vector<1x1x8xf16>
scf.yield %72 : vector<1x1x8xf16>
}
scf.yield %59 : vector<1x1x8xf16>
}
%17 = vector.transfer_read %subview_6[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%18 = vector.extract %16[0, 0] : vector<1x1x8xf16>
%19 = vector.extract %17[0] : vector<1xf16>
%20 = vector.extract_strided_slice %18 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%21 = vector.reduction <add>, %20 : vector<4xf16> into f16
%22 = vector.insert %21, %cst [0] : f16 into vector<2xf16>
%23 = vector.extract_strided_slice %18 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %22 [1] : f16 into vector<2xf16>
%26 = vector.bitcast %25 : vector<2xf16> to vector<1xi32>
%27 = vector.extract %26[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %27, %c1_i32, %c64_i32 : i32
%28 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%29 = vector.bitcast %28 : vector<1xi32> to vector<2xf16>
%30 = arith.addf %25, %29 : vector<2xf16>
%31 = vector.bitcast %30 : vector<2xf16> to vector<1xi32>
%32 = vector.extract %31[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %32, %c2_i32, %c64_i32 : i32
%33 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%34 = vector.bitcast %33 : vector<1xi32> to vector<2xf16>
%35 = arith.addf %30, %34 : vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %37, %c4_i32, %c64_i32 : i32
%38 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %42, %c8_i32, %c64_i32 : i32
%43 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %47, %c16_i32, %c64_i32 : i32
%48 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %52, %c32_i32, %c64_i32 : i32
%53 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.reduction <add>, %55 : vector<2xf16> into f16
%57 = arith.addf %56, %19 : f16
%58 = vector.broadcast %57 : f16 to vector<1xf16>
vector.transfer_write %58, %subview_6[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
} {mapping = [#gpu.thread<y>]}
return
}
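// Annotation: this dump appears identical to the preceding canonicalized
// form; CleanupBufferAllocView made no visible change to this function.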
// -----// IR Dump After GPUDistribute (iree-codegen-gpu-distribute) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0 = arith.constant 0 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0_2 = arith.constant 0 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0_2] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
memref.copy %13, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
%16 = gpu.thread_id x
%17 = gpu.thread_id y
%18 = gpu.thread_id z
%c1_7 = arith.constant 1 : index
%19 = arith.cmpi ult, %16, %c1_7 : index
scf.if %19 {
%subview_8 = memref.subview %subview[%17] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = scf.for %arg0 = %c0_2 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %subview_5[%17, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_20 = memref.subview %subview_6[%17, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%63 = scf.for %arg2 = %c0_2 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_21 = memref.subview %alloc[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_22 = memref.subview %subview_4[%17, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%64 = vector.transfer_read %subview_22[%c0_2, %c0_2, %c0_2], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%65 = vector.transfer_read %subview_19[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%66 = vector.broadcast %65 : vector<1x1xf16> to vector<1x1x8xf16>
%67 = vector.transfer_read %subview_20[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%68 = vector.broadcast %67 : vector<1x1xf16> to vector<1x1x8xf16>
%69 = arith.extui %64 : vector<1x1x8xi4> to vector<1x1x8xi32>
%70 = arith.uitofp %69 : vector<1x1x8xi32> to vector<1x1x8xf16>
%71 = arith.subf %70, %68 : vector<1x1x8xf16>
%72 = arith.mulf %71, %66 : vector<1x1x8xf16>
%73 = vector.transfer_read %subview_21[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%74 = vector.broadcast %73 : vector<1x8xf16> to vector<1x1x8xf16>
%75 = arith.mulf %74, %72 : vector<1x1x8xf16>
%76 = arith.addf %75, %arg3 : vector<1x1x8xf16>
scf.yield %76 : vector<1x1x8xf16>
}
scf.yield %63 : vector<1x1x8xf16>
}
%21 = vector.transfer_read %subview_8[%c0_2], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%22 = vector.extract %20[0, 0] : vector<1x1x8xf16>
%23 = vector.extract %21[0] : vector<1xf16>
%24 = vector.extract_strided_slice %22 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%25 = vector.reduction <add>, %24 : vector<4xf16> into f16
%26 = vector.insert %25, %cst [0] : f16 into vector<2xf16>
%27 = vector.extract_strided_slice %22 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%28 = vector.reduction <add>, %27 : vector<4xf16> into f16
%29 = vector.insert %28, %26 [1] : f16 into vector<2xf16>
%30 = vector.bitcast %29 : vector<2xf16> to vector<1xi32>
%31 = vector.extract %30[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %31, %c1_i32, %c64_i32 : i32
%32 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%33 = vector.bitcast %32 : vector<1xi32> to vector<2xf16>
%34 = arith.addf %29, %33 : vector<2xf16>
%35 = vector.bitcast %34 : vector<2xf16> to vector<1xi32>
%36 = vector.extract %35[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %36, %c2_i32, %c64_i32 : i32
%37 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%38 = vector.bitcast %37 : vector<1xi32> to vector<2xf16>
%39 = arith.addf %34, %38 : vector<2xf16>
%40 = vector.bitcast %39 : vector<2xf16> to vector<1xi32>
%41 = vector.extract %40[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %41, %c4_i32, %c64_i32 : i32
%42 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%43 = vector.bitcast %42 : vector<1xi32> to vector<2xf16>
%44 = arith.addf %39, %43 : vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %46, %c8_i32, %c64_i32 : i32
%47 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %51, %c16_i32, %c64_i32 : i32
%52 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %56, %c32_i32, %c64_i32 : i32
%57 = vector.broadcast %shuffleResult_17 : i32 to vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.reduction <add>, %59 : vector<2xf16> into f16
%61 = arith.addf %60, %23 : f16
%62 = vector.broadcast %61 : f16 to vector<1xf16>
vector.transfer_write %62, %subview_8[%c0_2] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
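// Annotation: GPUDistribute lowered the scf.forall with
// mapping = [#gpu.thread<y>] to explicit gpu.thread_id queries plus an
// scf.if guard on thread id x < 1, so the former forall body runs once per
// y-thread (y = 0..3) with %17 (thread id y) standing in for the old
// induction variable. The unused z thread id and the duplicate index
// constant introduced here are already gone in the next dump.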
// -----// IR Dump After MemrefCopyToLinalgPass (iree-codegen-memrefcopy-to-linalg) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0 = arith.constant 0 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 : i32 to index
%7 = arith.index_castui %2 : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %10, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %11, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%7) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %12, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%9) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %14[%15] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %10[%15, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %11[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %12[%15, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
gpu.barrier
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloc : memref<32x128xf16, #gpu.address_space<workgroup>>) attrs = {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
gpu.barrier
%16 = gpu.thread_id x
%17 = gpu.thread_id y
%18 = arith.cmpi ult, %16, %c1 : index
scf.if %18 {
%subview_6 = memref.subview %subview[%17] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%19 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_17 = memref.subview %subview_4[%17, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_18 = memref.subview %subview_5[%17, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%62 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_19 = memref.subview %alloc[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_20 = memref.subview %subview_3[%17, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%63 = vector.transfer_read %subview_20[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%64 = vector.transfer_read %subview_17[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%65 = vector.broadcast %64 : vector<1x1xf16> to vector<1x1x8xf16>
%66 = vector.transfer_read %subview_18[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%67 = vector.broadcast %66 : vector<1x1xf16> to vector<1x1x8xf16>
%68 = arith.extui %63 : vector<1x1x8xi4> to vector<1x1x8xi32>
%69 = arith.uitofp %68 : vector<1x1x8xi32> to vector<1x1x8xf16>
%70 = arith.subf %69, %67 : vector<1x1x8xf16>
%71 = arith.mulf %70, %65 : vector<1x1x8xf16>
%72 = vector.transfer_read %subview_19[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%73 = vector.broadcast %72 : vector<1x8xf16> to vector<1x1x8xf16>
%74 = arith.mulf %73, %71 : vector<1x1x8xf16>
%75 = arith.addf %74, %arg3 : vector<1x1x8xf16>
scf.yield %75 : vector<1x1x8xf16>
}
scf.yield %62 : vector<1x1x8xf16>
}
%20 = vector.transfer_read %subview_6[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%21 = vector.extract %19[0, 0] : vector<1x1x8xf16>
%22 = vector.extract %20[0] : vector<1xf16>
%23 = vector.extract_strided_slice %21 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%24 = vector.reduction <add>, %23 : vector<4xf16> into f16
%25 = vector.insert %24, %cst [0] : f16 into vector<2xf16>
%26 = vector.extract_strided_slice %21 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%27 = vector.reduction <add>, %26 : vector<4xf16> into f16
%28 = vector.insert %27, %25 [1] : f16 into vector<2xf16>
%29 = vector.bitcast %28 : vector<2xf16> to vector<1xi32>
%30 = vector.extract %29[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %30, %c1_i32, %c64_i32 : i32
%31 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%32 = vector.bitcast %31 : vector<1xi32> to vector<2xf16>
%33 = arith.addf %28, %32 : vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %35, %c2_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult_7 : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %40, %c4_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_9 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %45, %c8_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_11 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %50, %c16_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %55, %c32_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.reduction <add>, %58 : vector<2xf16> into f16
%60 = arith.addf %59, %22 : f16
%61 = vector.broadcast %60 : f16 to vector<1xf16>
vector.transfer_write %61, %subview_6[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After GPUDistributeSharedMemoryCopy (iree-codegen-gpu-distribute-shared-memory-copy) //----- //
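// In the dump below, the tagged workgroup copy has been distributed: each thread moves one
// 1x8 f16 slice, with its row/column derived from the x/y/z thread ids by the two affine maps
// (s1 * 4 + s2 * 16 + s0 floordiv 16 and s0 * 8 - (s0 floordiv 16) * 128); the transfer is
// issued once for rows 0-15 and once more (inside a single-trip scf.for) for rows 16-31.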
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%c32_6 = arith.constant 32 : index
%subview_7 = memref.subview %16[%c0, %c0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_8 = memref.subview %alloc[%c0, %c0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%22 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_9 = memref.subview %subview_7[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_10 = memref.subview %subview_8[%21, %22] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%23 = vector.transfer_read %subview_9[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %23, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%c1_11 = arith.constant 1 : index
%24 = arith.muli %c16, %c1_11 : index
%25 = arith.addi %c0, %24 : index
scf.for %arg0 = %c0 to %c128 step %c128 {
%subview_12 = memref.subview %16[%25, %arg0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %alloc[%25, %arg0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%29 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%30 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%31 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%32 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_14 = memref.subview %subview_12[%29, %30] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%31, %32] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%33 = vector.transfer_read %subview_14[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %33, %subview_15[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
}
gpu.barrier
%26 = gpu.thread_id x
%27 = gpu.thread_id y
%28 = arith.cmpi ult, %26, %c1 : index
scf.if %28 {
%subview_12 = memref.subview %subview[%27] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%29 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_23 = memref.subview %subview_4[%27, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_24 = memref.subview %subview_5[%27, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%72 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_25 = memref.subview %alloc[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_26 = memref.subview %subview_3[%27, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%73 = vector.transfer_read %subview_26[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%74 = vector.transfer_read %subview_23[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%75 = vector.broadcast %74 : vector<1x1xf16> to vector<1x1x8xf16>
%76 = vector.transfer_read %subview_24[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%77 = vector.broadcast %76 : vector<1x1xf16> to vector<1x1x8xf16>
%78 = arith.extui %73 : vector<1x1x8xi4> to vector<1x1x8xi32>
%79 = arith.uitofp %78 : vector<1x1x8xi32> to vector<1x1x8xf16>
%80 = arith.subf %79, %77 : vector<1x1x8xf16>
%81 = arith.mulf %80, %75 : vector<1x1x8xf16>
%82 = vector.transfer_read %subview_25[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%83 = vector.broadcast %82 : vector<1x8xf16> to vector<1x1x8xf16>
%84 = arith.mulf %83, %81 : vector<1x1x8xf16>
%85 = arith.addf %84, %arg3 : vector<1x1x8xf16>
scf.yield %85 : vector<1x1x8xf16>
}
scf.yield %72 : vector<1x1x8xf16>
}
%30 = vector.transfer_read %subview_12[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%31 = vector.extract %29[0, 0] : vector<1x1x8xf16>
%32 = vector.extract %30[0] : vector<1xf16>
%33 = vector.extract_strided_slice %31 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %cst [0] : f16 into vector<2xf16>
%36 = vector.extract_strided_slice %31 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %35 [1] : f16 into vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %40, %c1_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_13, %valid_14 = gpu.shuffle xor %45, %c2_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_13 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %50, %c4_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %55, %c8_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_17 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %60, %c16_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_19 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %65, %c32_i32, %c64_i32 : i32
%66 = vector.broadcast %shuffleResult_21 : i32 to vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.reduction <add>, %68 : vector<2xf16> into f16
%70 = arith.addf %69, %32 : f16
%71 = vector.broadcast %70 : f16 to vector<1xf16>
vector.transfer_write %71, %subview_12[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
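// In the dump below, canonicalization folds the single-trip scf.for around the second half of
// the copy into straight-line code with constant subview offsets (row 16 of the source, element
// offset 2048 into the workgroup buffer) and drops the dead index arithmetic.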
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_6 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %alloc[0, 0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%22 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_8 = memref.subview %subview_6[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview_7[%21, %22] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%23 = vector.transfer_read %subview_8[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %23, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_10 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %alloc[16, 0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1], offset: 2048>, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%27 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_12 = memref.subview %subview_10[%24, %25] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview_11[%26, %27] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: 2048>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%28 = vector.transfer_read %subview_12[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %28, %subview_13[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%29 = gpu.thread_id x
%30 = gpu.thread_id y
%31 = arith.cmpi ult, %29, %c1 : index
scf.if %31 {
%subview_14 = memref.subview %subview[%30] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%32 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_25 = memref.subview %subview_4[%30, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_26 = memref.subview %subview_5[%30, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %alloc[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_28 = memref.subview %subview_3[%30, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%76 = vector.transfer_read %subview_28[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%77 = vector.transfer_read %subview_25[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%78 = vector.broadcast %77 : vector<1x1xf16> to vector<1x1x8xf16>
%79 = vector.transfer_read %subview_26[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%80 = vector.broadcast %79 : vector<1x1xf16> to vector<1x1x8xf16>
%81 = arith.extui %76 : vector<1x1x8xi4> to vector<1x1x8xi32>
%82 = arith.uitofp %81 : vector<1x1x8xi32> to vector<1x1x8xf16>
%83 = arith.subf %82, %80 : vector<1x1x8xf16>
%84 = arith.mulf %83, %78 : vector<1x1x8xf16>
%85 = vector.transfer_read %subview_27[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%86 = vector.broadcast %85 : vector<1x8xf16> to vector<1x1x8xf16>
%87 = arith.mulf %86, %84 : vector<1x1x8xf16>
%88 = arith.addf %87, %arg3 : vector<1x1x8xf16>
scf.yield %88 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%33 = vector.transfer_read %subview_14[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%34 = vector.extract %32[0, 0] : vector<1x1x8xf16>
%35 = vector.extract %33[0] : vector<1xf16>
%36 = vector.extract_strided_slice %34 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %cst [0] : f16 into vector<2xf16>
%39 = vector.extract_strided_slice %34 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%40 = vector.reduction <add>, %39 : vector<4xf16> into f16
%41 = vector.insert %40, %38 [1] : f16 into vector<2xf16>
%42 = vector.bitcast %41 : vector<2xf16> to vector<1xi32>
%43 = vector.extract %42[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %43, %c1_i32, %c64_i32 : i32
%44 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%45 = vector.bitcast %44 : vector<1xi32> to vector<2xf16>
%46 = arith.addf %41, %45 : vector<2xf16>
%47 = vector.bitcast %46 : vector<2xf16> to vector<1xi32>
%48 = vector.extract %47[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %48, %c2_i32, %c64_i32 : i32
%49 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%50 = vector.bitcast %49 : vector<1xi32> to vector<2xf16>
%51 = arith.addf %46, %50 : vector<2xf16>
%52 = vector.bitcast %51 : vector<2xf16> to vector<1xi32>
%53 = vector.extract %52[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %53, %c4_i32, %c64_i32 : i32
%54 = vector.broadcast %shuffleResult_17 : i32 to vector<1xi32>
%55 = vector.bitcast %54 : vector<1xi32> to vector<2xf16>
%56 = arith.addf %51, %55 : vector<2xf16>
%57 = vector.bitcast %56 : vector<2xf16> to vector<1xi32>
%58 = vector.extract %57[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %58, %c8_i32, %c64_i32 : i32
%59 = vector.broadcast %shuffleResult_19 : i32 to vector<1xi32>
%60 = vector.bitcast %59 : vector<1xi32> to vector<2xf16>
%61 = arith.addf %56, %60 : vector<2xf16>
%62 = vector.bitcast %61 : vector<2xf16> to vector<1xi32>
%63 = vector.extract %62[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %63, %c16_i32, %c64_i32 : i32
%64 = vector.broadcast %shuffleResult_21 : i32 to vector<1xi32>
%65 = vector.bitcast %64 : vector<1xi32> to vector<2xf16>
%66 = arith.addf %61, %65 : vector<2xf16>
%67 = vector.bitcast %66 : vector<2xf16> to vector<1xi32>
%68 = vector.extract %67[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %68, %c32_i32, %c64_i32 : i32
%69 = vector.broadcast %shuffleResult_23 : i32 to vector<1xi32>
%70 = vector.bitcast %69 : vector<1xi32> to vector<2xf16>
%71 = arith.addf %66, %70 : vector<2xf16>
%72 = vector.reduction <add>, %71 : vector<2xf16> into f16
%73 = arith.addf %72, %35 : f16
%74 = vector.broadcast %73 : f16 to vector<1xf16>
vector.transfer_write %74, %subview_14[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
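// In the dump below, CSE removes the recomputed gpu.thread_id calls and the duplicated
// affine.apply results, so a single row/column index pair (%19, %20) is reused for both halves
// of the workgroup copy, and the scf.if guard indexes with the original %0/%1 thread ids.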
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x128xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_6 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %alloc[0, 0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_8 = memref.subview %subview_6[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview_7[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_8[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %21, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_10 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %alloc[16, 0] [16, 128] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[128, 1], offset: 2048>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %subview_10[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview_11[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: 2048>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_12[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %22, %subview_13[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_14 = memref.subview %subview[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_25 = memref.subview %subview_4[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_26 = memref.subview %subview_5[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %alloc[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_28 = memref.subview %subview_3[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%68 = vector.transfer_read %subview_28[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%69 = vector.transfer_read %subview_25[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%70 = vector.broadcast %69 : vector<1x1xf16> to vector<1x1x8xf16>
%71 = vector.transfer_read %subview_26[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%72 = vector.broadcast %71 : vector<1x1xf16> to vector<1x1x8xf16>
%73 = arith.extui %68 : vector<1x1x8xi4> to vector<1x1x8xi32>
%74 = arith.uitofp %73 : vector<1x1x8xi32> to vector<1x1x8xf16>
%75 = arith.subf %74, %72 : vector<1x1x8xf16>
%76 = arith.mulf %75, %70 : vector<1x1x8xf16>
%77 = vector.transfer_read %subview_27[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%78 = vector.broadcast %77 : vector<1x8xf16> to vector<1x1x8xf16>
%79 = arith.mulf %78, %76 : vector<1x1x8xf16>
%80 = arith.addf %79, %arg3 : vector<1x1x8xf16>
scf.yield %80 : vector<1x1x8xf16>
}
scf.yield %67 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_14[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_15, %valid_16 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_15 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_17 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_19 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_21 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_23 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
vector.transfer_write %66, %subview_14[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After GPUReduceBankConflicts (iree-codegen-gpu-reduce-bank-conflicts) //----- //
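// In the dump below, the workgroup allocation is padded from 32x128 to 32x136 f16 (row stride
// 136) and accessed through a 32x128 subview; per the pass name, the extra eight elements per
// row are there to reduce shared-memory bank conflicts.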
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_2 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_3 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_1, %subview_3[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_7 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_8 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_9 = memref.subview %subview_7[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_9[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %21, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_11 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_12 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_13 = memref.subview %subview_11[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_13[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x8xf16>
vector.transfer_write %22, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_15 = memref.subview %subview_3[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<1x1x8xf16>) {
%subview_26 = memref.subview %subview_5[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_28 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_29 = memref.subview %subview_4[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%68 = vector.transfer_read %subview_29[%c0, %c0, %c0], %c0_i4 {in_bounds = [true, true, true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1x8xi4>
%69 = vector.transfer_read %subview_26[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%70 = vector.broadcast %69 : vector<1x1xf16> to vector<1x1x8xf16>
%71 = vector.transfer_read %subview_27[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1x1xf16>
%72 = vector.broadcast %71 : vector<1x1xf16> to vector<1x1x8xf16>
%73 = arith.extui %68 : vector<1x1x8xi4> to vector<1x1x8xi32>
%74 = arith.uitofp %73 : vector<1x1x8xi32> to vector<1x1x8xf16>
%75 = arith.subf %74, %72 : vector<1x1x8xf16>
%76 = arith.mulf %75, %70 : vector<1x1x8xf16>
%77 = vector.transfer_read %subview_28[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<1x8xf16>
%78 = vector.broadcast %77 : vector<1x8xf16> to vector<1x1x8xf16>
%79 = arith.mulf %78, %76 : vector<1x1x8xf16>
%80 = arith.addf %79, %arg3 : vector<1x1x8xf16>
scf.yield %80 : vector<1x1x8xf16>
}
scf.yield %67 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_15[%c0], %cst_2 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.broadcast %shuffleResult : i32 to vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_16, %valid_17 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.broadcast %shuffleResult_16 : i32 to vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_18, %valid_19 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.broadcast %shuffleResult_18 : i32 to vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_20, %valid_21 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.broadcast %shuffleResult_20 : i32 to vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_22, %valid_23 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.broadcast %shuffleResult_22 : i32 to vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_24, %valid_25 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.broadcast %shuffleResult_24 : i32 to vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.broadcast %65 : f16 to vector<1xf16>
vector.transfer_write %66, %subview_15[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
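// Overview of this dispatch as seen in the dumps below (editorial annotation; inferred
// from the IR, hedged where not explicit):
// - Each workgroup produces 4 consecutive elements of the 4096xf16 output
//   (row offset = workgroup_id_x * 4).
// - The 32x128xf16 operand (binding 1) is staged into a 32x136xf16 workgroup
//   allocation in two 16-row halves; the row stride is padded from 128 to 136,
//   presumably to avoid shared-memory bank conflicts.
// - The 4096x32x128xi4 operand (presumably the quantized weight matrix) is
//   dequantized on the fly: zero-extend i4 to i32, convert to f16, subtract the
//   per-row/per-group zero point, then multiply by the per-row/per-group scale.
// - Partial sums are kept in 8-wide f16 vectors and finished with a 64-lane
//   subgroup butterfly reduction (gpu.shuffle xor at offsets 1..32 on a pair of
//   f16 values packed into one i32).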
// -----// IR Dump After SPIRVVectorize (iree-spirv-vectorize) //----- //
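// Relative to the previous dump, SPIRVVectorize flattens the rank-2/3 transfer
// ops to 1-D vector<8xf16>/vector<8xi4> accesses, splits the 8-wide
// dequantization and multiply-accumulate into two 4-wide halves, hoists the
// scale and zero-point reads out of the innermost loop, and materializes the
// shuffled i32 values with vector.splat instead of vector.broadcast.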
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_4 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_2, %subview_4[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_8 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %subview_9[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_10[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %subview_11[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_14[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %22, %subview_15[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_16 = memref.subview %subview_4[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_28 = memref.subview %subview_7[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = vector.transfer_read %subview_27[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%68 = vector.transfer_read %subview_28[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%69 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
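        // Inner loop over the 128-element group in steps of 8: load 8 i4 weights,
        // split them into two 4-element halves, zero-extend to i32 and convert to
        // f16, subtract the zero point (%68) and scale by %67, multiply by the
        // staged f16 operand from workgroup memory, and accumulate into the
        // 8-wide f16 loop-carried value.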
%subview_29 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_30 = memref.subview %subview_5[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%70 = vector.transfer_read %subview_30[%c0, %c0, %c0], %c0_i4 {in_bounds = [true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%71 = vector.extract_strided_slice %70 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%72 = arith.extui %71 : vector<4xi4> to vector<4xi32>
%73 = vector.extract_strided_slice %70 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%74 = arith.extui %73 : vector<4xi4> to vector<4xi32>
%75 = arith.uitofp %72 : vector<4xi32> to vector<4xf16>
%76 = arith.uitofp %74 : vector<4xi32> to vector<4xf16>
%77 = vector.extract %68[0] : vector<1xf16>
%78 = vector.splat %77 : vector<4xf16>
%79 = arith.subf %75, %78 : vector<4xf16>
%80 = vector.extract %68[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = arith.subf %76, %81 : vector<4xf16>
%83 = vector.extract %67[0] : vector<1xf16>
%84 = vector.splat %83 : vector<4xf16>
%85 = arith.mulf %79, %84 : vector<4xf16>
%86 = vector.extract %67[0] : vector<1xf16>
%87 = vector.splat %86 : vector<4xf16>
%88 = arith.mulf %82, %87 : vector<4xf16>
%89 = vector.transfer_read %subview_29[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<8xf16>
%90 = vector.extract_strided_slice %89 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %85 : vector<4xf16>
%92 = vector.extract_strided_slice %89 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%93 = arith.mulf %92, %88 : vector<4xf16>
%94 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%95 = vector.extract_strided_slice %94 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%96 = arith.addf %91, %95 : vector<4xf16>
%97 = vector.insert_strided_slice %96, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%98 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%99 = vector.extract_strided_slice %98 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.addf %93, %99 : vector<4xf16>
%101 = vector.insert_strided_slice %100, %97 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%102 = vector.broadcast %101 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %102 : vector<1x1x8xf16>
}
scf.yield %69 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_16[%c0], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
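      // Subgroup reduction: each 4-element half of the accumulator is reduced to
      // an f16, the two results are packed into a vector<2xf16> and bitcast to a
      // single i32, which is combined across the 64-lane subgroup with
      // gpu.shuffle xor at offsets 1, 2, 4, 8, 16, 32 (a butterfly reduction).
      // The final pair is summed, added to the value already in the output
      // buffer (%27), and written back.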
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst_0 [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.splat %shuffleResult : vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult_17 : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_19 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_21 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_23 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_25, %valid_26 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_25 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.splat %65 : vector<1xf16>
vector.transfer_write %66, %subview_16[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After ForOpCanonicalization (iree-codegen-canonicalize-scf-for) //----- //
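// No changes are visible in this dispatch relative to the SPIRVVectorize dump
// above; the vector<1x1x8xf16> loop-carried value is left as is.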
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_4 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_2, %subview_4[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_8 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %subview_9[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_10[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %subview_11[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_14[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %22, %subview_15[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_16 = memref.subview %subview_4[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_28 = memref.subview %subview_7[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = vector.transfer_read %subview_27[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%68 = vector.transfer_read %subview_28[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%69 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_29 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_30 = memref.subview %subview_5[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%70 = vector.transfer_read %subview_30[%c0, %c0, %c0], %c0_i4 {in_bounds = [true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%71 = vector.extract_strided_slice %70 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%72 = arith.extui %71 : vector<4xi4> to vector<4xi32>
%73 = vector.extract_strided_slice %70 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%74 = arith.extui %73 : vector<4xi4> to vector<4xi32>
%75 = arith.uitofp %72 : vector<4xi32> to vector<4xf16>
%76 = arith.uitofp %74 : vector<4xi32> to vector<4xf16>
%77 = vector.extract %68[0] : vector<1xf16>
%78 = vector.splat %77 : vector<4xf16>
%79 = arith.subf %75, %78 : vector<4xf16>
%80 = vector.extract %68[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = arith.subf %76, %81 : vector<4xf16>
%83 = vector.extract %67[0] : vector<1xf16>
%84 = vector.splat %83 : vector<4xf16>
%85 = arith.mulf %79, %84 : vector<4xf16>
%86 = vector.extract %67[0] : vector<1xf16>
%87 = vector.splat %86 : vector<4xf16>
%88 = arith.mulf %82, %87 : vector<4xf16>
%89 = vector.transfer_read %subview_29[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<8xf16>
%90 = vector.extract_strided_slice %89 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %85 : vector<4xf16>
%92 = vector.extract_strided_slice %89 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%93 = arith.mulf %92, %88 : vector<4xf16>
%94 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%95 = vector.extract_strided_slice %94 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%96 = arith.addf %91, %95 : vector<4xf16>
%97 = vector.insert_strided_slice %96, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%98 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%99 = vector.extract_strided_slice %98 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.addf %93, %99 : vector<4xf16>
%101 = vector.insert_strided_slice %100, %97 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%102 = vector.broadcast %101 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %102 : vector<1x1x8xf16>
}
scf.yield %69 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_16[%c0], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst_0 [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.splat %shuffleResult : vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult_17 : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_19 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_21 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_23 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_25, %valid_26 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_25 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.splat %65 : vector<1xf16>
vector.transfer_write %66, %subview_16[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
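// The IR is now printed at module scope; the function body itself is unchanged
// from the previous dump (the duplicated extract/splat pairs in the inner loop
// are only cleaned up by CSE below).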
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_4 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_2, %subview_4[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_8 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %subview_9[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_10[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %subview_11[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_14[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %22, %subview_15[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_16 = memref.subview %subview_4[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_28 = memref.subview %subview_7[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = vector.transfer_read %subview_27[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%68 = vector.transfer_read %subview_28[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%69 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_29 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_30 = memref.subview %subview_5[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%70 = vector.transfer_read %subview_30[%c0, %c0, %c0], %c0_i4 {in_bounds = [true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%71 = vector.extract_strided_slice %70 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%72 = arith.extui %71 : vector<4xi4> to vector<4xi32>
%73 = vector.extract_strided_slice %70 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%74 = arith.extui %73 : vector<4xi4> to vector<4xi32>
%75 = arith.uitofp %72 : vector<4xi32> to vector<4xf16>
%76 = arith.uitofp %74 : vector<4xi32> to vector<4xf16>
%77 = vector.extract %68[0] : vector<1xf16>
%78 = vector.splat %77 : vector<4xf16>
%79 = arith.subf %75, %78 : vector<4xf16>
%80 = vector.extract %68[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = arith.subf %76, %81 : vector<4xf16>
%83 = vector.extract %67[0] : vector<1xf16>
%84 = vector.splat %83 : vector<4xf16>
%85 = arith.mulf %79, %84 : vector<4xf16>
%86 = vector.extract %67[0] : vector<1xf16>
%87 = vector.splat %86 : vector<4xf16>
%88 = arith.mulf %82, %87 : vector<4xf16>
%89 = vector.transfer_read %subview_29[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<8xf16>
%90 = vector.extract_strided_slice %89 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %85 : vector<4xf16>
%92 = vector.extract_strided_slice %89 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%93 = arith.mulf %92, %88 : vector<4xf16>
%94 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%95 = vector.extract_strided_slice %94 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%96 = arith.addf %91, %95 : vector<4xf16>
%97 = vector.insert_strided_slice %96, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%98 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%99 = vector.extract_strided_slice %98 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.addf %93, %99 : vector<4xf16>
%101 = vector.insert_strided_slice %100, %97 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%102 = vector.broadcast %101 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %102 : vector<1x1x8xf16>
}
scf.yield %69 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_16[%c0], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst_0 [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.splat %shuffleResult : vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult_17 : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_19 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_21 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_23 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_25, %valid_26 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_25 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.splat %65 : vector<1xf16>
vector.transfer_write %66, %subview_16[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
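// CSE deduplicates the repeated vector.extract/vector.splat of the scale and
// zero-point values and the repeated extract of the loop-carried accumulator in
// the innermost loop; otherwise the function is unchanged.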
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_4 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_2, %subview_4[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_8 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %subview_9[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_10[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %subview_11[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_14[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %22, %subview_15[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_16 = memref.subview %subview_4[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_28 = memref.subview %subview_7[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = vector.transfer_read %subview_27[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%68 = vector.transfer_read %subview_28[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%69 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_29 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_30 = memref.subview %subview_5[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%70 = vector.transfer_read %subview_30[%c0, %c0, %c0], %c0_i4 {in_bounds = [true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%71 = vector.extract_strided_slice %70 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%72 = arith.extui %71 : vector<4xi4> to vector<4xi32>
%73 = vector.extract_strided_slice %70 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%74 = arith.extui %73 : vector<4xi4> to vector<4xi32>
%75 = arith.uitofp %72 : vector<4xi32> to vector<4xf16>
%76 = arith.uitofp %74 : vector<4xi32> to vector<4xf16>
%77 = vector.extract %68[0] : vector<1xf16>
%78 = vector.splat %77 : vector<4xf16>
%79 = arith.subf %75, %78 : vector<4xf16>
%80 = arith.subf %76, %78 : vector<4xf16>
%81 = vector.extract %67[0] : vector<1xf16>
%82 = vector.splat %81 : vector<4xf16>
%83 = arith.mulf %79, %82 : vector<4xf16>
%84 = arith.mulf %80, %82 : vector<4xf16>
%85 = vector.transfer_read %subview_29[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<8xf16>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%87 = arith.mulf %86, %83 : vector<4xf16>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %84 : vector<4xf16>
%90 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%91 = vector.extract_strided_slice %90 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%92 = arith.addf %87, %91 : vector<4xf16>
%93 = vector.insert_strided_slice %92, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%94 = vector.extract_strided_slice %90 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%95 = arith.addf %89, %94 : vector<4xf16>
%96 = vector.insert_strided_slice %95, %93 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%97 = vector.broadcast %96 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %97 : vector<1x1x8xf16>
}
scf.yield %69 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_16[%c0], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst_0 [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.splat %shuffleResult : vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult_17 : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_19 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_21 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_23 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_25, %valid_26 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_25 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.splat %65 : vector<1xf16>
vector.transfer_write %66, %subview_16[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
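// ---------------------------------------------------------------------------
// Not part of the pass output: a minimal standalone sketch of the i4 -> f16
// dequantization idiom used in the inner loops above (extui, uitofp, subtract
// zero point, multiply by scale). The function name and arguments are
// hypothetical; the dumps split the 8-wide vector into two 4-wide halves
// before extending, which is omitted here for brevity.
// ---------------------------------------------------------------------------
func.func @dequant_i4_sketch(%q: vector<8xi4>, %scale: f16, %zp: f16) -> vector<8xf16> {
  // Widen the packed 4-bit codes and convert to f16.
  %ext = arith.extui %q : vector<8xi4> to vector<8xi32>
  %f = arith.uitofp %ext : vector<8xi32> to vector<8xf16>
  // Broadcast the per-group zero point and scale.
  %zps = vector.splat %zp : vector<8xf16>
  %scs = vector.splat %scale : vector<8xf16>
  // value = (code - zero_point) * scale
  %centered = arith.subf %f, %zps : vector<8xf16>
  %dequant = arith.mulf %centered, %scs : vector<8xf16>
  return %dequant : vector<8xf16>
}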
// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
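// Relative to the previous dump, the vector.extract/vector.splat of the
// per-group scale and zero point are hoisted out of the innermost k-loop
// (over %c128) into the enclosing loop over %c32, so every 8-element slice of
// quantized weights reuses the same broadcast values.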
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%subview = memref.subview %alloc[0, 0] [32, 128] [1, 1] : memref<32x136xf16, #gpu.address_space<workgroup>> to memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%subview_4 = memref.subview %17[%18] [4] [1] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %13[%18, 0, 0] [4, 32, 128] [1, 1, 1] : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_6 = memref.subview %14[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_7 = memref.subview %15[%18, 0] [4, 32] [1, 1] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
vector.transfer_write %cst_2, %subview_4[%c0] {in_bounds = [true]} : vector<4xf16>, memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%subview_8 = memref.subview %16[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_9 = memref.subview %subview[0, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>>
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%subview_10 = memref.subview %subview_8[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_11 = memref.subview %subview_9[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%21 = vector.transfer_read %subview_10[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %subview_11[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_12 = memref.subview %16[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_13 = memref.subview %subview[16, 0] [16, 128] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>>
%subview_14 = memref.subview %subview_12[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_15 = memref.subview %subview_13[%19, %20] [1, 8] [1, 1] : memref<16x128xf16, strided<[136, 1], offset: 2176>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%22 = vector.transfer_read %subview_14[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %22, %subview_15[%c0, %c0] {in_bounds = [true]} : vector<8xf16>, memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
gpu.barrier
%23 = arith.cmpi ult, %0, %c1 : index
scf.if %23 {
%subview_16 = memref.subview %subview_4[%1] [1] [1] : memref<4xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%24 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%subview_27 = memref.subview %subview_6[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_28 = memref.subview %subview_7[%1, %arg0] [1, 1] [1, 1] : memref<4x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%67 = vector.transfer_read %subview_27[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%68 = vector.transfer_read %subview_28[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x1xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%69 = vector.extract %68[0] : vector<1xf16>
%70 = vector.splat %69 : vector<4xf16>
%71 = vector.extract %67[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%subview_29 = memref.subview %subview[%arg0, %arg2] [1, 8] [1, 1] : memref<32x128xf16, strided<[136, 1]>, #gpu.address_space<workgroup>> to memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_30 = memref.subview %subview_5[%1, %arg0, %arg2] [1, 1, 8] [1, 1, 1] : memref<4x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%74 = vector.transfer_read %subview_30[%c0, %c0, %c0], %c0_i4 {in_bounds = [true]} : memref<1x1x8xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%75 = vector.extract_strided_slice %74 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%76 = arith.extui %75 : vector<4xi4> to vector<4xi32>
%77 = vector.extract_strided_slice %74 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = arith.uitofp %76 : vector<4xi32> to vector<4xf16>
%80 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%81 = arith.subf %79, %70 : vector<4xf16>
%82 = arith.subf %80, %70 : vector<4xf16>
%83 = arith.mulf %81, %72 : vector<4xf16>
%84 = arith.mulf %82, %72 : vector<4xf16>
%85 = vector.transfer_read %subview_29[%c0, %c0], %cst_3 {in_bounds = [true]} : memref<1x8xf16, strided<[136, 1], offset: ?>, #gpu.address_space<workgroup>>, vector<8xf16>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%87 = arith.mulf %86, %83 : vector<4xf16>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %84 : vector<4xf16>
%90 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%91 = vector.extract_strided_slice %90 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%92 = arith.addf %87, %91 : vector<4xf16>
%93 = vector.insert_strided_slice %92, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%94 = vector.extract_strided_slice %90 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%95 = arith.addf %89, %94 : vector<4xf16>
%96 = vector.insert_strided_slice %95, %93 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%97 = vector.broadcast %96 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %97 : vector<1x1x8xf16>
}
scf.yield %73 : vector<1x1x8xf16>
}
%25 = vector.transfer_read %subview_16[%c0], %cst_3 {in_bounds = [true]} : memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%26 = vector.extract %24[0, 0] : vector<1x1x8xf16>
%27 = vector.extract %25[0] : vector<1xf16>
%28 = vector.extract_strided_slice %26 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%29 = vector.reduction <add>, %28 : vector<4xf16> into f16
%30 = vector.insert %29, %cst_0 [0] : f16 into vector<2xf16>
%31 = vector.extract_strided_slice %26 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%32 = vector.reduction <add>, %31 : vector<4xf16> into f16
%33 = vector.insert %32, %30 [1] : f16 into vector<2xf16>
%34 = vector.bitcast %33 : vector<2xf16> to vector<1xi32>
%35 = vector.extract %34[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %35, %c1_i32, %c64_i32 : i32
%36 = vector.splat %shuffleResult : vector<1xi32>
%37 = vector.bitcast %36 : vector<1xi32> to vector<2xf16>
%38 = arith.addf %33, %37 : vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult_17, %valid_18 = gpu.shuffle xor %40, %c2_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult_17 : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_19, %valid_20 = gpu.shuffle xor %45, %c4_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_19 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_21, %valid_22 = gpu.shuffle xor %50, %c8_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_21 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_23, %valid_24 = gpu.shuffle xor %55, %c16_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_23 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_25, %valid_26 = gpu.shuffle xor %60, %c32_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_25 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.reduction <add>, %63 : vector<2xf16> into f16
%65 = arith.addf %64, %27 : f16
%66 = vector.splat %65 : vector<1xf16>
vector.transfer_write %66, %subview_16[%c0] {in_bounds = [true]} : vector<1xf16>, memref<1xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
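// ---------------------------------------------------------------------------
// Not part of the pass output: a standalone sketch of the gpu.shuffle xor
// butterfly reduction emitted above. Six xor-shuffle/add steps with offsets
// 1, 2, 4, 8, 16, 32 sum one value across a 64-lane subgroup, leaving the
// total in every lane. The dumps shuffle two f16 partial sums at once by
// bitcasting vector<2xf16> to i32; this hypothetical function reduces a
// single f32 per lane instead, purely for illustration.
// ---------------------------------------------------------------------------
func.func @subgroup_sum_sketch(%v: f32) -> f32 {
  %width = arith.constant 64 : i32
  %o1 = arith.constant 1 : i32
  %o2 = arith.constant 2 : i32
  %o4 = arith.constant 4 : i32
  %o8 = arith.constant 8 : i32
  %o16 = arith.constant 16 : i32
  %o32 = arith.constant 32 : i32
  // Each step pairs a lane with its xor-partner and accumulates both sums.
  %s1, %v1 = gpu.shuffle xor %v, %o1, %width : f32
  %a1 = arith.addf %v, %s1 : f32
  %s2, %v2 = gpu.shuffle xor %a1, %o2, %width : f32
  %a2 = arith.addf %a1, %s2 : f32
  %s4, %v4 = gpu.shuffle xor %a2, %o4, %width : f32
  %a3 = arith.addf %a2, %s4 : f32
  %s8, %v8 = gpu.shuffle xor %a3, %o8, %width : f32
  %a4 = arith.addf %a3, %s8 : f32
  %s16, %v16 = gpu.shuffle xor %a4, %o16, %width : f32
  %a5 = arith.addf %a4, %s16 : f32
  %s32, %v32 = gpu.shuffle xor %a5, %o32, %width : f32
  %a6 = arith.addf %a5, %s32 : f32
  return %a6 : f32
}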
// -----// IR Dump After FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
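// All memref.subview chains are folded away: loads and stores now address the
// original binding memrefs (%13..%17) and the shared-memory %alloc directly,
// with the offsets recovered as affine.apply expressions on the thread and
// workgroup ids.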
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%75 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%76 = vector.transfer_read %14[%75, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%78 = vector.transfer_read %15[%77, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.extract %78[0] : vector<1xf16>
%80 = vector.splat %79 : vector<4xf16>
%81 = vector.extract %76[0] : vector<1xf16>
%82 = vector.splat %81 : vector<4xf16>
%83 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%84 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%85 = vector.transfer_read %13[%84, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %80 : vector<4xf16>
%93 = arith.subf %91, %80 : vector<4xf16>
%94 = arith.mulf %92, %82 : vector<4xf16>
%95 = arith.mulf %93, %82 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %83 : vector<1x1x8xf16>
}
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = vector.transfer_read %17[%31], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%33 = vector.extract %30[0, 0] : vector<1x1x8xf16>
%34 = vector.extract %32[0] : vector<1xf16>
%35 = vector.extract_strided_slice %33 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%36 = vector.reduction <add>, %35 : vector<4xf16> into f16
%37 = vector.insert %36, %cst_0 [0] : f16 into vector<2xf16>
%38 = vector.extract_strided_slice %33 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %37 [1] : f16 into vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %42, %c1_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %47, %c2_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_4 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %52, %c4_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_6 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %57, %c8_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_8 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %62, %c16_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_10 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.bitcast %65 : vector<2xf16> to vector<1xi32>
%67 = vector.extract %66[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %67, %c32_i32, %c64_i32 : i32
%68 = vector.splat %shuffleResult_12 : vector<1xi32>
%69 = vector.bitcast %68 : vector<1xi32> to vector<2xf16>
%70 = arith.addf %65, %69 : vector<2xf16>
%71 = vector.reduction <add>, %70 : vector<2xf16> into f16
%72 = arith.addf %71, %34 : f16
%73 = vector.splat %72 : vector<1xf16>
%74 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %73, %17[%74] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
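// Second run of the pass: the repeated affine.apply row-index computations
// ((s0 + s1 * 4) on %1 and %workgroup_id_x) are hoisted above the reduction
// loops; the per-iteration transfer_reads themselves are unchanged.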
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After LoopInvariantCodeMotion (loop-invariant-code-motion) //----- //
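// Loop-invariant code motion finds nothing left to move here; the function
// appears identical to the previous dump.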
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After LinalgExtToLoops (iree-linalg-ext-to-loops) //----- //
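// No LinalgExt ops remain at this point (everything was already lowered to
// vector/scf form), so this pass leaves the function untouched.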
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After MemrefCopyToLinalgPass (iree-codegen-memrefcopy-to-linalg) //----- //
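// (annotation) This pass converts memref.copy ops into linalg copies, as the pass name suggests; this dispatch contains no memref.copy, so the function body below is identical to the previous dump.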
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After LinalgLowerToLoops (convert-linalg-to-loops) //----- //
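// (annotation) No linalg ops remain after vectorization, so convert-linalg-to-loops leaves the function unchanged.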
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
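// (annotation) The remaining scf.for loops run 32 and 16 iterations respectively, so there are no single-iteration loops to remove; the body is unchanged.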
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
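// (annotation) Canonicalization makes no structural change to this function; the duplicated affine.apply index computations are left for CSE below.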
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %21, %alloc[%22, %23] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%24 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%25 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%26 = vector.transfer_read %16[%24, %25], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
vector.transfer_write %26, %alloc[%27, %28] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%29 = arith.cmpi ult, %0, %c1 : index
scf.if %29 {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%33 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%78 = vector.transfer_read %14[%30, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%79 = vector.transfer_read %15[%31, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%80 = vector.extract %79[0] : vector<1xf16>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.extract %78[0] : vector<1xf16>
%83 = vector.splat %82 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%85 = vector.transfer_read %13[%32, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%86 = vector.extract_strided_slice %85 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%87 = arith.extui %86 : vector<4xi4> to vector<4xi32>
%88 = vector.extract_strided_slice %85 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = arith.uitofp %87 : vector<4xi32> to vector<4xf16>
%91 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%92 = arith.subf %90, %81 : vector<4xf16>
%93 = arith.subf %91, %81 : vector<4xf16>
%94 = arith.mulf %92, %83 : vector<4xf16>
%95 = arith.mulf %93, %83 : vector<4xf16>
%96 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%97 = vector.extract_strided_slice %96 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%98 = arith.mulf %97, %94 : vector<4xf16>
%99 = vector.extract_strided_slice %96 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%100 = arith.mulf %99, %95 : vector<4xf16>
%101 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%102 = vector.extract_strided_slice %101 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%103 = arith.addf %98, %102 : vector<4xf16>
%104 = vector.insert_strided_slice %103, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%105 = vector.extract_strided_slice %101 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%106 = arith.addf %100, %105 : vector<4xf16>
%107 = vector.insert_strided_slice %106, %104 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%108 = vector.broadcast %107 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %108 : vector<1x1x8xf16>
}
scf.yield %84 : vector<1x1x8xf16>
}
%34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%35 = vector.transfer_read %17[%34], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%36 = vector.extract %33[0, 0] : vector<1x1x8xf16>
%37 = vector.extract %35[0] : vector<1xf16>
%38 = vector.extract_strided_slice %36 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_0 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %36 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%42 = vector.reduction <add>, %41 : vector<4xf16> into f16
%43 = vector.insert %42, %40 [1] : f16 into vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %45, %c1_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %50, %c2_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_4 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %55, %c4_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_6 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %60, %c8_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_8 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %65, %c16_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_10 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.bitcast %68 : vector<2xf16> to vector<1xi32>
%70 = vector.extract %69[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %70, %c32_i32, %c64_i32 : i32
%71 = vector.splat %shuffleResult_12 : vector<1xi32>
%72 = vector.bitcast %71 : vector<1xi32> to vector<2xf16>
%73 = arith.addf %68, %72 : vector<2xf16>
%74 = vector.reduction <add>, %73 : vector<2xf16> into f16
%75 = arith.addf %74, %37 : f16
%76 = vector.splat %75 : vector<1xf16>
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
vector.transfer_write %76, %17[%77] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
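// (annotation) CSE deduplicates the repeated affine.apply index computations: the shared-memory writes now reuse %19/%20 and %22/%20, and the three identical row indices inside the scf.if collapse into a single %25.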
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
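// Each workgroup owns the 4 output elements starting at workgroup_id_x * 4; every thread stores the same 4-element zero vector to that slice before the barrier.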
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
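// Cooperatively stage the 32x128 f16 activation tile into workgroup memory, padded from 128 to 136 columns (presumably to avoid shared-memory bank conflicts).
// Each thread copies two 8-element chunks: row = tid.y * 4 + tid.z * 16 + tid.x / 16 (and row + 16), col = (tid.x mod 16) * 8.
// e.g. tid = (x=17, y=2, z=0) copies rows 9 and 25 at column 8.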
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
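// Only threads with tid.x == 0 enter the guarded region below; each such thread reduces one output row, row = tid.y + workgroup_id_x * 4.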
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
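// Inner loop over the 128-wide dimension in steps of 8: load 8 i4 weights, zero-extend and convert to f16,
// dequantize as (value - zero_point) * scale using the splatted %72 / %74, multiply by the staged activations,
// and accumulate into the 8-wide f16 accumulator carried through iter_args.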
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
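// Epilogue: read back the zero-initialized output element and reduce the 8-wide accumulator to two f16 partial sums.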
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
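// Pack the two f16 partial sums into one i32 and run a six-step xor-butterfly gpu.shuffle across the 64-wide subgroup
// (offsets 1, 2, 4, 8, 16, 32), adding the shuffled vector<2xf16> after each step; the result is reduced to a scalar,
// added to the value read from the output buffer, and written back.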
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
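// NOTE (annotation, not part of the compiler output): the tail of the dump above shows the
// cross-lane accumulation strategy. Each thread's vector<8xf16> of partial products is first
// reduced to two f16 values, which are packed into a single i32 and combined across the
// 64-lane subgroup with `gpu.shuffle xor` at offsets 1, 2, 4, 8, 16 and 32, before the final
// vector<2xf16> reduction, the add of the previously written output element, and the store.
// A minimal sketch of one butterfly step, assuming the same 64-wide subgroup (the dump above
// is the authoritative lowering):
//
//   func.func @butterfly_step(%v: vector<2xf16>) -> vector<2xf16> {
//     %offset = arith.constant 1 : i32
//     %width = arith.constant 64 : i32
//     // Pack the two f16 partial sums into one 32-bit shuffle payload.
//     %bits = vector.bitcast %v : vector<2xf16> to vector<1xi32>
//     %word = vector.extract %bits[0] : vector<1xi32>
//     // Exchange with the lane whose id differs in the given bit, then accumulate.
//     %peer, %valid = gpu.shuffle xor %word, %offset, %width : i32
//     %pbits = vector.splat %peer : vector<1xi32>
//     %pv = vector.bitcast %pbits : vector<1xi32> to vector<2xf16>
//     %sum = arith.addf %v, %pv : vector<2xf16>
//     return %sum : vector<2xf16>
//   }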
// -----// IR Dump After SPIRVLowerExecutableTarget (iree-spirv-lower-executable-target-pass) //----- //
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>}> {
hal.executable.export public @forward_dispatch_3_generic_4096x32x128_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 5, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<SPIRVMatvecPromoteSubgroupReduce>, workgroup_size = [64 : index, 4 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%c1024 = arith.constant 1024 : index
%c1 = arith.constant 1 : index
hal.return %c1024, %c1, %c1 : index, index, index
}
builtin.module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
}
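// NOTE (annotation, not part of the compiler output): after SPIRVLowerExecutableTarget the
// dispatch is fully mapped: translation_info is SPIRVMatvecPromoteSubgroupReduce with
// workgroup_size = [64, 4, 1] and 1024 workgroups along x, so each workgroup owns 4 of the
// 4096 output rows (row = workgroup_id_x * 4 + thread_id_y). The 32x128 f16 activation tile
// is staged cooperatively into the 32x136 workgroup buffer %alloc (each thread copies two
// rows of 8 contiguous f16); the 8 extra columns are presumably padding to break the
// power-of-two row stride and avoid shared-memory bank conflicts. In the inner loop the
// packed i4 weights are dequantized as (uitofp(extui(q)) - zero_point) * scale before being
// multiplied by the staged activations. A scalar sketch of that dequantization, assuming the
// same semantics as the vectorized IR above:
//
//   func.func @dequant_one(%q: i4, %scale: f16, %zero_point: f16) -> f16 {
//     // Widen the 4-bit weight, convert to f16, then apply the per-group affine params.
//     %wide = arith.extui %q : i4 to i32
//     %f = arith.uitofp %wide : i32 to f16
//     %centered = arith.subf %f, %zero_point : f16
//     %w = arith.mulf %centered, %scale : f16
//     return %w : f16
//   }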
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
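// NOTE (annotation): the canonicalizer made no changes to this dispatch; the function body
// above is identical to the SPIRVLowerExecutableTarget output.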
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
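// NOTE (annotation): CSE found no redundant subexpressions to eliminate; the function body
// is unchanged from the canonicalized form.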
// -----// IR Dump After ConvertComplexToStandard (convert-complex-to-standard) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
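// NOTE (annotation): convert-complex-to-standard is a no-op for this dispatch, since there
// are no complex-typed ops to lower; the body is unchanged.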
// -----// IR Dump After PolynomialApproximationPass (iree-codegen-polynomial-approximation) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
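// NOTE (annotation): the polynomial approximation pass is also a no-op here, as the dispatch
// contains no transcendental math ops to expand; the body carries through unchanged into the
// PadDynamicAlloc dump below, where the only allocation (%alloc) is already static.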
// -----// IR Dump After PadDynamicAlloc (iree-codegen-pad-dynamic-alloc) //----- //
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
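// Only threads with thread_id.x == 0 enter the reduction below; thread_id.y
// selects one of this workgroup's four output rows. The xor shuffles further
// down still assume a 64-lane subgroup, which looks like the work-in-progress
// part of this kernel.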
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
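// %69 is the per-group scale (buffer at offset %9) and %70 the zero point
// (offset %10); both are splatted to 4 lanes for the vectorized
// dequantization in the inner loop.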
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
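// Per step: load 8 packed i4 weights, widen to i32 then f16 in two 4-wide
// halves, dequantize as (w - zero_point) * scale, multiply by the matching
// 8 f16 values from the workgroup tile, and accumulate into the 8-lane f16
// accumulator carried in %arg3.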
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
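// Subgroup reduction: the two f16 partial sums are packed into one i32 and
// combined across a 64-lane subgroup with xor butterfly shuffles at offsets
// 1, 2, 4, 8, 16, and 32, then reduced to a single f16, added to the
// zero-initialized output element, and stored back.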
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After GPUCheckResourceUsage (iree-codegen-gpu-check-resource-usage) //----- //
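// This pass only verifies resource limits (e.g. that the 32x136xf16 workgroup
// allocation, about 8.5 kB, fits the target's shared-memory budget); the
// function body below is identical to the previous dump.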
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
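// No memref aliasing ops (subview, expand/collapse_shape) are present at this
// point, so there is nothing to fold; the body is unchanged.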
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After ExpandOps (memref-expand) //----- //
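// memref-expand finds nothing to rewrite in this function; the dump matches
// the previous one.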
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
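// The IR is already canonical here; no changes relative to the previous dump.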
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
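// CSE finds no redundant subexpressions either; the function is unchanged
// going into SPIR-V load/store vectorization.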
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c0_i4 = arith.constant 0 : i4
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%cst_3 = arith.constant 0.000000e+00 : f16
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x136xf16, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
vector.transfer_write %cst_2, %17[%18] {in_bounds = [true]} : vector<4xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%19 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%20 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%21 = vector.transfer_read %16[%19, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %21, %alloc[%19, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%23 = vector.transfer_read %16[%22, %20], %cst_3 {in_bounds = [true]} : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
vector.transfer_write %23, %alloc[%22, %20] {in_bounds = [true]} : vector<8xf16>, memref<32x136xf16, #gpu.address_space<workgroup>>
gpu.barrier
%24 = arith.cmpi ult, %0, %c1 : index
scf.if %24 {
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%26 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%69 = vector.transfer_read %14[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%70 = vector.transfer_read %15[%25, %arg0], %cst_3 {in_bounds = [true]} : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%71 = vector.extract %70[0] : vector<1xf16>
%72 = vector.splat %71 : vector<4xf16>
%73 = vector.extract %69[0] : vector<1xf16>
%74 = vector.splat %73 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%76 = vector.transfer_read %13[%25, %arg0, %arg2], %c0_i4 {in_bounds = [true]} : memref<4096x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8xi4>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%78 = arith.extui %77 : vector<4xi4> to vector<4xi32>
%79 = vector.extract_strided_slice %76 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%80 = arith.extui %79 : vector<4xi4> to vector<4xi32>
%81 = arith.uitofp %78 : vector<4xi32> to vector<4xf16>
%82 = arith.uitofp %80 : vector<4xi32> to vector<4xf16>
%83 = arith.subf %81, %72 : vector<4xf16>
%84 = arith.subf %82, %72 : vector<4xf16>
%85 = arith.mulf %83, %74 : vector<4xf16>
%86 = arith.mulf %84, %74 : vector<4xf16>
%87 = vector.transfer_read %alloc[%arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x136xf16, #gpu.address_space<workgroup>>, vector<8xf16>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%89 = arith.mulf %88, %85 : vector<4xf16>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%91 = arith.mulf %90, %86 : vector<4xf16>
%92 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%93 = vector.extract_strided_slice %92 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%94 = arith.addf %89, %93 : vector<4xf16>
%95 = vector.insert_strided_slice %94, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%96 = vector.extract_strided_slice %92 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%97 = arith.addf %91, %96 : vector<4xf16>
%98 = vector.insert_strided_slice %97, %95 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%99 = vector.broadcast %98 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %99 : vector<1x1x8xf16>
}
scf.yield %75 : vector<1x1x8xf16>
}
%27 = vector.transfer_read %17[%25], %cst_3 {in_bounds = [true]} : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%28 = vector.extract %26[0, 0] : vector<1x1x8xf16>
%29 = vector.extract %27[0] : vector<1xf16>
%30 = vector.extract_strided_slice %28 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%31 = vector.reduction <add>, %30 : vector<4xf16> into f16
%32 = vector.insert %31, %cst_0 [0] : f16 into vector<2xf16>
%33 = vector.extract_strided_slice %28 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%34 = vector.reduction <add>, %33 : vector<4xf16> into f16
%35 = vector.insert %34, %32 [1] : f16 into vector<2xf16>
%36 = vector.bitcast %35 : vector<2xf16> to vector<1xi32>
%37 = vector.extract %36[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %37, %c1_i32, %c64_i32 : i32
%38 = vector.splat %shuffleResult : vector<1xi32>
%39 = vector.bitcast %38 : vector<1xi32> to vector<2xf16>
%40 = arith.addf %35, %39 : vector<2xf16>
%41 = vector.bitcast %40 : vector<2xf16> to vector<1xi32>
%42 = vector.extract %41[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %42, %c2_i32, %c64_i32 : i32
%43 = vector.splat %shuffleResult_4 : vector<1xi32>
%44 = vector.bitcast %43 : vector<1xi32> to vector<2xf16>
%45 = arith.addf %40, %44 : vector<2xf16>
%46 = vector.bitcast %45 : vector<2xf16> to vector<1xi32>
%47 = vector.extract %46[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %47, %c4_i32, %c64_i32 : i32
%48 = vector.splat %shuffleResult_6 : vector<1xi32>
%49 = vector.bitcast %48 : vector<1xi32> to vector<2xf16>
%50 = arith.addf %45, %49 : vector<2xf16>
%51 = vector.bitcast %50 : vector<2xf16> to vector<1xi32>
%52 = vector.extract %51[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %52, %c8_i32, %c64_i32 : i32
%53 = vector.splat %shuffleResult_8 : vector<1xi32>
%54 = vector.bitcast %53 : vector<1xi32> to vector<2xf16>
%55 = arith.addf %50, %54 : vector<2xf16>
%56 = vector.bitcast %55 : vector<2xf16> to vector<1xi32>
%57 = vector.extract %56[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %57, %c16_i32, %c64_i32 : i32
%58 = vector.splat %shuffleResult_10 : vector<1xi32>
%59 = vector.bitcast %58 : vector<1xi32> to vector<2xf16>
%60 = arith.addf %55, %59 : vector<2xf16>
%61 = vector.bitcast %60 : vector<2xf16> to vector<1xi32>
%62 = vector.extract %61[0] : vector<1xi32>
%shuffleResult_12, %valid_13 = gpu.shuffle xor %62, %c32_i32, %c64_i32 : i32
%63 = vector.splat %shuffleResult_12 : vector<1xi32>
%64 = vector.bitcast %63 : vector<1xi32> to vector<2xf16>
%65 = arith.addf %60, %64 : vector<2xf16>
%66 = vector.reduction <add>, %65 : vector<2xf16> into f16
%67 = arith.addf %66, %29 : f16
%68 = vector.splat %67 : vector<1xf16>
vector.transfer_write %68, %17[%25] {in_bounds = [true]} : vector<1xf16>, memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After SPIRVVectorizeLoadStore (iree-spirv-vectorize-load-store) //----- //
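// This is where the memory accesses change shape for SPIR-V: the workgroup
// tile is retyped to memref<32x17xvector<4xf32>> (17 128-bit words per row
// in place of 136 f16), the i4 weight buffer is addressed as vector<1xi32>
// words, the 32x128 f16 operand as 32x16xvector<4xf32>, and the 4xf16
// zero-initializing transfer_write is scalarized into per-element f16 stores.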
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c0]
%20 = vector.extract %cst_2[0] : vector<4xf16>
memref.store %20, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c1]
%22 = vector.extract %cst_2[1] : vector<4xf16>
memref.store %22, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%23 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c2]
%24 = vector.extract %cst_2[2] : vector<4xf16>
memref.store %24, %17[%23] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c3]
%26 = vector.extract %cst_2[3] : vector<4xf16>
memref.store %26, %17[%25] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%29 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
%30 = memref.load %16[%27, %29] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
memref.store %30, %alloc[%27, %31] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%33 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
%34 = memref.load %16[%32, %33] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%35 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
memref.store %34, %alloc[%32, %35] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%36 = arith.cmpi ult, %0, %c1 : index
scf.if %36 {
%37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%38 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x1x8xf16>) {
%81 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%82 = memref.load %14[%37, %81] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%83 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%84 = memref.load %15[%37, %83] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%85 = vector.splat %84 : vector<4xf16>
%86 = vector.splat %82 : vector<4xf16>
%87 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<1x1x8xf16>) {
%88 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%89 = memref.load %13[%37, %arg0, %88] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%90 = vector.bitcast %89 : vector<1xi32> to vector<8xi4>
%91 = vector.extract_strided_slice %90 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%92 = arith.extui %91 : vector<4xi4> to vector<4xi32>
%93 = vector.extract_strided_slice %90 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%94 = arith.extui %93 : vector<4xi4> to vector<4xi32>
%95 = arith.uitofp %92 : vector<4xi32> to vector<4xf16>
%96 = arith.uitofp %94 : vector<4xi32> to vector<4xf16>
%97 = arith.subf %95, %85 : vector<4xf16>
%98 = arith.subf %96, %85 : vector<4xf16>
%99 = arith.mulf %97, %86 : vector<4xf16>
%100 = arith.mulf %98, %86 : vector<4xf16>
%101 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%102 = memref.load %alloc[%arg0, %101] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%103 = vector.bitcast %102 : vector<4xf32> to vector<8xf16>
%104 = vector.extract_strided_slice %103 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%105 = arith.mulf %104, %99 : vector<4xf16>
%106 = vector.extract_strided_slice %103 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%107 = arith.mulf %106, %100 : vector<4xf16>
%108 = vector.extract %arg3[0, 0] : vector<1x1x8xf16>
%109 = vector.extract_strided_slice %108 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%110 = arith.addf %105, %109 : vector<4xf16>
%111 = vector.insert_strided_slice %110, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%112 = vector.extract_strided_slice %108 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%113 = arith.addf %107, %112 : vector<4xf16>
%114 = vector.insert_strided_slice %113, %111 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%115 = vector.broadcast %114 : vector<8xf16> to vector<1x1x8xf16>
scf.yield %115 : vector<1x1x8xf16>
}
scf.yield %87 : vector<1x1x8xf16>
}
%39 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%37, %c0]
%40 = memref.load %17[%39] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%41 = vector.extract %38[0, 0] : vector<1x1x8xf16>
%42 = vector.extract_strided_slice %41 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %cst_0 [0] : f16 into vector<2xf16>
%45 = vector.extract_strided_slice %41 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%46 = vector.reduction <add>, %45 : vector<4xf16> into f16
%47 = vector.insert %46, %44 [1] : f16 into vector<2xf16>
%48 = vector.bitcast %47 : vector<2xf16> to vector<1xi32>
%49 = vector.extract %48[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %49, %c1_i32, %c64_i32 : i32
%50 = vector.splat %shuffleResult : vector<1xi32>
%51 = vector.bitcast %50 : vector<1xi32> to vector<2xf16>
%52 = arith.addf %47, %51 : vector<2xf16>
%53 = vector.bitcast %52 : vector<2xf16> to vector<1xi32>
%54 = vector.extract %53[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %54, %c2_i32, %c64_i32 : i32
%55 = vector.splat %shuffleResult_3 : vector<1xi32>
%56 = vector.bitcast %55 : vector<1xi32> to vector<2xf16>
%57 = arith.addf %52, %56 : vector<2xf16>
%58 = vector.bitcast %57 : vector<2xf16> to vector<1xi32>
%59 = vector.extract %58[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %59, %c4_i32, %c64_i32 : i32
%60 = vector.splat %shuffleResult_5 : vector<1xi32>
%61 = vector.bitcast %60 : vector<1xi32> to vector<2xf16>
%62 = arith.addf %57, %61 : vector<2xf16>
%63 = vector.bitcast %62 : vector<2xf16> to vector<1xi32>
%64 = vector.extract %63[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %64, %c8_i32, %c64_i32 : i32
%65 = vector.splat %shuffleResult_7 : vector<1xi32>
%66 = vector.bitcast %65 : vector<1xi32> to vector<2xf16>
%67 = arith.addf %62, %66 : vector<2xf16>
%68 = vector.bitcast %67 : vector<2xf16> to vector<1xi32>
%69 = vector.extract %68[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %69, %c16_i32, %c64_i32 : i32
%70 = vector.splat %shuffleResult_9 : vector<1xi32>
%71 = vector.bitcast %70 : vector<1xi32> to vector<2xf16>
%72 = arith.addf %67, %71 : vector<2xf16>
%73 = vector.bitcast %72 : vector<2xf16> to vector<1xi32>
%74 = vector.extract %73[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %74, %c32_i32, %c64_i32 : i32
%75 = vector.splat %shuffleResult_11 : vector<1xi32>
%76 = vector.bitcast %75 : vector<1xi32> to vector<2xf16>
%77 = arith.addf %72, %76 : vector<2xf16>
%78 = vector.reduction <add>, %77 : vector<2xf16> into f16
%79 = arith.addf %78, %40 : f16
%80 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%37, %c0]
memref.store %79, %17[%80] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After ForOpCanonicalization (iree-codegen-canonicalize-scf-for) //----- //
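// After this pass the loop-carried accumulator is no longer a vector<1x1x8xf16>:
// the iter_arg of both scf.for loops is carried as a single vector<4xf32>
// (128 bits), with vector.bitcast at the loop boundaries replacing the
// extract/broadcast pairs around the accumulator.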
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%cst = arith.constant dense<0.000000e+00> : vector<8xf16>
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<0.000000e+00> : vector<2xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<1x1x8xf16>
%cst_2 = arith.constant dense<0.000000e+00> : vector<4xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c0]
%20 = vector.extract %cst_2[0] : vector<4xf16>
memref.store %20, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c1]
%22 = vector.extract %cst_2[1] : vector<4xf16>
memref.store %22, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%23 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c2]
%24 = vector.extract %cst_2[2] : vector<4xf16>
memref.store %24, %17[%23] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c3]
%26 = vector.extract %cst_2[3] : vector<4xf16>
memref.store %26, %17[%25] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%28 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%29 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
%30 = memref.load %16[%27, %29] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
memref.store %30, %alloc[%27, %31] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%33 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
%34 = memref.load %16[%32, %33] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%35 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%28, %c8]
memref.store %34, %alloc[%32, %35] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%36 = arith.cmpi ult, %0, %c1 : index
scf.if %36 {
%37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%38 = vector.extract %cst_1[0, 0] : vector<1x1x8xf16>
%39 = vector.bitcast %38 : vector<8xf16> to vector<4xf32>
%40 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %39) -> (vector<4xf32>) {
%83 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%84 = memref.load %14[%37, %83] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%85 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%86 = memref.load %15[%37, %85] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%87 = vector.splat %86 : vector<4xf16>
%88 = vector.splat %84 : vector<4xf16>
%89 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%90 = vector.bitcast %arg3 : vector<4xf32> to vector<8xf16>
%91 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%92 = memref.load %13[%37, %arg0, %91] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%93 = vector.bitcast %92 : vector<1xi32> to vector<8xi4>
%94 = vector.extract_strided_slice %93 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%95 = arith.extui %94 : vector<4xi4> to vector<4xi32>
%96 = vector.extract_strided_slice %93 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%97 = arith.extui %96 : vector<4xi4> to vector<4xi32>
%98 = arith.uitofp %95 : vector<4xi32> to vector<4xf16>
%99 = arith.uitofp %97 : vector<4xi32> to vector<4xf16>
%100 = arith.subf %98, %87 : vector<4xf16>
%101 = arith.subf %99, %87 : vector<4xf16>
%102 = arith.mulf %100, %88 : vector<4xf16>
%103 = arith.mulf %101, %88 : vector<4xf16>
%104 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%105 = memref.load %alloc[%arg0, %104] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%106 = vector.bitcast %105 : vector<4xf32> to vector<8xf16>
%107 = vector.extract_strided_slice %106 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%108 = arith.mulf %107, %102 : vector<4xf16>
%109 = vector.extract_strided_slice %106 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%110 = arith.mulf %109, %103 : vector<4xf16>
%111 = vector.extract_strided_slice %90 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%112 = arith.addf %108, %111 : vector<4xf16>
%113 = vector.insert_strided_slice %112, %cst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
%114 = vector.extract_strided_slice %90 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%115 = arith.addf %110, %114 : vector<4xf16>
%116 = vector.insert_strided_slice %115, %113 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
%117 = vector.bitcast %116 : vector<8xf16> to vector<4xf32>
scf.yield %117 : vector<4xf32>
}
scf.yield %89 : vector<4xf32>
}
%41 = vector.bitcast %40 : vector<4xf32> to vector<8xf16>
%42 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%37, %c0]
%43 = memref.load %17[%42] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%44 = vector.extract_strided_slice %41 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%45 = vector.reduction <add>, %44 : vector<4xf16> into f16
%46 = vector.insert %45, %cst_0 [0] : f16 into vector<2xf16>
%47 = vector.extract_strided_slice %41 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%48 = vector.reduction <add>, %47 : vector<4xf16> into f16
%49 = vector.insert %48, %46 [1] : f16 into vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %51, %c1_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %56, %c2_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_3 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %61, %c4_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_5 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %66, %c8_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_7 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %71, %c16_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_9 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.bitcast %74 : vector<2xf16> to vector<1xi32>
%76 = vector.extract %75[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %76, %c32_i32, %c64_i32 : i32
%77 = vector.splat %shuffleResult_11 : vector<1xi32>
%78 = vector.bitcast %77 : vector<1xi32> to vector<2xf16>
%79 = arith.addf %74, %78 : vector<2xf16>
%80 = vector.reduction <add>, %79 : vector<2xf16> into f16
%81 = arith.addf %80, %43 : f16
%82 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%37, %c0]
memref.store %81, %17[%82] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
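// In this dump the zero-initialization stores use the scalar f16 constant
// directly, the accumulator is seeded from the dense<0.0> vector<4xf32>
// constant, and the vector<8xf16> bitcasts inside the inner loop have been
// narrowed to paired vector<2xf32> extract_strided_slice + bitcast-to-
// vector<4xf16> operations.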
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%cst = arith.constant 0.000000e+00 : f16
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_1 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c0]
memref.store %cst, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c1]
memref.store %cst, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c2]
memref.store %cst, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%22 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c3]
memref.store %cst, %17[%22] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%23 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%24 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%25 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%26 = memref.load %16[%23, %25] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%27 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %26, %alloc[%23, %27] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%28 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%29 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%30 = memref.load %16[%28, %29] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %30, %alloc[%28, %31] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%32 = arith.cmpi ult, %0, %c1 : index
scf.if %32 {
%33 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_0) -> (vector<4xf32>) {
%78 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%79 = memref.load %14[%33, %78] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%81 = memref.load %15[%33, %80] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%86 = memref.load %13[%33, %arg0, %85] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%87 = vector.bitcast %86 : vector<1xi32> to vector<8xi4>
%88 = vector.extract_strided_slice %87 {offsets = [0], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%89 = arith.extui %88 : vector<4xi4> to vector<4xi32>
%90 = vector.extract_strided_slice %87 {offsets = [4], sizes = [4], strides = [1]} : vector<8xi4> to vector<4xi4>
%91 = arith.extui %90 : vector<4xi4> to vector<4xi32>
%92 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%93 = arith.uitofp %91 : vector<4xi32> to vector<4xf16>
%94 = arith.subf %92, %82 : vector<4xf16>
%95 = arith.subf %93, %82 : vector<4xf16>
%96 = arith.mulf %94, %83 : vector<4xf16>
%97 = arith.mulf %95, %83 : vector<4xf16>
%98 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%99 = memref.load %alloc[%arg0, %98] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%100 = vector.extract_strided_slice %99 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%101 = vector.bitcast %100 : vector<2xf32> to vector<4xf16>
%102 = arith.mulf %101, %96 : vector<4xf16>
%103 = vector.extract_strided_slice %99 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%104 = vector.bitcast %103 : vector<2xf32> to vector<4xf16>
%105 = arith.mulf %104, %97 : vector<4xf16>
%106 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%107 = vector.bitcast %106 : vector<2xf32> to vector<4xf16>
%108 = arith.addf %102, %107 : vector<4xf16>
%109 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%110 = vector.bitcast %109 : vector<2xf32> to vector<4xf16>
%111 = arith.addf %105, %110 : vector<4xf16>
%112 = vector.bitcast %111 : vector<4xf16> to vector<2xf32>
%113 = vector.bitcast %108 : vector<4xf16> to vector<2xf32>
%114 = vector.insert_strided_slice %113, %cst_0 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%115 = vector.insert_strided_slice %112, %114 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %115 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
%36 = memref.load %17[%35] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_1 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_2, %valid_3 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_2 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_4, %valid_5 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_4 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_6, %valid_7 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_6 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_8, %valid_9 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_8 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_10, %valid_11 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_10 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
memref.store %76, %17[%77] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After SPIRVBreakDownLargeVector (iree-spirv-breakdown-large-vector) //----- //
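// The vector<1xi32> -> vector<8xi4> bitcast and the i4 slices are broken down
// here into scalar bit manipulation, presumably because sub-byte vectors cannot
// be lowered directly to SPIR-V: each i32 word is shifted (shrui by 4, 8, ..., 28)
// and masked with 15 to extract the eight nibbles into two vector<4xi32> values,
// which are then dequantized as before (uitofp, subtract zero point, multiply by
// scale).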
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c0]
memref.store %cst_0, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c1]
memref.store %cst_0, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c2]
memref.store %cst_0, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%22 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c3]
memref.store %cst_0, %17[%22] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%23 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%24 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%25 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%26 = memref.load %16[%23, %25] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%27 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %26, %alloc[%23, %27] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%28 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%29 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%30 = memref.load %16[%28, %29] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %30, %alloc[%28, %31] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%32 = arith.cmpi ult, %0, %c1 : index
scf.if %32 {
%33 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%78 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%79 = memref.load %14[%33, %78] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%81 = memref.load %15[%33, %80] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%86 = memref.load %13[%33, %arg0, %85] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%87 = vector.extract %86[0] : vector<1xi32>
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %cst [0] : i32 into vector<4xi32>
%90 = arith.shrui %87, %c4_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %89 [1] : i32 into vector<4xi32>
%93 = arith.shrui %87, %c8_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [2] : i32 into vector<4xi32>
%96 = arith.shrui %87, %c12_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [3] : i32 into vector<4xi32>
%99 = vector.extract %86[0] : vector<1xi32>
%100 = arith.shrui %99, %c16_i32 : i32
%101 = arith.andi %100, %c15_i32 : i32
%102 = vector.insert %101, %cst [0] : i32 into vector<4xi32>
%103 = arith.shrui %99, %c20_i32 : i32
%104 = arith.andi %103, %c15_i32 : i32
%105 = vector.insert %104, %102 [1] : i32 into vector<4xi32>
%106 = arith.shrui %99, %c24_i32 : i32
%107 = arith.andi %106, %c15_i32 : i32
%108 = vector.insert %107, %105 [2] : i32 into vector<4xi32>
%109 = arith.shrui %99, %c28_i32 : i32
%110 = arith.andi %109, %c15_i32 : i32
%111 = vector.insert %110, %108 [3] : i32 into vector<4xi32>
%112 = arith.uitofp %98 : vector<4xi32> to vector<4xf16>
%113 = arith.uitofp %111 : vector<4xi32> to vector<4xf16>
%114 = arith.subf %112, %82 : vector<4xf16>
%115 = arith.subf %113, %82 : vector<4xf16>
%116 = arith.mulf %114, %83 : vector<4xf16>
%117 = arith.mulf %115, %83 : vector<4xf16>
%118 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%119 = memref.load %alloc[%arg0, %118] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%120 = vector.extract_strided_slice %119 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%121 = vector.bitcast %120 : vector<2xf32> to vector<4xf16>
%122 = arith.mulf %121, %116 : vector<4xf16>
%123 = vector.extract_strided_slice %119 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%124 = vector.bitcast %123 : vector<2xf32> to vector<4xf16>
%125 = arith.mulf %124, %117 : vector<4xf16>
%126 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%127 = vector.bitcast %126 : vector<2xf32> to vector<4xf16>
%128 = arith.addf %122, %127 : vector<4xf16>
%129 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%130 = vector.bitcast %129 : vector<2xf32> to vector<4xf16>
%131 = arith.addf %125, %130 : vector<4xf16>
%132 = vector.bitcast %131 : vector<4xf16> to vector<2xf32>
%133 = vector.bitcast %128 : vector<4xf16> to vector<2xf32>
%134 = vector.insert_strided_slice %133, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%135 = vector.insert_strided_slice %132, %134 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %135 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
%36 = memref.load %17[%35] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
memref.store %76, %17[%77] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After ForOpCanonicalization (iree-codegen-canonicalize-scf-for) //----- //
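// This second run of the pass appears to leave the IR unchanged; the dump below
// matches the output of the preceding breakdown pass.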
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c0]
memref.store %cst_0, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c1]
memref.store %cst_0, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c2]
memref.store %cst_0, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%22 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %c3]
memref.store %cst_0, %17[%22] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%23 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%24 = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 16) * 128)>()[%0]
%25 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%26 = memref.load %16[%23, %25] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%27 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %26, %alloc[%23, %27] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%28 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%29 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
%30 = memref.load %16[%28, %29] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%24, %c8]
memref.store %30, %alloc[%28, %31] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%32 = arith.cmpi ult, %0, %c1 : index
scf.if %32 {
%33 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%78 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%79 = memref.load %14[%33, %78] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %c0]
%81 = memref.load %15[%33, %80] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%86 = memref.load %13[%33, %arg0, %85] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%87 = vector.extract %86[0] : vector<1xi32>
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %cst [0] : i32 into vector<4xi32>
%90 = arith.shrui %87, %c4_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %89 [1] : i32 into vector<4xi32>
%93 = arith.shrui %87, %c8_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [2] : i32 into vector<4xi32>
%96 = arith.shrui %87, %c12_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [3] : i32 into vector<4xi32>
%99 = vector.extract %86[0] : vector<1xi32>
%100 = arith.shrui %99, %c16_i32 : i32
%101 = arith.andi %100, %c15_i32 : i32
%102 = vector.insert %101, %cst [0] : i32 into vector<4xi32>
%103 = arith.shrui %99, %c20_i32 : i32
%104 = arith.andi %103, %c15_i32 : i32
%105 = vector.insert %104, %102 [1] : i32 into vector<4xi32>
%106 = arith.shrui %99, %c24_i32 : i32
%107 = arith.andi %106, %c15_i32 : i32
%108 = vector.insert %107, %105 [2] : i32 into vector<4xi32>
%109 = arith.shrui %99, %c28_i32 : i32
%110 = arith.andi %109, %c15_i32 : i32
%111 = vector.insert %110, %108 [3] : i32 into vector<4xi32>
%112 = arith.uitofp %98 : vector<4xi32> to vector<4xf16>
%113 = arith.uitofp %111 : vector<4xi32> to vector<4xf16>
%114 = arith.subf %112, %82 : vector<4xf16>
%115 = arith.subf %113, %82 : vector<4xf16>
%116 = arith.mulf %114, %83 : vector<4xf16>
%117 = arith.mulf %115, %83 : vector<4xf16>
%118 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv s1)>()[%arg2, %c8]
%119 = memref.load %alloc[%arg0, %118] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%120 = vector.extract_strided_slice %119 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%121 = vector.bitcast %120 : vector<2xf32> to vector<4xf16>
%122 = arith.mulf %121, %116 : vector<4xf16>
%123 = vector.extract_strided_slice %119 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%124 = vector.bitcast %123 : vector<2xf32> to vector<4xf16>
%125 = arith.mulf %124, %117 : vector<4xf16>
%126 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%127 = vector.bitcast %126 : vector<2xf32> to vector<4xf16>
%128 = arith.addf %122, %127 : vector<4xf16>
%129 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%130 = vector.bitcast %129 : vector<2xf32> to vector<4xf16>
%131 = arith.addf %125, %130 : vector<4xf16>
%132 = vector.bitcast %131 : vector<4xf16> to vector<2xf32>
%133 = vector.bitcast %128 : vector<4xf16> to vector<2xf32>
%134 = vector.insert_strided_slice %133, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%135 = vector.insert_strided_slice %132, %134 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %135 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
%36 = memref.load %17[%35] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %c0]
memref.store %76, %17[%77] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
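// Note: this dispatch is a quantized gemv. Each workgroup handles four output rows
// (row = workgroup_id_x * 4 + thread_id_y). The 32x128xf16 input is staged into
// workgroup memory as a 32x17 tile of vector<4xf32> (one padding column beyond the
// 32x16 source view, presumably to avoid shared-memory bank conflicts). The i4
// weights arrive packed eight-per-i32 word; nibble k is recovered as
// (word >> (4 * k)) & 0xF, widened with uitofp, then dequantized as
// (value - zero_point) * scale. Per-row partial sums live in a vector<4xf32>
// accumulator that is really two vector<4xf16> halves accessed through bitcasts;
// after the loops each half is reduced and the result combined across the 64-wide
// subgroup with xor shuffles at offsets 1, 2, 4, 8, 16, and 32 before the store.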
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
memref.store %cst_0, %17[%18] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%19 = affine.apply affine_map<()[s0] -> (s0 * 4 + 1)>()[%workgroup_id_x]
memref.store %cst_0, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4 + 2)>()[%workgroup_id_x]
memref.store %cst_0, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0] -> (s0 * 4 + 3)>()[%workgroup_id_x]
memref.store %cst_0, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
%24 = memref.load %16[%22, %23] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
memref.store %24, %alloc[%22, %25] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%26 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%27 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
%28 = memref.load %16[%26, %27] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%29 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
memref.store %28, %alloc[%26, %29] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%30 = arith.cmpi ult, %0, %c1 : index
scf.if %30 {
%31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%32 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%76 = memref.load %14[%31, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%77 = memref.load %15[%31, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%78 = vector.splat %77 : vector<4xf16>
%79 = vector.splat %76 : vector<4xf16>
%80 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%81 = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%arg2]
%82 = memref.load %13[%31, %arg0, %81] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%83 = vector.extract %82[0] : vector<1xi32>
%84 = arith.andi %83, %c15_i32 : i32
%85 = vector.insert %84, %cst [0] : i32 into vector<4xi32>
%86 = arith.shrui %83, %c4_i32 : i32
%87 = arith.andi %86, %c15_i32 : i32
%88 = vector.insert %87, %85 [1] : i32 into vector<4xi32>
%89 = arith.shrui %83, %c8_i32 : i32
%90 = arith.andi %89, %c15_i32 : i32
%91 = vector.insert %90, %88 [2] : i32 into vector<4xi32>
%92 = arith.shrui %83, %c12_i32 : i32
%93 = arith.andi %92, %c15_i32 : i32
%94 = vector.insert %93, %91 [3] : i32 into vector<4xi32>
%95 = vector.extract %82[0] : vector<1xi32>
%96 = arith.shrui %95, %c16_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %cst [0] : i32 into vector<4xi32>
%99 = arith.shrui %95, %c20_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %98 [1] : i32 into vector<4xi32>
%102 = arith.shrui %95, %c24_i32 : i32
%103 = arith.andi %102, %c15_i32 : i32
%104 = vector.insert %103, %101 [2] : i32 into vector<4xi32>
%105 = arith.shrui %95, %c28_i32 : i32
%106 = arith.andi %105, %c15_i32 : i32
%107 = vector.insert %106, %104 [3] : i32 into vector<4xi32>
%108 = arith.uitofp %94 : vector<4xi32> to vector<4xf16>
%109 = arith.uitofp %107 : vector<4xi32> to vector<4xf16>
%110 = arith.subf %108, %78 : vector<4xf16>
%111 = arith.subf %109, %78 : vector<4xf16>
%112 = arith.mulf %110, %79 : vector<4xf16>
%113 = arith.mulf %111, %79 : vector<4xf16>
%114 = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%arg2]
%115 = memref.load %alloc[%arg0, %114] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%116 = vector.extract_strided_slice %115 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%117 = vector.bitcast %116 : vector<2xf32> to vector<4xf16>
%118 = arith.mulf %117, %112 : vector<4xf16>
%119 = vector.extract_strided_slice %115 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%120 = vector.bitcast %119 : vector<2xf32> to vector<4xf16>
%121 = arith.mulf %120, %113 : vector<4xf16>
%122 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%123 = vector.bitcast %122 : vector<2xf32> to vector<4xf16>
%124 = arith.addf %118, %123 : vector<4xf16>
%125 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%126 = vector.bitcast %125 : vector<2xf32> to vector<4xf16>
%127 = arith.addf %121, %126 : vector<4xf16>
%128 = vector.bitcast %127 : vector<4xf16> to vector<2xf32>
%129 = vector.bitcast %124 : vector<4xf16> to vector<2xf32>
%130 = vector.insert_strided_slice %129, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%131 = vector.insert_strided_slice %128, %130 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %131 : vector<4xf32>
}
scf.yield %80 : vector<4xf32>
}
%33 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%34 = memref.load %17[%33] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%35 = vector.extract_strided_slice %32 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%36 = vector.bitcast %35 : vector<2xf32> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %cst_2 [0] : f16 into vector<2xf16>
%39 = vector.extract_strided_slice %32 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%40 = vector.bitcast %39 : vector<2xf32> to vector<4xf16>
%41 = vector.reduction <add>, %40 : vector<4xf16> into f16
%42 = vector.insert %41, %38 [1] : f16 into vector<2xf16>
%43 = vector.bitcast %42 : vector<2xf16> to vector<1xi32>
%44 = vector.extract %43[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %44, %c1_i32, %c64_i32 : i32
%45 = vector.splat %shuffleResult : vector<1xi32>
%46 = vector.bitcast %45 : vector<1xi32> to vector<2xf16>
%47 = arith.addf %42, %46 : vector<2xf16>
%48 = vector.bitcast %47 : vector<2xf16> to vector<1xi32>
%49 = vector.extract %48[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %49, %c2_i32, %c64_i32 : i32
%50 = vector.splat %shuffleResult_3 : vector<1xi32>
%51 = vector.bitcast %50 : vector<1xi32> to vector<2xf16>
%52 = arith.addf %47, %51 : vector<2xf16>
%53 = vector.bitcast %52 : vector<2xf16> to vector<1xi32>
%54 = vector.extract %53[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %54, %c4_i32, %c64_i32 : i32
%55 = vector.splat %shuffleResult_5 : vector<1xi32>
%56 = vector.bitcast %55 : vector<1xi32> to vector<2xf16>
%57 = arith.addf %52, %56 : vector<2xf16>
%58 = vector.bitcast %57 : vector<2xf16> to vector<1xi32>
%59 = vector.extract %58[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %59, %c8_i32, %c64_i32 : i32
%60 = vector.splat %shuffleResult_7 : vector<1xi32>
%61 = vector.bitcast %60 : vector<1xi32> to vector<2xf16>
%62 = arith.addf %57, %61 : vector<2xf16>
%63 = vector.bitcast %62 : vector<2xf16> to vector<1xi32>
%64 = vector.extract %63[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %64, %c16_i32, %c64_i32 : i32
%65 = vector.splat %shuffleResult_9 : vector<1xi32>
%66 = vector.bitcast %65 : vector<1xi32> to vector<2xf16>
%67 = arith.addf %62, %66 : vector<2xf16>
%68 = vector.bitcast %67 : vector<2xf16> to vector<1xi32>
%69 = vector.extract %68[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %69, %c32_i32, %c64_i32 : i32
%70 = vector.splat %shuffleResult_11 : vector<1xi32>
%71 = vector.bitcast %70 : vector<1xi32> to vector<2xf16>
%72 = arith.addf %67, %71 : vector<2xf16>
%73 = vector.reduction <add>, %72 : vector<2xf16> into f16
%74 = arith.addf %73, %34 : f16
%75 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
memref.store %74, %17[%75] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
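// Note: CSE below merges the duplicated affine.apply computations from the previous
// dump (the repeated "s0 mod 16" column index and the repeated "s0 + s1 * 4" row
// index), so the shared-memory stores and the final load/store of the output row
// reuse a single index value.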
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
memref.store %cst_0, %17[%18] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%19 = affine.apply affine_map<()[s0] -> (s0 * 4 + 1)>()[%workgroup_id_x]
memref.store %cst_0, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4 + 2)>()[%workgroup_id_x]
memref.store %cst_0, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0] -> (s0 * 4 + 3)>()[%workgroup_id_x]
memref.store %cst_0, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
%24 = memref.load %16[%22, %23] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.store %24, %alloc[%22, %23] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%25 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%26 = memref.load %16[%25, %23] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.store %26, %alloc[%25, %23] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%27 = arith.cmpi ult, %0, %c1 : index
scf.if %27 {
%28 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%29 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%71 = memref.load %14[%28, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%72 = memref.load %15[%28, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%73 = vector.splat %72 : vector<4xf16>
%74 = vector.splat %71 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%76 = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%arg2]
%77 = memref.load %13[%28, %arg0, %76] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%78 = vector.extract %77[0] : vector<1xi32>
%79 = arith.andi %78, %c15_i32 : i32
%80 = vector.insert %79, %cst [0] : i32 into vector<4xi32>
%81 = arith.shrui %78, %c4_i32 : i32
%82 = arith.andi %81, %c15_i32 : i32
%83 = vector.insert %82, %80 [1] : i32 into vector<4xi32>
%84 = arith.shrui %78, %c8_i32 : i32
%85 = arith.andi %84, %c15_i32 : i32
%86 = vector.insert %85, %83 [2] : i32 into vector<4xi32>
%87 = arith.shrui %78, %c12_i32 : i32
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %86 [3] : i32 into vector<4xi32>
%90 = arith.shrui %78, %c16_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %cst [0] : i32 into vector<4xi32>
%93 = arith.shrui %78, %c20_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [1] : i32 into vector<4xi32>
%96 = arith.shrui %78, %c24_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [2] : i32 into vector<4xi32>
%99 = arith.shrui %78, %c28_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %98 [3] : i32 into vector<4xi32>
%102 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%103 = arith.uitofp %101 : vector<4xi32> to vector<4xf16>
%104 = arith.subf %102, %73 : vector<4xf16>
%105 = arith.subf %103, %73 : vector<4xf16>
%106 = arith.mulf %104, %74 : vector<4xf16>
%107 = arith.mulf %105, %74 : vector<4xf16>
%108 = memref.load %alloc[%arg0, %76] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%109 = vector.extract_strided_slice %108 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%110 = vector.bitcast %109 : vector<2xf32> to vector<4xf16>
%111 = arith.mulf %110, %106 : vector<4xf16>
%112 = vector.extract_strided_slice %108 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%113 = vector.bitcast %112 : vector<2xf32> to vector<4xf16>
%114 = arith.mulf %113, %107 : vector<4xf16>
%115 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%116 = vector.bitcast %115 : vector<2xf32> to vector<4xf16>
%117 = arith.addf %111, %116 : vector<4xf16>
%118 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%119 = vector.bitcast %118 : vector<2xf32> to vector<4xf16>
%120 = arith.addf %114, %119 : vector<4xf16>
%121 = vector.bitcast %120 : vector<4xf16> to vector<2xf32>
%122 = vector.bitcast %117 : vector<4xf16> to vector<2xf32>
%123 = vector.insert_strided_slice %122, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%124 = vector.insert_strided_slice %121, %123 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %124 : vector<4xf32>
}
scf.yield %75 : vector<4xf32>
}
%30 = memref.load %17[%28] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = vector.extract_strided_slice %29 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%32 = vector.bitcast %31 : vector<2xf32> to vector<4xf16>
%33 = vector.reduction <add>, %32 : vector<4xf16> into f16
%34 = vector.insert %33, %cst_2 [0] : f16 into vector<2xf16>
%35 = vector.extract_strided_slice %29 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%36 = vector.bitcast %35 : vector<2xf32> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %34 [1] : f16 into vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %40, %c1_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %45, %c2_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_3 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %50, %c4_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_5 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %55, %c8_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_7 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %60, %c16_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_9 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %65, %c32_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_11 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.reduction <add>, %68 : vector<2xf16> into f16
%70 = arith.addf %69, %30 : f16
memref.store %70, %17[%28] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
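// Note: no changes visible for this dispatch compared with the CSE output; the
// function body below appears identical apart from being printed at function scope
// rather than inside a module.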
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%8) flags(ReadOnly) : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %13, 1 : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %14, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %15, 1 : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %16, 1 : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %17, 1 : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%18 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
memref.store %cst_0, %17[%18] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%19 = affine.apply affine_map<()[s0] -> (s0 * 4 + 1)>()[%workgroup_id_x]
memref.store %cst_0, %17[%19] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%20 = affine.apply affine_map<()[s0] -> (s0 * 4 + 2)>()[%workgroup_id_x]
memref.store %cst_0, %17[%20] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%21 = affine.apply affine_map<()[s0] -> (s0 * 4 + 3)>()[%workgroup_id_x]
memref.store %cst_0, %17[%21] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%22 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16)>()[%0, %1, %2]
%23 = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%0]
%24 = memref.load %16[%22, %23] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.store %24, %alloc[%22, %23] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%25 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 4 + s2 * 16 + s0 floordiv 16 + 16)>()[%0, %1, %2]
%26 = memref.load %16[%25, %23] : memref<32x16xvector<4xf32>, strided<[16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.store %26, %alloc[%25, %23] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%27 = arith.cmpi ult, %0, %c1 : index
scf.if %27 {
%28 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%1, %workgroup_id_x]
%29 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%71 = memref.load %14[%28, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%72 = memref.load %15[%28, %arg0] : memref<4096x32xf16, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%73 = vector.splat %72 : vector<4xf16>
%74 = vector.splat %71 : vector<4xf16>
%75 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%76 = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%arg2]
%77 = memref.load %13[%28, %arg0, %76] : memref<4096x32x16xvector<1xi32>, strided<[512, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%78 = vector.extract %77[0] : vector<1xi32>
%79 = arith.andi %78, %c15_i32 : i32
%80 = vector.insert %79, %cst [0] : i32 into vector<4xi32>
%81 = arith.shrui %78, %c4_i32 : i32
%82 = arith.andi %81, %c15_i32 : i32
%83 = vector.insert %82, %80 [1] : i32 into vector<4xi32>
%84 = arith.shrui %78, %c8_i32 : i32
%85 = arith.andi %84, %c15_i32 : i32
%86 = vector.insert %85, %83 [2] : i32 into vector<4xi32>
%87 = arith.shrui %78, %c12_i32 : i32
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %86 [3] : i32 into vector<4xi32>
%90 = arith.shrui %78, %c16_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %cst [0] : i32 into vector<4xi32>
%93 = arith.shrui %78, %c20_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [1] : i32 into vector<4xi32>
%96 = arith.shrui %78, %c24_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [2] : i32 into vector<4xi32>
%99 = arith.shrui %78, %c28_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %98 [3] : i32 into vector<4xi32>
%102 = arith.uitofp %89 : vector<4xi32> to vector<4xf16>
%103 = arith.uitofp %101 : vector<4xi32> to vector<4xf16>
%104 = arith.subf %102, %73 : vector<4xf16>
%105 = arith.subf %103, %73 : vector<4xf16>
%106 = arith.mulf %104, %74 : vector<4xf16>
%107 = arith.mulf %105, %74 : vector<4xf16>
%108 = memref.load %alloc[%arg0, %76] : memref<32x17xvector<4xf32>, #gpu.address_space<workgroup>>
%109 = vector.extract_strided_slice %108 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%110 = vector.bitcast %109 : vector<2xf32> to vector<4xf16>
%111 = arith.mulf %110, %106 : vector<4xf16>
%112 = vector.extract_strided_slice %108 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%113 = vector.bitcast %112 : vector<2xf32> to vector<4xf16>
%114 = arith.mulf %113, %107 : vector<4xf16>
%115 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%116 = vector.bitcast %115 : vector<2xf32> to vector<4xf16>
%117 = arith.addf %111, %116 : vector<4xf16>
%118 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%119 = vector.bitcast %118 : vector<2xf32> to vector<4xf16>
%120 = arith.addf %114, %119 : vector<4xf16>
%121 = vector.bitcast %120 : vector<4xf16> to vector<2xf32>
%122 = vector.bitcast %117 : vector<4xf16> to vector<2xf32>
%123 = vector.insert_strided_slice %122, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%124 = vector.insert_strided_slice %121, %123 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %124 : vector<4xf32>
}
scf.yield %75 : vector<4xf32>
}
%30 = memref.load %17[%28] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%31 = vector.extract_strided_slice %29 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%32 = vector.bitcast %31 : vector<2xf32> to vector<4xf16>
%33 = vector.reduction <add>, %32 : vector<4xf16> into f16
%34 = vector.insert %33, %cst_2 [0] : f16 into vector<2xf16>
%35 = vector.extract_strided_slice %29 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%36 = vector.bitcast %35 : vector<2xf32> to vector<4xf16>
%37 = vector.reduction <add>, %36 : vector<4xf16> into f16
%38 = vector.insert %37, %34 [1] : f16 into vector<2xf16>
%39 = vector.bitcast %38 : vector<2xf16> to vector<1xi32>
%40 = vector.extract %39[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %40, %c1_i32, %c64_i32 : i32
%41 = vector.splat %shuffleResult : vector<1xi32>
%42 = vector.bitcast %41 : vector<1xi32> to vector<2xf16>
%43 = arith.addf %38, %42 : vector<2xf16>
%44 = vector.bitcast %43 : vector<2xf16> to vector<1xi32>
%45 = vector.extract %44[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %45, %c2_i32, %c64_i32 : i32
%46 = vector.splat %shuffleResult_3 : vector<1xi32>
%47 = vector.bitcast %46 : vector<1xi32> to vector<2xf16>
%48 = arith.addf %43, %47 : vector<2xf16>
%49 = vector.bitcast %48 : vector<2xf16> to vector<1xi32>
%50 = vector.extract %49[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %50, %c4_i32, %c64_i32 : i32
%51 = vector.splat %shuffleResult_5 : vector<1xi32>
%52 = vector.bitcast %51 : vector<1xi32> to vector<2xf16>
%53 = arith.addf %48, %52 : vector<2xf16>
%54 = vector.bitcast %53 : vector<2xf16> to vector<1xi32>
%55 = vector.extract %54[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %55, %c8_i32, %c64_i32 : i32
%56 = vector.splat %shuffleResult_7 : vector<1xi32>
%57 = vector.bitcast %56 : vector<1xi32> to vector<2xf16>
%58 = arith.addf %53, %57 : vector<2xf16>
%59 = vector.bitcast %58 : vector<2xf16> to vector<1xi32>
%60 = vector.extract %59[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %60, %c16_i32, %c64_i32 : i32
%61 = vector.splat %shuffleResult_9 : vector<1xi32>
%62 = vector.bitcast %61 : vector<1xi32> to vector<2xf16>
%63 = arith.addf %58, %62 : vector<2xf16>
%64 = vector.bitcast %63 : vector<2xf16> to vector<1xi32>
%65 = vector.extract %64[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %65, %c32_i32, %c64_i32 : i32
%66 = vector.splat %shuffleResult_11 : vector<1xi32>
%67 = vector.bitcast %66 : vector<1xi32> to vector<2xf16>
%68 = arith.addf %63, %67 : vector<2xf16>
%69 = vector.reduction <add>, %68 : vector<2xf16> into f16
%70 = arith.addf %69, %30 : f16
memref.store %70, %17[%28] : memref<4096xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After FlattenMemRefSubspan (iree-codegen-flatten-memref-subspan) //----- //
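// Note: this pass collapses every binding to a 1-D dynamically sized memref at
// offset 0, folding the byte offsets and the original multi-dimensional strides
// into the affine index maps at each load/store. The 32x17 workgroup tile likewise
// becomes a flat 544-element vector<4xf32> buffer.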
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = affine.apply affine_map<()[s0] -> (s0 floordiv 4 + 2097152)>()[%8]
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%13}
%15 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%9]
%16 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%15}
%17 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%10]
%18 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%17}
%19 = affine.apply affine_map<()[s0] -> (s0 floordiv 16 + 512)>()[%11]
%20 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 4096)>()[%12]
%22 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%21}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%23 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%23] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%24 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 1)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%24] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%25] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%26 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 3)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%26] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16)>()[%11, %0, %1, %2]
%28 = memref.load %20[%27] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%29 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16)>()[%0, %1, %2]
memref.store %28, %alloc[%29] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%30 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16 + 256)>()[%11, %0, %1, %2]
%31 = memref.load %20[%30] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16 + 272)>()[%0, %1, %2]
memref.store %31, %alloc[%32] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%33 = arith.cmpi ult, %0, %c1 : index
scf.if %33 {
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%78 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%9, %1, %workgroup_id_x]
%79 = memref.load %16[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%10, %1, %workgroup_id_x]
%81 = memref.load %18[%80] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 * 16 + s1 * 512 + s2 * 2048 + s3 floordiv 8 + s0 floordiv 4)>(%arg0)[%8, %1, %workgroup_id_x, %arg2]
%86 = memref.load %14[%85] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%87 = vector.extract %86[0] : vector<1xi32>
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %cst [0] : i32 into vector<4xi32>
%90 = arith.shrui %87, %c4_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %89 [1] : i32 into vector<4xi32>
%93 = arith.shrui %87, %c8_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [2] : i32 into vector<4xi32>
%96 = arith.shrui %87, %c12_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [3] : i32 into vector<4xi32>
%99 = arith.shrui %87, %c16_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %cst [0] : i32 into vector<4xi32>
%102 = arith.shrui %87, %c20_i32 : i32
%103 = arith.andi %102, %c15_i32 : i32
%104 = vector.insert %103, %101 [1] : i32 into vector<4xi32>
%105 = arith.shrui %87, %c24_i32 : i32
%106 = arith.andi %105, %c15_i32 : i32
%107 = vector.insert %106, %104 [2] : i32 into vector<4xi32>
%108 = arith.shrui %87, %c28_i32 : i32
%109 = arith.andi %108, %c15_i32 : i32
%110 = vector.insert %109, %107 [3] : i32 into vector<4xi32>
%111 = arith.uitofp %98 : vector<4xi32> to vector<4xf16>
%112 = arith.uitofp %110 : vector<4xi32> to vector<4xf16>
%113 = arith.subf %111, %82 : vector<4xf16>
%114 = arith.subf %112, %82 : vector<4xf16>
%115 = arith.mulf %113, %83 : vector<4xf16>
%116 = arith.mulf %114, %83 : vector<4xf16>
%117 = affine.apply affine_map<(d0)[s0] -> (d0 * 17 + s0 floordiv 8)>(%arg0)[%arg2]
%118 = memref.load %alloc[%117] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%119 = vector.extract_strided_slice %118 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%120 = vector.bitcast %119 : vector<2xf32> to vector<4xf16>
%121 = arith.mulf %120, %115 : vector<4xf16>
%122 = vector.extract_strided_slice %118 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%123 = vector.bitcast %122 : vector<2xf32> to vector<4xf16>
%124 = arith.mulf %123, %116 : vector<4xf16>
%125 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%126 = vector.bitcast %125 : vector<2xf32> to vector<4xf16>
%127 = arith.addf %121, %126 : vector<4xf16>
%128 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%129 = vector.bitcast %128 : vector<2xf32> to vector<4xf16>
%130 = arith.addf %124, %129 : vector<4xf16>
%131 = vector.bitcast %130 : vector<4xf16> to vector<2xf32>
%132 = vector.bitcast %127 : vector<4xf16> to vector<2xf32>
%133 = vector.insert_strided_slice %132, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%134 = vector.insert_strided_slice %131, %133 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %134 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
%36 = memref.load %22[%35] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
memref.store %76, %22[%77] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After SPIRVEraseStorageBufferStaticShape (iree-spirv-erase-storage-buffer-static-shape) //----- //
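// Note: no visible effect on this dispatch; the storage buffer bindings are already
// dynamically shaped after flattening, so the function below matches the previous
// dump.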
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = affine.apply affine_map<()[s0] -> (s0 floordiv 4 + 2097152)>()[%8]
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%13}
%15 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%9]
%16 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%15}
%17 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%10]
%18 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%17}
%19 = affine.apply affine_map<()[s0] -> (s0 floordiv 16 + 512)>()[%11]
%20 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 4096)>()[%12]
%22 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%21}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%23 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%23] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%24 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 1)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%24] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%25] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%26 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 3)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%26] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16)>()[%11, %0, %1, %2]
%28 = memref.load %20[%27] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%29 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16)>()[%0, %1, %2]
memref.store %28, %alloc[%29] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%30 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16 + 256)>()[%11, %0, %1, %2]
%31 = memref.load %20[%30] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16 + 272)>()[%0, %1, %2]
memref.store %31, %alloc[%32] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%33 = arith.cmpi ult, %0, %c1 : index
scf.if %33 {
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%78 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%9, %1, %workgroup_id_x]
%79 = memref.load %16[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%10, %1, %workgroup_id_x]
%81 = memref.load %18[%80] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 * 16 + s1 * 512 + s2 * 2048 + s3 floordiv 8 + s0 floordiv 4)>(%arg0)[%8, %1, %workgroup_id_x, %arg2]
%86 = memref.load %14[%85] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%87 = vector.extract %86[0] : vector<1xi32>
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %cst [0] : i32 into vector<4xi32>
%90 = arith.shrui %87, %c4_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %89 [1] : i32 into vector<4xi32>
%93 = arith.shrui %87, %c8_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [2] : i32 into vector<4xi32>
%96 = arith.shrui %87, %c12_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [3] : i32 into vector<4xi32>
%99 = arith.shrui %87, %c16_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %cst [0] : i32 into vector<4xi32>
%102 = arith.shrui %87, %c20_i32 : i32
%103 = arith.andi %102, %c15_i32 : i32
%104 = vector.insert %103, %101 [1] : i32 into vector<4xi32>
%105 = arith.shrui %87, %c24_i32 : i32
%106 = arith.andi %105, %c15_i32 : i32
%107 = vector.insert %106, %104 [2] : i32 into vector<4xi32>
%108 = arith.shrui %87, %c28_i32 : i32
%109 = arith.andi %108, %c15_i32 : i32
%110 = vector.insert %109, %107 [3] : i32 into vector<4xi32>
%111 = arith.uitofp %98 : vector<4xi32> to vector<4xf16>
%112 = arith.uitofp %110 : vector<4xi32> to vector<4xf16>
%113 = arith.subf %111, %82 : vector<4xf16>
%114 = arith.subf %112, %82 : vector<4xf16>
%115 = arith.mulf %113, %83 : vector<4xf16>
%116 = arith.mulf %114, %83 : vector<4xf16>
%117 = affine.apply affine_map<(d0)[s0] -> (d0 * 17 + s0 floordiv 8)>(%arg0)[%arg2]
%118 = memref.load %alloc[%117] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%119 = vector.extract_strided_slice %118 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%120 = vector.bitcast %119 : vector<2xf32> to vector<4xf16>
%121 = arith.mulf %120, %115 : vector<4xf16>
%122 = vector.extract_strided_slice %118 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%123 = vector.bitcast %122 : vector<2xf32> to vector<4xf16>
%124 = arith.mulf %123, %116 : vector<4xf16>
%125 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%126 = vector.bitcast %125 : vector<2xf32> to vector<4xf16>
%127 = arith.addf %121, %126 : vector<4xf16>
%128 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%129 = vector.bitcast %128 : vector<2xf32> to vector<4xf16>
%130 = arith.addf %124, %129 : vector<4xf16>
%131 = vector.bitcast %130 : vector<4xf16> to vector<2xf32>
%132 = vector.bitcast %127 : vector<4xf16> to vector<2xf32>
%133 = vector.insert_strided_slice %132, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%134 = vector.insert_strided_slice %131, %133 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %134 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
%36 = memref.load %22[%35] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
memref.store %76, %22[%77] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = affine.apply affine_map<()[s0] -> (s0 floordiv 4 + 2097152)>()[%8]
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%13}
%15 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%9]
%16 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%15}
%17 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%10]
%18 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%17}
%19 = affine.apply affine_map<()[s0] -> (s0 floordiv 16 + 512)>()[%11]
%20 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 4096)>()[%12]
%22 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%21}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%23 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%23] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%24 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 1)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%24] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%25] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%26 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 3)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%26] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16)>()[%11, %0, %1, %2]
%28 = memref.load %20[%27] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%29 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16)>()[%0, %1, %2]
memref.store %28, %alloc[%29] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%30 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16 + 256)>()[%11, %0, %1, %2]
%31 = memref.load %20[%30] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16 + 272)>()[%0, %1, %2]
memref.store %31, %alloc[%32] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%33 = arith.cmpi ult, %0, %c1 : index
scf.if %33 {
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%78 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%9, %1, %workgroup_id_x]
%79 = memref.load %16[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%80 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%10, %1, %workgroup_id_x]
%81 = memref.load %18[%80] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%82 = vector.splat %81 : vector<4xf16>
%83 = vector.splat %79 : vector<4xf16>
%84 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%85 = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 * 16 + s1 * 512 + s2 * 2048 + s3 floordiv 8 + s0 floordiv 4)>(%arg0)[%8, %1, %workgroup_id_x, %arg2]
%86 = memref.load %14[%85] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%87 = vector.extract %86[0] : vector<1xi32>
%88 = arith.andi %87, %c15_i32 : i32
%89 = vector.insert %88, %cst [0] : i32 into vector<4xi32>
%90 = arith.shrui %87, %c4_i32 : i32
%91 = arith.andi %90, %c15_i32 : i32
%92 = vector.insert %91, %89 [1] : i32 into vector<4xi32>
%93 = arith.shrui %87, %c8_i32 : i32
%94 = arith.andi %93, %c15_i32 : i32
%95 = vector.insert %94, %92 [2] : i32 into vector<4xi32>
%96 = arith.shrui %87, %c12_i32 : i32
%97 = arith.andi %96, %c15_i32 : i32
%98 = vector.insert %97, %95 [3] : i32 into vector<4xi32>
%99 = arith.shrui %87, %c16_i32 : i32
%100 = arith.andi %99, %c15_i32 : i32
%101 = vector.insert %100, %cst [0] : i32 into vector<4xi32>
%102 = arith.shrui %87, %c20_i32 : i32
%103 = arith.andi %102, %c15_i32 : i32
%104 = vector.insert %103, %101 [1] : i32 into vector<4xi32>
%105 = arith.shrui %87, %c24_i32 : i32
%106 = arith.andi %105, %c15_i32 : i32
%107 = vector.insert %106, %104 [2] : i32 into vector<4xi32>
%108 = arith.shrui %87, %c28_i32 : i32
%109 = arith.andi %108, %c15_i32 : i32
%110 = vector.insert %109, %107 [3] : i32 into vector<4xi32>
%111 = arith.uitofp %98 : vector<4xi32> to vector<4xf16>
%112 = arith.uitofp %110 : vector<4xi32> to vector<4xf16>
%113 = arith.subf %111, %82 : vector<4xf16>
%114 = arith.subf %112, %82 : vector<4xf16>
%115 = arith.mulf %113, %83 : vector<4xf16>
%116 = arith.mulf %114, %83 : vector<4xf16>
%117 = affine.apply affine_map<(d0)[s0] -> (d0 * 17 + s0 floordiv 8)>(%arg0)[%arg2]
%118 = memref.load %alloc[%117] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%119 = vector.extract_strided_slice %118 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%120 = vector.bitcast %119 : vector<2xf32> to vector<4xf16>
%121 = arith.mulf %120, %115 : vector<4xf16>
%122 = vector.extract_strided_slice %118 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%123 = vector.bitcast %122 : vector<2xf32> to vector<4xf16>
%124 = arith.mulf %123, %116 : vector<4xf16>
%125 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%126 = vector.bitcast %125 : vector<2xf32> to vector<4xf16>
%127 = arith.addf %121, %126 : vector<4xf16>
%128 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%129 = vector.bitcast %128 : vector<2xf32> to vector<4xf16>
%130 = arith.addf %124, %129 : vector<4xf16>
%131 = vector.bitcast %130 : vector<4xf16> to vector<2xf32>
%132 = vector.bitcast %127 : vector<4xf16> to vector<2xf32>
%133 = vector.insert_strided_slice %132, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%134 = vector.insert_strided_slice %131, %133 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %134 : vector<4xf32>
}
scf.yield %84 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
%36 = memref.load %22[%35] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
%77 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
memref.store %76, %22[%77] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
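// NOTE (added, not compiler output): inside the inner scf.for above, each loaded
// i32 word carries eight 4-bit weights. The shrui/andi pairs split them into two
// vector<4xi32> halves, and uitofp/subf/mulf apply the (q - zero_point) * scale
// dequantization from the original generic op. A small Python model of one word
// (scale/zero_point mirror the dispatch operands; the sample numbers are made up):

def dequant_i4_word(word, scale, zero_point):
    # Little-nibble order: nibble i sits at bits [4*i, 4*i + 4).
    nibbles = [(word >> (4 * i)) & 0xF for i in range(8)]
    return [(q - zero_point) * scale for q in nibbles]

# 0x76543210 packs the values 0..7; with zero_point=8 and scale=0.5 this yields
# [-4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5].
print(dequant_i4_word(0x76543210, scale=0.5, zero_point=8.0))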
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = affine.apply affine_map<()[s0] -> (s0 floordiv 4 + 2097152)>()[%8]
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%13}
%15 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%9]
%16 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%15}
%17 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 131072)>()[%10]
%18 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%17}
%19 = affine.apply affine_map<()[s0] -> (s0 floordiv 16 + 512)>()[%11]
%20 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = affine.apply affine_map<()[s0] -> (s0 floordiv 2 + 4096)>()[%12]
%22 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%21}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%23 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%23] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%24 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 1)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%24] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%25 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 2)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%25] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%26 = affine.apply affine_map<()[s0, s1] -> (s1 * 4 + s0 floordiv 2 + 3)>()[%12, %workgroup_id_x]
memref.store %cst_0, %22[%26] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%27 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16)>()[%11, %0, %1, %2]
%28 = memref.load %20[%27] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%29 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16)>()[%0, %1, %2]
memref.store %28, %alloc[%29] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%30 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 + s2 * 64 + s3 * 256 + s0 floordiv 16 + 256)>()[%11, %0, %1, %2]
%31 = memref.load %20[%30] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%32 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 68 + s2 * 272 + s0 floordiv 16 + 272)>()[%0, %1, %2]
memref.store %31, %alloc[%32] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%33 = arith.cmpi ult, %0, %c1 : index
scf.if %33 {
%34 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%77 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%9, %1, %workgroup_id_x]
%78 = memref.load %16[%77] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%79 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 32 + s2 * 128 + s0 floordiv 2)>(%arg0)[%10, %1, %workgroup_id_x]
%80 = memref.load %18[%79] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%81 = vector.splat %80 : vector<4xf16>
%82 = vector.splat %78 : vector<4xf16>
%83 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%84 = affine.apply affine_map<(d0)[s0, s1, s2, s3] -> (d0 * 16 + s1 * 512 + s2 * 2048 + s3 floordiv 8 + s0 floordiv 4)>(%arg0)[%8, %1, %workgroup_id_x, %arg2]
%85 = memref.load %14[%84] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%86 = vector.extract %85[0] : vector<1xi32>
%87 = arith.andi %86, %c15_i32 : i32
%88 = vector.insert %87, %cst [0] : i32 into vector<4xi32>
%89 = arith.shrui %86, %c4_i32 : i32
%90 = arith.andi %89, %c15_i32 : i32
%91 = vector.insert %90, %88 [1] : i32 into vector<4xi32>
%92 = arith.shrui %86, %c8_i32 : i32
%93 = arith.andi %92, %c15_i32 : i32
%94 = vector.insert %93, %91 [2] : i32 into vector<4xi32>
%95 = arith.shrui %86, %c12_i32 : i32
%96 = arith.andi %95, %c15_i32 : i32
%97 = vector.insert %96, %94 [3] : i32 into vector<4xi32>
%98 = arith.shrui %86, %c16_i32 : i32
%99 = arith.andi %98, %c15_i32 : i32
%100 = vector.insert %99, %cst [0] : i32 into vector<4xi32>
%101 = arith.shrui %86, %c20_i32 : i32
%102 = arith.andi %101, %c15_i32 : i32
%103 = vector.insert %102, %100 [1] : i32 into vector<4xi32>
%104 = arith.shrui %86, %c24_i32 : i32
%105 = arith.andi %104, %c15_i32 : i32
%106 = vector.insert %105, %103 [2] : i32 into vector<4xi32>
%107 = arith.shrui %86, %c28_i32 : i32
%108 = arith.andi %107, %c15_i32 : i32
%109 = vector.insert %108, %106 [3] : i32 into vector<4xi32>
%110 = arith.uitofp %97 : vector<4xi32> to vector<4xf16>
%111 = arith.uitofp %109 : vector<4xi32> to vector<4xf16>
%112 = arith.subf %110, %81 : vector<4xf16>
%113 = arith.subf %111, %81 : vector<4xf16>
%114 = arith.mulf %112, %82 : vector<4xf16>
%115 = arith.mulf %113, %82 : vector<4xf16>
%116 = affine.apply affine_map<(d0)[s0] -> (d0 * 17 + s0 floordiv 8)>(%arg0)[%arg2]
%117 = memref.load %alloc[%116] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%118 = vector.extract_strided_slice %117 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%119 = vector.bitcast %118 : vector<2xf32> to vector<4xf16>
%120 = arith.mulf %119, %114 : vector<4xf16>
%121 = vector.extract_strided_slice %117 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%122 = vector.bitcast %121 : vector<2xf32> to vector<4xf16>
%123 = arith.mulf %122, %115 : vector<4xf16>
%124 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%125 = vector.bitcast %124 : vector<2xf32> to vector<4xf16>
%126 = arith.addf %120, %125 : vector<4xf16>
%127 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%128 = vector.bitcast %127 : vector<2xf32> to vector<4xf16>
%129 = arith.addf %123, %128 : vector<4xf16>
%130 = vector.bitcast %129 : vector<4xf16> to vector<2xf32>
%131 = vector.bitcast %126 : vector<4xf16> to vector<2xf32>
%132 = vector.insert_strided_slice %131, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%133 = vector.insert_strided_slice %130, %132 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %133 : vector<4xf32>
}
scf.yield %83 : vector<4xf32>
}
%35 = affine.apply affine_map<()[s0, s1, s2] -> (s1 + s2 * 4 + s0 floordiv 2)>()[%12, %1, %workgroup_id_x]
%36 = memref.load %22[%35] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%37 = vector.extract_strided_slice %34 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%38 = vector.bitcast %37 : vector<2xf32> to vector<4xf16>
%39 = vector.reduction <add>, %38 : vector<4xf16> into f16
%40 = vector.insert %39, %cst_2 [0] : f16 into vector<2xf16>
%41 = vector.extract_strided_slice %34 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%42 = vector.bitcast %41 : vector<2xf32> to vector<4xf16>
%43 = vector.reduction <add>, %42 : vector<4xf16> into f16
%44 = vector.insert %43, %40 [1] : f16 into vector<2xf16>
%45 = vector.bitcast %44 : vector<2xf16> to vector<1xi32>
%46 = vector.extract %45[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %46, %c1_i32, %c64_i32 : i32
%47 = vector.splat %shuffleResult : vector<1xi32>
%48 = vector.bitcast %47 : vector<1xi32> to vector<2xf16>
%49 = arith.addf %44, %48 : vector<2xf16>
%50 = vector.bitcast %49 : vector<2xf16> to vector<1xi32>
%51 = vector.extract %50[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %51, %c2_i32, %c64_i32 : i32
%52 = vector.splat %shuffleResult_3 : vector<1xi32>
%53 = vector.bitcast %52 : vector<1xi32> to vector<2xf16>
%54 = arith.addf %49, %53 : vector<2xf16>
%55 = vector.bitcast %54 : vector<2xf16> to vector<1xi32>
%56 = vector.extract %55[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %56, %c4_i32, %c64_i32 : i32
%57 = vector.splat %shuffleResult_5 : vector<1xi32>
%58 = vector.bitcast %57 : vector<1xi32> to vector<2xf16>
%59 = arith.addf %54, %58 : vector<2xf16>
%60 = vector.bitcast %59 : vector<2xf16> to vector<1xi32>
%61 = vector.extract %60[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %61, %c8_i32, %c64_i32 : i32
%62 = vector.splat %shuffleResult_7 : vector<1xi32>
%63 = vector.bitcast %62 : vector<1xi32> to vector<2xf16>
%64 = arith.addf %59, %63 : vector<2xf16>
%65 = vector.bitcast %64 : vector<2xf16> to vector<1xi32>
%66 = vector.extract %65[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %66, %c16_i32, %c64_i32 : i32
%67 = vector.splat %shuffleResult_9 : vector<1xi32>
%68 = vector.bitcast %67 : vector<1xi32> to vector<2xf16>
%69 = arith.addf %64, %68 : vector<2xf16>
%70 = vector.bitcast %69 : vector<2xf16> to vector<1xi32>
%71 = vector.extract %70[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %71, %c32_i32, %c64_i32 : i32
%72 = vector.splat %shuffleResult_11 : vector<1xi32>
%73 = vector.bitcast %72 : vector<1xi32> to vector<2xf16>
%74 = arith.addf %69, %73 : vector<2xf16>
%75 = vector.reduction <add>, %74 : vector<2xf16> into f16
%76 = arith.addf %75, %36 : f16
memref.store %76, %22[%35] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
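// NOTE (added, not compiler output): the ConvertAffineToStandard dump below expands
// every `floordiv` in the affine maps into a cmpi/subi/select/divsi ladder. For a
// positive constant divisor (the only case in this dispatch) the expansion mirrors
// negative operands around -1 so that truncating divsi yields floor division.
// A small Python check of that rewrite (assumption: divisor c > 0):

def expanded_floordiv(s, c):
    # Mirror negatives around -1, truncate-divide, mirror back.
    neg = s < 0
    t = (-1 - s) if neg else s
    q = t // c          # t >= 0 and c > 0, so // truncates here, like arith.divsi
    return (-1 - q) if neg else q

for s in (-9, -8, -1, 0, 7, 8):
    assert expanded_floordiv(s, 4) == s // 4   # agrees with floor division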
// -----// IR Dump After ConvertAffineToStandard (lower-affine) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%c4 = arith.constant 4 : index
%c0_3 = arith.constant 0 : index
%c-1 = arith.constant -1 : index
%13 = arith.cmpi slt, %8, %c0_3 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%c2097152 = arith.constant 2097152 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%19}
%c2 = arith.constant 2 : index
%c0_4 = arith.constant 0 : index
%c-1_5 = arith.constant -1 : index
%21 = arith.cmpi slt, %9, %c0_4 : index
%22 = arith.subi %c-1_5, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1_5, %24 : index
%26 = arith.select %21, %25, %24 : index
%c131072 = arith.constant 131072 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%27}
%c2_6 = arith.constant 2 : index
%c0_7 = arith.constant 0 : index
%c-1_8 = arith.constant -1 : index
%29 = arith.cmpi slt, %10, %c0_7 : index
%30 = arith.subi %c-1_8, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2_6 : index
%33 = arith.subi %c-1_8, %32 : index
%34 = arith.select %29, %33, %32 : index
%c131072_9 = arith.constant 131072 : index
%35 = arith.addi %34, %c131072_9 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%35}
%c16 = arith.constant 16 : index
%c0_10 = arith.constant 0 : index
%c-1_11 = arith.constant -1 : index
%37 = arith.cmpi slt, %11, %c0_10 : index
%38 = arith.subi %c-1_11, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1_11, %40 : index
%42 = arith.select %37, %41, %40 : index
%c512 = arith.constant 512 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%43}
%c2_12 = arith.constant 2 : index
%c0_13 = arith.constant 0 : index
%c-1_14 = arith.constant -1 : index
%45 = arith.cmpi slt, %12, %c0_13 : index
%46 = arith.subi %c-1_14, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2_12 : index
%49 = arith.subi %c-1_14, %48 : index
%50 = arith.select %45, %49, %48 : index
%c4096 = arith.constant 4096 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%c4_15 = arith.constant 4 : index
%53 = arith.muli %workgroup_id_x, %c4_15 : index
%c2_16 = arith.constant 2 : index
%c0_17 = arith.constant 0 : index
%c-1_18 = arith.constant -1 : index
%54 = arith.cmpi slt, %12, %c0_17 : index
%55 = arith.subi %c-1_18, %12 : index
%56 = arith.select %54, %55, %12 : index
%57 = arith.divsi %56, %c2_16 : index
%58 = arith.subi %c-1_18, %57 : index
%59 = arith.select %54, %58, %57 : index
%60 = arith.addi %53, %59 : index
memref.store %cst_0, %52[%60] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_19 = arith.constant 4 : index
%61 = arith.muli %workgroup_id_x, %c4_19 : index
%c2_20 = arith.constant 2 : index
%c0_21 = arith.constant 0 : index
%c-1_22 = arith.constant -1 : index
%62 = arith.cmpi slt, %12, %c0_21 : index
%63 = arith.subi %c-1_22, %12 : index
%64 = arith.select %62, %63, %12 : index
%65 = arith.divsi %64, %c2_20 : index
%66 = arith.subi %c-1_22, %65 : index
%67 = arith.select %62, %66, %65 : index
%68 = arith.addi %61, %67 : index
%c1_23 = arith.constant 1 : index
%69 = arith.addi %68, %c1_23 : index
memref.store %cst_0, %52[%69] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_24 = arith.constant 4 : index
%70 = arith.muli %workgroup_id_x, %c4_24 : index
%c2_25 = arith.constant 2 : index
%c0_26 = arith.constant 0 : index
%c-1_27 = arith.constant -1 : index
%71 = arith.cmpi slt, %12, %c0_26 : index
%72 = arith.subi %c-1_27, %12 : index
%73 = arith.select %71, %72, %12 : index
%74 = arith.divsi %73, %c2_25 : index
%75 = arith.subi %c-1_27, %74 : index
%76 = arith.select %71, %75, %74 : index
%77 = arith.addi %70, %76 : index
%c2_28 = arith.constant 2 : index
%78 = arith.addi %77, %c2_28 : index
memref.store %cst_0, %52[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_29 = arith.constant 4 : index
%79 = arith.muli %workgroup_id_x, %c4_29 : index
%c2_30 = arith.constant 2 : index
%c0_31 = arith.constant 0 : index
%c-1_32 = arith.constant -1 : index
%80 = arith.cmpi slt, %12, %c0_31 : index
%81 = arith.subi %c-1_32, %12 : index
%82 = arith.select %80, %81, %12 : index
%83 = arith.divsi %82, %c2_30 : index
%84 = arith.subi %c-1_32, %83 : index
%85 = arith.select %80, %84, %83 : index
%86 = arith.addi %79, %85 : index
%c3 = arith.constant 3 : index
%87 = arith.addi %86, %c3 : index
memref.store %cst_0, %52[%87] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%c64 = arith.constant 64 : index
%88 = arith.muli %1, %c64 : index
%89 = arith.addi %0, %88 : index
%c256 = arith.constant 256 : index
%90 = arith.muli %2, %c256 : index
%91 = arith.addi %89, %90 : index
%c16_33 = arith.constant 16 : index
%c0_34 = arith.constant 0 : index
%c-1_35 = arith.constant -1 : index
%92 = arith.cmpi slt, %11, %c0_34 : index
%93 = arith.subi %c-1_35, %11 : index
%94 = arith.select %92, %93, %11 : index
%95 = arith.divsi %94, %c16_33 : index
%96 = arith.subi %c-1_35, %95 : index
%97 = arith.select %92, %96, %95 : index
%98 = arith.addi %91, %97 : index
%99 = memref.load %44[%98] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%c68 = arith.constant 68 : index
%100 = arith.muli %1, %c68 : index
%101 = arith.addi %0, %100 : index
%c272 = arith.constant 272 : index
%102 = arith.muli %2, %c272 : index
%103 = arith.addi %101, %102 : index
%c16_36 = arith.constant 16 : index
%c0_37 = arith.constant 0 : index
%c-1_38 = arith.constant -1 : index
%104 = arith.cmpi slt, %0, %c0_37 : index
%105 = arith.subi %c-1_38, %0 : index
%106 = arith.select %104, %105, %0 : index
%107 = arith.divsi %106, %c16_36 : index
%108 = arith.subi %c-1_38, %107 : index
%109 = arith.select %104, %108, %107 : index
%110 = arith.addi %103, %109 : index
memref.store %99, %alloc[%110] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%c64_39 = arith.constant 64 : index
%111 = arith.muli %1, %c64_39 : index
%112 = arith.addi %0, %111 : index
%c256_40 = arith.constant 256 : index
%113 = arith.muli %2, %c256_40 : index
%114 = arith.addi %112, %113 : index
%c16_41 = arith.constant 16 : index
%c0_42 = arith.constant 0 : index
%c-1_43 = arith.constant -1 : index
%115 = arith.cmpi slt, %11, %c0_42 : index
%116 = arith.subi %c-1_43, %11 : index
%117 = arith.select %115, %116, %11 : index
%118 = arith.divsi %117, %c16_41 : index
%119 = arith.subi %c-1_43, %118 : index
%120 = arith.select %115, %119, %118 : index
%121 = arith.addi %114, %120 : index
%c256_44 = arith.constant 256 : index
%122 = arith.addi %121, %c256_44 : index
%123 = memref.load %44[%122] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%c68_45 = arith.constant 68 : index
%124 = arith.muli %1, %c68_45 : index
%125 = arith.addi %0, %124 : index
%c272_46 = arith.constant 272 : index
%126 = arith.muli %2, %c272_46 : index
%127 = arith.addi %125, %126 : index
%c16_47 = arith.constant 16 : index
%c0_48 = arith.constant 0 : index
%c-1_49 = arith.constant -1 : index
%128 = arith.cmpi slt, %0, %c0_48 : index
%129 = arith.subi %c-1_49, %0 : index
%130 = arith.select %128, %129, %0 : index
%131 = arith.divsi %130, %c16_47 : index
%132 = arith.subi %c-1_49, %131 : index
%133 = arith.select %128, %132, %131 : index
%134 = arith.addi %127, %133 : index
%c272_50 = arith.constant 272 : index
%135 = arith.addi %134, %c272_50 : index
memref.store %123, %alloc[%135] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%136 = arith.cmpi ult, %0, %c1 : index
scf.if %136 {
%137 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%c32_65 = arith.constant 32 : index
%188 = arith.muli %1, %c32_65 : index
%189 = arith.addi %arg0, %188 : index
%c128_66 = arith.constant 128 : index
%190 = arith.muli %workgroup_id_x, %c128_66 : index
%191 = arith.addi %189, %190 : index
%c2_67 = arith.constant 2 : index
%c0_68 = arith.constant 0 : index
%c-1_69 = arith.constant -1 : index
%192 = arith.cmpi slt, %9, %c0_68 : index
%193 = arith.subi %c-1_69, %9 : index
%194 = arith.select %192, %193, %9 : index
%195 = arith.divsi %194, %c2_67 : index
%196 = arith.subi %c-1_69, %195 : index
%197 = arith.select %192, %196, %195 : index
%198 = arith.addi %191, %197 : index
%199 = memref.load %28[%198] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c32_70 = arith.constant 32 : index
%200 = arith.muli %1, %c32_70 : index
%201 = arith.addi %arg0, %200 : index
%c128_71 = arith.constant 128 : index
%202 = arith.muli %workgroup_id_x, %c128_71 : index
%203 = arith.addi %201, %202 : index
%c2_72 = arith.constant 2 : index
%c0_73 = arith.constant 0 : index
%c-1_74 = arith.constant -1 : index
%204 = arith.cmpi slt, %10, %c0_73 : index
%205 = arith.subi %c-1_74, %10 : index
%206 = arith.select %204, %205, %10 : index
%207 = arith.divsi %206, %c2_72 : index
%208 = arith.subi %c-1_74, %207 : index
%209 = arith.select %204, %208, %207 : index
%210 = arith.addi %203, %209 : index
%211 = memref.load %36[%210] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%212 = vector.splat %211 : vector<4xf16>
%213 = vector.splat %199 : vector<4xf16>
%214 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%c16_75 = arith.constant 16 : index
%215 = arith.muli %arg0, %c16_75 : index
%c512_76 = arith.constant 512 : index
%216 = arith.muli %1, %c512_76 : index
%217 = arith.addi %215, %216 : index
%c2048 = arith.constant 2048 : index
%218 = arith.muli %workgroup_id_x, %c2048 : index
%219 = arith.addi %217, %218 : index
%c8_77 = arith.constant 8 : index
%c0_78 = arith.constant 0 : index
%c-1_79 = arith.constant -1 : index
%220 = arith.cmpi slt, %arg2, %c0_78 : index
%221 = arith.subi %c-1_79, %arg2 : index
%222 = arith.select %220, %221, %arg2 : index
%223 = arith.divsi %222, %c8_77 : index
%224 = arith.subi %c-1_79, %223 : index
%225 = arith.select %220, %224, %223 : index
%226 = arith.addi %219, %225 : index
%c4_80 = arith.constant 4 : index
%c0_81 = arith.constant 0 : index
%c-1_82 = arith.constant -1 : index
%227 = arith.cmpi slt, %8, %c0_81 : index
%228 = arith.subi %c-1_82, %8 : index
%229 = arith.select %227, %228, %8 : index
%230 = arith.divsi %229, %c4_80 : index
%231 = arith.subi %c-1_82, %230 : index
%232 = arith.select %227, %231, %230 : index
%233 = arith.addi %226, %232 : index
%234 = memref.load %20[%233] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%235 = vector.extract %234[0] : vector<1xi32>
%236 = arith.andi %235, %c15_i32 : i32
%237 = vector.insert %236, %cst [0] : i32 into vector<4xi32>
%238 = arith.shrui %235, %c4_i32 : i32
%239 = arith.andi %238, %c15_i32 : i32
%240 = vector.insert %239, %237 [1] : i32 into vector<4xi32>
%241 = arith.shrui %235, %c8_i32 : i32
%242 = arith.andi %241, %c15_i32 : i32
%243 = vector.insert %242, %240 [2] : i32 into vector<4xi32>
%244 = arith.shrui %235, %c12_i32 : i32
%245 = arith.andi %244, %c15_i32 : i32
%246 = vector.insert %245, %243 [3] : i32 into vector<4xi32>
%247 = arith.shrui %235, %c16_i32 : i32
%248 = arith.andi %247, %c15_i32 : i32
%249 = vector.insert %248, %cst [0] : i32 into vector<4xi32>
%250 = arith.shrui %235, %c20_i32 : i32
%251 = arith.andi %250, %c15_i32 : i32
%252 = vector.insert %251, %249 [1] : i32 into vector<4xi32>
%253 = arith.shrui %235, %c24_i32 : i32
%254 = arith.andi %253, %c15_i32 : i32
%255 = vector.insert %254, %252 [2] : i32 into vector<4xi32>
%256 = arith.shrui %235, %c28_i32 : i32
%257 = arith.andi %256, %c15_i32 : i32
%258 = vector.insert %257, %255 [3] : i32 into vector<4xi32>
%259 = arith.uitofp %246 : vector<4xi32> to vector<4xf16>
%260 = arith.uitofp %258 : vector<4xi32> to vector<4xf16>
%261 = arith.subf %259, %212 : vector<4xf16>
%262 = arith.subf %260, %212 : vector<4xf16>
%263 = arith.mulf %261, %213 : vector<4xf16>
%264 = arith.mulf %262, %213 : vector<4xf16>
%c17 = arith.constant 17 : index
%265 = arith.muli %arg0, %c17 : index
%c8_83 = arith.constant 8 : index
%c0_84 = arith.constant 0 : index
%c-1_85 = arith.constant -1 : index
%266 = arith.cmpi slt, %arg2, %c0_84 : index
%267 = arith.subi %c-1_85, %arg2 : index
%268 = arith.select %266, %267, %arg2 : index
%269 = arith.divsi %268, %c8_83 : index
%270 = arith.subi %c-1_85, %269 : index
%271 = arith.select %266, %270, %269 : index
%272 = arith.addi %265, %271 : index
%273 = memref.load %alloc[%272] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%274 = vector.extract_strided_slice %273 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%275 = vector.bitcast %274 : vector<2xf32> to vector<4xf16>
%276 = arith.mulf %275, %263 : vector<4xf16>
%277 = vector.extract_strided_slice %273 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%278 = vector.bitcast %277 : vector<2xf32> to vector<4xf16>
%279 = arith.mulf %278, %264 : vector<4xf16>
%280 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%281 = vector.bitcast %280 : vector<2xf32> to vector<4xf16>
%282 = arith.addf %276, %281 : vector<4xf16>
%283 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%284 = vector.bitcast %283 : vector<2xf32> to vector<4xf16>
%285 = arith.addf %279, %284 : vector<4xf16>
%286 = vector.bitcast %285 : vector<4xf16> to vector<2xf32>
%287 = vector.bitcast %282 : vector<4xf16> to vector<2xf32>
%288 = vector.insert_strided_slice %287, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%289 = vector.insert_strided_slice %286, %288 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %289 : vector<4xf32>
}
scf.yield %214 : vector<4xf32>
}
%c4_51 = arith.constant 4 : index
%138 = arith.muli %workgroup_id_x, %c4_51 : index
%139 = arith.addi %1, %138 : index
%c2_52 = arith.constant 2 : index
%c0_53 = arith.constant 0 : index
%c-1_54 = arith.constant -1 : index
%140 = arith.cmpi slt, %12, %c0_53 : index
%141 = arith.subi %c-1_54, %12 : index
%142 = arith.select %140, %141, %12 : index
%143 = arith.divsi %142, %c2_52 : index
%144 = arith.subi %c-1_54, %143 : index
%145 = arith.select %140, %144, %143 : index
%146 = arith.addi %139, %145 : index
%147 = memref.load %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%148 = vector.extract_strided_slice %137 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%149 = vector.bitcast %148 : vector<2xf32> to vector<4xf16>
%150 = vector.reduction <add>, %149 : vector<4xf16> into f16
%151 = vector.insert %150, %cst_2 [0] : f16 into vector<2xf16>
%152 = vector.extract_strided_slice %137 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%153 = vector.bitcast %152 : vector<2xf32> to vector<4xf16>
%154 = vector.reduction <add>, %153 : vector<4xf16> into f16
%155 = vector.insert %154, %151 [1] : f16 into vector<2xf16>
%156 = vector.bitcast %155 : vector<2xf16> to vector<1xi32>
%157 = vector.extract %156[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %157, %c1_i32, %c64_i32 : i32
%158 = vector.splat %shuffleResult : vector<1xi32>
%159 = vector.bitcast %158 : vector<1xi32> to vector<2xf16>
%160 = arith.addf %155, %159 : vector<2xf16>
%161 = vector.bitcast %160 : vector<2xf16> to vector<1xi32>
%162 = vector.extract %161[0] : vector<1xi32>
%shuffleResult_55, %valid_56 = gpu.shuffle xor %162, %c2_i32, %c64_i32 : i32
%163 = vector.splat %shuffleResult_55 : vector<1xi32>
%164 = vector.bitcast %163 : vector<1xi32> to vector<2xf16>
%165 = arith.addf %160, %164 : vector<2xf16>
%166 = vector.bitcast %165 : vector<2xf16> to vector<1xi32>
%167 = vector.extract %166[0] : vector<1xi32>
%shuffleResult_57, %valid_58 = gpu.shuffle xor %167, %c4_i32, %c64_i32 : i32
%168 = vector.splat %shuffleResult_57 : vector<1xi32>
%169 = vector.bitcast %168 : vector<1xi32> to vector<2xf16>
%170 = arith.addf %165, %169 : vector<2xf16>
%171 = vector.bitcast %170 : vector<2xf16> to vector<1xi32>
%172 = vector.extract %171[0] : vector<1xi32>
%shuffleResult_59, %valid_60 = gpu.shuffle xor %172, %c8_i32, %c64_i32 : i32
%173 = vector.splat %shuffleResult_59 : vector<1xi32>
%174 = vector.bitcast %173 : vector<1xi32> to vector<2xf16>
%175 = arith.addf %170, %174 : vector<2xf16>
%176 = vector.bitcast %175 : vector<2xf16> to vector<1xi32>
%177 = vector.extract %176[0] : vector<1xi32>
%shuffleResult_61, %valid_62 = gpu.shuffle xor %177, %c16_i32, %c64_i32 : i32
%178 = vector.splat %shuffleResult_61 : vector<1xi32>
%179 = vector.bitcast %178 : vector<1xi32> to vector<2xf16>
%180 = arith.addf %175, %179 : vector<2xf16>
%181 = vector.bitcast %180 : vector<2xf16> to vector<1xi32>
%182 = vector.extract %181[0] : vector<1xi32>
%shuffleResult_63, %valid_64 = gpu.shuffle xor %182, %c32_i32, %c64_i32 : i32
%183 = vector.splat %shuffleResult_63 : vector<1xi32>
%184 = vector.bitcast %183 : vector<1xi32> to vector<2xf16>
%185 = arith.addf %180, %184 : vector<2xf16>
%186 = vector.reduction <add>, %185 : vector<2xf16> into f16
%187 = arith.addf %186, %147 : f16
memref.store %187, %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
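// NOTE (added, not compiler output): the packed-weight loads above address the
// 4096x32x128 i4 tensor as i32 words (eight i4 values per word, so 16 words per
// 128-wide group and 512 words per row). The affine map
//   d0 * 16 + s1 * 512 + s2 * 2048 + s3 floordiv 8 + s0 floordiv 4
// is that flattening, with the row split as workgroup_id_x * 4 + thread_id_y and
// the push-constant byte offset divided down to words. A Python restatement
// (function and argument names are my own, for illustration):

def packed_word_index(row, group, elem, base_byte_offset=0):
    # row in [0, 4096), group in [0, 32), elem in [0, 128); 8 nibbles per i32 word.
    return base_byte_offset // 4 + row * 512 + group * 16 + elem // 8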
// -----// IR Dump After TosaToArith (tosa-to-arith) //----- //
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%c4 = arith.constant 4 : index
%c0_3 = arith.constant 0 : index
%c-1 = arith.constant -1 : index
%13 = arith.cmpi slt, %8, %c0_3 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%c2097152 = arith.constant 2097152 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%19}
%c2 = arith.constant 2 : index
%c0_4 = arith.constant 0 : index
%c-1_5 = arith.constant -1 : index
%21 = arith.cmpi slt, %9, %c0_4 : index
%22 = arith.subi %c-1_5, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1_5, %24 : index
%26 = arith.select %21, %25, %24 : index
%c131072 = arith.constant 131072 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%27}
%c2_6 = arith.constant 2 : index
%c0_7 = arith.constant 0 : index
%c-1_8 = arith.constant -1 : index
%29 = arith.cmpi slt, %10, %c0_7 : index
%30 = arith.subi %c-1_8, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2_6 : index
%33 = arith.subi %c-1_8, %32 : index
%34 = arith.select %29, %33, %32 : index
%c131072_9 = arith.constant 131072 : index
%35 = arith.addi %34, %c131072_9 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%35}
%c16 = arith.constant 16 : index
%c0_10 = arith.constant 0 : index
%c-1_11 = arith.constant -1 : index
%37 = arith.cmpi slt, %11, %c0_10 : index
%38 = arith.subi %c-1_11, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1_11, %40 : index
%42 = arith.select %37, %41, %40 : index
%c512 = arith.constant 512 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%43}
%c2_12 = arith.constant 2 : index
%c0_13 = arith.constant 0 : index
%c-1_14 = arith.constant -1 : index
%45 = arith.cmpi slt, %12, %c0_13 : index
%46 = arith.subi %c-1_14, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2_12 : index
%49 = arith.subi %c-1_14, %48 : index
%50 = arith.select %45, %49, %48 : index
%c4096 = arith.constant 4096 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%c4_15 = arith.constant 4 : index
%53 = arith.muli %workgroup_id_x, %c4_15 : index
%c2_16 = arith.constant 2 : index
%c0_17 = arith.constant 0 : index
%c-1_18 = arith.constant -1 : index
%54 = arith.cmpi slt, %12, %c0_17 : index
%55 = arith.subi %c-1_18, %12 : index
%56 = arith.select %54, %55, %12 : index
%57 = arith.divsi %56, %c2_16 : index
%58 = arith.subi %c-1_18, %57 : index
%59 = arith.select %54, %58, %57 : index
%60 = arith.addi %53, %59 : index
memref.store %cst_0, %52[%60] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_19 = arith.constant 4 : index
%61 = arith.muli %workgroup_id_x, %c4_19 : index
%c2_20 = arith.constant 2 : index
%c0_21 = arith.constant 0 : index
%c-1_22 = arith.constant -1 : index
%62 = arith.cmpi slt, %12, %c0_21 : index
%63 = arith.subi %c-1_22, %12 : index
%64 = arith.select %62, %63, %12 : index
%65 = arith.divsi %64, %c2_20 : index
%66 = arith.subi %c-1_22, %65 : index
%67 = arith.select %62, %66, %65 : index
%68 = arith.addi %61, %67 : index
%c1_23 = arith.constant 1 : index
%69 = arith.addi %68, %c1_23 : index
memref.store %cst_0, %52[%69] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_24 = arith.constant 4 : index
%70 = arith.muli %workgroup_id_x, %c4_24 : index
%c2_25 = arith.constant 2 : index
%c0_26 = arith.constant 0 : index
%c-1_27 = arith.constant -1 : index
%71 = arith.cmpi slt, %12, %c0_26 : index
%72 = arith.subi %c-1_27, %12 : index
%73 = arith.select %71, %72, %12 : index
%74 = arith.divsi %73, %c2_25 : index
%75 = arith.subi %c-1_27, %74 : index
%76 = arith.select %71, %75, %74 : index
%77 = arith.addi %70, %76 : index
%c2_28 = arith.constant 2 : index
%78 = arith.addi %77, %c2_28 : index
memref.store %cst_0, %52[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c4_29 = arith.constant 4 : index
%79 = arith.muli %workgroup_id_x, %c4_29 : index
%c2_30 = arith.constant 2 : index
%c0_31 = arith.constant 0 : index
%c-1_32 = arith.constant -1 : index
%80 = arith.cmpi slt, %12, %c0_31 : index
%81 = arith.subi %c-1_32, %12 : index
%82 = arith.select %80, %81, %12 : index
%83 = arith.divsi %82, %c2_30 : index
%84 = arith.subi %c-1_32, %83 : index
%85 = arith.select %80, %84, %83 : index
%86 = arith.addi %79, %85 : index
%c3 = arith.constant 3 : index
%87 = arith.addi %86, %c3 : index
memref.store %cst_0, %52[%87] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%c64 = arith.constant 64 : index
%88 = arith.muli %1, %c64 : index
%89 = arith.addi %0, %88 : index
%c256 = arith.constant 256 : index
%90 = arith.muli %2, %c256 : index
%91 = arith.addi %89, %90 : index
%c16_33 = arith.constant 16 : index
%c0_34 = arith.constant 0 : index
%c-1_35 = arith.constant -1 : index
%92 = arith.cmpi slt, %11, %c0_34 : index
%93 = arith.subi %c-1_35, %11 : index
%94 = arith.select %92, %93, %11 : index
%95 = arith.divsi %94, %c16_33 : index
%96 = arith.subi %c-1_35, %95 : index
%97 = arith.select %92, %96, %95 : index
%98 = arith.addi %91, %97 : index
%99 = memref.load %44[%98] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%c68 = arith.constant 68 : index
%100 = arith.muli %1, %c68 : index
%101 = arith.addi %0, %100 : index
%c272 = arith.constant 272 : index
%102 = arith.muli %2, %c272 : index
%103 = arith.addi %101, %102 : index
%c16_36 = arith.constant 16 : index
%c0_37 = arith.constant 0 : index
%c-1_38 = arith.constant -1 : index
%104 = arith.cmpi slt, %0, %c0_37 : index
%105 = arith.subi %c-1_38, %0 : index
%106 = arith.select %104, %105, %0 : index
%107 = arith.divsi %106, %c16_36 : index
%108 = arith.subi %c-1_38, %107 : index
%109 = arith.select %104, %108, %107 : index
%110 = arith.addi %103, %109 : index
memref.store %99, %alloc[%110] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%c64_39 = arith.constant 64 : index
%111 = arith.muli %1, %c64_39 : index
%112 = arith.addi %0, %111 : index
%c256_40 = arith.constant 256 : index
%113 = arith.muli %2, %c256_40 : index
%114 = arith.addi %112, %113 : index
%c16_41 = arith.constant 16 : index
%c0_42 = arith.constant 0 : index
%c-1_43 = arith.constant -1 : index
%115 = arith.cmpi slt, %11, %c0_42 : index
%116 = arith.subi %c-1_43, %11 : index
%117 = arith.select %115, %116, %11 : index
%118 = arith.divsi %117, %c16_41 : index
%119 = arith.subi %c-1_43, %118 : index
%120 = arith.select %115, %119, %118 : index
%121 = arith.addi %114, %120 : index
%c256_44 = arith.constant 256 : index
%122 = arith.addi %121, %c256_44 : index
%123 = memref.load %44[%122] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%c68_45 = arith.constant 68 : index
%124 = arith.muli %1, %c68_45 : index
%125 = arith.addi %0, %124 : index
%c272_46 = arith.constant 272 : index
%126 = arith.muli %2, %c272_46 : index
%127 = arith.addi %125, %126 : index
%c16_47 = arith.constant 16 : index
%c0_48 = arith.constant 0 : index
%c-1_49 = arith.constant -1 : index
%128 = arith.cmpi slt, %0, %c0_48 : index
%129 = arith.subi %c-1_49, %0 : index
%130 = arith.select %128, %129, %0 : index
%131 = arith.divsi %130, %c16_47 : index
%132 = arith.subi %c-1_49, %131 : index
%133 = arith.select %128, %132, %131 : index
%134 = arith.addi %127, %133 : index
%c272_50 = arith.constant 272 : index
%135 = arith.addi %134, %c272_50 : index
memref.store %123, %alloc[%135] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%136 = arith.cmpi ult, %0, %c1 : index
scf.if %136 {
%137 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%c32_65 = arith.constant 32 : index
%188 = arith.muli %1, %c32_65 : index
%189 = arith.addi %arg0, %188 : index
%c128_66 = arith.constant 128 : index
%190 = arith.muli %workgroup_id_x, %c128_66 : index
%191 = arith.addi %189, %190 : index
%c2_67 = arith.constant 2 : index
%c0_68 = arith.constant 0 : index
%c-1_69 = arith.constant -1 : index
%192 = arith.cmpi slt, %9, %c0_68 : index
%193 = arith.subi %c-1_69, %9 : index
%194 = arith.select %192, %193, %9 : index
%195 = arith.divsi %194, %c2_67 : index
%196 = arith.subi %c-1_69, %195 : index
%197 = arith.select %192, %196, %195 : index
%198 = arith.addi %191, %197 : index
%199 = memref.load %28[%198] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%c32_70 = arith.constant 32 : index
%200 = arith.muli %1, %c32_70 : index
%201 = arith.addi %arg0, %200 : index
%c128_71 = arith.constant 128 : index
%202 = arith.muli %workgroup_id_x, %c128_71 : index
%203 = arith.addi %201, %202 : index
%c2_72 = arith.constant 2 : index
%c0_73 = arith.constant 0 : index
%c-1_74 = arith.constant -1 : index
%204 = arith.cmpi slt, %10, %c0_73 : index
%205 = arith.subi %c-1_74, %10 : index
%206 = arith.select %204, %205, %10 : index
%207 = arith.divsi %206, %c2_72 : index
%208 = arith.subi %c-1_74, %207 : index
%209 = arith.select %204, %208, %207 : index
%210 = arith.addi %203, %209 : index
%211 = memref.load %36[%210] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%212 = vector.splat %211 : vector<4xf16>
%213 = vector.splat %199 : vector<4xf16>
%214 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%c16_75 = arith.constant 16 : index
%215 = arith.muli %arg0, %c16_75 : index
%c512_76 = arith.constant 512 : index
%216 = arith.muli %1, %c512_76 : index
%217 = arith.addi %215, %216 : index
%c2048 = arith.constant 2048 : index
%218 = arith.muli %workgroup_id_x, %c2048 : index
%219 = arith.addi %217, %218 : index
%c8_77 = arith.constant 8 : index
%c0_78 = arith.constant 0 : index
%c-1_79 = arith.constant -1 : index
%220 = arith.cmpi slt, %arg2, %c0_78 : index
%221 = arith.subi %c-1_79, %arg2 : index
%222 = arith.select %220, %221, %arg2 : index
%223 = arith.divsi %222, %c8_77 : index
%224 = arith.subi %c-1_79, %223 : index
%225 = arith.select %220, %224, %223 : index
%226 = arith.addi %219, %225 : index
%c4_80 = arith.constant 4 : index
%c0_81 = arith.constant 0 : index
%c-1_82 = arith.constant -1 : index
%227 = arith.cmpi slt, %8, %c0_81 : index
%228 = arith.subi %c-1_82, %8 : index
%229 = arith.select %227, %228, %8 : index
%230 = arith.divsi %229, %c4_80 : index
%231 = arith.subi %c-1_82, %230 : index
%232 = arith.select %227, %231, %230 : index
%233 = arith.addi %226, %232 : index
%234 = memref.load %20[%233] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%235 = vector.extract %234[0] : vector<1xi32>
%236 = arith.andi %235, %c15_i32 : i32
%237 = vector.insert %236, %cst [0] : i32 into vector<4xi32>
%238 = arith.shrui %235, %c4_i32 : i32
%239 = arith.andi %238, %c15_i32 : i32
%240 = vector.insert %239, %237 [1] : i32 into vector<4xi32>
%241 = arith.shrui %235, %c8_i32 : i32
%242 = arith.andi %241, %c15_i32 : i32
%243 = vector.insert %242, %240 [2] : i32 into vector<4xi32>
%244 = arith.shrui %235, %c12_i32 : i32
%245 = arith.andi %244, %c15_i32 : i32
%246 = vector.insert %245, %243 [3] : i32 into vector<4xi32>
%247 = arith.shrui %235, %c16_i32 : i32
%248 = arith.andi %247, %c15_i32 : i32
%249 = vector.insert %248, %cst [0] : i32 into vector<4xi32>
%250 = arith.shrui %235, %c20_i32 : i32
%251 = arith.andi %250, %c15_i32 : i32
%252 = vector.insert %251, %249 [1] : i32 into vector<4xi32>
%253 = arith.shrui %235, %c24_i32 : i32
%254 = arith.andi %253, %c15_i32 : i32
%255 = vector.insert %254, %252 [2] : i32 into vector<4xi32>
%256 = arith.shrui %235, %c28_i32 : i32
%257 = arith.andi %256, %c15_i32 : i32
%258 = vector.insert %257, %255 [3] : i32 into vector<4xi32>
%259 = arith.uitofp %246 : vector<4xi32> to vector<4xf16>
%260 = arith.uitofp %258 : vector<4xi32> to vector<4xf16>
%261 = arith.subf %259, %212 : vector<4xf16>
%262 = arith.subf %260, %212 : vector<4xf16>
%263 = arith.mulf %261, %213 : vector<4xf16>
%264 = arith.mulf %262, %213 : vector<4xf16>
%c17 = arith.constant 17 : index
%265 = arith.muli %arg0, %c17 : index
%c8_83 = arith.constant 8 : index
%c0_84 = arith.constant 0 : index
%c-1_85 = arith.constant -1 : index
%266 = arith.cmpi slt, %arg2, %c0_84 : index
%267 = arith.subi %c-1_85, %arg2 : index
%268 = arith.select %266, %267, %arg2 : index
%269 = arith.divsi %268, %c8_83 : index
%270 = arith.subi %c-1_85, %269 : index
%271 = arith.select %266, %270, %269 : index
%272 = arith.addi %265, %271 : index
%273 = memref.load %alloc[%272] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%274 = vector.extract_strided_slice %273 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%275 = vector.bitcast %274 : vector<2xf32> to vector<4xf16>
%276 = arith.mulf %275, %263 : vector<4xf16>
%277 = vector.extract_strided_slice %273 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%278 = vector.bitcast %277 : vector<2xf32> to vector<4xf16>
%279 = arith.mulf %278, %264 : vector<4xf16>
%280 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%281 = vector.bitcast %280 : vector<2xf32> to vector<4xf16>
%282 = arith.addf %276, %281 : vector<4xf16>
%283 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%284 = vector.bitcast %283 : vector<2xf32> to vector<4xf16>
%285 = arith.addf %279, %284 : vector<4xf16>
%286 = vector.bitcast %285 : vector<4xf16> to vector<2xf32>
%287 = vector.bitcast %282 : vector<4xf16> to vector<2xf32>
%288 = vector.insert_strided_slice %287, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%289 = vector.insert_strided_slice %286, %288 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %289 : vector<4xf32>
}
scf.yield %214 : vector<4xf32>
}
%c4_51 = arith.constant 4 : index
%138 = arith.muli %workgroup_id_x, %c4_51 : index
%139 = arith.addi %1, %138 : index
%c2_52 = arith.constant 2 : index
%c0_53 = arith.constant 0 : index
%c-1_54 = arith.constant -1 : index
%140 = arith.cmpi slt, %12, %c0_53 : index
%141 = arith.subi %c-1_54, %12 : index
%142 = arith.select %140, %141, %12 : index
%143 = arith.divsi %142, %c2_52 : index
%144 = arith.subi %c-1_54, %143 : index
%145 = arith.select %140, %144, %143 : index
%146 = arith.addi %139, %145 : index
%147 = memref.load %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%148 = vector.extract_strided_slice %137 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%149 = vector.bitcast %148 : vector<2xf32> to vector<4xf16>
%150 = vector.reduction <add>, %149 : vector<4xf16> into f16
%151 = vector.insert %150, %cst_2 [0] : f16 into vector<2xf16>
%152 = vector.extract_strided_slice %137 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%153 = vector.bitcast %152 : vector<2xf32> to vector<4xf16>
%154 = vector.reduction <add>, %153 : vector<4xf16> into f16
%155 = vector.insert %154, %151 [1] : f16 into vector<2xf16>
%156 = vector.bitcast %155 : vector<2xf16> to vector<1xi32>
%157 = vector.extract %156[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %157, %c1_i32, %c64_i32 : i32
%158 = vector.splat %shuffleResult : vector<1xi32>
%159 = vector.bitcast %158 : vector<1xi32> to vector<2xf16>
%160 = arith.addf %155, %159 : vector<2xf16>
%161 = vector.bitcast %160 : vector<2xf16> to vector<1xi32>
%162 = vector.extract %161[0] : vector<1xi32>
%shuffleResult_55, %valid_56 = gpu.shuffle xor %162, %c2_i32, %c64_i32 : i32
%163 = vector.splat %shuffleResult_55 : vector<1xi32>
%164 = vector.bitcast %163 : vector<1xi32> to vector<2xf16>
%165 = arith.addf %160, %164 : vector<2xf16>
%166 = vector.bitcast %165 : vector<2xf16> to vector<1xi32>
%167 = vector.extract %166[0] : vector<1xi32>
%shuffleResult_57, %valid_58 = gpu.shuffle xor %167, %c4_i32, %c64_i32 : i32
%168 = vector.splat %shuffleResult_57 : vector<1xi32>
%169 = vector.bitcast %168 : vector<1xi32> to vector<2xf16>
%170 = arith.addf %165, %169 : vector<2xf16>
%171 = vector.bitcast %170 : vector<2xf16> to vector<1xi32>
%172 = vector.extract %171[0] : vector<1xi32>
%shuffleResult_59, %valid_60 = gpu.shuffle xor %172, %c8_i32, %c64_i32 : i32
%173 = vector.splat %shuffleResult_59 : vector<1xi32>
%174 = vector.bitcast %173 : vector<1xi32> to vector<2xf16>
%175 = arith.addf %170, %174 : vector<2xf16>
%176 = vector.bitcast %175 : vector<2xf16> to vector<1xi32>
%177 = vector.extract %176[0] : vector<1xi32>
%shuffleResult_61, %valid_62 = gpu.shuffle xor %177, %c16_i32, %c64_i32 : i32
%178 = vector.splat %shuffleResult_61 : vector<1xi32>
%179 = vector.bitcast %178 : vector<1xi32> to vector<2xf16>
%180 = arith.addf %175, %179 : vector<2xf16>
%181 = vector.bitcast %180 : vector<2xf16> to vector<1xi32>
%182 = vector.extract %181[0] : vector<1xi32>
%shuffleResult_63, %valid_64 = gpu.shuffle xor %182, %c32_i32, %c64_i32 : i32
%183 = vector.splat %shuffleResult_63 : vector<1xi32>
%184 = vector.bitcast %183 : vector<1xi32> to vector<2xf16>
%185 = arith.addf %180, %184 : vector<2xf16>
%186 = vector.reduction <add>, %185 : vector<2xf16> into f16
%187 = arith.addf %186, %147 : f16
memref.store %187, %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
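// note: canonicalization mainly hoists and deduplicates the constants that were materialized
// per use in the dump above (the repeated %c8_*, %c4_*, %c-1_* copies) into the single
// constant block at the top of the function; the arithmetic itself is unchanged.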
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.cmpi slt, %12, %c0 : index
%55 = arith.subi %c-1, %12 : index
%56 = arith.select %54, %55, %12 : index
%57 = arith.divsi %56, %c2 : index
%58 = arith.subi %c-1, %57 : index
%59 = arith.select %54, %58, %57 : index
%60 = arith.addi %53, %59 : index
memref.store %cst_0, %52[%60] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%61 = arith.muli %workgroup_id_x, %c4 : index
%62 = arith.cmpi slt, %12, %c0 : index
%63 = arith.subi %c-1, %12 : index
%64 = arith.select %62, %63, %12 : index
%65 = arith.divsi %64, %c2 : index
%66 = arith.subi %c-1, %65 : index
%67 = arith.select %62, %66, %65 : index
%68 = arith.addi %61, %67 : index
%69 = arith.addi %68, %c1 : index
memref.store %cst_0, %52[%69] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%70 = arith.muli %workgroup_id_x, %c4 : index
%71 = arith.cmpi slt, %12, %c0 : index
%72 = arith.subi %c-1, %12 : index
%73 = arith.select %71, %72, %12 : index
%74 = arith.divsi %73, %c2 : index
%75 = arith.subi %c-1, %74 : index
%76 = arith.select %71, %75, %74 : index
%77 = arith.addi %70, %76 : index
%78 = arith.addi %77, %c2 : index
memref.store %cst_0, %52[%78] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%79 = arith.muli %workgroup_id_x, %c4 : index
%80 = arith.cmpi slt, %12, %c0 : index
%81 = arith.subi %c-1, %12 : index
%82 = arith.select %80, %81, %12 : index
%83 = arith.divsi %82, %c2 : index
%84 = arith.subi %c-1, %83 : index
%85 = arith.select %80, %84, %83 : index
%86 = arith.addi %79, %85 : index
%87 = arith.addi %86, %c3 : index
memref.store %cst_0, %52[%87] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
%88 = arith.muli %1, %c64 : index
%89 = arith.addi %0, %88 : index
%90 = arith.muli %2, %c256 : index
%91 = arith.addi %89, %90 : index
%92 = arith.cmpi slt, %11, %c0 : index
%93 = arith.subi %c-1, %11 : index
%94 = arith.select %92, %93, %11 : index
%95 = arith.divsi %94, %c16 : index
%96 = arith.subi %c-1, %95 : index
%97 = arith.select %92, %96, %95 : index
%98 = arith.addi %91, %97 : index
%99 = memref.load %44[%98] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%100 = arith.muli %1, %c68 : index
%101 = arith.addi %0, %100 : index
%102 = arith.muli %2, %c272 : index
%103 = arith.addi %101, %102 : index
%104 = arith.cmpi slt, %0, %c0 : index
%105 = arith.subi %c-1, %0 : index
%106 = arith.select %104, %105, %0 : index
%107 = arith.divsi %106, %c16 : index
%108 = arith.subi %c-1, %107 : index
%109 = arith.select %104, %108, %107 : index
%110 = arith.addi %103, %109 : index
memref.store %99, %alloc[%110] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%111 = arith.muli %1, %c64 : index
%112 = arith.addi %0, %111 : index
%113 = arith.muli %2, %c256 : index
%114 = arith.addi %112, %113 : index
%115 = arith.cmpi slt, %11, %c0 : index
%116 = arith.subi %c-1, %11 : index
%117 = arith.select %115, %116, %11 : index
%118 = arith.divsi %117, %c16 : index
%119 = arith.subi %c-1, %118 : index
%120 = arith.select %115, %119, %118 : index
%121 = arith.addi %114, %120 : index
%122 = arith.addi %121, %c256 : index
%123 = memref.load %44[%122] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%124 = arith.muli %1, %c68 : index
%125 = arith.addi %0, %124 : index
%126 = arith.muli %2, %c272 : index
%127 = arith.addi %125, %126 : index
%128 = arith.cmpi slt, %0, %c0 : index
%129 = arith.subi %c-1, %0 : index
%130 = arith.select %128, %129, %0 : index
%131 = arith.divsi %130, %c16 : index
%132 = arith.subi %c-1, %131 : index
%133 = arith.select %128, %132, %131 : index
%134 = arith.addi %127, %133 : index
%135 = arith.addi %134, %c272 : index
memref.store %123, %alloc[%135] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%136 = arith.cmpi ult, %0, %c1 : index
scf.if %136 {
%137 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%188 = arith.muli %1, %c32 : index
%189 = arith.addi %arg0, %188 : index
%190 = arith.muli %workgroup_id_x, %c128 : index
%191 = arith.addi %189, %190 : index
%192 = arith.cmpi slt, %9, %c0 : index
%193 = arith.subi %c-1, %9 : index
%194 = arith.select %192, %193, %9 : index
%195 = arith.divsi %194, %c2 : index
%196 = arith.subi %c-1, %195 : index
%197 = arith.select %192, %196, %195 : index
%198 = arith.addi %191, %197 : index
%199 = memref.load %28[%198] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%200 = arith.muli %1, %c32 : index
%201 = arith.addi %arg0, %200 : index
%202 = arith.muli %workgroup_id_x, %c128 : index
%203 = arith.addi %201, %202 : index
%204 = arith.cmpi slt, %10, %c0 : index
%205 = arith.subi %c-1, %10 : index
%206 = arith.select %204, %205, %10 : index
%207 = arith.divsi %206, %c2 : index
%208 = arith.subi %c-1, %207 : index
%209 = arith.select %204, %208, %207 : index
%210 = arith.addi %203, %209 : index
%211 = memref.load %36[%210] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%212 = vector.splat %211 : vector<4xf16>
%213 = vector.splat %199 : vector<4xf16>
%214 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%215 = arith.muli %arg0, %c16 : index
%216 = arith.muli %1, %c512 : index
%217 = arith.addi %215, %216 : index
%218 = arith.muli %workgroup_id_x, %c2048 : index
%219 = arith.addi %217, %218 : index
%220 = arith.cmpi slt, %arg2, %c0 : index
%221 = arith.subi %c-1, %arg2 : index
%222 = arith.select %220, %221, %arg2 : index
%223 = arith.divsi %222, %c8 : index
%224 = arith.subi %c-1, %223 : index
%225 = arith.select %220, %224, %223 : index
%226 = arith.addi %219, %225 : index
%227 = arith.cmpi slt, %8, %c0 : index
%228 = arith.subi %c-1, %8 : index
%229 = arith.select %227, %228, %8 : index
%230 = arith.divsi %229, %c4 : index
%231 = arith.subi %c-1, %230 : index
%232 = arith.select %227, %231, %230 : index
%233 = arith.addi %226, %232 : index
%234 = memref.load %20[%233] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
%235 = vector.extract %234[0] : vector<1xi32>
%236 = arith.andi %235, %c15_i32 : i32
%237 = vector.insert %236, %cst [0] : i32 into vector<4xi32>
%238 = arith.shrui %235, %c4_i32 : i32
%239 = arith.andi %238, %c15_i32 : i32
%240 = vector.insert %239, %237 [1] : i32 into vector<4xi32>
%241 = arith.shrui %235, %c8_i32 : i32
%242 = arith.andi %241, %c15_i32 : i32
%243 = vector.insert %242, %240 [2] : i32 into vector<4xi32>
%244 = arith.shrui %235, %c12_i32 : i32
%245 = arith.andi %244, %c15_i32 : i32
%246 = vector.insert %245, %243 [3] : i32 into vector<4xi32>
%247 = arith.shrui %235, %c16_i32 : i32
%248 = arith.andi %247, %c15_i32 : i32
%249 = vector.insert %248, %cst [0] : i32 into vector<4xi32>
%250 = arith.shrui %235, %c20_i32 : i32
%251 = arith.andi %250, %c15_i32 : i32
%252 = vector.insert %251, %249 [1] : i32 into vector<4xi32>
%253 = arith.shrui %235, %c24_i32 : i32
%254 = arith.andi %253, %c15_i32 : i32
%255 = vector.insert %254, %252 [2] : i32 into vector<4xi32>
%256 = arith.shrui %235, %c28_i32 : i32
%257 = arith.andi %256, %c15_i32 : i32
%258 = vector.insert %257, %255 [3] : i32 into vector<4xi32>
%259 = arith.uitofp %246 : vector<4xi32> to vector<4xf16>
%260 = arith.uitofp %258 : vector<4xi32> to vector<4xf16>
%261 = arith.subf %259, %212 : vector<4xf16>
%262 = arith.subf %260, %212 : vector<4xf16>
%263 = arith.mulf %261, %213 : vector<4xf16>
%264 = arith.mulf %262, %213 : vector<4xf16>
%265 = arith.muli %arg0, %c17 : index
%266 = arith.cmpi slt, %arg2, %c0 : index
%267 = arith.subi %c-1, %arg2 : index
%268 = arith.select %266, %267, %arg2 : index
%269 = arith.divsi %268, %c8 : index
%270 = arith.subi %c-1, %269 : index
%271 = arith.select %266, %270, %269 : index
%272 = arith.addi %265, %271 : index
%273 = memref.load %alloc[%272] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%274 = vector.extract_strided_slice %273 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%275 = vector.bitcast %274 : vector<2xf32> to vector<4xf16>
%276 = arith.mulf %275, %263 : vector<4xf16>
%277 = vector.extract_strided_slice %273 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%278 = vector.bitcast %277 : vector<2xf32> to vector<4xf16>
%279 = arith.mulf %278, %264 : vector<4xf16>
%280 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%281 = vector.bitcast %280 : vector<2xf32> to vector<4xf16>
%282 = arith.addf %276, %281 : vector<4xf16>
%283 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%284 = vector.bitcast %283 : vector<2xf32> to vector<4xf16>
%285 = arith.addf %279, %284 : vector<4xf16>
%286 = vector.bitcast %285 : vector<4xf16> to vector<2xf32>
%287 = vector.bitcast %282 : vector<4xf16> to vector<2xf32>
%288 = vector.insert_strided_slice %287, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%289 = vector.insert_strided_slice %286, %288 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %289 : vector<4xf32>
}
scf.yield %214 : vector<4xf32>
}
%138 = arith.muli %workgroup_id_x, %c4 : index
%139 = arith.addi %1, %138 : index
%140 = arith.cmpi slt, %12, %c0 : index
%141 = arith.subi %c-1, %12 : index
%142 = arith.select %140, %141, %12 : index
%143 = arith.divsi %142, %c2 : index
%144 = arith.subi %c-1, %143 : index
%145 = arith.select %140, %144, %143 : index
%146 = arith.addi %139, %145 : index
%147 = memref.load %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%148 = vector.extract_strided_slice %137 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%149 = vector.bitcast %148 : vector<2xf32> to vector<4xf16>
%150 = vector.reduction <add>, %149 : vector<4xf16> into f16
%151 = vector.insert %150, %cst_2 [0] : f16 into vector<2xf16>
%152 = vector.extract_strided_slice %137 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%153 = vector.bitcast %152 : vector<2xf32> to vector<4xf16>
%154 = vector.reduction <add>, %153 : vector<4xf16> into f16
%155 = vector.insert %154, %151 [1] : f16 into vector<2xf16>
%156 = vector.bitcast %155 : vector<2xf16> to vector<1xi32>
%157 = vector.extract %156[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %157, %c1_i32, %c64_i32 : i32
%158 = vector.splat %shuffleResult : vector<1xi32>
%159 = vector.bitcast %158 : vector<1xi32> to vector<2xf16>
%160 = arith.addf %155, %159 : vector<2xf16>
%161 = vector.bitcast %160 : vector<2xf16> to vector<1xi32>
%162 = vector.extract %161[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %162, %c2_i32, %c64_i32 : i32
%163 = vector.splat %shuffleResult_3 : vector<1xi32>
%164 = vector.bitcast %163 : vector<1xi32> to vector<2xf16>
%165 = arith.addf %160, %164 : vector<2xf16>
%166 = vector.bitcast %165 : vector<2xf16> to vector<1xi32>
%167 = vector.extract %166[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %167, %c4_i32, %c64_i32 : i32
%168 = vector.splat %shuffleResult_5 : vector<1xi32>
%169 = vector.bitcast %168 : vector<1xi32> to vector<2xf16>
%170 = arith.addf %165, %169 : vector<2xf16>
%171 = vector.bitcast %170 : vector<2xf16> to vector<1xi32>
%172 = vector.extract %171[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %172, %c8_i32, %c64_i32 : i32
%173 = vector.splat %shuffleResult_7 : vector<1xi32>
%174 = vector.bitcast %173 : vector<1xi32> to vector<2xf16>
%175 = arith.addf %170, %174 : vector<2xf16>
%176 = vector.bitcast %175 : vector<2xf16> to vector<1xi32>
%177 = vector.extract %176[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %177, %c16_i32, %c64_i32 : i32
%178 = vector.splat %shuffleResult_9 : vector<1xi32>
%179 = vector.bitcast %178 : vector<1xi32> to vector<2xf16>
%180 = arith.addf %175, %179 : vector<2xf16>
%181 = vector.bitcast %180 : vector<2xf16> to vector<1xi32>
%182 = vector.extract %181[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %182, %c32_i32, %c64_i32 : i32
%183 = vector.splat %shuffleResult_11 : vector<1xi32>
%184 = vector.bitcast %183 : vector<1xi32> to vector<2xf16>
%185 = arith.addf %180, %184 : vector<2xf16>
%186 = vector.reduction <add>, %185 : vector<2xf16> into f16
%187 = arith.addf %186, %147 : f16
memref.store %187, %52[%146] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
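// note: CSE shares the recomputed signed-division offset chains from the previous dump:
// the output offset (%50, %53, %54) is now computed once and reused by the four
// zero-initializing stores and the final store, and the %arg2 / 8 sequence inside the
// inner loop is computed once per iteration instead of twice.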
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #hal.descriptor_type<storage_buffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
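// note: the four f16 outputs owned by this workgroup (indices workgroup_id_x * 4 + 0..3,
// plus the dynamic base offset %50) are zero-initialized before the barrier below.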
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
gpu.barrier
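// note: cooperative staging - each thread copies two vector<4xf32> elements (16 f16 values)
// of the f16 tile in binding 1 into the 544-element workgroup buffer; the destination index
// adds one extra vector every 16 elements (row stride 17, 32 x 17 = 544), likely padding to
// avoid shared-memory bank conflicts.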
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #hal.descriptor_type<storage_buffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
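// note: %78 is thread_id.x < 1, so only threads with thread_id.x == 0 run the reduction
// below; each such thread (indexed by thread_id.y) accumulates one of the workgroup's four
// output elements.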
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #hal.descriptor_type<storage_buffer>>
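// note: the i32 word just loaded packs eight i4 weights; the shrui/andi chain below unpacks
// them into two vector<4xi32>, which are converted to f16 and dequantized by subtracting the
// splatted %131 (offset) and multiplying by the splatted %132 (scale), matching the sub/mul
// in the original generic op.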
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
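// note: the matching eight staged f16 values for this k-tile are read back from workgroup
// memory (row stride 17 vector<4xf32>), bitcast to two vector<4xf16> halves, multiplied with
// the dequantized weights, and accumulated into the packed vector<4xf32> carried by %arg3.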
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #gpu.address_space<workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
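// note: reduction epilogue - each vector<4xf16> half of the packed accumulator collapses to
// an f16 partial sum, the two sums are packed into one i32 and combined across the 64-wide
// subgroup with xor shuffles at offsets 1, 2, 4, 8, 16, 32, and the final sum is added to the
// value already stored in the output buffer.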
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #hal.descriptor_type<storage_buffer>>
}
return
}
}
// -----// IR Dump After SPIRVMapMemRefStorageClass (iree-spirv-map-memref-storage-class) //----- //
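// note: this pass only rewrites memory-space attributes for SPIR-V conversion:
// #hal.descriptor_type<storage_buffer> becomes #spirv.storage_class<StorageBuffer> and
// #gpu.address_space<workgroup> becomes #spirv.storage_class<Workgroup>; the rest of the IR
// is unchanged.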
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
// -----// IR Dump After SPIRVEmulateI64 (iree-spirv-emulate-i64) //----- //
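// note: no 64-bit integer values appear in this dispatch, so i64 emulation has nothing to
// rewrite; this dump appears identical to the previous one.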
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
}
// -----// IR Dump After ConvertBf16ArithToF32 (iree-convert-bf16-arith-to-f32) //----- //
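// Note: this dispatch is f16-only (no bf16 arithmetic anywhere in the function), so this pass appears to leave the IR untouched; the dump below matches the previous one.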
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
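    // Each workgroup covers 4 rows of the 4096-element output (%53 = workgroup_id_x * 4); every thread redundantly zero-initializes those four f16 results before the barrier.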
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
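    // Cooperatively stage the shared vector operand (binding 1, loaded as vector<4xf32>, i.e. 8 x f16) into workgroup memory. The extra slot inserted per 16 vectors (and the row stride of 17 on the read side below) looks like padding to avoid shared-memory bank conflicts.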
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
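    // Only threads with thread_id x == 0 run the reduction below; each such thread (one per y in 0..3) accumulates one of the workgroup's four output rows, carrying 8 f16 partial sums bitcast into a vector<4xf32>.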
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
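          // Unpack the eight 4-bit weights from the loaded i32 into two vector<4xi32> of nibbles (shift + mask with 0xF), then dequantize: subtract the per-group zero point (%131) and multiply by the per-group scale (%132).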
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
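          // Multiply the dequantized weights by the staged vector values from workgroup memory and add into the loop-carried accumulator (%arg3), whose two vector<2xf32> halves are bitcast views of 4 f16 partial sums each.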
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
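      // Subgroup butterfly reduction: the two f16 partial sums are packed into a single i32 and combined across the 64-wide subgroup with xor shuffles at offsets 1, 2, 4, 8, 16, 32, then the reduced value is added into the output element.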
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
}
// -----// IR Dump After ConvertBf16ToUInt16Buffers (iree-convert-bf16-to-uint16-buffers) //----- //
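// Note: there are no bf16 buffers in this dispatch, so this pass also appears to be a no-op; the dump below matches the previous one.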
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
}
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
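// Note: canonicalization makes no further changes at this point; the function body below is unchanged from the previous dump.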
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
}
// -----// IR Dump After CSE (cse) //----- //
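// Note: CSE finds no additional redundancy here; the function body below is unchanged from the previous dump.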
module {
func.func @forward_dispatch_3_generic_4096x32x128_f16() {
%c17 = arith.constant 17 : index
%c2048 = arith.constant 2048 : index
%c272 = arith.constant 272 : index
%c68 = arith.constant 68 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%c3 = arith.constant 3 : index
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%c16 = arith.constant 16 : index
%c131072 = arith.constant 131072 : index
%c2 = arith.constant 2 : index
%c2097152 = arith.constant 2097152 : index
%c-1 = arith.constant -1 : index
%c4 = arith.constant 4 : index
%c12_i32 = arith.constant 12 : i32
%cst = arith.constant dense<0> : vector<4xi32>
%c15_i32 = arith.constant 15 : i32
%c20_i32 = arith.constant 20 : i32
%c24_i32 = arith.constant 24 : i32
%c28_i32 = arith.constant 28 : i32
%cst_0 = arith.constant 0.000000e+00 : f16
%cst_1 = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32_i32 = arith.constant 32 : i32
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i32 = arith.constant 1 : i32
%cst_2 = arith.constant dense<0.000000e+00> : vector<2xf16>
%c8 = arith.constant 8 : index
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%0 = gpu.thread_id x
%1 = gpu.thread_id y
%2 = gpu.thread_id z
%alloc = memref.alloc() : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%3 = hal.interface.constant.load[0] : i32
%4 = hal.interface.constant.load[1] : i32
%5 = hal.interface.constant.load[2] : i32
%6 = hal.interface.constant.load[3] : i32
%7 = hal.interface.constant.load[4] : i32
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = arith.index_castui %5 : i32 to index
%11 = arith.index_castui %6 : i32 to index
%12 = arith.index_castui %7 : i32 to index
%13 = arith.cmpi slt, %8, %c0 : index
%14 = arith.subi %c-1, %8 : index
%15 = arith.select %13, %14, %8 : index
%16 = arith.divsi %15, %c4 : index
%17 = arith.subi %c-1, %16 : index
%18 = arith.select %13, %17, %16 : index
%19 = arith.addi %18, %c2097152 : index
%20 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>{%19}
%21 = arith.cmpi slt, %9, %c0 : index
%22 = arith.subi %c-1, %9 : index
%23 = arith.select %21, %22, %9 : index
%24 = arith.divsi %23, %c2 : index
%25 = arith.subi %c-1, %24 : index
%26 = arith.select %21, %25, %24 : index
%27 = arith.addi %26, %c131072 : index
%28 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%27}
%29 = arith.cmpi slt, %10, %c0 : index
%30 = arith.subi %c-1, %10 : index
%31 = arith.select %29, %30, %10 : index
%32 = arith.divsi %31, %c2 : index
%33 = arith.subi %c-1, %32 : index
%34 = arith.select %29, %33, %32 : index
%35 = arith.addi %34, %c131072 : index
%36 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%35}
%37 = arith.cmpi slt, %11, %c0 : index
%38 = arith.subi %c-1, %11 : index
%39 = arith.select %37, %38, %11 : index
%40 = arith.divsi %39, %c16 : index
%41 = arith.subi %c-1, %40 : index
%42 = arith.select %37, %41, %40 : index
%43 = arith.addi %42, %c512 : index
%44 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%43}
%45 = arith.cmpi slt, %12, %c0 : index
%46 = arith.subi %c-1, %12 : index
%47 = arith.select %45, %46, %12 : index
%48 = arith.divsi %47, %c2 : index
%49 = arith.subi %c-1, %48 : index
%50 = arith.select %45, %49, %48 : index
%51 = arith.addi %50, %c4096 : index
%52 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<?xf16, #spirv.storage_class<StorageBuffer>>{%51}
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%53 = arith.muli %workgroup_id_x, %c4 : index
%54 = arith.addi %53, %50 : index
memref.store %cst_0, %52[%54] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%55 = arith.addi %54, %c1 : index
memref.store %cst_0, %52[%55] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%56 = arith.addi %54, %c2 : index
memref.store %cst_0, %52[%56] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%57 = arith.addi %54, %c3 : index
memref.store %cst_0, %52[%57] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
gpu.barrier
%58 = arith.muli %1, %c64 : index
%59 = arith.addi %0, %58 : index
%60 = arith.muli %2, %c256 : index
%61 = arith.addi %59, %60 : index
%62 = arith.addi %61, %42 : index
%63 = memref.load %44[%62] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%64 = arith.muli %1, %c68 : index
%65 = arith.addi %0, %64 : index
%66 = arith.muli %2, %c272 : index
%67 = arith.addi %65, %66 : index
%68 = arith.cmpi slt, %0, %c0 : index
%69 = arith.subi %c-1, %0 : index
%70 = arith.select %68, %69, %0 : index
%71 = arith.divsi %70, %c16 : index
%72 = arith.subi %c-1, %71 : index
%73 = arith.select %68, %72, %71 : index
%74 = arith.addi %67, %73 : index
memref.store %63, %alloc[%74] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%75 = arith.addi %62, %c256 : index
%76 = memref.load %44[%75] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%77 = arith.addi %74, %c272 : index
memref.store %76, %alloc[%77] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
gpu.barrier
%78 = arith.cmpi ult, %0, %c1 : index
scf.if %78 {
%79 = scf.for %arg0 = %c0 to %c32 step %c1 iter_args(%arg1 = %cst_1) -> (vector<4xf32>) {
%123 = arith.muli %1, %c32 : index
%124 = arith.addi %arg0, %123 : index
%125 = arith.muli %workgroup_id_x, %c128 : index
%126 = arith.addi %124, %125 : index
%127 = arith.addi %126, %26 : index
%128 = memref.load %28[%127] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%129 = arith.addi %126, %34 : index
%130 = memref.load %36[%129] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%131 = vector.splat %130 : vector<4xf16>
%132 = vector.splat %128 : vector<4xf16>
%133 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %arg1) -> (vector<4xf32>) {
%134 = arith.muli %arg0, %c16 : index
%135 = arith.muli %1, %c512 : index
%136 = arith.addi %134, %135 : index
%137 = arith.muli %workgroup_id_x, %c2048 : index
%138 = arith.addi %136, %137 : index
%139 = arith.cmpi slt, %arg2, %c0 : index
%140 = arith.subi %c-1, %arg2 : index
%141 = arith.select %139, %140, %arg2 : index
%142 = arith.divsi %141, %c8 : index
%143 = arith.subi %c-1, %142 : index
%144 = arith.select %139, %143, %142 : index
%145 = arith.addi %138, %144 : index
%146 = arith.addi %145, %18 : index
%147 = memref.load %20[%146] : memref<?xvector<1xi32>, #spirv.storage_class<StorageBuffer>>
%148 = vector.extract %147[0] : vector<1xi32>
%149 = arith.andi %148, %c15_i32 : i32
%150 = vector.insert %149, %cst [0] : i32 into vector<4xi32>
%151 = arith.shrui %148, %c4_i32 : i32
%152 = arith.andi %151, %c15_i32 : i32
%153 = vector.insert %152, %150 [1] : i32 into vector<4xi32>
%154 = arith.shrui %148, %c8_i32 : i32
%155 = arith.andi %154, %c15_i32 : i32
%156 = vector.insert %155, %153 [2] : i32 into vector<4xi32>
%157 = arith.shrui %148, %c12_i32 : i32
%158 = arith.andi %157, %c15_i32 : i32
%159 = vector.insert %158, %156 [3] : i32 into vector<4xi32>
%160 = arith.shrui %148, %c16_i32 : i32
%161 = arith.andi %160, %c15_i32 : i32
%162 = vector.insert %161, %cst [0] : i32 into vector<4xi32>
%163 = arith.shrui %148, %c20_i32 : i32
%164 = arith.andi %163, %c15_i32 : i32
%165 = vector.insert %164, %162 [1] : i32 into vector<4xi32>
%166 = arith.shrui %148, %c24_i32 : i32
%167 = arith.andi %166, %c15_i32 : i32
%168 = vector.insert %167, %165 [2] : i32 into vector<4xi32>
%169 = arith.shrui %148, %c28_i32 : i32
%170 = arith.andi %169, %c15_i32 : i32
%171 = vector.insert %170, %168 [3] : i32 into vector<4xi32>
%172 = arith.uitofp %159 : vector<4xi32> to vector<4xf16>
%173 = arith.uitofp %171 : vector<4xi32> to vector<4xf16>
%174 = arith.subf %172, %131 : vector<4xf16>
%175 = arith.subf %173, %131 : vector<4xf16>
%176 = arith.mulf %174, %132 : vector<4xf16>
%177 = arith.mulf %175, %132 : vector<4xf16>
%178 = arith.muli %arg0, %c17 : index
%179 = arith.addi %178, %144 : index
%180 = memref.load %alloc[%179] : memref<544xvector<4xf32>, #spirv.storage_class<Workgroup>>
%181 = vector.extract_strided_slice %180 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%182 = vector.bitcast %181 : vector<2xf32> to vector<4xf16>
%183 = arith.mulf %182, %176 : vector<4xf16>
%184 = vector.extract_strided_slice %180 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%185 = vector.bitcast %184 : vector<2xf32> to vector<4xf16>
%186 = arith.mulf %185, %177 : vector<4xf16>
%187 = vector.extract_strided_slice %arg3 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%188 = vector.bitcast %187 : vector<2xf32> to vector<4xf16>
%189 = arith.addf %183, %188 : vector<4xf16>
%190 = vector.extract_strided_slice %arg3 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%191 = vector.bitcast %190 : vector<2xf32> to vector<4xf16>
%192 = arith.addf %186, %191 : vector<4xf16>
%193 = vector.bitcast %192 : vector<4xf16> to vector<2xf32>
%194 = vector.bitcast %189 : vector<4xf16> to vector<2xf32>
%195 = vector.insert_strided_slice %194, %cst_1 {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
%196 = vector.insert_strided_slice %193, %195 {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
scf.yield %196 : vector<4xf32>
}
scf.yield %133 : vector<4xf32>
}
%80 = arith.addi %1, %53 : index
%81 = arith.addi %80, %50 : index
%82 = memref.load %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
%83 = vector.extract_strided_slice %79 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%84 = vector.bitcast %83 : vector<2xf32> to vector<4xf16>
%85 = vector.reduction <add>, %84 : vector<4xf16> into f16
%86 = vector.insert %85, %cst_2 [0] : f16 into vector<2xf16>
%87 = vector.extract_strided_slice %79 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
%88 = vector.bitcast %87 : vector<2xf32> to vector<4xf16>
%89 = vector.reduction <add>, %88 : vector<4xf16> into f16
%90 = vector.insert %89, %86 [1] : f16 into vector<2xf16>
%91 = vector.bitcast %90 : vector<2xf16> to vector<1xi32>
%92 = vector.extract %91[0] : vector<1xi32>
%shuffleResult, %valid = gpu.shuffle xor %92, %c1_i32, %c64_i32 : i32
%93 = vector.splat %shuffleResult : vector<1xi32>
%94 = vector.bitcast %93 : vector<1xi32> to vector<2xf16>
%95 = arith.addf %90, %94 : vector<2xf16>
%96 = vector.bitcast %95 : vector<2xf16> to vector<1xi32>
%97 = vector.extract %96[0] : vector<1xi32>
%shuffleResult_3, %valid_4 = gpu.shuffle xor %97, %c2_i32, %c64_i32 : i32
%98 = vector.splat %shuffleResult_3 : vector<1xi32>
%99 = vector.bitcast %98 : vector<1xi32> to vector<2xf16>
%100 = arith.addf %95, %99 : vector<2xf16>
%101 = vector.bitcast %100 : vector<2xf16> to vector<1xi32>
%102 = vector.extract %101[0] : vector<1xi32>
%shuffleResult_5, %valid_6 = gpu.shuffle xor %102, %c4_i32, %c64_i32 : i32
%103 = vector.splat %shuffleResult_5 : vector<1xi32>
%104 = vector.bitcast %103 : vector<1xi32> to vector<2xf16>
%105 = arith.addf %100, %104 : vector<2xf16>
%106 = vector.bitcast %105 : vector<2xf16> to vector<1xi32>
%107 = vector.extract %106[0] : vector<1xi32>
%shuffleResult_7, %valid_8 = gpu.shuffle xor %107, %c8_i32, %c64_i32 : i32
%108 = vector.splat %shuffleResult_7 : vector<1xi32>
%109 = vector.bitcast %108 : vector<1xi32> to vector<2xf16>
%110 = arith.addf %105, %109 : vector<2xf16>
%111 = vector.bitcast %110 : vector<2xf16> to vector<1xi32>
%112 = vector.extract %111[0] : vector<1xi32>
%shuffleResult_9, %valid_10 = gpu.shuffle xor %112, %c16_i32, %c64_i32 : i32
%113 = vector.splat %shuffleResult_9 : vector<1xi32>
%114 = vector.bitcast %113 : vector<1xi32> to vector<2xf16>
%115 = arith.addf %110, %114 : vector<2xf16>
%116 = vector.bitcast %115 : vector<2xf16> to vector<1xi32>
%117 = vector.extract %116[0] : vector<1xi32>
%shuffleResult_11, %valid_12 = gpu.shuffle xor %117, %c32_i32, %c64_i32 : i32
%118 = vector.splat %shuffleResult_11 : vector<1xi32>
%119 = vector.bitcast %118 : vector<1xi32> to vector<2xf16>
%120 = arith.addf %115, %119 : vector<2xf16>
%121 = vector.reduction <add>, %120 : vector<2xf16> into f16
%122 = arith.addf %121, %82 : f16
memref.store %122, %52[%81] : memref<?xf16, #spirv.storage_class<StorageBuffer>>
}
return
}
}
// -----// IR Dump After ConvertToSPIRV (iree-convert-to-spirv) //----- //
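// Lowering to the spirv dialect: thread/workgroup IDs become LocalInvocationId/WorkgroupId built-ins, the hal.interface constants become a single push-constant struct, the workgroup alloc becomes a Workgroup global, and the storage-buffer bindings become global variables (binding 0 is declared twice with {aliased}, once per element type it is accessed as).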
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>} {
spirv.module Logical GLSL450 {
spirv.GlobalVariable @__builtin__WorkgroupId__ built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
spirv.GlobalVariable @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
spirv.GlobalVariable @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<544 x vector<4xf32>>)>, Workgroup>
spirv.GlobalVariable @__builtin__LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
spirv.GlobalVariable @__resource_var_0_0__0 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @__resource_var_0_0_ bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f16, stride=2> [0])>, StorageBuffer>
spirv.GlobalVariable @__resource_var_0_1_ bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
spirv.GlobalVariable @__resource_var_0_2_ bind(0, 2) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f16, stride=2> [0])>, StorageBuffer>
spirv.func @forward_dispatch_3_generic_4096x32x128_f16() "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [64, 4, 1]>} {
%cst17_i32 = spirv.Constant 17 : i32
%cst2048_i32 = spirv.Constant 2048 : i32
%cst272_i32 = spirv.Constant 272 : i32
%cst68_i32 = spirv.Constant 68 : i32
%cst256_i32 = spirv.Constant 256 : i32
%cst64_i32 = spirv.Constant 64 : i32
%cst3_i32 = spirv.Constant 3 : i32
%cst4096_i32 = spirv.Constant 4096 : i32
%cst512_i32 = spirv.Constant 512 : i32
%cst16_i32 = spirv.Constant 16 : i32
%cst131072_i32 = spirv.Constant 131072 : i32
%cst2_i32 = spirv.Constant 2 : i32
%cst2097152_i32 = spirv.Constant 2097152 : i32
%cst-1_i32 = spirv.Constant -1 : i32
%cst4_i32 = spirv.Constant 4 : i32
%cst12_i32 = spirv.Constant 12 : i32
%cst_vec_4xi32 = spirv.Constant dense<0> : vector<4xi32>
%cst15_i32 = spirv.Constant 15 : i32
%cst20_i32 = spirv.Constant 20 : i32
%cst24_i32 = spirv.Constant 24 : i32
%cst28_i32 = spirv.Constant 28 : i32
%cst_f16 = spirv.Constant 0.000000e+00 : f16
%cst_vec_4xf32 = spirv.Constant dense<0.000000e+00> : vector<4xf32>
%cst0_i32 = spirv.Constant 0 : i32
%cst1_i32 = spirv.Constant 1 : i32
%cst32_i32 = spirv.Constant 32 : i32
%cst16_i32_0 = spirv.Constant 16 : i32
%cst8_i32 = spirv.Constant 8 : i32
%cst4_i32_1 = spirv.Constant 4 : i32
%cst2_i32_2 = spirv.Constant 2 : i32
%cst64_i32_3 = spirv.Constant 64 : i32
%cst1_i32_4 = spirv.Constant 1 : i32
%cst_vec_2xf16 = spirv.Constant dense<0.000000e+00> : vector<2xf16>
%cst8_i32_5 = spirv.Constant 8 : i32
%cst128_i32 = spirv.Constant 128 : i32
%cst32_i32_6 = spirv.Constant 32 : i32
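// Constant pool for the kernel body, including the xor-shuffle offsets (1, 2, 4, 8, 16, 32)
// and the subgroup width 64 used by the reduction. Several literals appear twice
// (e.g. %cst16_i32 / %cst16_i32_0, %cst4_i32 / %cst4_i32_1), presumably introduced by the
// op-by-op conversion; one would expect a later CSE/canonicalization pass to fold them.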
%__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
%0 = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32>
%1 = spirv.CompositeExtract %0[0 : i32] : vector<3xi32>
%__builtin__LocalInvocationId___addr_7 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
%2 = spirv.Load "Input" %__builtin__LocalInvocationId___addr_7 : vector<3xi32>
%3 = spirv.CompositeExtract %2[1 : i32] : vector<3xi32>
%__builtin__LocalInvocationId___addr_8 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
%4 = spirv.Load "Input" %__builtin__LocalInvocationId___addr_8 : vector<3xi32>
%5 = spirv.CompositeExtract %4[2 : i32] : vector<3xi32>
%__workgroup_mem__5_addr = spirv.mlir.addressof @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<544 x vector<4xf32>>)>, Workgroup>
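// The following sequence loads the five i32 push constants (elements 0 through 4 of the
// push-constant struct), one AccessChain + Load pair per element.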
%cst0_i32_9 = spirv.Constant 0 : i32
%cst0_i32_10 = spirv.Constant 0 : i32
%__push_constant_var___addr = spirv.mlir.addressof @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
%6 = spirv.AccessChain %__push_constant_var___addr[%cst0_i32_9, %cst0_i32_10] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
%7 = spirv.Load "PushConstant" %6 : i32
%cst0_i32_11 = spirv.Constant 0 : i32
%cst1_i32_12 = spirv.Constant 1 : i32
%__push_constant_var___addr_13 = spirv.mlir.addressof @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
%8 = spirv.AccessChain %__push_constant_var___addr_13[%cst0_i32_11, %cst1_i32_12] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
%9 = spirv.Load "PushConstant" %8 : i32
%cst0_i32_14 = spirv.Constant 0 : i32
%cst2_i32_15 = spirv.Constant 2 : i32
%__push_constant_var___addr_16 = spirv.mlir.addressof @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
%10 = spirv.AccessChain %__push_constant_var___addr_16[%cst0_i32_14, %cst2_i32_15] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
%11 = spirv.Load "PushConstant" %10 : i32
%cst0_i32_17 = spirv.Constant 0 : i32
%cst3_i32_18 = spirv.Constant 3 : i32
%__push_constant_var___addr_19 = spirv.mlir.addressof @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
%12 = spirv.AccessChain %__push_constant_var___addr_19[%cst0_i32_17, %cst3_i32_18] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
%13 = spirv.Load "PushConstant" %12 : i32
%cst0_i32_20 = spirv.Constant 0 : i32
%cst4_i32_21 = spirv.Constant 4 : i32
%__push_constant_var___addr_22 = spirv.mlir.addressof @__push_constant_var__ : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>
%14 = spirv.AccessChain %__push_constant_var___addr_22[%cst0_i32_20, %cst4_i32_21] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
%15 = spirv.Load "PushConstant" %14 : i32