
@Max191
Last active September 18, 2023 16:57
Tiling for reassociated quantized matmul
// -----// IR Dump After FoldUnitExtentDims (iree-flow-fold-unit-extent-dims) //----- //
func.func @quantized_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<11008x32x128xi4>
%1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<11008x32x1xf32>
%2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<11008x32x1xf32>
%3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<1x1x32x128xf32>
%4 = tensor.empty() : tensor<11008x32x128xf32>
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_0 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0, %collapsed, %collapsed_0 : tensor<11008x32x128xi4>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%4 : tensor<11008x32x128xf32>) {
^bb0(%in: i4, %in_2: f32, %in_3: f32, %out: f32):
%10 = arith.extui %in : i4 to i32
%11 = arith.uitofp %10 : i32 to f32
%12 = arith.subf %11, %in_3 : f32
%13 = arith.mulf %12, %in_2 : f32
linalg.yield %13 : f32
} -> tensor<11008x32x128xf32>
%collapsed_1 = tensor.collapse_shape %3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
%6 = tensor.empty() : tensor<11008xf32>
%7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<11008xf32>) -> tensor<11008xf32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%collapsed_1, %5 : tensor<32x128xf32>, tensor<11008x32x128xf32>) outs(%7 : tensor<11008xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%10 = arith.mulf %in, %in_2 : f32
%11 = arith.addf %10, %out : f32
linalg.yield %11 : f32
} -> tensor<11008xf32>
%expanded = tensor.expand_shape %8 [[0, 1, 2]] : tensor<11008xf32> into tensor<1x1x11008xf32>
%9 = hal.tensor.export %expanded "output 0" : tensor<1x1x11008xf32> -> !hal.buffer_view
return %9 : !hal.buffer_view
}
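
For reference, the computation in this first dump can be written out in plain NumPy. This is a minimal sketch with made-up array names (w_i4, scale, zero, x), not anything emitted by IREE: the i4 weights are dequantized group-wise as (w - zero) * scale, then contracted against the f32 activation over the group (32) and inner (128) dimensions.

import numpy as np

N, G, K = 11008, 32, 128                                        # output channels, groups, elements per group
rng = np.random.default_rng(0)
w_i4  = rng.integers(0, 16, size=(N, G, K)).astype(np.int32)    # unsigned i4 weight values
scale = rng.random((N, G), dtype=np.float32)                    # per-(channel, group) scale
zero  = rng.random((N, G), dtype=np.float32)                    # per-(channel, group) zero point
x     = rng.random((G, K), dtype=np.float32)                    # activation (1x1x32x128 collapsed)

# %5: elementwise dequantization, (uitofp(w) - zero) * scale
w_f32 = (w_i4.astype(np.float32) - zero[:, :, None]) * scale[:, :, None]

# %8: out[n] = sum_{g,k} x[g,k] * w_f32[n,g,k]
out_ref = np.einsum('gk,ngk->n', x, w_f32)
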
// -----// IR Dump After FuseDequantizationMatmul (iree-flow-fuse-dequantization-matmul) //----- //
func.func @quantized_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<11008x32x128xi4>
%1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<11008x32x1xf32>
%2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<11008x32x1xf32>
%3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<1x1x32x128xf32>
%4 = tensor.empty() : tensor<11008x32x128xf32>
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_0 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0, %collapsed, %collapsed_0 : tensor<11008x32x128xi4>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%4 : tensor<11008x32x128xf32>) {
^bb0(%in: i4, %in_4: f32, %in_5: f32, %out: f32):
%22 = arith.extui %in : i4 to i32
%23 = arith.uitofp %22 : i32 to f32
%24 = arith.subf %23, %in_5 : f32
%25 = arith.mulf %24, %in_4 : f32
linalg.yield %25 : f32
} -> tensor<11008x32x128xf32>
%collapsed_1 = tensor.collapse_shape %3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
%6 = tensor.empty() : tensor<11008xf32>
%7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<11008xf32>) -> tensor<11008xf32>
%cst_2 = arith.constant 3.276700e+04 : f32
%cst_3 = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%8 = tensor.empty() : tensor<32xf32>
%9 = linalg.fill ins(%cst_3 : f32) outs(%8 : tensor<32xf32>) -> tensor<32xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_1 : tensor<32x128xf32>) outs(%9 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%22 = math.absf %in : f32
%23 = arith.maxf %22, %out : f32
linalg.yield %23 : f32
} -> tensor<32xf32>
%11 = tensor.empty() : tensor<32xf32>
%12 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%10 : tensor<32xf32>) outs(%11 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%22 = arith.divf %in, %cst_2 : f32
linalg.yield %22 : f32
} -> tensor<32xf32>
%13 = tensor.empty() : tensor<32xf32>
%14 = linalg.fill ins(%cst_3 : f32) outs(%13 : tensor<32xf32>) -> tensor<32xf32>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_1 : tensor<32x128xf32>) outs(%14 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%22 = arith.addf %in, %out : f32
linalg.yield %22 : f32
} -> tensor<32xf32>
%16 = tensor.empty() : tensor<32x128xi16>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_1, %12 : tensor<32x128xf32>, tensor<32xf32>) outs(%16 : tensor<32x128xi16>) {
^bb0(%in: f32, %in_4: f32, %out: i16):
%22 = arith.divf %in, %in_4 : f32
%23 = arith.fptosi %22 : f32 to i16
linalg.yield %23 : i16
} -> tensor<32x128xi16>
%18 = tensor.empty() : tensor<11008x32xi32>
%19 = linalg.fill ins(%c0_i32 : i32) outs(%18 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
%20 = flow.dispatch.region -> (tensor<11008xf32>) {
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%17, %0 : tensor<32x128xi16>, tensor<11008x32x128xi4>) outs(%19 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_4: i4, %out: i32):
%24 = arith.extsi %in : i16 to i32
%25 = arith.extui %in_4 : i4 to i32
%26 = arith.muli %24, %25 : i32
%27 = arith.addi %26, %out : i32
linalg.yield %27 : i32
} -> tensor<11008x32xi32>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%22, %12, %15, %collapsed, %collapsed_0 : tensor<11008x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%7 : tensor<11008xf32>) {
^bb0(%in: i32, %in_4: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
%24 = arith.sitofp %in : i32 to f32
%25 = arith.mulf %24, %in_4 : f32
%26 = arith.mulf %25, %in_6 : f32
%27 = arith.mulf %in_7, %in_6 : f32
%28 = arith.mulf %27, %in_5 : f32
%29 = arith.subf %26, %28 : f32
%30 = arith.addf %29, %out : f32
linalg.yield %30 : f32
} -> tensor<11008xf32>
flow.return %23 : tensor<11008xf32>
}
%expanded = tensor.expand_shape %20 [[0, 1, 2]] : tensor<11008xf32> into tensor<1x1x11008xf32>
%21 = hal.tensor.export %expanded "output 0" : tensor<1x1x11008xf32> -> !hal.buffer_view
return %21 : !hal.buffer_view
}
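
After FuseDequantizationMatmul, the activation is dynamically quantized per group to i16 (scale = max|x| / 32767, plus a per-group sum), the inner product runs entirely in integers, and the dequantization is reassociated into a per-group floating-point correction applied after the integer accumulation. A hedged NumPy continuation of the sketch above (same made-up names; not IREE output):

scale_x = np.abs(x).max(axis=1) / 32767.0            # %10/%12: per-group dynamic scale
sum_x   = x.sum(axis=1)                              # %15: per-group sum of the activation
q_x     = (x / scale_x[:, None]).astype(np.int16)    # %17: activation quantized to i16

# %22 (inside the flow.dispatch.region): integer-only accumulation per (channel, group)
acc = np.einsum('gk,ngk->ng', q_x.astype(np.int32), w_i4)

# %23: reassociated float correction, reduced over groups
out = (acc.astype(np.float32) * scale_x[None, :] * scale
       - zero * scale * sum_x[None, :]).sum(axis=1)

# out approximates out_ref from the previous sketch, since x ~= scale_x * q_x elementwise
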
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @quantized_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 3.276700e+04 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<11008x32x128xi4>
%1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<11008x32x1xf32>
%2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<11008x32x1xf32>
%3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<1x1x32x128xf32>
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_1 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_2 = tensor.collapse_shape %3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
%4 = tensor.empty() : tensor<11008xf32>
%5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<11008xf32>) -> tensor<11008xf32>
%6 = tensor.empty() : tensor<32xf32>
%7 = linalg.fill ins(%cst_0 : f32) outs(%6 : tensor<32xf32>) -> tensor<32xf32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_2 : tensor<32x128xf32>) outs(%7 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%20 = math.absf %in : f32
%21 = arith.maxf %20, %out : f32
linalg.yield %21 : f32
} -> tensor<32xf32>
%9 = tensor.empty() : tensor<32xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%8 : tensor<32xf32>) outs(%9 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%20 = arith.divf %in, %cst : f32
linalg.yield %20 : f32
} -> tensor<32xf32>
%11 = tensor.empty() : tensor<32xf32>
%12 = linalg.fill ins(%cst_0 : f32) outs(%11 : tensor<32xf32>) -> tensor<32xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_2 : tensor<32x128xf32>) outs(%12 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%20 = arith.addf %in, %out : f32
linalg.yield %20 : f32
} -> tensor<32xf32>
%14 = tensor.empty() : tensor<32x128xi16>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_2, %10 : tensor<32x128xf32>, tensor<32xf32>) outs(%14 : tensor<32x128xi16>) {
^bb0(%in: f32, %in_3: f32, %out: i16):
%20 = arith.divf %in, %in_3 : f32
%21 = arith.fptosi %20 : f32 to i16
linalg.yield %21 : i16
} -> tensor<32x128xi16>
%16 = tensor.empty() : tensor<11008x32xi32>
%17 = linalg.fill ins(%c0_i32 : i32) outs(%16 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
%18 = flow.dispatch.region -> (tensor<11008xf32>) {
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%15, %0 : tensor<32x128xi16>, tensor<11008x32x128xi4>) outs(%17 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_3: i4, %out: i32):
%22 = arith.extsi %in : i16 to i32
%23 = arith.extui %in_3 : i4 to i32
%24 = arith.muli %22, %23 : i32
%25 = arith.addi %24, %out : i32
linalg.yield %25 : i32
} -> tensor<11008x32xi32>
%21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%20, %10, %13, %collapsed, %collapsed_1 : tensor<11008x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%5 : tensor<11008xf32>) {
^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
%22 = arith.sitofp %in : i32 to f32
%23 = arith.mulf %22, %in_3 : f32
%24 = arith.mulf %23, %in_5 : f32
%25 = arith.mulf %in_6, %in_5 : f32
%26 = arith.mulf %25, %in_4 : f32
%27 = arith.subf %24, %26 : f32
%28 = arith.addf %27, %out : f32
linalg.yield %28 : f32
} -> tensor<11008xf32>
flow.return %21 : tensor<11008xf32>
}
%expanded = tensor.expand_shape %18 [[0, 1, 2]] : tensor<11008xf32> into tensor<1x1x11008xf32>
%19 = hal.tensor.export %expanded "output 0" : tensor<1x1x11008xf32> -> !hal.buffer_view
return %19 : !hal.buffer_view
}
// -----// IR Dump After CSE (cse) //----- //
func.func @quantized_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 3.276700e+04 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<11008x32x128xi4>
%1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<11008x32x1xf32>
%2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<11008x32x1xf32>
%3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<1x1x32x128xf32>
%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_1 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<11008x32x1xf32> into tensor<11008x32xf32>
%collapsed_2 = tensor.collapse_shape %3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf32> into tensor<32x128xf32>
%4 = tensor.empty() : tensor<11008xf32>
%5 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<11008xf32>) -> tensor<11008xf32>
%6 = tensor.empty() : tensor<32xf32>
%7 = linalg.fill ins(%cst_0 : f32) outs(%6 : tensor<32xf32>) -> tensor<32xf32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_2 : tensor<32x128xf32>) outs(%7 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%17 = math.absf %in : f32
%18 = arith.maxf %17, %out : f32
linalg.yield %18 : f32
} -> tensor<32xf32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%8 : tensor<32xf32>) outs(%6 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%17 = arith.divf %in, %cst : f32
linalg.yield %17 : f32
} -> tensor<32xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_2 : tensor<32x128xf32>) outs(%7 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%17 = arith.addf %in, %out : f32
linalg.yield %17 : f32
} -> tensor<32xf32>
%11 = tensor.empty() : tensor<32x128xi16>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_2, %9 : tensor<32x128xf32>, tensor<32xf32>) outs(%11 : tensor<32x128xi16>) {
^bb0(%in: f32, %in_3: f32, %out: i16):
%17 = arith.divf %in, %in_3 : f32
%18 = arith.fptosi %17 : f32 to i16
linalg.yield %18 : i16
} -> tensor<32x128xi16>
%13 = tensor.empty() : tensor<11008x32xi32>
%14 = linalg.fill ins(%c0_i32 : i32) outs(%13 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
%15 = flow.dispatch.region -> (tensor<11008xf32>) {
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %0 : tensor<32x128xi16>, tensor<11008x32x128xi4>) outs(%14 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_3: i4, %out: i32):
%19 = arith.extsi %in : i16 to i32
%20 = arith.extui %in_3 : i4 to i32
%21 = arith.muli %19, %20 : i32
%22 = arith.addi %21, %out : i32
linalg.yield %22 : i32
} -> tensor<11008x32xi32>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%17, %9, %10, %collapsed, %collapsed_1 : tensor<11008x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%5 : tensor<11008xf32>) {
^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
%19 = arith.sitofp %in : i32 to f32
%20 = arith.mulf %19, %in_3 : f32
%21 = arith.mulf %20, %in_5 : f32
%22 = arith.mulf %in_6, %in_5 : f32
%23 = arith.mulf %22, %in_4 : f32
%24 = arith.subf %21, %23 : f32
%25 = arith.addf %24, %out : f32
linalg.yield %25 : f32
} -> tensor<11008xf32>
flow.return %18 : tensor<11008xf32>
}
%expanded = tensor.expand_shape %15 [[0, 1, 2]] : tensor<11008xf32> into tensor<1x1x11008xf32>
%16 = hal.tensor.export %expanded "output 0" : tensor<1x1x11008xf32> -> !hal.buffer_view
return %16 : !hal.buffer_view
}
// -----// IR Dump After EraseHALDescriptorTypeFromMemRef (iree-codegen-erase-hal-descriptor-type-from-memref) //----- //
module {
func.func @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32() {
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c256) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xi16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c128) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%5 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%6 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
%7 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xi16>> -> tensor<32x128xi16>
%8 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [11008, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>> -> tensor<11008x32x128xi4>
%9 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%10 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%11 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<11008x32xf32>
%12 = flow.dispatch.tensor.load %5, offsets = [0, 0], sizes = [11008, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<11008x32xf32>
%13 = tensor.empty() : tensor<11008xf32>
%14 = tensor.empty() : tensor<11008x32xi32>
%15 = linalg.fill ins(%cst : f32) outs(%13 : tensor<11008xf32>) -> tensor<11008xf32>
%16 = linalg.fill ins(%c0_i32 : i32) outs(%14 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%7, %8 : tensor<32x128xi16>, tensor<11008x32x128xi4>) outs(%16 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_0: i4, %out: i32):
%19 = arith.extsi %in : i16 to i32
%20 = arith.extui %in_0 : i4 to i32
%21 = arith.muli %19, %20 : i32
%22 = arith.addi %21, %out : i32
linalg.yield %22 : i32
} -> tensor<11008x32xi32>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%17, %9, %10, %11, %12 : tensor<11008x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<11008x32xf32>, tensor<11008x32xf32>) outs(%15 : tensor<11008xf32>) {
^bb0(%in: i32, %in_0: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32):
%19 = arith.sitofp %in : i32 to f32
%20 = arith.mulf %19, %in_0 : f32
%21 = arith.mulf %20, %in_2 : f32
%22 = arith.mulf %in_3, %in_2 : f32
%23 = arith.mulf %22, %in_1 : f32
%24 = arith.subf %21, %23 : f32
%25 = arith.addf %24, %out : f32
linalg.yield %25 : f32
} -> tensor<11008xf32>
flow.dispatch.tensor.store %18, %6, offsets = [0], sizes = [11008], strides = [1] : tensor<11008xf32> -> !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
return
}
}
// -----// IR Dump After TileAndDistributeToWorkgroups (iree-codegen-tile-and-distribute-to-workgroups) //----- //
hal.executable.variant public @embedded_elf_x86_64, target = <"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-sm4,+sse4.1,+avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,+avx512vpopcntdq,+cmov,-avx512vp2intersect,+avx512cd,+movbe,-avxvnniint8,-avx512er,-amx-int8,-kl,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,+avx512vl,-uintr,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,+gfni,-avxvnniint16,-amx-fp16,+xsaveopt,+rdrnd,+avx512f,-amx-bf16,+avx512bf16,+avx512vnni,+cx8,+avx512bw,+sse3,+pku,+fsgsbase,+clzero,+mwaitx,-lwp,+lzcnt,+sha,-movdir64b,+wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,+avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,+avx512vbmi2,-prefetchi,+rdpid,-fma4,+avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,+avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = true}> {
hal.executable.export public @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer, ReadOnly>, <4, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
^bb0(%arg0: !hal.device):
%c344 = arith.constant 344 : index
%c1 = arith.constant 1 : index
hal.return %c344, %c1, %c1 : index, index, index
}
builtin.module {
func.func @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32() {
%c32 = arith.constant 32 : index
%c11008 = arith.constant 11008 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c256) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xi16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c128) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%5 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%6 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
%7 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xi16>> -> tensor<32x128xi16>
%8 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%9 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%10 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%11 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %10 to %c11008 step %11 {
%12 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [%c32, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>> -> tensor<?x32x128xi4>
%13 = tensor.empty() : tensor<32x32xi32>
%14 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0], [8, 0], [0, 0], [0, 0]]>} ins(%c0_i32 : i32) outs(%13 : tensor<32x32xi32>) -> tensor<32x32xi32>
%cast = tensor.cast %12 : tensor<?x32x128xi4> to tensor<32x32x128xi4>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%7, %cast : tensor<32x128xi16>, tensor<32x32x128xi4>) outs(%14 : tensor<32x32xi32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0, 0], [8, 0, 0], [0, 0, 16], [0, 8, 0]]>} {
^bb0(%in: i16, %in_3: i4, %out: i32):
%21 = arith.extsi %in : i16 to i32
%22 = arith.extui %in_3 : i4 to i32
%23 = arith.muli %21, %22 : i32
%24 = arith.addi %23, %out : i32
linalg.yield %24 : i32
} -> tensor<32x32xi32>
%16 = flow.dispatch.tensor.load %4, offsets = [%arg0, 0], sizes = [%c32, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<?x32xf32>
%17 = flow.dispatch.tensor.load %5, offsets = [%arg0, 0], sizes = [%c32, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<?x32xf32>
%18 = tensor.empty() : tensor<32xf32>
%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32], [8], [0], [0]]>} ins(%cst : f32) outs(%18 : tensor<32xf32>) -> tensor<32xf32>
%cast_0 = tensor.cast %16 : tensor<?x32xf32> to tensor<32x32xf32>
%cast_1 = tensor.cast %17 : tensor<?x32xf32> to tensor<32x32xf32>
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%15, %8, %9, %cast_0, %cast_1 : tensor<32x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) outs(%19 : tensor<32xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0], [8, 0], [0, 32], [0, 0]]>} {
^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
%21 = arith.sitofp %in : i32 to f32
%22 = arith.mulf %21, %in_3 : f32
%23 = arith.mulf %22, %in_5 : f32
%24 = arith.mulf %in_6, %in_5 : f32
%25 = arith.mulf %24, %in_4 : f32
%26 = arith.subf %23, %25 : f32
%27 = arith.addf %26, %out : f32
linalg.yield %27 : f32
} -> tensor<32xf32>
%cast_2 = tensor.cast %20 : tensor<32xf32> to tensor<?xf32>
flow.dispatch.tensor.store %cast_2, %6, offsets = [%arg0], sizes = [%c32], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
}
return
}
}
}
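
TileAndDistributeToWorkgroups splits the 11008 output channels into 32-row tiles, which gives the 344 x 1 x 1 workgroup count returned by the hal.executable.export region, and each workgroup walks its tiles with a stride of workgroup_count * 32. A rough Python sketch of that loop structure (hypothetical helper, reusing the arrays from the sketches above; not IREE code):

TILE = 32
workgroup_count_x = N // TILE       # 344, matching %c344 in the export region
out_tiled = np.zeros(N, dtype=np.float32)

def workgroup(workgroup_id_x):
    # scf.for %arg0 = id*32 to 11008 step count*32 (a single iteration per workgroup here)
    for row0 in range(workgroup_id_x * TILE, N, workgroup_count_x * TILE):
        rows = slice(row0, row0 + TILE)
        acc = np.einsum('gk,ngk->ng', q_x.astype(np.int32), w_i4[rows])         # %15 in the dump
        out_tiled[rows] = (acc.astype(np.float32) * scale_x[None, :] * scale[rows]
                           - zero[rows] * scale[rows] * sum_x[None, :]).sum(axis=1)  # %20

for wg in range(workgroup_count_x):
    workgroup(wg)
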
// -----// IR Dump After LLVMCPUTileAndFuse (iree-llvmcpu-tile-and-fuse) //----- //
func.func @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32() {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%c11008 = arith.constant 11008 : index
%c256 = arith.constant 256 : index
%c128 = arith.constant 128 : index
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c256) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x128xi16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c128) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32xf32>>
%4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%5 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>>
%6 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
%7 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x128xi16>> -> tensor<32x128xi16>
%8 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%9 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [32], strides = [1] : !flow.dispatch.tensor<readonly:tensor<32xf32>> -> tensor<32xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%10 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%11 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
scf.for %arg0 = %10 to %c11008 step %11 {
%12 = flow.dispatch.tensor.load %6, offsets = [%arg0], sizes = [32], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<11008xf32>> -> tensor<32xf32>
%13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [32, 32, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32x128xi4>> -> tensor<32x32x128xi4>
%14 = flow.dispatch.tensor.load %4, offsets = [%arg0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<32x32xf32>
%15 = flow.dispatch.tensor.load %5, offsets = [%arg0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11008x32xf32>> -> tensor<32x32xf32>
%16 = scf.for %arg1 = %c0 to %c32 step %c8 iter_args(%arg2 = %12) -> (tensor<32xf32>) {
%extracted_slice = tensor.extract_slice %13[%arg1, 0, 0] [8, 32, 128] [1, 1, 1] : tensor<32x32x128xi4> to tensor<8x32x128xi4>
%17 = tensor.empty() : tensor<8x32xi32>
%18 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0], [8, 0], [0, 0], [0, 0]]>} ins(%c0_i32 : i32) outs(%17 : tensor<8x32xi32>) -> tensor<8x32xi32>
%19 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %18) -> (tensor<8x32xi32>) {
%extracted_slice_3 = tensor.extract_slice %7[0, %arg3] [32, 16] [1, 1] : tensor<32x128xi16> to tensor<32x16xi16>
%extracted_slice_4 = tensor.extract_slice %extracted_slice[0, 0, %arg3] [8, 32, 16] [1, 1, 1] : tensor<8x32x128xi4> to tensor<8x32x16xi4>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_3, %extracted_slice_4 : tensor<32x16xi16>, tensor<8x32x16xi4>) outs(%arg4 : tensor<8x32xi32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0, 0], [8, 0, 0], [0, 0, 16], [0, 8, 0]]>} {
^bb0(%in: i16, %in_5: i4, %out: i32):
%23 = arith.extsi %in : i16 to i32
%24 = arith.extui %in_5 : i4 to i32
%25 = arith.muli %23, %24 : i32
%26 = arith.addi %25, %out : i32
linalg.yield %26 : i32
} -> tensor<8x32xi32>
scf.yield %22 : tensor<8x32xi32>
}
%extracted_slice_0 = tensor.extract_slice %14[%arg1, 0] [8, 32] [1, 1] : tensor<32x32xf32> to tensor<8x32xf32>
%extracted_slice_1 = tensor.extract_slice %15[%arg1, 0] [8, 32] [1, 1] : tensor<32x32xf32> to tensor<8x32xf32>
%extracted_slice_2 = tensor.extract_slice %arg2[%arg1] [8] [1] : tensor<32xf32> to tensor<8xf32>
%20 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32], [8], [0], [0]]>} ins(%cst : f32) outs(%extracted_slice_2 : tensor<8xf32>) -> tensor<8xf32>
%21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%19, %8, %9, %extracted_slice_0, %extracted_slice_1 : tensor<8x32xi32>, tensor<32xf32>, tensor<32xf32>, tensor<8x32xf32>, tensor<8x32xf32>) outs(%20 : tensor<8xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 0], [8, 0], [0, 32], [0, 0]]>} {
^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
%22 = arith.sitofp %in : i32 to f32
%23 = arith.mulf %22, %in_3 : f32
%24 = arith.mulf %23, %in_5 : f32
%25 = arith.mulf %in_6, %in_5 : f32
%26 = arith.mulf %25, %in_4 : f32
%27 = arith.subf %24, %26 : f32
%28 = arith.addf %27, %out : f32
linalg.yield %28 : f32
} -> tensor<8xf32>
%inserted_slice = tensor.insert_slice %21 into %arg2[%arg1] [8] [1] : tensor<8xf32> into tensor<32xf32>
scf.yield %inserted_slice : tensor<32xf32>
}
flow.dispatch.tensor.store %16, %6, offsets = [%arg0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor<writeonly:tensor<11008xf32>>
}
return
}
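
LLVMCPUTileAndFuse then tiles within each 32-row workgroup tile: the integer matmul is broken into 8-row subtiles with the 128-wide reduction split into 16-element chunks (per the lowering_config tile sizes), and the float correction is fused per 8-row subtile before being inserted back into the 32-row result. A rough Python sketch of that inner loop nest (hypothetical names, continuing the sketches above; not IREE code):

def workgroup_tile(w_tile, scale_tile, zero_tile):
    # w_tile: (32, 32, 128) i4 values; scale_tile / zero_tile: (32, 32)
    out_tile = np.zeros(TILE, dtype=np.float32)
    for i in range(0, TILE, 8):                  # scf.for %arg1: 8-row subtile
        acc = np.zeros((8, G), dtype=np.int32)   # %18: zero-filled i32 accumulator
        for k in range(0, K, 16):                # scf.for %arg3: 16-wide reduction tile
            acc += np.einsum('gk,ngk->ng',
                             q_x[:, k:k+16].astype(np.int32),
                             w_tile[i:i+8, :, k:k+16])                    # %22
        corr = (acc.astype(np.float32) * scale_x[None, :] * scale_tile[i:i+8]
                - zero_tile[i:i+8] * scale_tile[i:i+8] * sum_x[None, :])  # %21
        out_tile[i:i+8] = corr.sum(axis=1)       # reduction over groups + tensor.insert_slice
    return out_tile
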
// -----// IR Dump After RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
func.func @quantized_matmul_dispatch_3_generic_11008x32x128_i16xi4xi32() {
%cst = arith.constant dense<0.000000e+00> : vector<8xf32>
%c0_i4 = arith.constant 0 : i4
%c0_i16 = arith.constant 0 : i16
%cst_0 = arith.constant dense<0> : vector<8x32xi32>
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c128 = arith.constant 128 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c256) flags(ReadOnly) : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<32xf32, #hal.descriptor_type<storage_buffer>>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c128) flags(ReadOnly) : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %3, 64 : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>
%4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %4, 64 : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
%5 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %5, 64 : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>>
%6 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : memref<11008xf32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %6, 64 : memref<11008xf32, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%7 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview = memref.subview %6[%7] [32] [1] : memref<11008xf32, #hal.descriptor_type<storage_buffer>> to memref<32xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_2 = memref.subview %1[%7, 0, 0] [32, 32, 128] [1, 1, 1] : memref<11008x32x128xi4, #hal.descriptor_type<storage_buffer>> to memref<32x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %4[%7, 0] [32, 32] [1, 1] : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>> to memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_4 = memref.subview %5[%7, 0] [32, 32] [1, 1] : memref<11008x32xf32, #hal.descriptor_type<storage_buffer>> to memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
scf.for %arg0 = %c0 to %c32 step %c8 {
%8 = scf.for %arg1 = %c0 to %c128 step %c16 iter_args(%arg2 = %cst_0) -> (vector<8x32xi32>) {
%22 = vector.transfer_read %0[%c0, %arg1], %c0_i16 {in_bounds = [true, true]} : memref<32x128xi16, strided<[128, 1], offset: 128>, #hal.descriptor_type<storage_buffer>>, vector<32x16xi16>
%23 = vector.transfer_read %subview_2[%arg0, %c0, %arg1], %c0_i4 {in_bounds = [true, true, true]} : memref<32x32x128xi4, strided<[4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32x16xi4>
%24 = arith.extsi %22 : vector<32x16xi16> to vector<32x16xi32>
%25 = arith.extui %23 : vector<8x32x16xi4> to vector<8x32x16xi32>
%26 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %24, %25, %arg2 : vector<32x16xi32>, vector<8x32x16xi32> into vector<8x32xi32>
scf.yield %26 : vector<8x32xi32>
}
%9 = vector.transfer_read %2[%c0], %cst_1 {in_bounds = [true]} : memref<32xf32, #hal.descriptor_type<storage_buffer>>, vector<32xf32>
%10 = vector.broadcast %9 : vector<32xf32> to vector<8x32xf32>
%11 = vector.transfer_read %3[%c0], %cst_1 {in_bounds = [true]} : memref<32xf32, strided<[1], offset: 32>, #hal.descriptor_type<storage_buffer>>, vector<32xf32>
%12 = vector.broadcast %11 : vector<32xf32> to vector<8x32xf32>
%13 = vector.transfer_read %subview_3[%arg0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32xf32>
%14 = vector.transfer_read %subview_4[%arg0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<8x32xf32>
%15 = arith.sitofp %8 : vector<8x32xi32> to vector<8x32xf32>
%16 = arith.mulf %15, %10 : vector<8x32xf32>
%17 = arith.mulf %16, %13 : vector<8x32xf32>
%18 = arith.mulf %14, %13 : vector<8x32xf32>
%19 = arith.mulf %18, %12 : vector<8x32xf32>
%20 = arith.subf %17, %19 : vector<8x32xf32>
%21 = vector.multi_reduction <add>, %20, %cst [1] : vector<8x32xf32> to vector<8xf32>
vector.transfer_write %21, %subview[%arg0] {in_bounds = [true]} : vector<8xf32>, memref<32xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
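
After vectorization (shown in the RemoveSingleIterationLoop dump above), the inner reduction becomes a vector.contract of a 32x16 i32 tile of the quantized activation against an 8x32x16 i32 tile of the weights, accumulating into an 8x32 i32 register tile, and the fused correction ends in a vector.multi_reduction over the group dimension. A minimal NumPy analogue of those two vector ops (assumed names; shapes taken from the dump; the elementwise f32 corrections %15-%20 between them are elided here):

a_tile = q_x[:, :16].astype(np.int32)    # extsi of vector<32x16xi16>  (%24)
b_tile = w_i4[:8, :, :16]                # extui of vector<8x32x16xi4> (%25), already stored as int32 here

# vector.contract, maps (d1,d2) x (d0,d1,d2) -> (d0,d1), kind = add (%26)
acc_v = np.einsum('gk,ngk->ng', a_tile, b_tile) + np.zeros((8, G), dtype=np.int32)

# vector.multi_reduction <add> over dimension [1] (%21): vector<8x32xf32> -> vector<8xf32>
reduced = acc_v.astype(np.float32).sum(axis=1)
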