Skip to content

Instantly share code, notes, and snippets.

@stellaraccident
Created November 12, 2021 19:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stellaraccident/e65efeeeb25502401bf5153198abc815 to your computer and use it in GitHub Desktop.
Save stellaraccident/e65efeeeb25502401bf5153198abc815 to your computer and use it in GitHub Desktop.
MM Packing
This file has been truncated, but you can view the full file.
Loading:
Loading: 0 packages loaded
Analyzing: target //third_party/mlir_edge/iree_llvm_sandbox:matmul_bench (0 packages loaded, 0 targets configured)
INFO: Analyzed target //third_party/mlir_edge/iree_llvm_sandbox:matmul_bench (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
[0 / 10] [Prepa] action 'BuildInfo build-info.txt'
Target //third_party/mlir_edge/iree_llvm_sandbox:matmul_bench up-to-date:
blaze-bin/third_party/mlir_edge/iree_llvm_sandbox/matmul_bench
INFO: Elapsed time: 0.528s, Critical Path: 0.33s, Remote (72.78% of the time): [queue: 0.00%, upload: 72.48%, setup: 0.00%, process: 0.00%]
INFO: Build completed successfully, 3 total actions
INFO: Running command line: blaze-bin/third_party/mlir_edge/iree_llvm_sandbox/matmul_bench
INFO: Build completed successfully, 3 total actions
###############################################################
Runtime problem size {'M': 2040, 'N': 2041, 'K': 2042}
Compile-time problem size {'M': 2040, 'N': 2041, 'K': 2042}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map0 = affine_map<(d0) -> (288, -d0 + 2040)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (128, -d0 + 2041)>
module {
func @matmul_on_tensors(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2040x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2040x2041xf32> -> tensor<2040x2041xf32>
%1 = scf.for %arg3 = %c0 to %c2040 step %c288 iter_args(%arg4 = %0) -> (tensor<2040x2041xf32>) {
%2 = affine.min #map0(%arg3)
%3 = scf.for %arg5 = %c0 to %c2042 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2040x2041xf32>) {
%4 = affine.min #map1(%arg5)
%5 = tensor.extract_slice %arg0[%arg3, %arg5] [%2, %4] [1, 1] : tensor<2040x2042xf32> to tensor<?x?xf32>
%6 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2040x2041xf32>) {
%7 = affine.min #map2(%arg7)
%8 = tensor.extract_slice %arg1[%arg5, %arg7] [%4, %7] [1, 1] : tensor<2042x2041xf32> to tensor<?x?xf32>
%9 = tensor.extract_slice %arg8[%arg3, %arg7] [%2, %7] [1, 1] : tensor<2040x2041xf32> to tensor<?x?xf32>
%10 = linalg.matmul ins(%5, %8 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%9 : tensor<?x?xf32>) -> tensor<?x?xf32>
%11 = tensor.insert_slice %10 into %arg8[%arg3, %arg7] [%2, %7] [1, 1] : tensor<?x?xf32> into tensor<2040x2041xf32>
scf.yield %11 : tensor<2040x2041xf32>
}
scf.yield %6 : tensor<2040x2041xf32>
}
scf.yield %3 : tensor<2040x2041xf32>
}
return %1 : tensor<2040x2041xf32>
}
func public @matmul_main(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2040x2041xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2040x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2040x2042xf32>, tensor<2042x2041xf32>, tensor<2040x2041xf32>) -> tensor<2040x2041xf32>
scf.yield %1 : tensor<2040x2041xf32>
}
return %0 : tensor<2040x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (-d0 + 32)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0) -> (-d0 + 16)>
#map11 = affine_map<(d0) -> (288, -d0 + 2040)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2040x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2040 = arith.constant 2040 : index
%c2041 = arith.constant 2041 : index
%c2042 = arith.constant 2042 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2040x2041xf32> -> tensor<2040x2041xf32>
%1 = linalg.init_tensor [4, 16, 4, 32, 16, 32] : tensor<4x16x4x32x16x32xf32>
%2 = tensor.cast %1 : tensor<4x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%3 = scf.for %arg3 = %c0 to %c2042 step %c512 iter_args(%arg4 = %2) -> (tensor<?x?x?x?x16x32xf32>) {
%7 = affine.apply #map0(%arg3)
%8 = affine.min #map1(%arg3)
%9 = scf.for %arg5 = %c0 to %c2041 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map2(%arg5)
%11 = affine.min #map3(%arg5)
%12 = scf.for %arg7 = %c0 to %11 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map4(%arg7)
%14 = affine.apply #map5(%arg7, %arg5)
%15 = affine.min #map6(%arg7, %11)
%16 = affine.apply #map7(%15)
%17 = scf.for %arg9 = %c0 to %8 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map8(%arg9)
%19 = affine.apply #map5(%arg9, %arg3)
%20 = affine.min #map9(%arg9, %8)
%21 = tensor.extract_slice %arg1[%19, %14] [%20, %15] [1, 1] : tensor<2042x2041xf32> to tensor<?x?xf32>
%22 = affine.apply #map10(%20)
%23 = linalg.pad_tensor %21 nofold low[%c0, %c0] high[%22, %16] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<16x32xf32>
%24 = tensor.insert_slice %23 into %arg10[%7, %10, %13, %18, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %24 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %9 : tensor<?x?x?x?x16x32xf32>
}
%4 = linalg.init_tensor [4, 32, 32, 9, 16] : tensor<4x32x32x9x16xf32>
%5 = tensor.cast %4 : tensor<4x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%6 = scf.for %arg3 = %c0 to %c2040 step %c288 iter_args(%arg4 = %0) -> (tensor<2040x2041xf32>) {
%7 = affine.min #map11(%arg3)
%8 = scf.for %arg5 = %c0 to %c2042 step %c512 iter_args(%arg6 = %5) -> (tensor<?x?x?x9x16xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = affine.min #map1(%arg5)
%12 = scf.for %arg7 = %c0 to %7 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map12(%arg7)
%14 = affine.apply #map5(%arg7, %arg3)
%15 = affine.min #map13(%arg7, %7)
%16 = affine.apply #map14(%15)
%17 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%18 = affine.apply #map8(%arg9)
%19 = affine.apply #map5(%arg9, %arg5)
%20 = affine.min #map9(%arg9, %11)
%21 = tensor.extract_slice %arg0[%14, %19] [%15, %20] [1, 1] : tensor<2040x2042xf32> to tensor<?x?xf32>
%22 = affine.apply #map10(%20)
%23 = linalg.pad_tensor %21 nofold low[%c0, %c0] high[%16, %22] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x16xf32>
%24 = tensor.insert_slice %23 into %arg10[%10, %13, %18, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %24 : tensor<?x?x?x9x16xf32>
}
scf.yield %17 : tensor<?x?x?x9x16xf32>
}
scf.yield %12 : tensor<?x?x?x9x16xf32>
}
%9 = scf.for %arg5 = %c0 to %c2042 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2040x2041xf32>) {
%10 = affine.min #map1(%arg5)
%11 = affine.apply #map0(%arg5)
%12 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2040x2041xf32>) {
%13 = affine.min #map3(%arg7)
%14 = tensor.extract_slice %arg8[%arg3, %arg7] [%7, %13] [1, 1] : tensor<2040x2041xf32> to tensor<?x?xf32>
%15 = affine.apply #map2(%arg7)
%16 = scf.for %arg9 = %c0 to %7 step %c9 iter_args(%arg10 = %14) -> (tensor<?x?xf32>) {
%18 = affine.min #map13(%arg9, %7)
%19 = affine.apply #map12(%arg9)
%20 = affine.apply #map14(%18)
%21 = scf.for %arg11 = %c0 to %13 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%22 = affine.min #map6(%arg11, %13)
%23 = affine.apply #map4(%arg11)
%24 = affine.apply #map7(%22)
%25 = scf.for %arg13 = %c0 to %10 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%26 = tensor.extract_slice %arg14[%arg9, %arg11] [%18, %22] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%27 = affine.apply #map8(%arg13)
%28 = tensor.extract_slice %8[%11, %19, %27, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%29 = tensor.extract_slice %3[%11, %15, %23, %27, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%30 = linalg.pad_tensor %26 low[%c0, %c0] high[%20, %24] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x32xf32>
%31 = linalg.matmul ins(%28, %29 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%30 : tensor<9x32xf32>) -> tensor<9x32xf32>
%32 = tensor.extract_slice %31[0, 0] [%18, %22] [1, 1] : tensor<9x32xf32> to tensor<?x?xf32>
%33 = tensor.insert_slice %32 into %arg14[%arg9, %arg11] [%18, %22] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %33 : tensor<?x?xf32>
}
scf.yield %25 : tensor<?x?xf32>
}
scf.yield %21 : tensor<?x?xf32>
}
%17 = tensor.insert_slice %16 into %arg8[%arg3, %arg7] [%7, %13] [1, 1] : tensor<?x?xf32> into tensor<2040x2041xf32>
scf.yield %17 : tensor<2040x2041xf32>
}
scf.yield %12 : tensor<2040x2041xf32>
}
scf.yield %9 : tensor<2040x2041xf32>
}
return %6 : tensor<2040x2041xf32>
}
func public @matmul_main(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2040x2041xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2040x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2040x2042xf32>, tensor<2042x2041xf32>, tensor<2040x2041xf32>) -> tensor<2040x2041xf32>
scf.yield %1 : tensor<2040x2041xf32>
}
return %0 : tensor<2040x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0) -> (288, -d0 + 2040)>
#map10 = affine_map<(d0) -> (d0 ceildiv 9)>
#map11 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map12 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map13 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2040x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2040x2041xf32> -> tensor<2040x2041xf32>
%1 = linalg.init_tensor [4, 16, 4, 32, 16, 32] : tensor<4x16x4x32x16x32xf32>
%2 = tensor.cast %1 : tensor<4x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%3 = scf.for %arg3 = %c0 to %c2042 step %c512 iter_args(%arg4 = %2) -> (tensor<?x?x?x?x16x32xf32>) {
%7 = affine.apply #map0(%arg3)
%8 = affine.min #map1(%arg3)
%9 = scf.for %arg5 = %c0 to %c2041 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map2(%arg5)
%11 = affine.min #map3(%arg5)
%12 = scf.for %arg7 = %c0 to %11 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map4(%arg7)
%14 = affine.apply #map5(%arg7, %arg5)
%15 = affine.min #map6(%arg7, %11)
%16 = scf.for %arg9 = %c0 to %8 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%17 = affine.apply #map7(%arg9)
%18 = affine.apply #map5(%arg9, %arg3)
%19 = affine.min #map8(%arg9, %8)
%20 = tensor.extract_slice %arg1[%18, %14] [%19, %15] [1, 1] : tensor<2042x2041xf32> to tensor<?x?xf32>
%21 = vector.transfer_read %20[%c0, %c0], %cst : tensor<?x?xf32>, vector<16x32xf32>
%22 = vector.transfer_write %21, %arg10[%7, %10, %13, %17, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %22 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %16 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %9 : tensor<?x?x?x?x16x32xf32>
}
%4 = linalg.init_tensor [4, 32, 32, 9, 16] : tensor<4x32x32x9x16xf32>
%5 = tensor.cast %4 : tensor<4x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%6 = scf.for %arg3 = %c0 to %c2040 step %c288 iter_args(%arg4 = %0) -> (tensor<2040x2041xf32>) {
%7 = affine.min #map9(%arg3)
%8 = scf.for %arg5 = %c0 to %c2042 step %c512 iter_args(%arg6 = %5) -> (tensor<?x?x?x9x16xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = affine.min #map1(%arg5)
%12 = scf.for %arg7 = %c0 to %7 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map10(%arg7)
%14 = affine.apply #map5(%arg7, %arg3)
%15 = affine.min #map11(%arg7, %7)
%16 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%17 = affine.apply #map7(%arg9)
%18 = affine.apply #map5(%arg9, %arg5)
%19 = affine.min #map8(%arg9, %11)
%20 = tensor.extract_slice %arg0[%14, %18] [%15, %19] [1, 1] : tensor<2040x2042xf32> to tensor<?x?xf32>
%21 = vector.transfer_read %20[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x16xf32>
%22 = vector.transfer_write %21, %arg10[%10, %13, %17, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %22 : tensor<?x?x?x9x16xf32>
}
scf.yield %16 : tensor<?x?x?x9x16xf32>
}
scf.yield %12 : tensor<?x?x?x9x16xf32>
}
%9 = scf.for %arg5 = %c0 to %c2042 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2040x2041xf32>) {
%10 = affine.min #map1(%arg5)
%11 = affine.apply #map0(%arg5)
%12 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2040x2041xf32>) {
%13 = affine.min #map3(%arg7)
%14 = tensor.extract_slice %arg8[%arg3, %arg7] [%7, %13] [1, 1] : tensor<2040x2041xf32> to tensor<?x?xf32>
%15 = affine.apply #map2(%arg7)
%16 = scf.for %arg9 = %c0 to %7 step %c9 iter_args(%arg10 = %14) -> (tensor<?x?xf32>) {
%18 = affine.min #map11(%arg9, %7)
%19 = affine.apply #map10(%arg9)
%20 = scf.for %arg11 = %c0 to %13 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%21 = affine.min #map6(%arg11, %13)
%22 = affine.apply #map4(%arg11)
%23 = scf.for %arg13 = %c0 to %10 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%24 = tensor.extract_slice %arg14[%arg9, %arg11] [%18, %21] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%25 = affine.apply #map7(%arg13)
%26 = vector.transfer_read %8[%11, %19, %25, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%27 = vector.transfer_read %3[%11, %15, %22, %25, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%28 = vector.transfer_read %24[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x32xf32>
%29 = vector.contract {indexing_maps = [#map12, #map13, #map14], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %26, %27, %28 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%30 = vector.transfer_write %29, %24[%c0, %c0] : vector<9x32xf32>, tensor<?x?xf32>
%31 = tensor.insert_slice %30 into %arg14[%arg9, %arg11] [%18, %21] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %31 : tensor<?x?xf32>
}
scf.yield %23 : tensor<?x?xf32>
}
scf.yield %20 : tensor<?x?xf32>
}
%17 = tensor.insert_slice %16 into %arg8[%arg3, %arg7] [%7, %13] [1, 1] : tensor<?x?xf32> into tensor<2040x2041xf32>
scf.yield %17 : tensor<2040x2041xf32>
}
scf.yield %12 : tensor<2040x2041xf32>
}
scf.yield %9 : tensor<2040x2041xf32>
}
return %6 : tensor<2040x2041xf32>
}
func public @matmul_main(%arg0: tensor<2040x2042xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2042x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2040x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2040x2041xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2040x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2040x2042xf32>, tensor<2042x2041xf32>, tensor<2040x2041xf32>) -> tensor<2040x2041xf32>
scf.yield %1 : tensor<2040x2041xf32>
}
return %0 : tensor<2040x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0) -> (288, -d0 + 2040)>
#map11 = affine_map<(d0) -> (d0 ceildiv 9)>
#map12 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map13 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map15 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map16 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2040 = arith.constant 2040 : index
%c2041 = arith.constant 2041 : index
%c2042 = arith.constant 2042 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%2 = affine.apply #map0(%arg3)
%3 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%4 = affine.apply #map2(%arg4)
%5 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %5 step %c32 {
%6 = affine.apply #map4(%arg5)
%7 = affine.apply #map5(%arg5, %arg4)
%8 = affine.min #map6(%arg5, %5)
scf.for %arg6 = %c0 to %3 step %c16 {
%9 = affine.apply #map7(%arg6)
%10 = affine.apply #map5(%arg6, %arg3)
%11 = affine.min #map8(%arg6, %3)
%12 = memref.subview %arg1[%10, %7] [%11, %8] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<16x32xf32>
vector.transfer_write %13, %1[%2, %4, %6, %9, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%2 = affine.min #map10(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.apply #map0(%arg4)
%4 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%5 = affine.apply #map11(%arg5)
%6 = affine.apply #map5(%arg5, %arg3)
%7 = affine.min #map12(%arg5, %2)
scf.for %arg6 = %c0 to %4 step %c16 {
%8 = affine.apply #map7(%arg6)
%9 = affine.apply #map5(%arg6, %arg4)
%10 = affine.min #map8(%arg6, %4)
%11 = memref.subview %arg0[%6, %9] [%7, %10] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map13>
%12 = vector.transfer_read %11[%c0, %c0], %cst : memref<?x?xf32, #map13>, vector<9x16xf32>
vector.transfer_write %12, %0[%3, %5, %8, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.min #map1(%arg4)
%4 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%5 = affine.min #map3(%arg5)
%6 = memref.subview %arg2[%arg3, %arg5] [%2, %5] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%7 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%8 = affine.min #map12(%arg6, %2)
%9 = affine.apply #map11(%arg6)
scf.for %arg7 = %c0 to %5 step %c32 {
%10 = affine.min #map6(%arg7, %5)
%11 = affine.apply #map4(%arg7)
%12 = memref.subview %6[%arg6, %arg7] [%8, %10] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<9x32xf32>
%14 = scf.for %arg8 = %c0 to %3 step %c16 iter_args(%arg9 = %13) -> (vector<9x32xf32>) {
%15 = affine.apply #map7(%arg8)
%16 = vector.transfer_read %0[%4, %9, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%17 = vector.transfer_read %1[%4, %7, %11, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%18 = vector.contract {indexing_maps = [#map14, #map15, #map16], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %16, %17, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %18 : vector<9x32xf32>
}
vector.transfer_write %14, %12[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map9>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0) -> (288, -d0 + 2040)>
#map11 = affine_map<(d0) -> (d0 ceildiv 9)>
#map12 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map13 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%2 = affine.apply #map0(%arg3)
%3 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%4 = affine.apply #map2(%arg4)
%5 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %5 step %c32 {
%6 = affine.apply #map4(%arg5)
%7 = affine.apply #map5(%arg5, %arg4)
%8 = affine.min #map6(%arg5, %5)
scf.for %arg6 = %c0 to %3 step %c16 {
%9 = affine.apply #map7(%arg6)
%10 = affine.apply #map5(%arg6, %arg3)
%11 = affine.min #map8(%arg6, %3)
%12 = memref.subview %arg1[%10, %7] [%11, %8] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<16x32xf32>
vector.transfer_write %13, %1[%2, %4, %6, %9, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%2 = affine.min #map10(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.apply #map0(%arg4)
%4 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%5 = affine.apply #map11(%arg5)
%6 = affine.apply #map5(%arg5, %arg3)
%7 = affine.min #map12(%arg5, %2)
scf.for %arg6 = %c0 to %4 step %c16 {
%8 = affine.apply #map7(%arg6)
%9 = affine.apply #map5(%arg6, %arg4)
%10 = affine.min #map8(%arg6, %4)
%11 = memref.subview %arg0[%6, %9] [%7, %10] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map13>
%12 = vector.transfer_read %11[%c0, %c0], %cst : memref<?x?xf32, #map13>, vector<9x16xf32>
vector.transfer_write %12, %0[%3, %5, %8, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.min #map1(%arg4)
%4 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%5 = affine.min #map3(%arg5)
%6 = memref.subview %arg2[%arg3, %arg5] [%2, %5] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%7 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%8 = affine.min #map12(%arg6, %2)
%9 = affine.apply #map11(%arg6)
scf.for %arg7 = %c0 to %5 step %c32 {
%10 = affine.min #map6(%arg7, %5)
%11 = affine.apply #map4(%arg7)
%12 = memref.subview %6[%arg6, %arg7] [%8, %10] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<9x32xf32>
%14 = scf.for %arg8 = %c0 to %3 step %c16 iter_args(%arg9 = %13) -> (vector<9x32xf32>) {
%15 = affine.apply #map7(%arg8)
%16 = vector.transfer_read %0[%4, %9, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%17 = vector.transfer_read %1[%4, %7, %11, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%18 = vector.transpose %16, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%19 = vector.extract %18[0] : vector<16x9xf32>
%20 = vector.extract %17[0] : vector<16x32xf32>
%21 = vector.outerproduct %19, %20, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%22 = vector.extract %18[1] : vector<16x9xf32>
%23 = vector.extract %17[1] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %21 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %18[2] : vector<16x9xf32>
%26 = vector.extract %17[2] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %18[3] : vector<16x9xf32>
%29 = vector.extract %17[3] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %18[4] : vector<16x9xf32>
%32 = vector.extract %17[4] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %18[5] : vector<16x9xf32>
%35 = vector.extract %17[5] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %18[6] : vector<16x9xf32>
%38 = vector.extract %17[6] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %18[7] : vector<16x9xf32>
%41 = vector.extract %17[7] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %18[8] : vector<16x9xf32>
%44 = vector.extract %17[8] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %18[9] : vector<16x9xf32>
%47 = vector.extract %17[9] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %18[10] : vector<16x9xf32>
%50 = vector.extract %17[10] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %18[11] : vector<16x9xf32>
%53 = vector.extract %17[11] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %18[12] : vector<16x9xf32>
%56 = vector.extract %17[12] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %18[13] : vector<16x9xf32>
%59 = vector.extract %17[13] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %18[14] : vector<16x9xf32>
%62 = vector.extract %17[14] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %18[15] : vector<16x9xf32>
%65 = vector.extract %17[15] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %66 : vector<9x32xf32>
}
vector.transfer_write %14, %12[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map9>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0) -> (288, -d0 + 2040)>
#map11 = affine_map<(d0) -> (d0 ceildiv 9)>
#map12 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map13 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2040 = arith.constant 2040 : index
%c2041 = arith.constant 2041 : index
%c2042 = arith.constant 2042 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%2 = affine.apply #map0(%arg3)
%3 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%4 = affine.apply #map2(%arg4)
%5 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %5 step %c32 {
%6 = affine.apply #map4(%arg5)
%7 = affine.apply #map5(%arg5, %arg4)
%8 = affine.min #map6(%arg5, %5)
scf.for %arg6 = %c0 to %3 step %c16 {
%9 = affine.apply #map7(%arg6)
%10 = affine.apply #map5(%arg6, %arg3)
%11 = affine.min #map8(%arg6, %3)
%12 = memref.subview %arg1[%10, %7] [%11, %8] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<16x32xf32>
vector.transfer_write %13, %1[%2, %4, %6, %9, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%2 = affine.min #map10(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.apply #map0(%arg4)
%4 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%5 = affine.apply #map11(%arg5)
%6 = affine.apply #map5(%arg5, %arg3)
%7 = affine.min #map12(%arg5, %2)
scf.for %arg6 = %c0 to %4 step %c16 {
%8 = affine.apply #map7(%arg6)
%9 = affine.apply #map5(%arg6, %arg4)
%10 = affine.min #map8(%arg6, %4)
%11 = memref.subview %arg0[%6, %9] [%7, %10] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map13>
%12 = vector.transfer_read %11[%c0, %c0], %cst : memref<?x?xf32, #map13>, vector<9x16xf32>
vector.transfer_write %12, %0[%3, %5, %8, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%3 = affine.min #map1(%arg4)
%4 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%5 = affine.min #map3(%arg5)
%6 = memref.subview %arg2[%arg3, %arg5] [%2, %5] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%7 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%8 = affine.min #map12(%arg6, %2)
%9 = affine.apply #map11(%arg6)
scf.for %arg7 = %c0 to %5 step %c32 {
%10 = affine.min #map6(%arg7, %5)
%11 = affine.apply #map4(%arg7)
%12 = memref.subview %6[%arg6, %arg7] [%8, %10] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%13 = vector.transfer_read %12[%c0, %c0], %cst : memref<?x?xf32, #map9>, vector<9x32xf32>
%14 = scf.for %arg8 = %c0 to %3 step %c16 iter_args(%arg9 = %13) -> (vector<9x32xf32>) {
%15 = affine.apply #map7(%arg8)
%16 = vector.transfer_read %0[%4, %9, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%17 = vector.transfer_read %1[%4, %7, %11, %15, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%18 = vector.transpose %16, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%19 = vector.extract %18[0] : vector<16x9xf32>
%20 = vector.extract %17[0] : vector<16x32xf32>
%21 = vector.outerproduct %19, %20, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%22 = vector.extract %18[1] : vector<16x9xf32>
%23 = vector.extract %17[1] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %21 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %18[2] : vector<16x9xf32>
%26 = vector.extract %17[2] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %18[3] : vector<16x9xf32>
%29 = vector.extract %17[3] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %18[4] : vector<16x9xf32>
%32 = vector.extract %17[4] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %18[5] : vector<16x9xf32>
%35 = vector.extract %17[5] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %18[6] : vector<16x9xf32>
%38 = vector.extract %17[6] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %18[7] : vector<16x9xf32>
%41 = vector.extract %17[7] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %18[8] : vector<16x9xf32>
%44 = vector.extract %17[8] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %18[9] : vector<16x9xf32>
%47 = vector.extract %17[9] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %18[10] : vector<16x9xf32>
%50 = vector.extract %17[10] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %18[11] : vector<16x9xf32>
%53 = vector.extract %17[11] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %18[12] : vector<16x9xf32>
%56 = vector.extract %17[12] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %18[13] : vector<16x9xf32>
%59 = vector.extract %17[13] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %18[14] : vector<16x9xf32>
%62 = vector.extract %17[14] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %18[15] : vector<16x9xf32>
%65 = vector.extract %17[15] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %66 : vector<9x32xf32>
}
vector.transfer_write %14, %12[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map9>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map12 = affine_map<(d0) -> (288, -d0 + 2040)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%5 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%6 = affine.apply #map0(%arg3)
%7 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %9 step %c32 {
%10 = affine.apply #map4(%arg5)
%11 = affine.apply #map5(%arg5, %arg4)
%12 = affine.min #map6(%arg5, %9)
%13 = arith.cmpi sle, %c32, %12 : index
scf.for %arg6 = %c0 to %7 step %c16 {
%14 = affine.apply #map7(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map8(%arg6, %7)
%17 = memref.subview %arg1[%15, %11] [%16, %12] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = arith.andi %18, %13 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%22 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %22 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%22 = memref.subview %17[0, 0] [%16, %12] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%23 = memref.subview %0[0, 0] [%16, %12] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%22, %23) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%24 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map10>
scf.yield %24 : memref<?x?xf32, #map10>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<16x32xf32>
vector.transfer_write %21, %5[%6, %8, %10, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%6 = affine.min #map12(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.apply #map0(%arg4)
%8 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %6 step %c9 {
%9 = affine.apply #map13(%arg5)
%10 = affine.apply #map5(%arg5, %arg3)
%11 = affine.min #map14(%arg5, %6)
%12 = arith.cmpi sle, %c9, %11 : index
scf.for %arg6 = %c0 to %8 step %c16 {
%13 = affine.apply #map7(%arg6)
%14 = affine.apply #map5(%arg6, %arg4)
%15 = affine.min #map8(%arg6, %8)
%16 = memref.subview %arg0[%10, %14] [%11, %15] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map15>
%17 = arith.cmpi sle, %c16, %15 : index
%18 = arith.andi %12, %17 : i1
%19 = scf.if %18 -> (memref<?x?xf32, #map10>) {
%21 = memref.cast %16 : memref<?x?xf32, #map15> to memref<?x?xf32, #map10>
scf.yield %21 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%21 = memref.subview %16[0, 0] [%11, %15] [1, 1] : memref<?x?xf32, #map15> to memref<?x?xf32, #map15>
%22 = memref.subview %1[0, 0] [%11, %15] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%21, %22) : memref<?x?xf32, #map15>, memref<?x?xf32, #map16>
%23 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map10>
scf.yield %23 : memref<?x?xf32, #map10>
}
%20 = vector.transfer_read %19[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<9x16xf32>
vector.transfer_write %20, %4[%7, %9, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.min #map1(%arg4)
%8 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%9 = affine.min #map3(%arg5)
%10 = memref.subview %arg2[%arg3, %arg5] [%6, %9] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%11 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %6 step %c9 {
%12 = affine.min #map14(%arg6, %6)
%13 = affine.apply #map13(%arg6)
%14 = arith.cmpi sle, %c9, %12 : index
scf.for %arg7 = %c0 to %9 step %c32 {
%15 = affine.min #map6(%arg7, %9)
%16 = affine.apply #map4(%arg7)
%17 = memref.subview %10[%arg6, %arg7] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c32, %15 : index
%19 = arith.andi %14, %18 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%25 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%25 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%26 = memref.subview %2[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%25, %26) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%27 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %27 : memref<?x?xf32, #map10>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<9x32xf32>
%22 = scf.for %arg8 = %c0 to %7 step %c16 iter_args(%arg9 = %21) -> (vector<9x32xf32>) {
%25 = affine.apply #map7(%arg8)
%26 = vector.transfer_read %4[%8, %13, %25, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%27 = vector.transfer_read %5[%8, %11, %16, %25, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%28 = vector.transpose %26, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%29 = vector.extract %28[0] : vector<16x9xf32>
%30 = vector.extract %27[0] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %28[1] : vector<16x9xf32>
%33 = vector.extract %27[1] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %28[2] : vector<16x9xf32>
%36 = vector.extract %27[2] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %28[3] : vector<16x9xf32>
%39 = vector.extract %27[3] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %28[4] : vector<16x9xf32>
%42 = vector.extract %27[4] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %28[5] : vector<16x9xf32>
%45 = vector.extract %27[5] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %28[6] : vector<16x9xf32>
%48 = vector.extract %27[6] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %28[7] : vector<16x9xf32>
%51 = vector.extract %27[7] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %28[8] : vector<16x9xf32>
%54 = vector.extract %27[8] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %28[9] : vector<16x9xf32>
%57 = vector.extract %27[9] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %28[10] : vector<16x9xf32>
%60 = vector.extract %27[10] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %28[11] : vector<16x9xf32>
%63 = vector.extract %27[11] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %28[12] : vector<16x9xf32>
%66 = vector.extract %27[12] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %28[13] : vector<16x9xf32>
%69 = vector.extract %27[13] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%71 = vector.extract %28[14] : vector<16x9xf32>
%72 = vector.extract %27[14] : vector<16x32xf32>
%73 = vector.outerproduct %71, %72, %70 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%74 = vector.extract %28[15] : vector<16x9xf32>
%75 = vector.extract %27[15] : vector<16x32xf32>
%76 = vector.outerproduct %74, %75, %73 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %76 : vector<9x32xf32>
}
%23 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%25 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
} else {
%25 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
}
vector.transfer_write %22, %23[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map10>
%24 = arith.xori %19, %true : i1
scf.if %24 {
%25 = memref.subview %3[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
%26 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
linalg.copy(%25, %26) : memref<?x?xf32, #map11>, memref<?x?xf32, #map9>
}
}
}
}
}
}
memref.dealloc %5 : memref<4x16x4x32x16x32xf32>
memref.dealloc %4 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map12 = affine_map<(d0) -> (288, -d0 + 2040)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2040 = arith.constant 2040 : index
%c2041 = arith.constant 2041 : index
%c2042 = arith.constant 2042 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%true = arith.constant true
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%5 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%6 = affine.apply #map0(%arg3)
%7 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %9 step %c32 {
%10 = affine.apply #map4(%arg5)
%11 = affine.apply #map5(%arg5, %arg4)
%12 = affine.min #map6(%arg5, %9)
%13 = arith.cmpi sle, %c32, %12 : index
scf.for %arg6 = %c0 to %7 step %c16 {
%14 = affine.apply #map7(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map8(%arg6, %7)
%17 = memref.subview %arg1[%15, %11] [%16, %12] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = arith.andi %18, %13 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%22 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %22 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%22 = memref.subview %17[0, 0] [%16, %12] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%23 = memref.subview %0[0, 0] [%16, %12] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%22, %23) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%24 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map10>
scf.yield %24 : memref<?x?xf32, #map10>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<16x32xf32>
vector.transfer_write %21, %5[%6, %8, %10, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%6 = affine.min #map12(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.apply #map0(%arg4)
%8 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %6 step %c9 {
%9 = affine.apply #map13(%arg5)
%10 = affine.apply #map5(%arg5, %arg3)
%11 = affine.min #map14(%arg5, %6)
%12 = arith.cmpi sle, %c9, %11 : index
scf.for %arg6 = %c0 to %8 step %c16 {
%13 = affine.apply #map7(%arg6)
%14 = affine.apply #map5(%arg6, %arg4)
%15 = affine.min #map8(%arg6, %8)
%16 = memref.subview %arg0[%10, %14] [%11, %15] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map15>
%17 = arith.cmpi sle, %c16, %15 : index
%18 = arith.andi %12, %17 : i1
%19 = scf.if %18 -> (memref<?x?xf32, #map10>) {
%21 = memref.cast %16 : memref<?x?xf32, #map15> to memref<?x?xf32, #map10>
scf.yield %21 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%21 = memref.subview %16[0, 0] [%11, %15] [1, 1] : memref<?x?xf32, #map15> to memref<?x?xf32, #map15>
%22 = memref.subview %1[0, 0] [%11, %15] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%21, %22) : memref<?x?xf32, #map15>, memref<?x?xf32, #map16>
%23 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map10>
scf.yield %23 : memref<?x?xf32, #map10>
}
%20 = vector.transfer_read %19[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<9x16xf32>
vector.transfer_write %20, %4[%7, %9, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.min #map1(%arg4)
%8 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%9 = affine.min #map3(%arg5)
%10 = memref.subview %arg2[%arg3, %arg5] [%6, %9] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%11 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %6 step %c9 {
%12 = affine.min #map14(%arg6, %6)
%13 = affine.apply #map13(%arg6)
%14 = arith.cmpi sle, %c9, %12 : index
scf.for %arg7 = %c0 to %9 step %c32 {
%15 = affine.min #map6(%arg7, %9)
%16 = affine.apply #map4(%arg7)
%17 = memref.subview %10[%arg6, %arg7] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c32, %15 : index
%19 = arith.andi %14, %18 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%25 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%25 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%26 = memref.subview %2[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%25, %26) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%27 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %27 : memref<?x?xf32, #map10>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map10>, vector<9x32xf32>
%22 = scf.for %arg8 = %c0 to %7 step %c16 iter_args(%arg9 = %21) -> (vector<9x32xf32>) {
%25 = affine.apply #map7(%arg8)
%26 = vector.transfer_read %4[%8, %13, %25, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%27 = vector.transfer_read %5[%8, %11, %16, %25, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%28 = vector.transpose %26, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%29 = vector.extract %28[0] : vector<16x9xf32>
%30 = vector.extract %27[0] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %28[1] : vector<16x9xf32>
%33 = vector.extract %27[1] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %28[2] : vector<16x9xf32>
%36 = vector.extract %27[2] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %28[3] : vector<16x9xf32>
%39 = vector.extract %27[3] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %28[4] : vector<16x9xf32>
%42 = vector.extract %27[4] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %28[5] : vector<16x9xf32>
%45 = vector.extract %27[5] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %28[6] : vector<16x9xf32>
%48 = vector.extract %27[6] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %28[7] : vector<16x9xf32>
%51 = vector.extract %27[7] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %28[8] : vector<16x9xf32>
%54 = vector.extract %27[8] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %28[9] : vector<16x9xf32>
%57 = vector.extract %27[9] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %28[10] : vector<16x9xf32>
%60 = vector.extract %27[10] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %28[11] : vector<16x9xf32>
%63 = vector.extract %27[11] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %28[12] : vector<16x9xf32>
%66 = vector.extract %27[12] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %28[13] : vector<16x9xf32>
%69 = vector.extract %27[13] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%71 = vector.extract %28[14] : vector<16x9xf32>
%72 = vector.extract %27[14] : vector<16x32xf32>
%73 = vector.outerproduct %71, %72, %70 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%74 = vector.extract %28[15] : vector<16x9xf32>
%75 = vector.extract %27[15] : vector<16x32xf32>
%76 = vector.outerproduct %74, %75, %73 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %76 : vector<9x32xf32>
}
%23 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%25 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
} else {
%25 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %25 : memref<?x?xf32, #map10>
}
vector.transfer_write %22, %23[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map10>
%24 = arith.xori %19, %true : i1
scf.if %24 {
%25 = memref.subview %3[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
%26 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
linalg.copy(%25, %26) : memref<?x?xf32, #map11>, memref<?x?xf32, #map9>
}
}
}
}
}
}
memref.dealloc %5 : memref<4x16x4x32x16x32xf32>
memref.dealloc %4 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map12 = affine_map<(d0) -> (288, -d0 + 2040)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%5 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%6 = affine.apply #map0(%arg3)
%7 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %9 step %c32 {
%10 = affine.apply #map4(%arg5)
%11 = affine.apply #map5(%arg5, %arg4)
%12 = affine.min #map6(%arg5, %9)
%13 = arith.cmpi sle, %c32, %12 : index
scf.for %arg6 = %c0 to %7 step %c16 {
%14 = affine.apply #map7(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map8(%arg6, %7)
%17 = memref.subview %arg1[%15, %11] [%16, %12] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = arith.andi %18, %13 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%37 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %37 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_1, %0) : f32, memref<16x32xf32>
%37 = memref.subview %17[0, 0] [%16, %12] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%38 = memref.subview %0[0, 0] [%16, %12] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%37, %38) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%39 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map10>
scf.yield %39 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%23 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%25 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%27 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%29 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.load %20[%c9, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%31 = vector.load %20[%c10, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.load %20[%c11, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%33 = vector.load %20[%c12, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.load %20[%c13, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%35 = vector.load %20[%c14, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.load %20[%c15, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
vector.store %21, %5[%6, %8, %10, %14, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %5[%6, %8, %10, %14, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %5[%6, %8, %10, %14, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %5[%6, %8, %10, %14, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %5[%6, %8, %10, %14, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %5[%6, %8, %10, %14, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %5[%6, %8, %10, %14, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %5[%6, %8, %10, %14, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %5[%6, %8, %10, %14, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %5[%6, %8, %10, %14, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %5[%6, %8, %10, %14, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %5[%6, %8, %10, %14, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %5[%6, %8, %10, %14, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %5[%6, %8, %10, %14, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %5[%6, %8, %10, %14, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %5[%6, %8, %10, %14, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%6 = affine.min #map12(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.apply #map0(%arg4)
%8 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %6 step %c9 {
%9 = affine.apply #map13(%arg5)
%10 = affine.apply #map5(%arg5, %arg3)
%11 = affine.min #map14(%arg5, %6)
%12 = arith.cmpi sle, %c9, %11 : index
scf.for %arg6 = %c0 to %8 step %c16 {
%13 = affine.apply #map7(%arg6)
%14 = affine.apply #map5(%arg6, %arg4)
%15 = affine.min #map8(%arg6, %8)
%16 = memref.subview %arg0[%10, %14] [%11, %15] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map15>
%17 = arith.cmpi sle, %c16, %15 : index
%18 = arith.andi %12, %17 : i1
%19 = scf.if %18 -> (memref<?x?xf32, #map10>) {
%29 = memref.cast %16 : memref<?x?xf32, #map15> to memref<?x?xf32, #map10>
scf.yield %29 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x16xf32>
%29 = memref.subview %16[0, 0] [%11, %15] [1, 1] : memref<?x?xf32, #map15> to memref<?x?xf32, #map15>
%30 = memref.subview %1[0, 0] [%11, %15] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%29, %30) : memref<?x?xf32, #map15>, memref<?x?xf32, #map16>
%31 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map10>
scf.yield %31 : memref<?x?xf32, #map10>
}
%20 = vector.load %19[%c0, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
vector.store %20, %4[%7, %9, %13, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %4[%7, %9, %13, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %4[%7, %9, %13, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %4[%7, %9, %13, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %4[%7, %9, %13, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %4[%7, %9, %13, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %4[%7, %9, %13, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %4[%7, %9, %13, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %4[%7, %9, %13, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.min #map1(%arg4)
%8 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%9 = affine.min #map3(%arg5)
%10 = memref.subview %arg2[%arg3, %arg5] [%6, %9] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%11 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %6 step %c9 {
%12 = affine.min #map14(%arg6, %6)
%13 = affine.apply #map13(%arg6)
%14 = arith.cmpi sle, %c9, %12 : index
scf.for %arg7 = %c0 to %9 step %c32 {
%15 = affine.min #map6(%arg7, %9)
%16 = affine.apply #map4(%arg7)
%17 = memref.subview %10[%arg6, %arg7] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c32, %15 : index
%19 = arith.andi %14, %18 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_1, %2) : f32, memref<9x32xf32>
%51 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%52 = memref.subview %2[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%51, %52) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%53 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %53 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.insert %21, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %7 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%51 = affine.apply #map7(%arg8)
%52 = vector.load %4[%8, %13, %51, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %cst [0] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %4[%8, %13, %51, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [1] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %4[%8, %13, %51, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [2] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %4[%8, %13, %51, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [3] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %4[%8, %13, %51, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [4] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %4[%8, %13, %51, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [5] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%8, %13, %51, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%65 = vector.insert %64, %63 [6] : vector<16xf32> into vector<9x16xf32>
%66 = vector.load %4[%8, %13, %51, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%67 = vector.insert %66, %65 [7] : vector<16xf32> into vector<9x16xf32>
%68 = vector.load %4[%8, %13, %51, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%69 = vector.insert %68, %67 [8] : vector<16xf32> into vector<9x16xf32>
%70 = vector.load %5[%8, %11, %16, %51, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %5[%8, %11, %16, %51, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %5[%8, %11, %16, %51, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %5[%8, %11, %16, %51, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %5[%8, %11, %16, %51, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %5[%8, %11, %16, %51, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %5[%8, %11, %16, %51, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %5[%8, %11, %16, %51, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %5[%8, %11, %16, %51, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %5[%8, %11, %16, %51, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %5[%8, %11, %16, %51, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %5[%8, %11, %16, %51, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %5[%8, %11, %16, %51, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %5[%8, %11, %16, %51, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %5[%8, %11, %16, %51, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %5[%8, %11, %16, %51, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.transpose %69, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%87 = vector.extract %86[0] : vector<16x9xf32>
%88 = vector.outerproduct %87, %70, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%89 = vector.extract %86[1] : vector<16x9xf32>
%90 = vector.outerproduct %89, %71, %88 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%91 = vector.extract %86[2] : vector<16x9xf32>
%92 = vector.outerproduct %91, %72, %90 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%93 = vector.extract %86[3] : vector<16x9xf32>
%94 = vector.outerproduct %93, %73, %92 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%95 = vector.extract %86[4] : vector<16x9xf32>
%96 = vector.outerproduct %95, %74, %94 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%97 = vector.extract %86[5] : vector<16x9xf32>
%98 = vector.outerproduct %97, %75, %96 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%99 = vector.extract %86[6] : vector<16x9xf32>
%100 = vector.outerproduct %99, %76, %98 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%101 = vector.extract %86[7] : vector<16x9xf32>
%102 = vector.outerproduct %101, %77, %100 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%103 = vector.extract %86[8] : vector<16x9xf32>
%104 = vector.outerproduct %103, %78, %102 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%105 = vector.extract %86[9] : vector<16x9xf32>
%106 = vector.outerproduct %105, %79, %104 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%107 = vector.extract %86[10] : vector<16x9xf32>
%108 = vector.outerproduct %107, %80, %106 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%109 = vector.extract %86[11] : vector<16x9xf32>
%110 = vector.outerproduct %109, %81, %108 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%111 = vector.extract %86[12] : vector<16x9xf32>
%112 = vector.outerproduct %111, %82, %110 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%113 = vector.extract %86[13] : vector<16x9xf32>
%114 = vector.outerproduct %113, %83, %112 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%115 = vector.extract %86[14] : vector<16x9xf32>
%116 = vector.outerproduct %115, %84, %114 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%117 = vector.extract %86[15] : vector<16x9xf32>
%118 = vector.outerproduct %117, %85, %116 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %118 : vector<9x32xf32>
}
%40 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
%51 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%50 = arith.xori %19, %true : i1
scf.if %50 {
%51 = memref.subview %3[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
%52 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
linalg.copy(%51, %52) : memref<?x?xf32, #map11>, memref<?x?xf32, #map9>
}
}
}
}
}
}
memref.dealloc %5 : memref<4x16x4x32x16x32xf32>
memref.dealloc %4 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map12 = affine_map<(d0) -> (288, -d0 + 2040)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2040 = arith.constant 2040 : index
%c2041 = arith.constant 2041 : index
%c2042 = arith.constant 2042 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%5 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%6 = affine.apply #map0(%arg3)
%7 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %9 step %c32 {
%10 = affine.apply #map4(%arg5)
%11 = affine.apply #map5(%arg5, %arg4)
%12 = affine.min #map6(%arg5, %9)
%13 = arith.cmpi sle, %c32, %12 : index
scf.for %arg6 = %c0 to %7 step %c16 {
%14 = affine.apply #map7(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map8(%arg6, %7)
%17 = memref.subview %arg1[%15, %11] [%16, %12] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = arith.andi %18, %13 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%37 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %37 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%37 = memref.subview %17[0, 0] [%16, %12] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%38 = memref.subview %0[0, 0] [%16, %12] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%37, %38) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%39 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map10>
scf.yield %39 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%23 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%25 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%27 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%29 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.load %20[%c9, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%31 = vector.load %20[%c10, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.load %20[%c11, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%33 = vector.load %20[%c12, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.load %20[%c13, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%35 = vector.load %20[%c14, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.load %20[%c15, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
vector.store %21, %5[%6, %8, %10, %14, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %5[%6, %8, %10, %14, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %5[%6, %8, %10, %14, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %5[%6, %8, %10, %14, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %5[%6, %8, %10, %14, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %5[%6, %8, %10, %14, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %5[%6, %8, %10, %14, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %5[%6, %8, %10, %14, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %5[%6, %8, %10, %14, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %5[%6, %8, %10, %14, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %5[%6, %8, %10, %14, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %5[%6, %8, %10, %14, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %5[%6, %8, %10, %14, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %5[%6, %8, %10, %14, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %5[%6, %8, %10, %14, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %5[%6, %8, %10, %14, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%6 = affine.min #map12(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.apply #map0(%arg4)
%8 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %6 step %c9 {
%9 = affine.apply #map13(%arg5)
%10 = affine.apply #map5(%arg5, %arg3)
%11 = affine.min #map14(%arg5, %6)
%12 = arith.cmpi sle, %c9, %11 : index
scf.for %arg6 = %c0 to %8 step %c16 {
%13 = affine.apply #map7(%arg6)
%14 = affine.apply #map5(%arg6, %arg4)
%15 = affine.min #map8(%arg6, %8)
%16 = memref.subview %arg0[%10, %14] [%11, %15] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map15>
%17 = arith.cmpi sle, %c16, %15 : index
%18 = arith.andi %12, %17 : i1
%19 = scf.if %18 -> (memref<?x?xf32, #map10>) {
%29 = memref.cast %16 : memref<?x?xf32, #map15> to memref<?x?xf32, #map10>
scf.yield %29 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%29 = memref.subview %16[0, 0] [%11, %15] [1, 1] : memref<?x?xf32, #map15> to memref<?x?xf32, #map15>
%30 = memref.subview %1[0, 0] [%11, %15] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%29, %30) : memref<?x?xf32, #map15>, memref<?x?xf32, #map16>
%31 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map10>
scf.yield %31 : memref<?x?xf32, #map10>
}
%20 = vector.load %19[%c0, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
vector.store %20, %4[%7, %9, %13, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %4[%7, %9, %13, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %4[%7, %9, %13, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %4[%7, %9, %13, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %4[%7, %9, %13, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %4[%7, %9, %13, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %4[%7, %9, %13, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %4[%7, %9, %13, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %4[%7, %9, %13, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.min #map1(%arg4)
%8 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%9 = affine.min #map3(%arg5)
%10 = memref.subview %arg2[%arg3, %arg5] [%6, %9] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%11 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %6 step %c9 {
%12 = affine.min #map14(%arg6, %6)
%13 = affine.apply #map13(%arg6)
%14 = arith.cmpi sle, %c9, %12 : index
scf.for %arg7 = %c0 to %9 step %c32 {
%15 = affine.min #map6(%arg7, %9)
%16 = affine.apply #map4(%arg7)
%17 = memref.subview %10[%arg6, %arg7] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c32, %15 : index
%19 = arith.andi %14, %18 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%51 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%52 = memref.subview %2[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%51, %52) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%53 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %53 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.insert %21, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %7 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%51 = affine.apply #map7(%arg8)
%52 = vector.load %4[%8, %13, %51, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %cst_1 [0] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %4[%8, %13, %51, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [1] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %4[%8, %13, %51, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [2] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %4[%8, %13, %51, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [3] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %4[%8, %13, %51, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [4] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %4[%8, %13, %51, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [5] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%8, %13, %51, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%65 = vector.insert %64, %63 [6] : vector<16xf32> into vector<9x16xf32>
%66 = vector.load %4[%8, %13, %51, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%67 = vector.insert %66, %65 [7] : vector<16xf32> into vector<9x16xf32>
%68 = vector.load %4[%8, %13, %51, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%69 = vector.insert %68, %67 [8] : vector<16xf32> into vector<9x16xf32>
%70 = vector.load %5[%8, %11, %16, %51, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %5[%8, %11, %16, %51, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %5[%8, %11, %16, %51, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %5[%8, %11, %16, %51, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %5[%8, %11, %16, %51, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %5[%8, %11, %16, %51, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %5[%8, %11, %16, %51, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %5[%8, %11, %16, %51, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %5[%8, %11, %16, %51, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %5[%8, %11, %16, %51, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %5[%8, %11, %16, %51, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %5[%8, %11, %16, %51, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %5[%8, %11, %16, %51, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %5[%8, %11, %16, %51, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %5[%8, %11, %16, %51, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %5[%8, %11, %16, %51, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.transpose %69, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%87 = vector.extract %86[0] : vector<16x9xf32>
%88 = vector.outerproduct %87, %70, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%89 = vector.extract %86[1] : vector<16x9xf32>
%90 = vector.outerproduct %89, %71, %88 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%91 = vector.extract %86[2] : vector<16x9xf32>
%92 = vector.outerproduct %91, %72, %90 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%93 = vector.extract %86[3] : vector<16x9xf32>
%94 = vector.outerproduct %93, %73, %92 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%95 = vector.extract %86[4] : vector<16x9xf32>
%96 = vector.outerproduct %95, %74, %94 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%97 = vector.extract %86[5] : vector<16x9xf32>
%98 = vector.outerproduct %97, %75, %96 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%99 = vector.extract %86[6] : vector<16x9xf32>
%100 = vector.outerproduct %99, %76, %98 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%101 = vector.extract %86[7] : vector<16x9xf32>
%102 = vector.outerproduct %101, %77, %100 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%103 = vector.extract %86[8] : vector<16x9xf32>
%104 = vector.outerproduct %103, %78, %102 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%105 = vector.extract %86[9] : vector<16x9xf32>
%106 = vector.outerproduct %105, %79, %104 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%107 = vector.extract %86[10] : vector<16x9xf32>
%108 = vector.outerproduct %107, %80, %106 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%109 = vector.extract %86[11] : vector<16x9xf32>
%110 = vector.outerproduct %109, %81, %108 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%111 = vector.extract %86[12] : vector<16x9xf32>
%112 = vector.outerproduct %111, %82, %110 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%113 = vector.extract %86[13] : vector<16x9xf32>
%114 = vector.outerproduct %113, %83, %112 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%115 = vector.extract %86[14] : vector<16x9xf32>
%116 = vector.outerproduct %115, %84, %114 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%117 = vector.extract %86[15] : vector<16x9xf32>
%118 = vector.outerproduct %117, %85, %116 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %118 : vector<9x32xf32>
}
%40 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
%51 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%50 = arith.xori %19, %true : i1
scf.if %50 {
%51 = memref.subview %3[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
%52 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
linalg.copy(%51, %52) : memref<?x?xf32, #map11>, memref<?x?xf32, #map9>
}
}
}
}
}
}
memref.dealloc %5 : memref<4x16x4x32x16x32xf32>
memref.dealloc %4 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692250>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (512, -d0 + 2042)>
#map2 = affine_map<(d0) -> (d0 ceildiv 128)>
#map3 = affine_map<(d0) -> (128, -d0 + 2041)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map7 = affine_map<(d0) -> (d0 ceildiv 16)>
#map8 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map9 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map10 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map12 = affine_map<(d0) -> (288, -d0 + 2040)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1)[s0] -> (d0 * 2042 + s0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<16x9xf32>
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2042 = arith.constant 2042 : index
%c2041 = arith.constant 2041 : index
%c2040 = arith.constant 2040 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%5 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst_2, %arg2) : f32, memref<2040x2041xf32>
scf.for %arg3 = %c0 to %c2042 step %c512 {
%6 = affine.apply #map0(%arg3)
%7 = affine.min #map1(%arg3)
scf.for %arg4 = %c0 to %c2041 step %c128 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)
scf.for %arg5 = %c0 to %9 step %c32 {
%10 = affine.apply #map4(%arg5)
%11 = affine.apply #map5(%arg5, %arg4)
%12 = affine.min #map6(%arg5, %9)
%13 = arith.cmpi sle, %c32, %12 : index
scf.for %arg6 = %c0 to %7 step %c16 {
%14 = affine.apply #map7(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map8(%arg6, %7)
%17 = memref.subview %arg1[%15, %11] [%16, %12] [1, 1] : memref<2042x2041xf32> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = arith.andi %18, %13 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%37 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %37 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_2, %0) : f32, memref<16x32xf32>
%37 = memref.subview %17[0, 0] [%16, %12] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%38 = memref.subview %0[0, 0] [%16, %12] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%37, %38) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%39 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map10>
scf.yield %39 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%23 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%25 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%27 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%29 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.load %20[%c9, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%31 = vector.load %20[%c10, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.load %20[%c11, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%33 = vector.load %20[%c12, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.load %20[%c13, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%35 = vector.load %20[%c14, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.load %20[%c15, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
vector.store %21, %5[%6, %8, %10, %14, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %5[%6, %8, %10, %14, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %5[%6, %8, %10, %14, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %5[%6, %8, %10, %14, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %5[%6, %8, %10, %14, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %5[%6, %8, %10, %14, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %5[%6, %8, %10, %14, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %5[%6, %8, %10, %14, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %5[%6, %8, %10, %14, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %5[%6, %8, %10, %14, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %5[%6, %8, %10, %14, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %5[%6, %8, %10, %14, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %5[%6, %8, %10, %14, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %5[%6, %8, %10, %14, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %5[%6, %8, %10, %14, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %5[%6, %8, %10, %14, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2040 step %c288 {
%6 = affine.min #map12(%arg3)
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.apply #map0(%arg4)
%8 = affine.min #map1(%arg4)
scf.for %arg5 = %c0 to %6 step %c9 {
%9 = affine.apply #map13(%arg5)
%10 = affine.apply #map5(%arg5, %arg3)
%11 = affine.min #map14(%arg5, %6)
%12 = arith.cmpi sle, %c9, %11 : index
scf.for %arg6 = %c0 to %8 step %c16 {
%13 = affine.apply #map7(%arg6)
%14 = affine.apply #map5(%arg6, %arg4)
%15 = affine.min #map8(%arg6, %8)
%16 = memref.subview %arg0[%10, %14] [%11, %15] [1, 1] : memref<2040x2042xf32> to memref<?x?xf32, #map15>
%17 = arith.cmpi sle, %c16, %15 : index
%18 = arith.andi %12, %17 : i1
%19 = scf.if %18 -> (memref<?x?xf32, #map10>) {
%29 = memref.cast %16 : memref<?x?xf32, #map15> to memref<?x?xf32, #map10>
scf.yield %29 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_2, %1) : f32, memref<9x16xf32>
%29 = memref.subview %16[0, 0] [%11, %15] [1, 1] : memref<?x?xf32, #map15> to memref<?x?xf32, #map15>
%30 = memref.subview %1[0, 0] [%11, %15] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%29, %30) : memref<?x?xf32, #map15>, memref<?x?xf32, #map16>
%31 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map10>
scf.yield %31 : memref<?x?xf32, #map10>
}
%20 = vector.load %19[%c0, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x?xf32, #map10>, vector<16xf32>
vector.store %20, %4[%7, %9, %13, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %4[%7, %9, %13, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %4[%7, %9, %13, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %4[%7, %9, %13, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %4[%7, %9, %13, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %4[%7, %9, %13, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %4[%7, %9, %13, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %4[%7, %9, %13, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %4[%7, %9, %13, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2042 step %c512 {
%7 = affine.min #map1(%arg4)
%8 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%9 = affine.min #map3(%arg5)
%10 = memref.subview %arg2[%arg3, %arg5] [%6, %9] [1, 1] : memref<2040x2041xf32> to memref<?x?xf32, #map9>
%11 = affine.apply #map2(%arg5)
scf.for %arg6 = %c0 to %6 step %c9 {
%12 = affine.min #map14(%arg6, %6)
%13 = affine.apply #map13(%arg6)
%14 = arith.cmpi sle, %c9, %12 : index
scf.for %arg7 = %c0 to %9 step %c32 {
%15 = affine.min #map6(%arg7, %9)
%16 = affine.apply #map4(%arg7)
%17 = memref.subview %10[%arg6, %arg7] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%18 = arith.cmpi sle, %c32, %15 : index
%19 = arith.andi %14, %18 : i1
%20 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
linalg.fill(%cst_2, %2) : f32, memref<9x32xf32>
%51 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%52 = memref.subview %2[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
linalg.copy(%51, %52) : memref<?x?xf32, #map9>, memref<?x?xf32, #map11>
%53 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %53 : memref<?x?xf32, #map10>
}
%21 = vector.load %20[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%22 = vector.insert %21, %cst_1 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %7 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%51 = affine.apply #map7(%arg8)
%52 = vector.load %4[%8, %13, %51, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %cst_0 [0] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %4[%8, %13, %51, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [1] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %4[%8, %13, %51, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [2] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %4[%8, %13, %51, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [3] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %4[%8, %13, %51, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [4] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %4[%8, %13, %51, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [5] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%8, %13, %51, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%65 = vector.insert %64, %63 [6] : vector<16xf32> into vector<9x16xf32>
%66 = vector.load %4[%8, %13, %51, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%67 = vector.insert %66, %65 [7] : vector<16xf32> into vector<9x16xf32>
%68 = vector.load %4[%8, %13, %51, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%69 = vector.insert %68, %67 [8] : vector<16xf32> into vector<9x16xf32>
%70 = vector.load %5[%8, %11, %16, %51, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %5[%8, %11, %16, %51, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %5[%8, %11, %16, %51, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %5[%8, %11, %16, %51, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %5[%8, %11, %16, %51, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %5[%8, %11, %16, %51, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %5[%8, %11, %16, %51, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %5[%8, %11, %16, %51, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %5[%8, %11, %16, %51, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %5[%8, %11, %16, %51, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %5[%8, %11, %16, %51, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %5[%8, %11, %16, %51, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %5[%8, %11, %16, %51, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %5[%8, %11, %16, %51, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %5[%8, %11, %16, %51, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %5[%8, %11, %16, %51, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.extract %69[0, 0] : vector<9x16xf32>
%87 = vector.insert %86, %cst [0, 0] : f32 into vector<16x9xf32>
%88 = vector.extract %69[1, 0] : vector<9x16xf32>
%89 = vector.insert %88, %87 [0, 1] : f32 into vector<16x9xf32>
%90 = vector.extract %69[2, 0] : vector<9x16xf32>
%91 = vector.insert %90, %89 [0, 2] : f32 into vector<16x9xf32>
%92 = vector.extract %69[3, 0] : vector<9x16xf32>
%93 = vector.insert %92, %91 [0, 3] : f32 into vector<16x9xf32>
%94 = vector.extract %69[4, 0] : vector<9x16xf32>
%95 = vector.insert %94, %93 [0, 4] : f32 into vector<16x9xf32>
%96 = vector.extract %69[5, 0] : vector<9x16xf32>
%97 = vector.insert %96, %95 [0, 5] : f32 into vector<16x9xf32>
%98 = vector.extract %69[6, 0] : vector<9x16xf32>
%99 = vector.insert %98, %97 [0, 6] : f32 into vector<16x9xf32>
%100 = vector.extract %69[7, 0] : vector<9x16xf32>
%101 = vector.insert %100, %99 [0, 7] : f32 into vector<16x9xf32>
%102 = vector.extract %69[8, 0] : vector<9x16xf32>
%103 = vector.insert %102, %101 [0, 8] : f32 into vector<16x9xf32>
%104 = vector.extract %69[0, 1] : vector<9x16xf32>
%105 = vector.insert %104, %103 [1, 0] : f32 into vector<16x9xf32>
%106 = vector.extract %69[1, 1] : vector<9x16xf32>
%107 = vector.insert %106, %105 [1, 1] : f32 into vector<16x9xf32>
%108 = vector.extract %69[2, 1] : vector<9x16xf32>
%109 = vector.insert %108, %107 [1, 2] : f32 into vector<16x9xf32>
%110 = vector.extract %69[3, 1] : vector<9x16xf32>
%111 = vector.insert %110, %109 [1, 3] : f32 into vector<16x9xf32>
%112 = vector.extract %69[4, 1] : vector<9x16xf32>
%113 = vector.insert %112, %111 [1, 4] : f32 into vector<16x9xf32>
%114 = vector.extract %69[5, 1] : vector<9x16xf32>
%115 = vector.insert %114, %113 [1, 5] : f32 into vector<16x9xf32>
%116 = vector.extract %69[6, 1] : vector<9x16xf32>
%117 = vector.insert %116, %115 [1, 6] : f32 into vector<16x9xf32>
%118 = vector.extract %69[7, 1] : vector<9x16xf32>
%119 = vector.insert %118, %117 [1, 7] : f32 into vector<16x9xf32>
%120 = vector.extract %69[8, 1] : vector<9x16xf32>
%121 = vector.insert %120, %119 [1, 8] : f32 into vector<16x9xf32>
%122 = vector.extract %69[0, 2] : vector<9x16xf32>
%123 = vector.insert %122, %121 [2, 0] : f32 into vector<16x9xf32>
%124 = vector.extract %69[1, 2] : vector<9x16xf32>
%125 = vector.insert %124, %123 [2, 1] : f32 into vector<16x9xf32>
%126 = vector.extract %69[2, 2] : vector<9x16xf32>
%127 = vector.insert %126, %125 [2, 2] : f32 into vector<16x9xf32>
%128 = vector.extract %69[3, 2] : vector<9x16xf32>
%129 = vector.insert %128, %127 [2, 3] : f32 into vector<16x9xf32>
%130 = vector.extract %69[4, 2] : vector<9x16xf32>
%131 = vector.insert %130, %129 [2, 4] : f32 into vector<16x9xf32>
%132 = vector.extract %69[5, 2] : vector<9x16xf32>
%133 = vector.insert %132, %131 [2, 5] : f32 into vector<16x9xf32>
%134 = vector.extract %69[6, 2] : vector<9x16xf32>
%135 = vector.insert %134, %133 [2, 6] : f32 into vector<16x9xf32>
%136 = vector.extract %69[7, 2] : vector<9x16xf32>
%137 = vector.insert %136, %135 [2, 7] : f32 into vector<16x9xf32>
%138 = vector.extract %69[8, 2] : vector<9x16xf32>
%139 = vector.insert %138, %137 [2, 8] : f32 into vector<16x9xf32>
%140 = vector.extract %69[0, 3] : vector<9x16xf32>
%141 = vector.insert %140, %139 [3, 0] : f32 into vector<16x9xf32>
%142 = vector.extract %69[1, 3] : vector<9x16xf32>
%143 = vector.insert %142, %141 [3, 1] : f32 into vector<16x9xf32>
%144 = vector.extract %69[2, 3] : vector<9x16xf32>
%145 = vector.insert %144, %143 [3, 2] : f32 into vector<16x9xf32>
%146 = vector.extract %69[3, 3] : vector<9x16xf32>
%147 = vector.insert %146, %145 [3, 3] : f32 into vector<16x9xf32>
%148 = vector.extract %69[4, 3] : vector<9x16xf32>
%149 = vector.insert %148, %147 [3, 4] : f32 into vector<16x9xf32>
%150 = vector.extract %69[5, 3] : vector<9x16xf32>
%151 = vector.insert %150, %149 [3, 5] : f32 into vector<16x9xf32>
%152 = vector.extract %69[6, 3] : vector<9x16xf32>
%153 = vector.insert %152, %151 [3, 6] : f32 into vector<16x9xf32>
%154 = vector.extract %69[7, 3] : vector<9x16xf32>
%155 = vector.insert %154, %153 [3, 7] : f32 into vector<16x9xf32>
%156 = vector.extract %69[8, 3] : vector<9x16xf32>
%157 = vector.insert %156, %155 [3, 8] : f32 into vector<16x9xf32>
%158 = vector.extract %69[0, 4] : vector<9x16xf32>
%159 = vector.insert %158, %157 [4, 0] : f32 into vector<16x9xf32>
%160 = vector.extract %69[1, 4] : vector<9x16xf32>
%161 = vector.insert %160, %159 [4, 1] : f32 into vector<16x9xf32>
%162 = vector.extract %69[2, 4] : vector<9x16xf32>
%163 = vector.insert %162, %161 [4, 2] : f32 into vector<16x9xf32>
%164 = vector.extract %69[3, 4] : vector<9x16xf32>
%165 = vector.insert %164, %163 [4, 3] : f32 into vector<16x9xf32>
%166 = vector.extract %69[4, 4] : vector<9x16xf32>
%167 = vector.insert %166, %165 [4, 4] : f32 into vector<16x9xf32>
%168 = vector.extract %69[5, 4] : vector<9x16xf32>
%169 = vector.insert %168, %167 [4, 5] : f32 into vector<16x9xf32>
%170 = vector.extract %69[6, 4] : vector<9x16xf32>
%171 = vector.insert %170, %169 [4, 6] : f32 into vector<16x9xf32>
%172 = vector.extract %69[7, 4] : vector<9x16xf32>
%173 = vector.insert %172, %171 [4, 7] : f32 into vector<16x9xf32>
%174 = vector.extract %69[8, 4] : vector<9x16xf32>
%175 = vector.insert %174, %173 [4, 8] : f32 into vector<16x9xf32>
%176 = vector.extract %69[0, 5] : vector<9x16xf32>
%177 = vector.insert %176, %175 [5, 0] : f32 into vector<16x9xf32>
%178 = vector.extract %69[1, 5] : vector<9x16xf32>
%179 = vector.insert %178, %177 [5, 1] : f32 into vector<16x9xf32>
%180 = vector.extract %69[2, 5] : vector<9x16xf32>
%181 = vector.insert %180, %179 [5, 2] : f32 into vector<16x9xf32>
%182 = vector.extract %69[3, 5] : vector<9x16xf32>
%183 = vector.insert %182, %181 [5, 3] : f32 into vector<16x9xf32>
%184 = vector.extract %69[4, 5] : vector<9x16xf32>
%185 = vector.insert %184, %183 [5, 4] : f32 into vector<16x9xf32>
%186 = vector.extract %69[5, 5] : vector<9x16xf32>
%187 = vector.insert %186, %185 [5, 5] : f32 into vector<16x9xf32>
%188 = vector.extract %69[6, 5] : vector<9x16xf32>
%189 = vector.insert %188, %187 [5, 6] : f32 into vector<16x9xf32>
%190 = vector.extract %69[7, 5] : vector<9x16xf32>
%191 = vector.insert %190, %189 [5, 7] : f32 into vector<16x9xf32>
%192 = vector.extract %69[8, 5] : vector<9x16xf32>
%193 = vector.insert %192, %191 [5, 8] : f32 into vector<16x9xf32>
%194 = vector.extract %69[0, 6] : vector<9x16xf32>
%195 = vector.insert %194, %193 [6, 0] : f32 into vector<16x9xf32>
%196 = vector.extract %69[1, 6] : vector<9x16xf32>
%197 = vector.insert %196, %195 [6, 1] : f32 into vector<16x9xf32>
%198 = vector.extract %69[2, 6] : vector<9x16xf32>
%199 = vector.insert %198, %197 [6, 2] : f32 into vector<16x9xf32>
%200 = vector.extract %69[3, 6] : vector<9x16xf32>
%201 = vector.insert %200, %199 [6, 3] : f32 into vector<16x9xf32>
%202 = vector.extract %69[4, 6] : vector<9x16xf32>
%203 = vector.insert %202, %201 [6, 4] : f32 into vector<16x9xf32>
%204 = vector.extract %69[5, 6] : vector<9x16xf32>
%205 = vector.insert %204, %203 [6, 5] : f32 into vector<16x9xf32>
%206 = vector.extract %69[6, 6] : vector<9x16xf32>
%207 = vector.insert %206, %205 [6, 6] : f32 into vector<16x9xf32>
%208 = vector.extract %69[7, 6] : vector<9x16xf32>
%209 = vector.insert %208, %207 [6, 7] : f32 into vector<16x9xf32>
%210 = vector.extract %69[8, 6] : vector<9x16xf32>
%211 = vector.insert %210, %209 [6, 8] : f32 into vector<16x9xf32>
%212 = vector.extract %69[0, 7] : vector<9x16xf32>
%213 = vector.insert %212, %211 [7, 0] : f32 into vector<16x9xf32>
%214 = vector.extract %69[1, 7] : vector<9x16xf32>
%215 = vector.insert %214, %213 [7, 1] : f32 into vector<16x9xf32>
%216 = vector.extract %69[2, 7] : vector<9x16xf32>
%217 = vector.insert %216, %215 [7, 2] : f32 into vector<16x9xf32>
%218 = vector.extract %69[3, 7] : vector<9x16xf32>
%219 = vector.insert %218, %217 [7, 3] : f32 into vector<16x9xf32>
%220 = vector.extract %69[4, 7] : vector<9x16xf32>
%221 = vector.insert %220, %219 [7, 4] : f32 into vector<16x9xf32>
%222 = vector.extract %69[5, 7] : vector<9x16xf32>
%223 = vector.insert %222, %221 [7, 5] : f32 into vector<16x9xf32>
%224 = vector.extract %69[6, 7] : vector<9x16xf32>
%225 = vector.insert %224, %223 [7, 6] : f32 into vector<16x9xf32>
%226 = vector.extract %69[7, 7] : vector<9x16xf32>
%227 = vector.insert %226, %225 [7, 7] : f32 into vector<16x9xf32>
%228 = vector.extract %69[8, 7] : vector<9x16xf32>
%229 = vector.insert %228, %227 [7, 8] : f32 into vector<16x9xf32>
%230 = vector.extract %69[0, 8] : vector<9x16xf32>
%231 = vector.insert %230, %229 [8, 0] : f32 into vector<16x9xf32>
%232 = vector.extract %69[1, 8] : vector<9x16xf32>
%233 = vector.insert %232, %231 [8, 1] : f32 into vector<16x9xf32>
%234 = vector.extract %69[2, 8] : vector<9x16xf32>
%235 = vector.insert %234, %233 [8, 2] : f32 into vector<16x9xf32>
%236 = vector.extract %69[3, 8] : vector<9x16xf32>
%237 = vector.insert %236, %235 [8, 3] : f32 into vector<16x9xf32>
%238 = vector.extract %69[4, 8] : vector<9x16xf32>
%239 = vector.insert %238, %237 [8, 4] : f32 into vector<16x9xf32>
%240 = vector.extract %69[5, 8] : vector<9x16xf32>
%241 = vector.insert %240, %239 [8, 5] : f32 into vector<16x9xf32>
%242 = vector.extract %69[6, 8] : vector<9x16xf32>
%243 = vector.insert %242, %241 [8, 6] : f32 into vector<16x9xf32>
%244 = vector.extract %69[7, 8] : vector<9x16xf32>
%245 = vector.insert %244, %243 [8, 7] : f32 into vector<16x9xf32>
%246 = vector.extract %69[8, 8] : vector<9x16xf32>
%247 = vector.insert %246, %245 [8, 8] : f32 into vector<16x9xf32>
%248 = vector.extract %69[0, 9] : vector<9x16xf32>
%249 = vector.insert %248, %247 [9, 0] : f32 into vector<16x9xf32>
%250 = vector.extract %69[1, 9] : vector<9x16xf32>
%251 = vector.insert %250, %249 [9, 1] : f32 into vector<16x9xf32>
%252 = vector.extract %69[2, 9] : vector<9x16xf32>
%253 = vector.insert %252, %251 [9, 2] : f32 into vector<16x9xf32>
%254 = vector.extract %69[3, 9] : vector<9x16xf32>
%255 = vector.insert %254, %253 [9, 3] : f32 into vector<16x9xf32>
%256 = vector.extract %69[4, 9] : vector<9x16xf32>
%257 = vector.insert %256, %255 [9, 4] : f32 into vector<16x9xf32>
%258 = vector.extract %69[5, 9] : vector<9x16xf32>
%259 = vector.insert %258, %257 [9, 5] : f32 into vector<16x9xf32>
%260 = vector.extract %69[6, 9] : vector<9x16xf32>
%261 = vector.insert %260, %259 [9, 6] : f32 into vector<16x9xf32>
%262 = vector.extract %69[7, 9] : vector<9x16xf32>
%263 = vector.insert %262, %261 [9, 7] : f32 into vector<16x9xf32>
%264 = vector.extract %69[8, 9] : vector<9x16xf32>
%265 = vector.insert %264, %263 [9, 8] : f32 into vector<16x9xf32>
%266 = vector.extract %69[0, 10] : vector<9x16xf32>
%267 = vector.insert %266, %265 [10, 0] : f32 into vector<16x9xf32>
%268 = vector.extract %69[1, 10] : vector<9x16xf32>
%269 = vector.insert %268, %267 [10, 1] : f32 into vector<16x9xf32>
%270 = vector.extract %69[2, 10] : vector<9x16xf32>
%271 = vector.insert %270, %269 [10, 2] : f32 into vector<16x9xf32>
%272 = vector.extract %69[3, 10] : vector<9x16xf32>
%273 = vector.insert %272, %271 [10, 3] : f32 into vector<16x9xf32>
%274 = vector.extract %69[4, 10] : vector<9x16xf32>
%275 = vector.insert %274, %273 [10, 4] : f32 into vector<16x9xf32>
%276 = vector.extract %69[5, 10] : vector<9x16xf32>
%277 = vector.insert %276, %275 [10, 5] : f32 into vector<16x9xf32>
%278 = vector.extract %69[6, 10] : vector<9x16xf32>
%279 = vector.insert %278, %277 [10, 6] : f32 into vector<16x9xf32>
%280 = vector.extract %69[7, 10] : vector<9x16xf32>
%281 = vector.insert %280, %279 [10, 7] : f32 into vector<16x9xf32>
%282 = vector.extract %69[8, 10] : vector<9x16xf32>
%283 = vector.insert %282, %281 [10, 8] : f32 into vector<16x9xf32>
%284 = vector.extract %69[0, 11] : vector<9x16xf32>
%285 = vector.insert %284, %283 [11, 0] : f32 into vector<16x9xf32>
%286 = vector.extract %69[1, 11] : vector<9x16xf32>
%287 = vector.insert %286, %285 [11, 1] : f32 into vector<16x9xf32>
%288 = vector.extract %69[2, 11] : vector<9x16xf32>
%289 = vector.insert %288, %287 [11, 2] : f32 into vector<16x9xf32>
%290 = vector.extract %69[3, 11] : vector<9x16xf32>
%291 = vector.insert %290, %289 [11, 3] : f32 into vector<16x9xf32>
%292 = vector.extract %69[4, 11] : vector<9x16xf32>
%293 = vector.insert %292, %291 [11, 4] : f32 into vector<16x9xf32>
%294 = vector.extract %69[5, 11] : vector<9x16xf32>
%295 = vector.insert %294, %293 [11, 5] : f32 into vector<16x9xf32>
%296 = vector.extract %69[6, 11] : vector<9x16xf32>
%297 = vector.insert %296, %295 [11, 6] : f32 into vector<16x9xf32>
%298 = vector.extract %69[7, 11] : vector<9x16xf32>
%299 = vector.insert %298, %297 [11, 7] : f32 into vector<16x9xf32>
%300 = vector.extract %69[8, 11] : vector<9x16xf32>
%301 = vector.insert %300, %299 [11, 8] : f32 into vector<16x9xf32>
%302 = vector.extract %69[0, 12] : vector<9x16xf32>
%303 = vector.insert %302, %301 [12, 0] : f32 into vector<16x9xf32>
%304 = vector.extract %69[1, 12] : vector<9x16xf32>
%305 = vector.insert %304, %303 [12, 1] : f32 into vector<16x9xf32>
%306 = vector.extract %69[2, 12] : vector<9x16xf32>
%307 = vector.insert %306, %305 [12, 2] : f32 into vector<16x9xf32>
%308 = vector.extract %69[3, 12] : vector<9x16xf32>
%309 = vector.insert %308, %307 [12, 3] : f32 into vector<16x9xf32>
%310 = vector.extract %69[4, 12] : vector<9x16xf32>
%311 = vector.insert %310, %309 [12, 4] : f32 into vector<16x9xf32>
%312 = vector.extract %69[5, 12] : vector<9x16xf32>
%313 = vector.insert %312, %311 [12, 5] : f32 into vector<16x9xf32>
%314 = vector.extract %69[6, 12] : vector<9x16xf32>
%315 = vector.insert %314, %313 [12, 6] : f32 into vector<16x9xf32>
%316 = vector.extract %69[7, 12] : vector<9x16xf32>
%317 = vector.insert %316, %315 [12, 7] : f32 into vector<16x9xf32>
%318 = vector.extract %69[8, 12] : vector<9x16xf32>
%319 = vector.insert %318, %317 [12, 8] : f32 into vector<16x9xf32>
%320 = vector.extract %69[0, 13] : vector<9x16xf32>
%321 = vector.insert %320, %319 [13, 0] : f32 into vector<16x9xf32>
%322 = vector.extract %69[1, 13] : vector<9x16xf32>
%323 = vector.insert %322, %321 [13, 1] : f32 into vector<16x9xf32>
%324 = vector.extract %69[2, 13] : vector<9x16xf32>
%325 = vector.insert %324, %323 [13, 2] : f32 into vector<16x9xf32>
%326 = vector.extract %69[3, 13] : vector<9x16xf32>
%327 = vector.insert %326, %325 [13, 3] : f32 into vector<16x9xf32>
%328 = vector.extract %69[4, 13] : vector<9x16xf32>
%329 = vector.insert %328, %327 [13, 4] : f32 into vector<16x9xf32>
%330 = vector.extract %69[5, 13] : vector<9x16xf32>
%331 = vector.insert %330, %329 [13, 5] : f32 into vector<16x9xf32>
%332 = vector.extract %69[6, 13] : vector<9x16xf32>
%333 = vector.insert %332, %331 [13, 6] : f32 into vector<16x9xf32>
%334 = vector.extract %69[7, 13] : vector<9x16xf32>
%335 = vector.insert %334, %333 [13, 7] : f32 into vector<16x9xf32>
%336 = vector.extract %69[8, 13] : vector<9x16xf32>
%337 = vector.insert %336, %335 [13, 8] : f32 into vector<16x9xf32>
%338 = vector.extract %69[0, 14] : vector<9x16xf32>
%339 = vector.insert %338, %337 [14, 0] : f32 into vector<16x9xf32>
%340 = vector.extract %69[1, 14] : vector<9x16xf32>
%341 = vector.insert %340, %339 [14, 1] : f32 into vector<16x9xf32>
%342 = vector.extract %69[2, 14] : vector<9x16xf32>
%343 = vector.insert %342, %341 [14, 2] : f32 into vector<16x9xf32>
%344 = vector.extract %69[3, 14] : vector<9x16xf32>
%345 = vector.insert %344, %343 [14, 3] : f32 into vector<16x9xf32>
%346 = vector.extract %69[4, 14] : vector<9x16xf32>
%347 = vector.insert %346, %345 [14, 4] : f32 into vector<16x9xf32>
%348 = vector.extract %69[5, 14] : vector<9x16xf32>
%349 = vector.insert %348, %347 [14, 5] : f32 into vector<16x9xf32>
%350 = vector.extract %69[6, 14] : vector<9x16xf32>
%351 = vector.insert %350, %349 [14, 6] : f32 into vector<16x9xf32>
%352 = vector.extract %69[7, 14] : vector<9x16xf32>
%353 = vector.insert %352, %351 [14, 7] : f32 into vector<16x9xf32>
%354 = vector.extract %69[8, 14] : vector<9x16xf32>
%355 = vector.insert %354, %353 [14, 8] : f32 into vector<16x9xf32>
%356 = vector.extract %69[0, 15] : vector<9x16xf32>
%357 = vector.insert %356, %355 [15, 0] : f32 into vector<16x9xf32>
%358 = vector.extract %69[1, 15] : vector<9x16xf32>
%359 = vector.insert %358, %357 [15, 1] : f32 into vector<16x9xf32>
%360 = vector.extract %69[2, 15] : vector<9x16xf32>
%361 = vector.insert %360, %359 [15, 2] : f32 into vector<16x9xf32>
%362 = vector.extract %69[3, 15] : vector<9x16xf32>
%363 = vector.insert %362, %361 [15, 3] : f32 into vector<16x9xf32>
%364 = vector.extract %69[4, 15] : vector<9x16xf32>
%365 = vector.insert %364, %363 [15, 4] : f32 into vector<16x9xf32>
%366 = vector.extract %69[5, 15] : vector<9x16xf32>
%367 = vector.insert %366, %365 [15, 5] : f32 into vector<16x9xf32>
%368 = vector.extract %69[6, 15] : vector<9x16xf32>
%369 = vector.insert %368, %367 [15, 6] : f32 into vector<16x9xf32>
%370 = vector.extract %69[7, 15] : vector<9x16xf32>
%371 = vector.insert %370, %369 [15, 7] : f32 into vector<16x9xf32>
%372 = vector.extract %69[8, 15] : vector<9x16xf32>
%373 = vector.insert %372, %371 [15, 8] : f32 into vector<16x9xf32>
%374 = vector.extract %373[0] : vector<16x9xf32>
%375 = vector.outerproduct %374, %70, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%376 = vector.extract %373[1] : vector<16x9xf32>
%377 = vector.outerproduct %376, %71, %375 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%378 = vector.extract %373[2] : vector<16x9xf32>
%379 = vector.outerproduct %378, %72, %377 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%380 = vector.extract %373[3] : vector<16x9xf32>
%381 = vector.outerproduct %380, %73, %379 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%382 = vector.extract %373[4] : vector<16x9xf32>
%383 = vector.outerproduct %382, %74, %381 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%384 = vector.extract %373[5] : vector<16x9xf32>
%385 = vector.outerproduct %384, %75, %383 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%386 = vector.extract %373[6] : vector<16x9xf32>
%387 = vector.outerproduct %386, %76, %385 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%388 = vector.extract %373[7] : vector<16x9xf32>
%389 = vector.outerproduct %388, %77, %387 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%390 = vector.extract %373[8] : vector<16x9xf32>
%391 = vector.outerproduct %390, %78, %389 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%392 = vector.extract %373[9] : vector<16x9xf32>
%393 = vector.outerproduct %392, %79, %391 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%394 = vector.extract %373[10] : vector<16x9xf32>
%395 = vector.outerproduct %394, %80, %393 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%396 = vector.extract %373[11] : vector<16x9xf32>
%397 = vector.outerproduct %396, %81, %395 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%398 = vector.extract %373[12] : vector<16x9xf32>
%399 = vector.outerproduct %398, %82, %397 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%400 = vector.extract %373[13] : vector<16x9xf32>
%401 = vector.outerproduct %400, %83, %399 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%402 = vector.extract %373[14] : vector<16x9xf32>
%403 = vector.outerproduct %402, %84, %401 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%404 = vector.extract %373[15] : vector<16x9xf32>
%405 = vector.outerproduct %404, %85, %403 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %405 : vector<9x32xf32>
}
%40 = scf.if %19 -> (memref<?x?xf32, #map10>) {
%51 = memref.cast %17 : memref<?x?xf32, #map9> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
} else {
%51 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map10>
scf.yield %51 : memref<?x?xf32, #map10>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x?xf32, #map10>, vector<32xf32>
%50 = arith.xori %19, %true : i1
scf.if %50 {
%51 = memref.subview %3[0, 0] [%12, %15] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map11>
%52 = memref.subview %17[0, 0] [%12, %15] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
linalg.copy(%51, %52) : memref<?x?xf32, #map11>, memref<?x?xf32, #map9>
}
}
}
}
}
}
memref.dealloc %5 : memref<4x16x4x32x16x32xf32>
memref.dealloc %4 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2040x2042xf32>, %arg1: memref<2042x2041xf32>, %arg2: memref<2040x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2040x2042xf32>, memref<2042x2041xf32>, memref<2040x2041xf32>) -> ()
}
return
}
}
compilation in 0.2561s
xxxxxxxxxx : 10 iters time on 1 threads in 0.1578s per iter sec (107.8 GFlop/s, 0.3168 GB/s) total time 1.578s
###############################################################
Runtime problem size {'M': 2040, 'N': 2041, 'K': 2042}
Compile-time problem size {'M': -1, 'N': 2041, 'K': -1}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map0 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map2 = affine_map<(d0) -> (128, -d0 + 2041)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2041xf32> -> tensor<?x2041xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2041xf32>) {
%4 = affine.min #map0(%arg3)[%1]
%5 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2041xf32>) {
%6 = affine.min #map1(%arg5)[%2]
%7 = tensor.extract_slice %arg0[%arg3, %arg5] [%4, %6] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%8 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2041xf32>) {
%9 = affine.min #map2(%arg7)
%10 = tensor.extract_slice %arg1[%arg5, %arg7] [%6, %9] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%11 = tensor.extract_slice %arg8[%arg3, %arg7] [%4, %9] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%12 = linalg.matmul ins(%7, %10 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) -> tensor<?x?xf32>
%13 = tensor.insert_slice %12 into %arg8[%arg3, %arg7] [%4, %9] [1, 1] : tensor<?x?xf32> into tensor<?x2041xf32>
scf.yield %13 : tensor<?x2041xf32>
}
scf.yield %8 : tensor<?x2041xf32>
}
scf.yield %5 : tensor<?x2041xf32>
}
return %3 : tensor<?x2041xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2041xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2041xf32>, tensor<?x2041xf32>) -> tensor<?x2041xf32>
scf.yield %1 : tensor<?x2041xf32>
}
return %0 : tensor<?x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (-d0 + 32)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0) -> (-d0 + 16)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2041 = arith.constant 2041 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2041xf32> -> tensor<?x2041xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply #map0()[%2]
%4 = linalg.init_tensor [%3, 16, 4, 32, 16, 32] : tensor<?x16x4x32x16x32xf32>
%5 = tensor.cast %4 : tensor<?x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%6 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %5) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map1(%arg3)
%11 = affine.min #map2(%arg3)[%2]
%12 = scf.for %arg5 = %c0 to %c2041 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map3(%arg5)
%14 = affine.min #map4(%arg5)
%15 = scf.for %arg7 = %c0 to %14 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%16 = affine.apply #map5(%arg7)
%17 = affine.apply #map6(%arg7, %arg5)
%18 = affine.min #map7(%arg7, %14)
%19 = affine.apply #map8(%18)
%20 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%21 = affine.apply #map9(%arg9)
%22 = affine.apply #map6(%arg9, %arg3)
%23 = affine.min #map10(%arg9, %11)
%24 = tensor.extract_slice %arg1[%22, %17] [%23, %18] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%25 = affine.apply #map11(%23)
%26 = linalg.pad_tensor %24 nofold low[%c0, %c0] high[%25, %19] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<16x32xf32>
%27 = tensor.insert_slice %26 into %arg10[%10, %13, %16, %21, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %27 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %20 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %15 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
%7 = linalg.init_tensor [%3, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%8 = tensor.cast %7 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%9 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2041xf32>) {
%10 = affine.min #map12(%arg3)[%1]
%11 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %8) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map1(%arg5)
%14 = affine.min #map2(%arg5)[%2]
%15 = scf.for %arg7 = %c0 to %10 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%16 = affine.apply #map13(%arg7)
%17 = affine.apply #map6(%arg7, %arg3)
%18 = affine.min #map14(%arg7, %10)
%19 = affine.apply #map15(%18)
%20 = scf.for %arg9 = %c0 to %14 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%21 = affine.apply #map9(%arg9)
%22 = affine.apply #map6(%arg9, %arg5)
%23 = affine.min #map10(%arg9, %14)
%24 = tensor.extract_slice %arg0[%17, %22] [%18, %23] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%25 = affine.apply #map11(%23)
%26 = linalg.pad_tensor %24 nofold low[%c0, %c0] high[%19, %25] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x16xf32>
%27 = tensor.insert_slice %26 into %arg10[%13, %16, %21, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %27 : tensor<?x?x?x9x16xf32>
}
scf.yield %20 : tensor<?x?x?x9x16xf32>
}
scf.yield %15 : tensor<?x?x?x9x16xf32>
}
%12 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2041xf32>) {
%13 = affine.min #map2(%arg5)[%2]
%14 = affine.apply #map1(%arg5)
%15 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2041xf32>) {
%16 = affine.min #map4(%arg7)
%17 = tensor.extract_slice %arg8[%arg3, %arg7] [%10, %16] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%18 = affine.apply #map3(%arg7)
%19 = scf.for %arg9 = %c0 to %10 step %c9 iter_args(%arg10 = %17) -> (tensor<?x?xf32>) {
%21 = affine.min #map14(%arg9, %10)
%22 = affine.apply #map13(%arg9)
%23 = affine.apply #map15(%21)
%24 = scf.for %arg11 = %c0 to %16 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%25 = affine.min #map7(%arg11, %16)
%26 = affine.apply #map5(%arg11)
%27 = affine.apply #map8(%25)
%28 = scf.for %arg13 = %c0 to %13 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%29 = tensor.extract_slice %arg14[%arg9, %arg11] [%21, %25] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%30 = affine.apply #map9(%arg13)
%31 = tensor.extract_slice %11[%14, %22, %30, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%32 = tensor.extract_slice %6[%14, %18, %26, %30, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%33 = linalg.pad_tensor %29 low[%c0, %c0] high[%23, %27] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x32xf32>
%34 = linalg.matmul ins(%31, %32 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%33 : tensor<9x32xf32>) -> tensor<9x32xf32>
%35 = tensor.extract_slice %34[0, 0] [%21, %25] [1, 1] : tensor<9x32xf32> to tensor<?x?xf32>
%36 = tensor.insert_slice %35 into %arg14[%arg9, %arg11] [%21, %25] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %36 : tensor<?x?xf32>
}
scf.yield %28 : tensor<?x?xf32>
}
scf.yield %24 : tensor<?x?xf32>
}
%20 = tensor.insert_slice %19 into %arg8[%arg3, %arg7] [%10, %16] [1, 1] : tensor<?x?xf32> into tensor<?x2041xf32>
scf.yield %20 : tensor<?x2041xf32>
}
scf.yield %15 : tensor<?x2041xf32>
}
scf.yield %12 : tensor<?x2041xf32>
}
return %9 : tensor<?x2041xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2041xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2041xf32>, tensor<?x2041xf32>) -> tensor<?x2041xf32>
scf.yield %1 : tensor<?x2041xf32>
}
return %0 : tensor<?x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map11 = affine_map<(d0) -> (d0 ceildiv 9)>
#map12 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map13 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map14 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map15 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2041xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2041xf32> -> tensor<?x2041xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply #map0()[%2]
%4 = linalg.init_tensor [%3, 16, 4, 32, 16, 32] : tensor<?x16x4x32x16x32xf32>
%5 = tensor.cast %4 : tensor<?x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%6 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %5) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map1(%arg3)
%11 = affine.min #map2(%arg3)[%2]
%12 = scf.for %arg5 = %c0 to %c2041 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map3(%arg5)
%14 = affine.min #map4(%arg5)
%15 = scf.for %arg7 = %c0 to %14 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%16 = affine.apply #map5(%arg7)
%17 = affine.apply #map6(%arg7, %arg5)
%18 = affine.min #map7(%arg7, %14)
%19 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%20 = affine.apply #map8(%arg9)
%21 = affine.apply #map6(%arg9, %arg3)
%22 = affine.min #map9(%arg9, %11)
%23 = tensor.extract_slice %arg1[%21, %17] [%22, %18] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%24 = vector.transfer_read %23[%c0, %c0], %cst : tensor<?x?xf32>, vector<16x32xf32>
%25 = vector.transfer_write %24, %arg10[%10, %13, %16, %20, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %25 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %19 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %15 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
%7 = linalg.init_tensor [%3, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%8 = tensor.cast %7 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%9 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2041xf32>) {
%10 = affine.min #map10(%arg3)[%1]
%11 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %8) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map1(%arg5)
%14 = affine.min #map2(%arg5)[%2]
%15 = scf.for %arg7 = %c0 to %10 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%16 = affine.apply #map11(%arg7)
%17 = affine.apply #map6(%arg7, %arg3)
%18 = affine.min #map12(%arg7, %10)
%19 = scf.for %arg9 = %c0 to %14 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%20 = affine.apply #map8(%arg9)
%21 = affine.apply #map6(%arg9, %arg5)
%22 = affine.min #map9(%arg9, %14)
%23 = tensor.extract_slice %arg0[%17, %21] [%18, %22] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%24 = vector.transfer_read %23[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x16xf32>
%25 = vector.transfer_write %24, %arg10[%13, %16, %20, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %25 : tensor<?x?x?x9x16xf32>
}
scf.yield %19 : tensor<?x?x?x9x16xf32>
}
scf.yield %15 : tensor<?x?x?x9x16xf32>
}
%12 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2041xf32>) {
%13 = affine.min #map2(%arg5)[%2]
%14 = affine.apply #map1(%arg5)
%15 = scf.for %arg7 = %c0 to %c2041 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2041xf32>) {
%16 = affine.min #map4(%arg7)
%17 = tensor.extract_slice %arg8[%arg3, %arg7] [%10, %16] [1, 1] : tensor<?x2041xf32> to tensor<?x?xf32>
%18 = affine.apply #map3(%arg7)
%19 = scf.for %arg9 = %c0 to %10 step %c9 iter_args(%arg10 = %17) -> (tensor<?x?xf32>) {
%21 = affine.min #map12(%arg9, %10)
%22 = affine.apply #map11(%arg9)
%23 = scf.for %arg11 = %c0 to %16 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%24 = affine.min #map7(%arg11, %16)
%25 = affine.apply #map5(%arg11)
%26 = scf.for %arg13 = %c0 to %13 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%27 = tensor.extract_slice %arg14[%arg9, %arg11] [%21, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%28 = affine.apply #map8(%arg13)
%29 = vector.transfer_read %11[%14, %22, %28, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%30 = vector.transfer_read %6[%14, %18, %25, %28, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%31 = vector.transfer_read %27[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x32xf32>
%32 = vector.contract {indexing_maps = [#map13, #map14, #map15], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %29, %30, %31 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%33 = vector.transfer_write %32, %27[%c0, %c0] : vector<9x32xf32>, tensor<?x?xf32>
%34 = tensor.insert_slice %33 into %arg14[%arg9, %arg11] [%21, %24] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %34 : tensor<?x?xf32>
}
scf.yield %26 : tensor<?x?xf32>
}
scf.yield %23 : tensor<?x?xf32>
}
%20 = tensor.insert_slice %19 into %arg8[%arg3, %arg7] [%10, %16] [1, 1] : tensor<?x?xf32> into tensor<?x2041xf32>
scf.yield %20 : tensor<?x2041xf32>
}
scf.yield %15 : tensor<?x2041xf32>
}
scf.yield %12 : tensor<?x2041xf32>
}
return %9 : tensor<?x2041xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2041xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2041xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2041xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2041xf32>, tensor<?x2041xf32>) -> tensor<?x2041xf32>
scf.yield %1 : tensor<?x2041xf32>
}
return %0 : tensor<?x2041xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map15 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map16 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map17 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2041 = arith.constant 2041 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%7 = affine.apply #map3(%arg4)
%8 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %8 step %c32 {
%9 = affine.apply #map5(%arg5)
%10 = affine.apply #map6(%arg5, %arg4)
%11 = affine.min #map7(%arg5, %8)
scf.for %arg6 = %c0 to %6 step %c16 {
%12 = affine.apply #map8(%arg6)
%13 = affine.apply #map6(%arg6, %arg3)
%14 = affine.min #map9(%arg6, %6)
%15 = memref.subview %arg1[%13, %10] [%14, %11] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<16x32xf32>
vector.transfer_write %16, %3[%5, %7, %9, %12, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map11(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map12(%arg5)
%9 = affine.apply #map6(%arg5, %arg3)
%10 = affine.min #map13(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map8(%arg6)
%12 = affine.apply #map6(%arg6, %arg4)
%13 = affine.min #map9(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map14>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map14>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%8 = affine.min #map4(%arg5)
%9 = memref.subview %arg2[%arg3, %arg5] [%5, %8] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%10 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%11 = affine.min #map13(%arg6, %5)
%12 = affine.apply #map12(%arg6)
scf.for %arg7 = %c0 to %8 step %c32 {
%13 = affine.min #map7(%arg7, %8)
%14 = affine.apply #map5(%arg7)
%15 = memref.subview %9[%arg6, %arg7] [%11, %13] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<9x32xf32>
%17 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %16) -> (vector<9x32xf32>) {
%18 = affine.apply #map8(%arg8)
%19 = vector.transfer_read %4[%7, %12, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%20 = vector.transfer_read %3[%7, %10, %14, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%21 = vector.contract {indexing_maps = [#map15, #map16, #map17], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %19, %20, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %21 : vector<9x32xf32>
}
vector.transfer_write %17, %15[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map10>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%7 = affine.apply #map3(%arg4)
%8 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %8 step %c32 {
%9 = affine.apply #map5(%arg5)
%10 = affine.apply #map6(%arg5, %arg4)
%11 = affine.min #map7(%arg5, %8)
scf.for %arg6 = %c0 to %6 step %c16 {
%12 = affine.apply #map8(%arg6)
%13 = affine.apply #map6(%arg6, %arg3)
%14 = affine.min #map9(%arg6, %6)
%15 = memref.subview %arg1[%13, %10] [%14, %11] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<16x32xf32>
vector.transfer_write %16, %3[%5, %7, %9, %12, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map11(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map12(%arg5)
%9 = affine.apply #map6(%arg5, %arg3)
%10 = affine.min #map13(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map8(%arg6)
%12 = affine.apply #map6(%arg6, %arg4)
%13 = affine.min #map9(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map14>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map14>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%8 = affine.min #map4(%arg5)
%9 = memref.subview %arg2[%arg3, %arg5] [%5, %8] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%10 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%11 = affine.min #map13(%arg6, %5)
%12 = affine.apply #map12(%arg6)
scf.for %arg7 = %c0 to %8 step %c32 {
%13 = affine.min #map7(%arg7, %8)
%14 = affine.apply #map5(%arg7)
%15 = memref.subview %9[%arg6, %arg7] [%11, %13] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<9x32xf32>
%17 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %16) -> (vector<9x32xf32>) {
%18 = affine.apply #map8(%arg8)
%19 = vector.transfer_read %4[%7, %12, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%20 = vector.transfer_read %3[%7, %10, %14, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%21 = vector.transpose %19, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%22 = vector.extract %21[0] : vector<16x9xf32>
%23 = vector.extract %20[0] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %21[1] : vector<16x9xf32>
%26 = vector.extract %20[1] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %21[2] : vector<16x9xf32>
%29 = vector.extract %20[2] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %21[3] : vector<16x9xf32>
%32 = vector.extract %20[3] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %21[4] : vector<16x9xf32>
%35 = vector.extract %20[4] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %21[5] : vector<16x9xf32>
%38 = vector.extract %20[5] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %21[6] : vector<16x9xf32>
%41 = vector.extract %20[6] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %21[7] : vector<16x9xf32>
%44 = vector.extract %20[7] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %21[8] : vector<16x9xf32>
%47 = vector.extract %20[8] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %21[9] : vector<16x9xf32>
%50 = vector.extract %20[9] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %21[10] : vector<16x9xf32>
%53 = vector.extract %20[10] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %21[11] : vector<16x9xf32>
%56 = vector.extract %20[11] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %21[12] : vector<16x9xf32>
%59 = vector.extract %20[12] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %21[13] : vector<16x9xf32>
%62 = vector.extract %20[13] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %21[14] : vector<16x9xf32>
%65 = vector.extract %20[14] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %21[15] : vector<16x9xf32>
%68 = vector.extract %20[15] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %69 : vector<9x32xf32>
}
vector.transfer_write %17, %15[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map10>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2041 = arith.constant 2041 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%7 = affine.apply #map3(%arg4)
%8 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %8 step %c32 {
%9 = affine.apply #map5(%arg5)
%10 = affine.apply #map6(%arg5, %arg4)
%11 = affine.min #map7(%arg5, %8)
scf.for %arg6 = %c0 to %6 step %c16 {
%12 = affine.apply #map8(%arg6)
%13 = affine.apply #map6(%arg6, %arg3)
%14 = affine.min #map9(%arg6, %6)
%15 = memref.subview %arg1[%13, %10] [%14, %11] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<16x32xf32>
vector.transfer_write %16, %3[%5, %7, %9, %12, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map11(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map12(%arg5)
%9 = affine.apply #map6(%arg5, %arg3)
%10 = affine.min #map13(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map8(%arg6)
%12 = affine.apply #map6(%arg6, %arg4)
%13 = affine.min #map9(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map14>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map14>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%8 = affine.min #map4(%arg5)
%9 = memref.subview %arg2[%arg3, %arg5] [%5, %8] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%10 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%11 = affine.min #map13(%arg6, %5)
%12 = affine.apply #map12(%arg6)
scf.for %arg7 = %c0 to %8 step %c32 {
%13 = affine.min #map7(%arg7, %8)
%14 = affine.apply #map5(%arg7)
%15 = memref.subview %9[%arg6, %arg7] [%11, %13] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%16 = vector.transfer_read %15[%c0, %c0], %cst : memref<?x?xf32, #map10>, vector<9x32xf32>
%17 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %16) -> (vector<9x32xf32>) {
%18 = affine.apply #map8(%arg8)
%19 = vector.transfer_read %4[%7, %12, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%20 = vector.transfer_read %3[%7, %10, %14, %18, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%21 = vector.transpose %19, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%22 = vector.extract %21[0] : vector<16x9xf32>
%23 = vector.extract %20[0] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %21[1] : vector<16x9xf32>
%26 = vector.extract %20[1] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %21[2] : vector<16x9xf32>
%29 = vector.extract %20[2] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %21[3] : vector<16x9xf32>
%32 = vector.extract %20[3] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %21[4] : vector<16x9xf32>
%35 = vector.extract %20[4] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %21[5] : vector<16x9xf32>
%38 = vector.extract %20[5] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %21[6] : vector<16x9xf32>
%41 = vector.extract %20[6] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %21[7] : vector<16x9xf32>
%44 = vector.extract %20[7] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %21[8] : vector<16x9xf32>
%47 = vector.extract %20[8] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %21[9] : vector<16x9xf32>
%50 = vector.extract %20[9] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %21[10] : vector<16x9xf32>
%53 = vector.extract %20[10] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %21[11] : vector<16x9xf32>
%56 = vector.extract %20[11] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %21[12] : vector<16x9xf32>
%59 = vector.extract %20[12] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %21[13] : vector<16x9xf32>
%62 = vector.extract %20[13] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %21[14] : vector<16x9xf32>
%65 = vector.extract %20[14] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %21[15] : vector<16x9xf32>
%68 = vector.extract %20[15] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %69 : vector<9x32xf32>
}
vector.transfer_write %17, %15[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map10>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%11 = affine.apply #map3(%arg4)
%12 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %12 step %c32 {
%13 = affine.apply #map5(%arg5)
%14 = affine.apply #map6(%arg5, %arg4)
%15 = affine.min #map7(%arg5, %12)
%16 = arith.cmpi sle, %c32, %15 : index
scf.for %arg6 = %c0 to %10 step %c16 {
%17 = affine.apply #map8(%arg6)
%18 = affine.apply #map6(%arg6, %arg3)
%19 = affine.min #map9(%arg6, %10)
%20 = memref.subview %arg1[%18, %14] [%19, %15] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c16, %19 : index
%22 = arith.andi %21, %16 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%25 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %25 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%25 = memref.subview %20[0, 0] [%19, %15] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%26 = memref.subview %0[0, 0] [%19, %15] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%25, %26) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%27 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %27 : memref<?x?xf32, #map11>
}
%24 = vector.transfer_read %23[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %24, %7[%9, %11, %13, %17, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map14(%arg5)
%13 = affine.apply #map6(%arg5, %arg3)
%14 = affine.min #map15(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map8(%arg6)
%17 = affine.apply #map6(%arg6, %arg4)
%18 = affine.min #map9(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map11>) {
scf.yield %19 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%24 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%25 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%24, %25) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%26 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %26 : memref<?x?xf32, #map11>
}
%23 = vector.transfer_read %22[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %23, %8[%10, %12, %16, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%12 = affine.min #map4(%arg5)
%13 = memref.subview %arg2[%arg3, %arg5] [%9, %12] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%14 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%15 = affine.min #map15(%arg6, %9)
%16 = affine.apply #map14(%arg6)
%17 = arith.cmpi sle, %c9, %15 : index
scf.for %arg7 = %c0 to %12 step %c32 {
%18 = affine.min #map7(%arg7, %12)
%19 = affine.apply #map5(%arg7)
%20 = memref.subview %13[%arg6, %arg7] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c32, %18 : index
%22 = arith.andi %17, %21 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%28 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%28 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%29 = memref.subview %2[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%28, %29) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%30 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
%24 = vector.transfer_read %23[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%25 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %24) -> (vector<9x32xf32>) {
%28 = affine.apply #map8(%arg8)
%29 = vector.transfer_read %8[%11, %16, %28, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%30 = vector.transfer_read %7[%11, %14, %19, %28, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%31 = vector.transpose %29, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%32 = vector.extract %31[0] : vector<16x9xf32>
%33 = vector.extract %30[0] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %31[1] : vector<16x9xf32>
%36 = vector.extract %30[1] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %31[2] : vector<16x9xf32>
%39 = vector.extract %30[2] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %31[3] : vector<16x9xf32>
%42 = vector.extract %30[3] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %31[4] : vector<16x9xf32>
%45 = vector.extract %30[4] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %31[5] : vector<16x9xf32>
%48 = vector.extract %30[5] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %31[6] : vector<16x9xf32>
%51 = vector.extract %30[6] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %31[7] : vector<16x9xf32>
%54 = vector.extract %30[7] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %31[8] : vector<16x9xf32>
%57 = vector.extract %30[8] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %31[9] : vector<16x9xf32>
%60 = vector.extract %30[9] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %31[10] : vector<16x9xf32>
%63 = vector.extract %30[10] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %31[11] : vector<16x9xf32>
%66 = vector.extract %30[11] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %31[12] : vector<16x9xf32>
%69 = vector.extract %30[12] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%71 = vector.extract %31[13] : vector<16x9xf32>
%72 = vector.extract %30[13] : vector<16x32xf32>
%73 = vector.outerproduct %71, %72, %70 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%74 = vector.extract %31[14] : vector<16x9xf32>
%75 = vector.extract %30[14] : vector<16x32xf32>
%76 = vector.outerproduct %74, %75, %73 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%77 = vector.extract %31[15] : vector<16x9xf32>
%78 = vector.extract %30[15] : vector<16x32xf32>
%79 = vector.outerproduct %77, %78, %76 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %79 : vector<9x32xf32>
}
%26 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%28 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
} else {
%28 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
vector.transfer_write %25, %26[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%27 = arith.xori %22, %true : i1
scf.if %27 {
%28 = memref.subview %3[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%29 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
linalg.copy(%28, %29) : memref<?x?xf32, #map12>, memref<?x?xf32, #map10>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2041 = arith.constant 2041 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%true = arith.constant true
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%11 = affine.apply #map3(%arg4)
%12 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %12 step %c32 {
%13 = affine.apply #map5(%arg5)
%14 = affine.apply #map6(%arg5, %arg4)
%15 = affine.min #map7(%arg5, %12)
%16 = arith.cmpi sle, %c32, %15 : index
scf.for %arg6 = %c0 to %10 step %c16 {
%17 = affine.apply #map8(%arg6)
%18 = affine.apply #map6(%arg6, %arg3)
%19 = affine.min #map9(%arg6, %10)
%20 = memref.subview %arg1[%18, %14] [%19, %15] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c16, %19 : index
%22 = arith.andi %21, %16 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%25 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %25 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%25 = memref.subview %20[0, 0] [%19, %15] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%26 = memref.subview %0[0, 0] [%19, %15] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%25, %26) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%27 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %27 : memref<?x?xf32, #map11>
}
%24 = vector.transfer_read %23[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %24, %7[%9, %11, %13, %17, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map14(%arg5)
%13 = affine.apply #map6(%arg5, %arg3)
%14 = affine.min #map15(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map8(%arg6)
%17 = affine.apply #map6(%arg6, %arg4)
%18 = affine.min #map9(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map11>) {
scf.yield %19 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%24 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%25 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%24, %25) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%26 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %26 : memref<?x?xf32, #map11>
}
%23 = vector.transfer_read %22[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %23, %8[%10, %12, %16, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%12 = affine.min #map4(%arg5)
%13 = memref.subview %arg2[%arg3, %arg5] [%9, %12] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%14 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%15 = affine.min #map15(%arg6, %9)
%16 = affine.apply #map14(%arg6)
%17 = arith.cmpi sle, %c9, %15 : index
scf.for %arg7 = %c0 to %12 step %c32 {
%18 = affine.min #map7(%arg7, %12)
%19 = affine.apply #map5(%arg7)
%20 = memref.subview %13[%arg6, %arg7] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c32, %18 : index
%22 = arith.andi %17, %21 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%28 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%28 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%29 = memref.subview %2[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%28, %29) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%30 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
%24 = vector.transfer_read %23[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%25 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %24) -> (vector<9x32xf32>) {
%28 = affine.apply #map8(%arg8)
%29 = vector.transfer_read %8[%11, %16, %28, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%30 = vector.transfer_read %7[%11, %14, %19, %28, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%31 = vector.transpose %29, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%32 = vector.extract %31[0] : vector<16x9xf32>
%33 = vector.extract %30[0] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %31[1] : vector<16x9xf32>
%36 = vector.extract %30[1] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %31[2] : vector<16x9xf32>
%39 = vector.extract %30[2] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %31[3] : vector<16x9xf32>
%42 = vector.extract %30[3] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %31[4] : vector<16x9xf32>
%45 = vector.extract %30[4] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %31[5] : vector<16x9xf32>
%48 = vector.extract %30[5] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %31[6] : vector<16x9xf32>
%51 = vector.extract %30[6] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %31[7] : vector<16x9xf32>
%54 = vector.extract %30[7] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %31[8] : vector<16x9xf32>
%57 = vector.extract %30[8] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %31[9] : vector<16x9xf32>
%60 = vector.extract %30[9] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %31[10] : vector<16x9xf32>
%63 = vector.extract %30[10] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %31[11] : vector<16x9xf32>
%66 = vector.extract %30[11] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %31[12] : vector<16x9xf32>
%69 = vector.extract %30[12] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%71 = vector.extract %31[13] : vector<16x9xf32>
%72 = vector.extract %30[13] : vector<16x32xf32>
%73 = vector.outerproduct %71, %72, %70 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%74 = vector.extract %31[14] : vector<16x9xf32>
%75 = vector.extract %30[14] : vector<16x32xf32>
%76 = vector.outerproduct %74, %75, %73 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%77 = vector.extract %31[15] : vector<16x9xf32>
%78 = vector.extract %30[15] : vector<16x32xf32>
%79 = vector.outerproduct %77, %78, %76 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %79 : vector<9x32xf32>
}
%26 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%28 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
} else {
%28 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
vector.transfer_write %25, %26[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%27 = arith.xori %22, %true : i1
scf.if %27 {
%28 = memref.subview %3[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%29 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
linalg.copy(%28, %29) : memref<?x?xf32, #map12>, memref<?x?xf32, #map10>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<?x2041xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%11 = affine.apply #map3(%arg4)
%12 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %12 step %c32 {
%13 = affine.apply #map5(%arg5)
%14 = affine.apply #map6(%arg5, %arg4)
%15 = affine.min #map7(%arg5, %12)
%16 = arith.cmpi sle, %c32, %15 : index
scf.for %arg6 = %c0 to %10 step %c16 {
%17 = affine.apply #map8(%arg6)
%18 = affine.apply #map6(%arg6, %arg3)
%19 = affine.min #map9(%arg6, %10)
%20 = memref.subview %arg1[%18, %14] [%19, %15] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c16, %19 : index
%22 = arith.andi %21, %16 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%40 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %40 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %0) : f32, memref<16x32xf32>
%40 = memref.subview %20[0, 0] [%19, %15] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%41 = memref.subview %0[0, 0] [%19, %15] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%40, %41) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%42 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %42 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%26 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %23[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %23[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %23[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %23[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %23[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %23[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %23[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %24, %7[%9, %11, %13, %17, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %13, %17, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %13, %17, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %13, %17, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %13, %17, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %13, %17, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %13, %17, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %13, %17, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %13, %17, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %13, %17, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %13, %17, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %13, %17, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %7[%9, %11, %13, %17, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %7[%9, %11, %13, %17, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %7[%9, %11, %13, %17, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %7[%9, %11, %13, %17, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map14(%arg5)
%13 = affine.apply #map6(%arg5, %arg3)
%14 = affine.min #map15(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map8(%arg6)
%17 = affine.apply #map6(%arg6, %arg4)
%18 = affine.min #map9(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map11>) {
scf.yield %19 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%32, %33) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %34 : memref<?x?xf32, #map11>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%12 = affine.min #map4(%arg5)
%13 = memref.subview %arg2[%arg3, %arg5] [%9, %12] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%14 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%15 = affine.min #map15(%arg6, %9)
%16 = affine.apply #map14(%arg6)
%17 = arith.cmpi sle, %c9, %15 : index
scf.for %arg7 = %c0 to %12 step %c32 {
%18 = affine.min #map7(%arg7, %12)
%19 = affine.apply #map5(%arg7)
%20 = memref.subview %13[%arg6, %arg7] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c32, %18 : index
%22 = arith.andi %17, %21 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %2) : f32, memref<9x32xf32>
%54 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%55 = memref.subview %2[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%54, %55) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%56 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.insert %24, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %25 [1] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [2] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [3] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [4] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [5] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [6] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [7] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [8] : vector<32xf32> into vector<9x32xf32>
%42 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %41) -> (vector<9x32xf32>) {
%54 = affine.apply #map8(%arg8)
%55 = vector.load %8[%11, %16, %54, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %cst [0] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %16, %54, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [1] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %16, %54, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [2] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %16, %54, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [3] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %16, %54, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [4] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %16, %54, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [5] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %16, %54, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [6] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %8[%11, %16, %54, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [7] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %8[%11, %16, %54, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [8] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %7[%11, %14, %19, %54, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %14, %19, %54, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %14, %19, %54, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %14, %19, %54, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %14, %19, %54, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %14, %19, %54, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %14, %19, %54, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %14, %19, %54, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %14, %19, %54, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %14, %19, %54, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %14, %19, %54, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %14, %19, %54, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %7[%11, %14, %19, %54, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %7[%11, %14, %19, %54, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %7[%11, %14, %19, %54, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %7[%11, %14, %19, %54, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%89 = vector.transpose %72, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%90 = vector.extract %89[0] : vector<16x9xf32>
%91 = vector.outerproduct %90, %73, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%92 = vector.extract %89[1] : vector<16x9xf32>
%93 = vector.outerproduct %92, %74, %91 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %89[2] : vector<16x9xf32>
%95 = vector.outerproduct %94, %75, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %89[3] : vector<16x9xf32>
%97 = vector.outerproduct %96, %76, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %89[4] : vector<16x9xf32>
%99 = vector.outerproduct %98, %77, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %89[5] : vector<16x9xf32>
%101 = vector.outerproduct %100, %78, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %89[6] : vector<16x9xf32>
%103 = vector.outerproduct %102, %79, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %89[7] : vector<16x9xf32>
%105 = vector.outerproduct %104, %80, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %89[8] : vector<16x9xf32>
%107 = vector.outerproduct %106, %81, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %89[9] : vector<16x9xf32>
%109 = vector.outerproduct %108, %82, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %89[10] : vector<16x9xf32>
%111 = vector.outerproduct %110, %83, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %89[11] : vector<16x9xf32>
%113 = vector.outerproduct %112, %84, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %89[12] : vector<16x9xf32>
%115 = vector.outerproduct %114, %85, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %89[13] : vector<16x9xf32>
%117 = vector.outerproduct %116, %86, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%118 = vector.extract %89[14] : vector<16x9xf32>
%119 = vector.outerproduct %118, %87, %117 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%120 = vector.extract %89[15] : vector<16x9xf32>
%121 = vector.outerproduct %120, %88, %119 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %121 : vector<9x32xf32>
}
%43 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
%54 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
}
%44 = vector.extract %42[0] : vector<9x32xf32>
vector.store %44, %43[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%45 = vector.extract %42[1] : vector<9x32xf32>
vector.store %45, %43[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%46 = vector.extract %42[2] : vector<9x32xf32>
vector.store %46, %43[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %42[3] : vector<9x32xf32>
vector.store %47, %43[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %42[4] : vector<9x32xf32>
vector.store %48, %43[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %42[5] : vector<9x32xf32>
vector.store %49, %43[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %42[6] : vector<9x32xf32>
vector.store %50, %43[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %42[7] : vector<9x32xf32>
vector.store %51, %43[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %42[8] : vector<9x32xf32>
vector.store %52, %43[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = arith.xori %22, %true : i1
scf.if %53 {
%54 = memref.subview %3[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%55 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
linalg.copy(%54, %55) : memref<?x?xf32, #map12>, memref<?x?xf32, #map10>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2041 = arith.constant 2041 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2041xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%11 = affine.apply #map3(%arg4)
%12 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %12 step %c32 {
%13 = affine.apply #map5(%arg5)
%14 = affine.apply #map6(%arg5, %arg4)
%15 = affine.min #map7(%arg5, %12)
%16 = arith.cmpi sle, %c32, %15 : index
scf.for %arg6 = %c0 to %10 step %c16 {
%17 = affine.apply #map8(%arg6)
%18 = affine.apply #map6(%arg6, %arg3)
%19 = affine.min #map9(%arg6, %10)
%20 = memref.subview %arg1[%18, %14] [%19, %15] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c16, %19 : index
%22 = arith.andi %21, %16 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%40 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %40 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%40 = memref.subview %20[0, 0] [%19, %15] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%41 = memref.subview %0[0, 0] [%19, %15] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%40, %41) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%42 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %42 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%26 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %23[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %23[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %23[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %23[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %23[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %23[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %23[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %24, %7[%9, %11, %13, %17, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %13, %17, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %13, %17, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %13, %17, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %13, %17, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %13, %17, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %13, %17, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %13, %17, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %13, %17, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %13, %17, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %13, %17, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %13, %17, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %7[%9, %11, %13, %17, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %7[%9, %11, %13, %17, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %7[%9, %11, %13, %17, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %7[%9, %11, %13, %17, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map14(%arg5)
%13 = affine.apply #map6(%arg5, %arg3)
%14 = affine.min #map15(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map8(%arg6)
%17 = affine.apply #map6(%arg6, %arg4)
%18 = affine.min #map9(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map11>) {
scf.yield %19 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%32, %33) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %34 : memref<?x?xf32, #map11>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%12 = affine.min #map4(%arg5)
%13 = memref.subview %arg2[%arg3, %arg5] [%9, %12] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%14 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%15 = affine.min #map15(%arg6, %9)
%16 = affine.apply #map14(%arg6)
%17 = arith.cmpi sle, %c9, %15 : index
scf.for %arg7 = %c0 to %12 step %c32 {
%18 = affine.min #map7(%arg7, %12)
%19 = affine.apply #map5(%arg7)
%20 = memref.subview %13[%arg6, %arg7] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c32, %18 : index
%22 = arith.andi %17, %21 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%54 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%55 = memref.subview %2[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%54, %55) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%56 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.insert %24, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %25 [1] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [2] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [3] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [4] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [5] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [6] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [7] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [8] : vector<32xf32> into vector<9x32xf32>
%42 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %41) -> (vector<9x32xf32>) {
%54 = affine.apply #map8(%arg8)
%55 = vector.load %8[%11, %16, %54, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %cst_1 [0] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %16, %54, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [1] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %16, %54, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [2] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %16, %54, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [3] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %16, %54, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [4] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %16, %54, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [5] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %16, %54, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [6] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %8[%11, %16, %54, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [7] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %8[%11, %16, %54, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [8] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %7[%11, %14, %19, %54, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %14, %19, %54, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %14, %19, %54, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %14, %19, %54, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %14, %19, %54, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %14, %19, %54, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %14, %19, %54, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %14, %19, %54, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %14, %19, %54, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %14, %19, %54, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %14, %19, %54, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %14, %19, %54, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %7[%11, %14, %19, %54, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %7[%11, %14, %19, %54, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %7[%11, %14, %19, %54, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %7[%11, %14, %19, %54, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%89 = vector.transpose %72, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%90 = vector.extract %89[0] : vector<16x9xf32>
%91 = vector.outerproduct %90, %73, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%92 = vector.extract %89[1] : vector<16x9xf32>
%93 = vector.outerproduct %92, %74, %91 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %89[2] : vector<16x9xf32>
%95 = vector.outerproduct %94, %75, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %89[3] : vector<16x9xf32>
%97 = vector.outerproduct %96, %76, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %89[4] : vector<16x9xf32>
%99 = vector.outerproduct %98, %77, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %89[5] : vector<16x9xf32>
%101 = vector.outerproduct %100, %78, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %89[6] : vector<16x9xf32>
%103 = vector.outerproduct %102, %79, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %89[7] : vector<16x9xf32>
%105 = vector.outerproduct %104, %80, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %89[8] : vector<16x9xf32>
%107 = vector.outerproduct %106, %81, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %89[9] : vector<16x9xf32>
%109 = vector.outerproduct %108, %82, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %89[10] : vector<16x9xf32>
%111 = vector.outerproduct %110, %83, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %89[11] : vector<16x9xf32>
%113 = vector.outerproduct %112, %84, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %89[12] : vector<16x9xf32>
%115 = vector.outerproduct %114, %85, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %89[13] : vector<16x9xf32>
%117 = vector.outerproduct %116, %86, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%118 = vector.extract %89[14] : vector<16x9xf32>
%119 = vector.outerproduct %118, %87, %117 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%120 = vector.extract %89[15] : vector<16x9xf32>
%121 = vector.outerproduct %120, %88, %119 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %121 : vector<9x32xf32>
}
%43 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
%54 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
}
%44 = vector.extract %42[0] : vector<9x32xf32>
vector.store %44, %43[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%45 = vector.extract %42[1] : vector<9x32xf32>
vector.store %45, %43[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%46 = vector.extract %42[2] : vector<9x32xf32>
vector.store %46, %43[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %42[3] : vector<9x32xf32>
vector.store %47, %43[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %42[4] : vector<9x32xf32>
vector.store %48, %43[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %42[5] : vector<9x32xf32>
vector.store %49, %43[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %42[6] : vector<9x32xf32>
vector.store %50, %43[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %42[7] : vector<9x32xf32>
vector.store %51, %43[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %42[8] : vector<9x32xf32>
vector.store %52, %43[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = arith.xori %22, %true : i1
scf.if %53 {
%54 = memref.subview %3[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%55 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
linalg.copy(%54, %55) : memref<?x?xf32, #map12>, memref<?x?xf32, #map10>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692250>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (128, -d0 + 2041)>
#map5 = affine_map<(d0) -> (d0 ceildiv 32)>
#map6 = affine_map<(d0, d1) -> (d0 + d1)>
#map7 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map8 = affine_map<(d0) -> (d0 ceildiv 16)>
#map9 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map10 = affine_map<(d0, d1)[s0] -> (d0 * 2041 + s0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<16x9xf32>
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2041 = arith.constant 2041 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_2, %arg2) : f32, memref<?x2041xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2041 step %c128 {
%11 = affine.apply #map3(%arg4)
%12 = affine.min #map4(%arg4)
scf.for %arg5 = %c0 to %12 step %c32 {
%13 = affine.apply #map5(%arg5)
%14 = affine.apply #map6(%arg5, %arg4)
%15 = affine.min #map7(%arg5, %12)
%16 = arith.cmpi sle, %c32, %15 : index
scf.for %arg6 = %c0 to %10 step %c16 {
%17 = affine.apply #map8(%arg6)
%18 = affine.apply #map6(%arg6, %arg3)
%19 = affine.min #map9(%arg6, %10)
%20 = memref.subview %arg1[%18, %14] [%19, %15] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c16, %19 : index
%22 = arith.andi %21, %16 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%40 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %40 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %0) : f32, memref<16x32xf32>
%40 = memref.subview %20[0, 0] [%19, %15] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%41 = memref.subview %0[0, 0] [%19, %15] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%40, %41) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%42 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %42 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%26 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %23[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %23[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %23[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %23[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %23[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %23[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %23[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %24, %7[%9, %11, %13, %17, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %13, %17, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %13, %17, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %13, %17, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %13, %17, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %13, %17, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %13, %17, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %13, %17, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %13, %17, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %13, %17, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %13, %17, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %13, %17, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %7[%9, %11, %13, %17, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %7[%9, %11, %13, %17, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %7[%9, %11, %13, %17, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %7[%9, %11, %13, %17, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map14(%arg5)
%13 = affine.apply #map6(%arg5, %arg3)
%14 = affine.min #map15(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map8(%arg6)
%17 = affine.apply #map6(%arg6, %arg4)
%18 = affine.min #map9(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map11>) {
scf.yield %19 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%32, %33) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %34 : memref<?x?xf32, #map11>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2041 step %c128 {
%12 = affine.min #map4(%arg5)
%13 = memref.subview %arg2[%arg3, %arg5] [%9, %12] [1, 1] : memref<?x2041xf32> to memref<?x?xf32, #map10>
%14 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%15 = affine.min #map15(%arg6, %9)
%16 = affine.apply #map14(%arg6)
%17 = arith.cmpi sle, %c9, %15 : index
scf.for %arg7 = %c0 to %12 step %c32 {
%18 = affine.min #map7(%arg7, %12)
%19 = affine.apply #map5(%arg7)
%20 = memref.subview %13[%arg6, %arg7] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%21 = arith.cmpi sle, %c32, %18 : index
%22 = arith.andi %17, %21 : i1
%23 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %2) : f32, memref<9x32xf32>
%54 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
%55 = memref.subview %2[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%54, %55) : memref<?x?xf32, #map10>, memref<?x?xf32, #map12>
%56 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%24 = vector.load %23[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%25 = vector.insert %24, %cst_1 [0] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %23[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %25 [1] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %23[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [2] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %23[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [3] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %23[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [4] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %23[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [5] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %23[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [6] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %23[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [7] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %23[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [8] : vector<32xf32> into vector<9x32xf32>
%42 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %41) -> (vector<9x32xf32>) {
%54 = affine.apply #map8(%arg8)
%55 = vector.load %8[%11, %16, %54, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %cst_0 [0] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %16, %54, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [1] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %16, %54, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [2] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %16, %54, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [3] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %16, %54, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [4] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %16, %54, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [5] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %16, %54, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [6] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %8[%11, %16, %54, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [7] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %8[%11, %16, %54, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [8] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %7[%11, %14, %19, %54, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %14, %19, %54, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %14, %19, %54, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %14, %19, %54, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %14, %19, %54, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %14, %19, %54, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %14, %19, %54, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %14, %19, %54, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %14, %19, %54, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %14, %19, %54, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %14, %19, %54, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %14, %19, %54, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %7[%11, %14, %19, %54, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %7[%11, %14, %19, %54, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %7[%11, %14, %19, %54, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %7[%11, %14, %19, %54, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%89 = vector.extract %72[0, 0] : vector<9x16xf32>
%90 = vector.insert %89, %cst [0, 0] : f32 into vector<16x9xf32>
%91 = vector.extract %72[1, 0] : vector<9x16xf32>
%92 = vector.insert %91, %90 [0, 1] : f32 into vector<16x9xf32>
%93 = vector.extract %72[2, 0] : vector<9x16xf32>
%94 = vector.insert %93, %92 [0, 2] : f32 into vector<16x9xf32>
%95 = vector.extract %72[3, 0] : vector<9x16xf32>
%96 = vector.insert %95, %94 [0, 3] : f32 into vector<16x9xf32>
%97 = vector.extract %72[4, 0] : vector<9x16xf32>
%98 = vector.insert %97, %96 [0, 4] : f32 into vector<16x9xf32>
%99 = vector.extract %72[5, 0] : vector<9x16xf32>
%100 = vector.insert %99, %98 [0, 5] : f32 into vector<16x9xf32>
%101 = vector.extract %72[6, 0] : vector<9x16xf32>
%102 = vector.insert %101, %100 [0, 6] : f32 into vector<16x9xf32>
%103 = vector.extract %72[7, 0] : vector<9x16xf32>
%104 = vector.insert %103, %102 [0, 7] : f32 into vector<16x9xf32>
%105 = vector.extract %72[8, 0] : vector<9x16xf32>
%106 = vector.insert %105, %104 [0, 8] : f32 into vector<16x9xf32>
%107 = vector.extract %72[0, 1] : vector<9x16xf32>
%108 = vector.insert %107, %106 [1, 0] : f32 into vector<16x9xf32>
%109 = vector.extract %72[1, 1] : vector<9x16xf32>
%110 = vector.insert %109, %108 [1, 1] : f32 into vector<16x9xf32>
%111 = vector.extract %72[2, 1] : vector<9x16xf32>
%112 = vector.insert %111, %110 [1, 2] : f32 into vector<16x9xf32>
%113 = vector.extract %72[3, 1] : vector<9x16xf32>
%114 = vector.insert %113, %112 [1, 3] : f32 into vector<16x9xf32>
%115 = vector.extract %72[4, 1] : vector<9x16xf32>
%116 = vector.insert %115, %114 [1, 4] : f32 into vector<16x9xf32>
%117 = vector.extract %72[5, 1] : vector<9x16xf32>
%118 = vector.insert %117, %116 [1, 5] : f32 into vector<16x9xf32>
%119 = vector.extract %72[6, 1] : vector<9x16xf32>
%120 = vector.insert %119, %118 [1, 6] : f32 into vector<16x9xf32>
%121 = vector.extract %72[7, 1] : vector<9x16xf32>
%122 = vector.insert %121, %120 [1, 7] : f32 into vector<16x9xf32>
%123 = vector.extract %72[8, 1] : vector<9x16xf32>
%124 = vector.insert %123, %122 [1, 8] : f32 into vector<16x9xf32>
%125 = vector.extract %72[0, 2] : vector<9x16xf32>
%126 = vector.insert %125, %124 [2, 0] : f32 into vector<16x9xf32>
%127 = vector.extract %72[1, 2] : vector<9x16xf32>
%128 = vector.insert %127, %126 [2, 1] : f32 into vector<16x9xf32>
%129 = vector.extract %72[2, 2] : vector<9x16xf32>
%130 = vector.insert %129, %128 [2, 2] : f32 into vector<16x9xf32>
%131 = vector.extract %72[3, 2] : vector<9x16xf32>
%132 = vector.insert %131, %130 [2, 3] : f32 into vector<16x9xf32>
%133 = vector.extract %72[4, 2] : vector<9x16xf32>
%134 = vector.insert %133, %132 [2, 4] : f32 into vector<16x9xf32>
%135 = vector.extract %72[5, 2] : vector<9x16xf32>
%136 = vector.insert %135, %134 [2, 5] : f32 into vector<16x9xf32>
%137 = vector.extract %72[6, 2] : vector<9x16xf32>
%138 = vector.insert %137, %136 [2, 6] : f32 into vector<16x9xf32>
%139 = vector.extract %72[7, 2] : vector<9x16xf32>
%140 = vector.insert %139, %138 [2, 7] : f32 into vector<16x9xf32>
%141 = vector.extract %72[8, 2] : vector<9x16xf32>
%142 = vector.insert %141, %140 [2, 8] : f32 into vector<16x9xf32>
%143 = vector.extract %72[0, 3] : vector<9x16xf32>
%144 = vector.insert %143, %142 [3, 0] : f32 into vector<16x9xf32>
%145 = vector.extract %72[1, 3] : vector<9x16xf32>
%146 = vector.insert %145, %144 [3, 1] : f32 into vector<16x9xf32>
%147 = vector.extract %72[2, 3] : vector<9x16xf32>
%148 = vector.insert %147, %146 [3, 2] : f32 into vector<16x9xf32>
%149 = vector.extract %72[3, 3] : vector<9x16xf32>
%150 = vector.insert %149, %148 [3, 3] : f32 into vector<16x9xf32>
%151 = vector.extract %72[4, 3] : vector<9x16xf32>
%152 = vector.insert %151, %150 [3, 4] : f32 into vector<16x9xf32>
%153 = vector.extract %72[5, 3] : vector<9x16xf32>
%154 = vector.insert %153, %152 [3, 5] : f32 into vector<16x9xf32>
%155 = vector.extract %72[6, 3] : vector<9x16xf32>
%156 = vector.insert %155, %154 [3, 6] : f32 into vector<16x9xf32>
%157 = vector.extract %72[7, 3] : vector<9x16xf32>
%158 = vector.insert %157, %156 [3, 7] : f32 into vector<16x9xf32>
%159 = vector.extract %72[8, 3] : vector<9x16xf32>
%160 = vector.insert %159, %158 [3, 8] : f32 into vector<16x9xf32>
%161 = vector.extract %72[0, 4] : vector<9x16xf32>
%162 = vector.insert %161, %160 [4, 0] : f32 into vector<16x9xf32>
%163 = vector.extract %72[1, 4] : vector<9x16xf32>
%164 = vector.insert %163, %162 [4, 1] : f32 into vector<16x9xf32>
%165 = vector.extract %72[2, 4] : vector<9x16xf32>
%166 = vector.insert %165, %164 [4, 2] : f32 into vector<16x9xf32>
%167 = vector.extract %72[3, 4] : vector<9x16xf32>
%168 = vector.insert %167, %166 [4, 3] : f32 into vector<16x9xf32>
%169 = vector.extract %72[4, 4] : vector<9x16xf32>
%170 = vector.insert %169, %168 [4, 4] : f32 into vector<16x9xf32>
%171 = vector.extract %72[5, 4] : vector<9x16xf32>
%172 = vector.insert %171, %170 [4, 5] : f32 into vector<16x9xf32>
%173 = vector.extract %72[6, 4] : vector<9x16xf32>
%174 = vector.insert %173, %172 [4, 6] : f32 into vector<16x9xf32>
%175 = vector.extract %72[7, 4] : vector<9x16xf32>
%176 = vector.insert %175, %174 [4, 7] : f32 into vector<16x9xf32>
%177 = vector.extract %72[8, 4] : vector<9x16xf32>
%178 = vector.insert %177, %176 [4, 8] : f32 into vector<16x9xf32>
%179 = vector.extract %72[0, 5] : vector<9x16xf32>
%180 = vector.insert %179, %178 [5, 0] : f32 into vector<16x9xf32>
%181 = vector.extract %72[1, 5] : vector<9x16xf32>
%182 = vector.insert %181, %180 [5, 1] : f32 into vector<16x9xf32>
%183 = vector.extract %72[2, 5] : vector<9x16xf32>
%184 = vector.insert %183, %182 [5, 2] : f32 into vector<16x9xf32>
%185 = vector.extract %72[3, 5] : vector<9x16xf32>
%186 = vector.insert %185, %184 [5, 3] : f32 into vector<16x9xf32>
%187 = vector.extract %72[4, 5] : vector<9x16xf32>
%188 = vector.insert %187, %186 [5, 4] : f32 into vector<16x9xf32>
%189 = vector.extract %72[5, 5] : vector<9x16xf32>
%190 = vector.insert %189, %188 [5, 5] : f32 into vector<16x9xf32>
%191 = vector.extract %72[6, 5] : vector<9x16xf32>
%192 = vector.insert %191, %190 [5, 6] : f32 into vector<16x9xf32>
%193 = vector.extract %72[7, 5] : vector<9x16xf32>
%194 = vector.insert %193, %192 [5, 7] : f32 into vector<16x9xf32>
%195 = vector.extract %72[8, 5] : vector<9x16xf32>
%196 = vector.insert %195, %194 [5, 8] : f32 into vector<16x9xf32>
%197 = vector.extract %72[0, 6] : vector<9x16xf32>
%198 = vector.insert %197, %196 [6, 0] : f32 into vector<16x9xf32>
%199 = vector.extract %72[1, 6] : vector<9x16xf32>
%200 = vector.insert %199, %198 [6, 1] : f32 into vector<16x9xf32>
%201 = vector.extract %72[2, 6] : vector<9x16xf32>
%202 = vector.insert %201, %200 [6, 2] : f32 into vector<16x9xf32>
%203 = vector.extract %72[3, 6] : vector<9x16xf32>
%204 = vector.insert %203, %202 [6, 3] : f32 into vector<16x9xf32>
%205 = vector.extract %72[4, 6] : vector<9x16xf32>
%206 = vector.insert %205, %204 [6, 4] : f32 into vector<16x9xf32>
%207 = vector.extract %72[5, 6] : vector<9x16xf32>
%208 = vector.insert %207, %206 [6, 5] : f32 into vector<16x9xf32>
%209 = vector.extract %72[6, 6] : vector<9x16xf32>
%210 = vector.insert %209, %208 [6, 6] : f32 into vector<16x9xf32>
%211 = vector.extract %72[7, 6] : vector<9x16xf32>
%212 = vector.insert %211, %210 [6, 7] : f32 into vector<16x9xf32>
%213 = vector.extract %72[8, 6] : vector<9x16xf32>
%214 = vector.insert %213, %212 [6, 8] : f32 into vector<16x9xf32>
%215 = vector.extract %72[0, 7] : vector<9x16xf32>
%216 = vector.insert %215, %214 [7, 0] : f32 into vector<16x9xf32>
%217 = vector.extract %72[1, 7] : vector<9x16xf32>
%218 = vector.insert %217, %216 [7, 1] : f32 into vector<16x9xf32>
%219 = vector.extract %72[2, 7] : vector<9x16xf32>
%220 = vector.insert %219, %218 [7, 2] : f32 into vector<16x9xf32>
%221 = vector.extract %72[3, 7] : vector<9x16xf32>
%222 = vector.insert %221, %220 [7, 3] : f32 into vector<16x9xf32>
%223 = vector.extract %72[4, 7] : vector<9x16xf32>
%224 = vector.insert %223, %222 [7, 4] : f32 into vector<16x9xf32>
%225 = vector.extract %72[5, 7] : vector<9x16xf32>
%226 = vector.insert %225, %224 [7, 5] : f32 into vector<16x9xf32>
%227 = vector.extract %72[6, 7] : vector<9x16xf32>
%228 = vector.insert %227, %226 [7, 6] : f32 into vector<16x9xf32>
%229 = vector.extract %72[7, 7] : vector<9x16xf32>
%230 = vector.insert %229, %228 [7, 7] : f32 into vector<16x9xf32>
%231 = vector.extract %72[8, 7] : vector<9x16xf32>
%232 = vector.insert %231, %230 [7, 8] : f32 into vector<16x9xf32>
%233 = vector.extract %72[0, 8] : vector<9x16xf32>
%234 = vector.insert %233, %232 [8, 0] : f32 into vector<16x9xf32>
%235 = vector.extract %72[1, 8] : vector<9x16xf32>
%236 = vector.insert %235, %234 [8, 1] : f32 into vector<16x9xf32>
%237 = vector.extract %72[2, 8] : vector<9x16xf32>
%238 = vector.insert %237, %236 [8, 2] : f32 into vector<16x9xf32>
%239 = vector.extract %72[3, 8] : vector<9x16xf32>
%240 = vector.insert %239, %238 [8, 3] : f32 into vector<16x9xf32>
%241 = vector.extract %72[4, 8] : vector<9x16xf32>
%242 = vector.insert %241, %240 [8, 4] : f32 into vector<16x9xf32>
%243 = vector.extract %72[5, 8] : vector<9x16xf32>
%244 = vector.insert %243, %242 [8, 5] : f32 into vector<16x9xf32>
%245 = vector.extract %72[6, 8] : vector<9x16xf32>
%246 = vector.insert %245, %244 [8, 6] : f32 into vector<16x9xf32>
%247 = vector.extract %72[7, 8] : vector<9x16xf32>
%248 = vector.insert %247, %246 [8, 7] : f32 into vector<16x9xf32>
%249 = vector.extract %72[8, 8] : vector<9x16xf32>
%250 = vector.insert %249, %248 [8, 8] : f32 into vector<16x9xf32>
%251 = vector.extract %72[0, 9] : vector<9x16xf32>
%252 = vector.insert %251, %250 [9, 0] : f32 into vector<16x9xf32>
%253 = vector.extract %72[1, 9] : vector<9x16xf32>
%254 = vector.insert %253, %252 [9, 1] : f32 into vector<16x9xf32>
%255 = vector.extract %72[2, 9] : vector<9x16xf32>
%256 = vector.insert %255, %254 [9, 2] : f32 into vector<16x9xf32>
%257 = vector.extract %72[3, 9] : vector<9x16xf32>
%258 = vector.insert %257, %256 [9, 3] : f32 into vector<16x9xf32>
%259 = vector.extract %72[4, 9] : vector<9x16xf32>
%260 = vector.insert %259, %258 [9, 4] : f32 into vector<16x9xf32>
%261 = vector.extract %72[5, 9] : vector<9x16xf32>
%262 = vector.insert %261, %260 [9, 5] : f32 into vector<16x9xf32>
%263 = vector.extract %72[6, 9] : vector<9x16xf32>
%264 = vector.insert %263, %262 [9, 6] : f32 into vector<16x9xf32>
%265 = vector.extract %72[7, 9] : vector<9x16xf32>
%266 = vector.insert %265, %264 [9, 7] : f32 into vector<16x9xf32>
%267 = vector.extract %72[8, 9] : vector<9x16xf32>
%268 = vector.insert %267, %266 [9, 8] : f32 into vector<16x9xf32>
%269 = vector.extract %72[0, 10] : vector<9x16xf32>
%270 = vector.insert %269, %268 [10, 0] : f32 into vector<16x9xf32>
%271 = vector.extract %72[1, 10] : vector<9x16xf32>
%272 = vector.insert %271, %270 [10, 1] : f32 into vector<16x9xf32>
%273 = vector.extract %72[2, 10] : vector<9x16xf32>
%274 = vector.insert %273, %272 [10, 2] : f32 into vector<16x9xf32>
%275 = vector.extract %72[3, 10] : vector<9x16xf32>
%276 = vector.insert %275, %274 [10, 3] : f32 into vector<16x9xf32>
%277 = vector.extract %72[4, 10] : vector<9x16xf32>
%278 = vector.insert %277, %276 [10, 4] : f32 into vector<16x9xf32>
%279 = vector.extract %72[5, 10] : vector<9x16xf32>
%280 = vector.insert %279, %278 [10, 5] : f32 into vector<16x9xf32>
%281 = vector.extract %72[6, 10] : vector<9x16xf32>
%282 = vector.insert %281, %280 [10, 6] : f32 into vector<16x9xf32>
%283 = vector.extract %72[7, 10] : vector<9x16xf32>
%284 = vector.insert %283, %282 [10, 7] : f32 into vector<16x9xf32>
%285 = vector.extract %72[8, 10] : vector<9x16xf32>
%286 = vector.insert %285, %284 [10, 8] : f32 into vector<16x9xf32>
%287 = vector.extract %72[0, 11] : vector<9x16xf32>
%288 = vector.insert %287, %286 [11, 0] : f32 into vector<16x9xf32>
%289 = vector.extract %72[1, 11] : vector<9x16xf32>
%290 = vector.insert %289, %288 [11, 1] : f32 into vector<16x9xf32>
%291 = vector.extract %72[2, 11] : vector<9x16xf32>
%292 = vector.insert %291, %290 [11, 2] : f32 into vector<16x9xf32>
%293 = vector.extract %72[3, 11] : vector<9x16xf32>
%294 = vector.insert %293, %292 [11, 3] : f32 into vector<16x9xf32>
%295 = vector.extract %72[4, 11] : vector<9x16xf32>
%296 = vector.insert %295, %294 [11, 4] : f32 into vector<16x9xf32>
%297 = vector.extract %72[5, 11] : vector<9x16xf32>
%298 = vector.insert %297, %296 [11, 5] : f32 into vector<16x9xf32>
%299 = vector.extract %72[6, 11] : vector<9x16xf32>
%300 = vector.insert %299, %298 [11, 6] : f32 into vector<16x9xf32>
%301 = vector.extract %72[7, 11] : vector<9x16xf32>
%302 = vector.insert %301, %300 [11, 7] : f32 into vector<16x9xf32>
%303 = vector.extract %72[8, 11] : vector<9x16xf32>
%304 = vector.insert %303, %302 [11, 8] : f32 into vector<16x9xf32>
%305 = vector.extract %72[0, 12] : vector<9x16xf32>
%306 = vector.insert %305, %304 [12, 0] : f32 into vector<16x9xf32>
%307 = vector.extract %72[1, 12] : vector<9x16xf32>
%308 = vector.insert %307, %306 [12, 1] : f32 into vector<16x9xf32>
%309 = vector.extract %72[2, 12] : vector<9x16xf32>
%310 = vector.insert %309, %308 [12, 2] : f32 into vector<16x9xf32>
%311 = vector.extract %72[3, 12] : vector<9x16xf32>
%312 = vector.insert %311, %310 [12, 3] : f32 into vector<16x9xf32>
%313 = vector.extract %72[4, 12] : vector<9x16xf32>
%314 = vector.insert %313, %312 [12, 4] : f32 into vector<16x9xf32>
%315 = vector.extract %72[5, 12] : vector<9x16xf32>
%316 = vector.insert %315, %314 [12, 5] : f32 into vector<16x9xf32>
%317 = vector.extract %72[6, 12] : vector<9x16xf32>
%318 = vector.insert %317, %316 [12, 6] : f32 into vector<16x9xf32>
%319 = vector.extract %72[7, 12] : vector<9x16xf32>
%320 = vector.insert %319, %318 [12, 7] : f32 into vector<16x9xf32>
%321 = vector.extract %72[8, 12] : vector<9x16xf32>
%322 = vector.insert %321, %320 [12, 8] : f32 into vector<16x9xf32>
%323 = vector.extract %72[0, 13] : vector<9x16xf32>
%324 = vector.insert %323, %322 [13, 0] : f32 into vector<16x9xf32>
%325 = vector.extract %72[1, 13] : vector<9x16xf32>
%326 = vector.insert %325, %324 [13, 1] : f32 into vector<16x9xf32>
%327 = vector.extract %72[2, 13] : vector<9x16xf32>
%328 = vector.insert %327, %326 [13, 2] : f32 into vector<16x9xf32>
%329 = vector.extract %72[3, 13] : vector<9x16xf32>
%330 = vector.insert %329, %328 [13, 3] : f32 into vector<16x9xf32>
%331 = vector.extract %72[4, 13] : vector<9x16xf32>
%332 = vector.insert %331, %330 [13, 4] : f32 into vector<16x9xf32>
%333 = vector.extract %72[5, 13] : vector<9x16xf32>
%334 = vector.insert %333, %332 [13, 5] : f32 into vector<16x9xf32>
%335 = vector.extract %72[6, 13] : vector<9x16xf32>
%336 = vector.insert %335, %334 [13, 6] : f32 into vector<16x9xf32>
%337 = vector.extract %72[7, 13] : vector<9x16xf32>
%338 = vector.insert %337, %336 [13, 7] : f32 into vector<16x9xf32>
%339 = vector.extract %72[8, 13] : vector<9x16xf32>
%340 = vector.insert %339, %338 [13, 8] : f32 into vector<16x9xf32>
%341 = vector.extract %72[0, 14] : vector<9x16xf32>
%342 = vector.insert %341, %340 [14, 0] : f32 into vector<16x9xf32>
%343 = vector.extract %72[1, 14] : vector<9x16xf32>
%344 = vector.insert %343, %342 [14, 1] : f32 into vector<16x9xf32>
%345 = vector.extract %72[2, 14] : vector<9x16xf32>
%346 = vector.insert %345, %344 [14, 2] : f32 into vector<16x9xf32>
%347 = vector.extract %72[3, 14] : vector<9x16xf32>
%348 = vector.insert %347, %346 [14, 3] : f32 into vector<16x9xf32>
%349 = vector.extract %72[4, 14] : vector<9x16xf32>
%350 = vector.insert %349, %348 [14, 4] : f32 into vector<16x9xf32>
%351 = vector.extract %72[5, 14] : vector<9x16xf32>
%352 = vector.insert %351, %350 [14, 5] : f32 into vector<16x9xf32>
%353 = vector.extract %72[6, 14] : vector<9x16xf32>
%354 = vector.insert %353, %352 [14, 6] : f32 into vector<16x9xf32>
%355 = vector.extract %72[7, 14] : vector<9x16xf32>
%356 = vector.insert %355, %354 [14, 7] : f32 into vector<16x9xf32>
%357 = vector.extract %72[8, 14] : vector<9x16xf32>
%358 = vector.insert %357, %356 [14, 8] : f32 into vector<16x9xf32>
%359 = vector.extract %72[0, 15] : vector<9x16xf32>
%360 = vector.insert %359, %358 [15, 0] : f32 into vector<16x9xf32>
%361 = vector.extract %72[1, 15] : vector<9x16xf32>
%362 = vector.insert %361, %360 [15, 1] : f32 into vector<16x9xf32>
%363 = vector.extract %72[2, 15] : vector<9x16xf32>
%364 = vector.insert %363, %362 [15, 2] : f32 into vector<16x9xf32>
%365 = vector.extract %72[3, 15] : vector<9x16xf32>
%366 = vector.insert %365, %364 [15, 3] : f32 into vector<16x9xf32>
%367 = vector.extract %72[4, 15] : vector<9x16xf32>
%368 = vector.insert %367, %366 [15, 4] : f32 into vector<16x9xf32>
%369 = vector.extract %72[5, 15] : vector<9x16xf32>
%370 = vector.insert %369, %368 [15, 5] : f32 into vector<16x9xf32>
%371 = vector.extract %72[6, 15] : vector<9x16xf32>
%372 = vector.insert %371, %370 [15, 6] : f32 into vector<16x9xf32>
%373 = vector.extract %72[7, 15] : vector<9x16xf32>
%374 = vector.insert %373, %372 [15, 7] : f32 into vector<16x9xf32>
%375 = vector.extract %72[8, 15] : vector<9x16xf32>
%376 = vector.insert %375, %374 [15, 8] : f32 into vector<16x9xf32>
%377 = vector.extract %376[0] : vector<16x9xf32>
%378 = vector.outerproduct %377, %73, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%379 = vector.extract %376[1] : vector<16x9xf32>
%380 = vector.outerproduct %379, %74, %378 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%381 = vector.extract %376[2] : vector<16x9xf32>
%382 = vector.outerproduct %381, %75, %380 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%383 = vector.extract %376[3] : vector<16x9xf32>
%384 = vector.outerproduct %383, %76, %382 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%385 = vector.extract %376[4] : vector<16x9xf32>
%386 = vector.outerproduct %385, %77, %384 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%387 = vector.extract %376[5] : vector<16x9xf32>
%388 = vector.outerproduct %387, %78, %386 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%389 = vector.extract %376[6] : vector<16x9xf32>
%390 = vector.outerproduct %389, %79, %388 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%391 = vector.extract %376[7] : vector<16x9xf32>
%392 = vector.outerproduct %391, %80, %390 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%393 = vector.extract %376[8] : vector<16x9xf32>
%394 = vector.outerproduct %393, %81, %392 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%395 = vector.extract %376[9] : vector<16x9xf32>
%396 = vector.outerproduct %395, %82, %394 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%397 = vector.extract %376[10] : vector<16x9xf32>
%398 = vector.outerproduct %397, %83, %396 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%399 = vector.extract %376[11] : vector<16x9xf32>
%400 = vector.outerproduct %399, %84, %398 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%401 = vector.extract %376[12] : vector<16x9xf32>
%402 = vector.outerproduct %401, %85, %400 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%403 = vector.extract %376[13] : vector<16x9xf32>
%404 = vector.outerproduct %403, %86, %402 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%405 = vector.extract %376[14] : vector<16x9xf32>
%406 = vector.outerproduct %405, %87, %404 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%407 = vector.extract %376[15] : vector<16x9xf32>
%408 = vector.outerproduct %407, %88, %406 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %408 : vector<9x32xf32>
}
%43 = scf.if %22 -> (memref<?x?xf32, #map11>) {
%54 = memref.cast %20 : memref<?x?xf32, #map10> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
} else {
%54 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %54 : memref<?x?xf32, #map11>
}
%44 = vector.extract %42[0] : vector<9x32xf32>
vector.store %44, %43[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%45 = vector.extract %42[1] : vector<9x32xf32>
vector.store %45, %43[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%46 = vector.extract %42[2] : vector<9x32xf32>
vector.store %46, %43[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %42[3] : vector<9x32xf32>
vector.store %47, %43[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %42[4] : vector<9x32xf32>
vector.store %48, %43[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %42[5] : vector<9x32xf32>
vector.store %49, %43[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %42[6] : vector<9x32xf32>
vector.store %50, %43[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %42[7] : vector<9x32xf32>
vector.store %51, %43[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %42[8] : vector<9x32xf32>
vector.store %52, %43[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = arith.xori %22, %true : i1
scf.if %53 {
%54 = memref.subview %3[0, 0] [%15, %18] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%55 = memref.subview %20[0, 0] [%15, %18] [1, 1] : memref<?x?xf32, #map10> to memref<?x?xf32, #map10>
linalg.copy(%54, %55) : memref<?x?xf32, #map12>, memref<?x?xf32, #map10>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2041xf32>, %arg2: memref<?x2041xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2041xf32>, memref<?x2041xf32>) -> ()
}
return
}
}
compilation in 0.2498s
xxxxxxxxxx : 10 iters time on 1 threads in 0.1555s per iter sec (109.4 GFlop/s, 0.3215 GB/s) total time 1.555s
###############################################################
Runtime problem size {'M': 2040, 'N': 2041, 'K': 2042}
Compile-time problem size {'M': -1, 'N': -1, 'K': -1}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map0 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map2 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%5 = affine.min #map0(%arg3)[%1]
%6 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%7 = affine.min #map1(%arg5)[%2]
%8 = tensor.extract_slice %arg0[%arg3, %arg5] [%5, %7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%9 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%10 = affine.min #map2(%arg7)[%3]
%11 = tensor.extract_slice %arg1[%arg5, %arg7] [%7, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%12 = tensor.extract_slice %arg8[%arg3, %arg7] [%5, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%13 = linalg.matmul ins(%8, %11 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) -> tensor<?x?xf32>
%14 = tensor.insert_slice %13 into %arg8[%arg3, %arg7] [%5, %10] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %14 : tensor<?x?xf32>
}
scf.yield %9 : tensor<?x?xf32>
}
scf.yield %6 : tensor<?x?xf32>
}
return %4 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (-d0 + 32)>
#map10 = affine_map<(d0) -> (d0 ceildiv 16)>
#map11 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map12 = affine_map<(d0) -> (-d0 + 16)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = affine.apply #map0()[%2]
%5 = affine.apply #map1()[%3]
%6 = linalg.init_tensor [%4, %5, 4, 32, 16, 32] : tensor<?x?x4x32x16x32xf32>
%7 = tensor.cast %6 : tensor<?x?x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%8 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %7) -> (tensor<?x?x?x?x16x32xf32>) {
%12 = affine.apply #map2(%arg3)
%13 = affine.min #map3(%arg3)[%2]
%14 = scf.for %arg5 = %c0 to %3 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg5)
%16 = affine.min #map5(%arg5)[%3]
%17 = scf.for %arg7 = %c0 to %16 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg7)
%19 = affine.apply #map7(%arg7, %arg5)
%20 = affine.min #map8(%arg7, %16)
%21 = affine.apply #map9(%20)
%22 = scf.for %arg9 = %c0 to %13 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%23 = affine.apply #map10(%arg9)
%24 = affine.apply #map7(%arg9, %arg3)
%25 = affine.min #map11(%arg9, %13)
%26 = tensor.extract_slice %arg1[%24, %19] [%25, %20] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%27 = affine.apply #map12(%25)
%28 = linalg.pad_tensor %26 nofold low[%c0, %c0] high[%27, %21] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<16x32xf32>
%29 = tensor.insert_slice %28 into %arg10[%12, %15, %18, %23, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %29 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %22 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
%9 = linalg.init_tensor [%4, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%10 = tensor.cast %9 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%11 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%12 = affine.min #map13(%arg3)[%1]
%13 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %10) -> (tensor<?x?x?x9x16xf32>) {
%15 = affine.apply #map2(%arg5)
%16 = affine.min #map3(%arg5)[%2]
%17 = scf.for %arg7 = %c0 to %12 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%18 = affine.apply #map14(%arg7)
%19 = affine.apply #map7(%arg7, %arg3)
%20 = affine.min #map15(%arg7, %12)
%21 = affine.apply #map16(%20)
%22 = scf.for %arg9 = %c0 to %16 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%23 = affine.apply #map10(%arg9)
%24 = affine.apply #map7(%arg9, %arg5)
%25 = affine.min #map11(%arg9, %16)
%26 = tensor.extract_slice %arg0[%19, %24] [%20, %25] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%27 = affine.apply #map12(%25)
%28 = linalg.pad_tensor %26 nofold low[%c0, %c0] high[%21, %27] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x16xf32>
%29 = tensor.insert_slice %28 into %arg10[%15, %18, %23, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %29 : tensor<?x?x?x9x16xf32>
}
scf.yield %22 : tensor<?x?x?x9x16xf32>
}
scf.yield %17 : tensor<?x?x?x9x16xf32>
}
%14 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%15 = affine.min #map3(%arg5)[%2]
%16 = affine.apply #map2(%arg5)
%17 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%18 = affine.min #map5(%arg7)[%3]
%19 = tensor.extract_slice %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%20 = affine.apply #map4(%arg7)
%21 = scf.for %arg9 = %c0 to %12 step %c9 iter_args(%arg10 = %19) -> (tensor<?x?xf32>) {
%23 = affine.min #map15(%arg9, %12)
%24 = affine.apply #map14(%arg9)
%25 = affine.apply #map16(%23)
%26 = scf.for %arg11 = %c0 to %18 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%27 = affine.min #map8(%arg11, %18)
%28 = affine.apply #map6(%arg11)
%29 = affine.apply #map9(%27)
%30 = scf.for %arg13 = %c0 to %15 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%31 = tensor.extract_slice %arg14[%arg9, %arg11] [%23, %27] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%32 = affine.apply #map10(%arg13)
%33 = tensor.extract_slice %13[%16, %24, %32, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%34 = tensor.extract_slice %8[%16, %20, %28, %32, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%35 = linalg.pad_tensor %31 low[%c0, %c0] high[%25, %29] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x32xf32>
%36 = linalg.matmul ins(%33, %34 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%35 : tensor<9x32xf32>) -> tensor<9x32xf32>
%37 = tensor.extract_slice %36[0, 0] [%23, %27] [1, 1] : tensor<9x32xf32> to tensor<?x?xf32>
%38 = tensor.insert_slice %37 into %arg14[%arg9, %arg11] [%23, %27] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %38 : tensor<?x?xf32>
}
scf.yield %30 : tensor<?x?xf32>
}
scf.yield %26 : tensor<?x?xf32>
}
%22 = tensor.insert_slice %21 into %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %22 : tensor<?x?xf32>
}
scf.yield %17 : tensor<?x?xf32>
}
scf.yield %14 : tensor<?x?xf32>
}
return %11 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map15 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map16 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = affine.apply #map0()[%2]
%5 = affine.apply #map1()[%3]
%6 = linalg.init_tensor [%4, %5, 4, 32, 16, 32] : tensor<?x?x4x32x16x32xf32>
%7 = tensor.cast %6 : tensor<?x?x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%8 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %7) -> (tensor<?x?x?x?x16x32xf32>) {
%12 = affine.apply #map2(%arg3)
%13 = affine.min #map3(%arg3)[%2]
%14 = scf.for %arg5 = %c0 to %3 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg5)
%16 = affine.min #map5(%arg5)[%3]
%17 = scf.for %arg7 = %c0 to %16 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg7)
%19 = affine.apply #map7(%arg7, %arg5)
%20 = affine.min #map8(%arg7, %16)
%21 = scf.for %arg9 = %c0 to %13 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%22 = affine.apply #map9(%arg9)
%23 = affine.apply #map7(%arg9, %arg3)
%24 = affine.min #map10(%arg9, %13)
%25 = tensor.extract_slice %arg1[%23, %19] [%24, %20] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = vector.transfer_read %25[%c0, %c0], %cst : tensor<?x?xf32>, vector<16x32xf32>
%27 = vector.transfer_write %26, %arg10[%12, %15, %18, %22, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %27 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %21 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
%9 = linalg.init_tensor [%4, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%10 = tensor.cast %9 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%11 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%12 = affine.min #map11(%arg3)[%1]
%13 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %10) -> (tensor<?x?x?x9x16xf32>) {
%15 = affine.apply #map2(%arg5)
%16 = affine.min #map3(%arg5)[%2]
%17 = scf.for %arg7 = %c0 to %12 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%18 = affine.apply #map12(%arg7)
%19 = affine.apply #map7(%arg7, %arg3)
%20 = affine.min #map13(%arg7, %12)
%21 = scf.for %arg9 = %c0 to %16 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%22 = affine.apply #map9(%arg9)
%23 = affine.apply #map7(%arg9, %arg5)
%24 = affine.min #map10(%arg9, %16)
%25 = tensor.extract_slice %arg0[%19, %23] [%20, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = vector.transfer_read %25[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x16xf32>
%27 = vector.transfer_write %26, %arg10[%15, %18, %22, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %27 : tensor<?x?x?x9x16xf32>
}
scf.yield %21 : tensor<?x?x?x9x16xf32>
}
scf.yield %17 : tensor<?x?x?x9x16xf32>
}
%14 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%15 = affine.min #map3(%arg5)[%2]
%16 = affine.apply #map2(%arg5)
%17 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%18 = affine.min #map5(%arg7)[%3]
%19 = tensor.extract_slice %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%20 = affine.apply #map4(%arg7)
%21 = scf.for %arg9 = %c0 to %12 step %c9 iter_args(%arg10 = %19) -> (tensor<?x?xf32>) {
%23 = affine.min #map13(%arg9, %12)
%24 = affine.apply #map12(%arg9)
%25 = scf.for %arg11 = %c0 to %18 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%26 = affine.min #map8(%arg11, %18)
%27 = affine.apply #map6(%arg11)
%28 = scf.for %arg13 = %c0 to %15 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%29 = tensor.extract_slice %arg14[%arg9, %arg11] [%23, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%30 = affine.apply #map9(%arg13)
%31 = vector.transfer_read %13[%16, %24, %30, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %8[%16, %20, %27, %30, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%33 = vector.transfer_read %29[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x32xf32>
%34 = vector.contract {indexing_maps = [#map14, #map15, #map16], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %31, %32, %33 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%35 = vector.transfer_write %34, %29[%c0, %c0] : vector<9x32xf32>, tensor<?x?xf32>
%36 = tensor.insert_slice %35 into %arg14[%arg9, %arg11] [%23, %26] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %36 : tensor<?x?xf32>
}
scf.yield %28 : tensor<?x?xf32>
}
scf.yield %25 : tensor<?x?xf32>
}
%22 = tensor.insert_slice %21 into %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %22 : tensor<?x?xf32>
}
scf.yield %17 : tensor<?x?xf32>
}
scf.yield %14 : tensor<?x?xf32>
}
return %11 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map16 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map17 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.contract {indexing_maps = [#map15, #map16, #map17], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %21, %22, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %23 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.transpose %21, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%24 = vector.extract %23[0] : vector<16x9xf32>
%25 = vector.extract %22[0] : vector<16x32xf32>
%26 = vector.outerproduct %24, %25, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%27 = vector.extract %23[1] : vector<16x9xf32>
%28 = vector.extract %22[1] : vector<16x32xf32>
%29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%30 = vector.extract %23[2] : vector<16x9xf32>
%31 = vector.extract %22[2] : vector<16x32xf32>
%32 = vector.outerproduct %30, %31, %29 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%33 = vector.extract %23[3] : vector<16x9xf32>
%34 = vector.extract %22[3] : vector<16x32xf32>
%35 = vector.outerproduct %33, %34, %32 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%36 = vector.extract %23[4] : vector<16x9xf32>
%37 = vector.extract %22[4] : vector<16x32xf32>
%38 = vector.outerproduct %36, %37, %35 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%39 = vector.extract %23[5] : vector<16x9xf32>
%40 = vector.extract %22[5] : vector<16x32xf32>
%41 = vector.outerproduct %39, %40, %38 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%42 = vector.extract %23[6] : vector<16x9xf32>
%43 = vector.extract %22[6] : vector<16x32xf32>
%44 = vector.outerproduct %42, %43, %41 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%45 = vector.extract %23[7] : vector<16x9xf32>
%46 = vector.extract %22[7] : vector<16x32xf32>
%47 = vector.outerproduct %45, %46, %44 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%48 = vector.extract %23[8] : vector<16x9xf32>
%49 = vector.extract %22[8] : vector<16x32xf32>
%50 = vector.outerproduct %48, %49, %47 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%51 = vector.extract %23[9] : vector<16x9xf32>
%52 = vector.extract %22[9] : vector<16x32xf32>
%53 = vector.outerproduct %51, %52, %50 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%54 = vector.extract %23[10] : vector<16x9xf32>
%55 = vector.extract %22[10] : vector<16x32xf32>
%56 = vector.outerproduct %54, %55, %53 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%57 = vector.extract %23[11] : vector<16x9xf32>
%58 = vector.extract %22[11] : vector<16x32xf32>
%59 = vector.outerproduct %57, %58, %56 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%60 = vector.extract %23[12] : vector<16x9xf32>
%61 = vector.extract %22[12] : vector<16x32xf32>
%62 = vector.outerproduct %60, %61, %59 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%63 = vector.extract %23[13] : vector<16x9xf32>
%64 = vector.extract %22[13] : vector<16x32xf32>
%65 = vector.outerproduct %63, %64, %62 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%66 = vector.extract %23[14] : vector<16x9xf32>
%67 = vector.extract %22[14] : vector<16x32xf32>
%68 = vector.outerproduct %66, %67, %65 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%69 = vector.extract %23[15] : vector<16x9xf32>
%70 = vector.extract %22[15] : vector<16x32xf32>
%71 = vector.outerproduct %69, %70, %68 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %71 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.transpose %21, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%24 = vector.extract %23[0] : vector<16x9xf32>
%25 = vector.extract %22[0] : vector<16x32xf32>
%26 = vector.outerproduct %24, %25, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%27 = vector.extract %23[1] : vector<16x9xf32>
%28 = vector.extract %22[1] : vector<16x32xf32>
%29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%30 = vector.extract %23[2] : vector<16x9xf32>
%31 = vector.extract %22[2] : vector<16x32xf32>
%32 = vector.outerproduct %30, %31, %29 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%33 = vector.extract %23[3] : vector<16x9xf32>
%34 = vector.extract %22[3] : vector<16x32xf32>
%35 = vector.outerproduct %33, %34, %32 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%36 = vector.extract %23[4] : vector<16x9xf32>
%37 = vector.extract %22[4] : vector<16x32xf32>
%38 = vector.outerproduct %36, %37, %35 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%39 = vector.extract %23[5] : vector<16x9xf32>
%40 = vector.extract %22[5] : vector<16x32xf32>
%41 = vector.outerproduct %39, %40, %38 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%42 = vector.extract %23[6] : vector<16x9xf32>
%43 = vector.extract %22[6] : vector<16x32xf32>
%44 = vector.outerproduct %42, %43, %41 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%45 = vector.extract %23[7] : vector<16x9xf32>
%46 = vector.extract %22[7] : vector<16x32xf32>
%47 = vector.outerproduct %45, %46, %44 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%48 = vector.extract %23[8] : vector<16x9xf32>
%49 = vector.extract %22[8] : vector<16x32xf32>
%50 = vector.outerproduct %48, %49, %47 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%51 = vector.extract %23[9] : vector<16x9xf32>
%52 = vector.extract %22[9] : vector<16x32xf32>
%53 = vector.outerproduct %51, %52, %50 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%54 = vector.extract %23[10] : vector<16x9xf32>
%55 = vector.extract %22[10] : vector<16x32xf32>
%56 = vector.outerproduct %54, %55, %53 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%57 = vector.extract %23[11] : vector<16x9xf32>
%58 = vector.extract %22[11] : vector<16x32xf32>
%59 = vector.outerproduct %57, %58, %56 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%60 = vector.extract %23[12] : vector<16x9xf32>
%61 = vector.extract %22[12] : vector<16x32xf32>
%62 = vector.outerproduct %60, %61, %59 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%63 = vector.extract %23[13] : vector<16x9xf32>
%64 = vector.extract %22[13] : vector<16x32xf32>
%65 = vector.outerproduct %63, %64, %62 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%66 = vector.extract %23[14] : vector<16x9xf32>
%67 = vector.extract %22[14] : vector<16x32xf32>
%68 = vector.outerproduct %66, %67, %65 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%69 = vector.extract %23[15] : vector<16x9xf32>
%70 = vector.extract %22[15] : vector<16x32xf32>
%71 = vector.outerproduct %69, %70, %68 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %71 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%true = arith.constant true
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%27 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%28 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%27, %28) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%29 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %29 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %26, %9[%11, %13, %15, %19, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%26 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%27 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%26, %27) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%28 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
%25 = vector.transfer_read %24[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %25, %10[%12, %14, %18, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%30 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%31 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%30, %31) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%32 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %32 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%27 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %26) -> (vector<9x32xf32>) {
%30 = affine.apply #map9(%arg8)
%31 = vector.transfer_read %10[%13, %18, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %9[%13, %16, %21, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%33 = vector.transpose %31, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%34 = vector.extract %33[0] : vector<16x9xf32>
%35 = vector.extract %32[0] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %33[1] : vector<16x9xf32>
%38 = vector.extract %32[1] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %33[2] : vector<16x9xf32>
%41 = vector.extract %32[2] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %33[3] : vector<16x9xf32>
%44 = vector.extract %32[3] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %33[4] : vector<16x9xf32>
%47 = vector.extract %32[4] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %33[5] : vector<16x9xf32>
%50 = vector.extract %32[5] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %33[6] : vector<16x9xf32>
%53 = vector.extract %32[6] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %33[7] : vector<16x9xf32>
%56 = vector.extract %32[7] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %33[8] : vector<16x9xf32>
%59 = vector.extract %32[8] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %33[9] : vector<16x9xf32>
%62 = vector.extract %32[9] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %33[10] : vector<16x9xf32>
%65 = vector.extract %32[10] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %33[11] : vector<16x9xf32>
%68 = vector.extract %32[11] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %33[12] : vector<16x9xf32>
%71 = vector.extract %32[12] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %33[13] : vector<16x9xf32>
%74 = vector.extract %32[13] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%76 = vector.extract %33[14] : vector<16x9xf32>
%77 = vector.extract %32[14] : vector<16x32xf32>
%78 = vector.outerproduct %76, %77, %75 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%79 = vector.extract %33[15] : vector<16x9xf32>
%80 = vector.extract %32[15] : vector<16x32xf32>
%81 = vector.outerproduct %79, %80, %78 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %81 : vector<9x32xf32>
}
%28 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%30 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
vector.transfer_write %27, %28[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%29 = arith.xori %24, %true : i1
scf.if %29 {
%30 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%31 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%30, %31) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%true = arith.constant true
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%27 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%28 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%27, %28) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%29 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %29 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %26, %9[%11, %13, %15, %19, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%26 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%27 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%26, %27) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%28 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
%25 = vector.transfer_read %24[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %25, %10[%12, %14, %18, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%30 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%31 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%30, %31) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%32 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %32 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%27 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %26) -> (vector<9x32xf32>) {
%30 = affine.apply #map9(%arg8)
%31 = vector.transfer_read %10[%13, %18, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %9[%13, %16, %21, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%33 = vector.transpose %31, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%34 = vector.extract %33[0] : vector<16x9xf32>
%35 = vector.extract %32[0] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %33[1] : vector<16x9xf32>
%38 = vector.extract %32[1] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %33[2] : vector<16x9xf32>
%41 = vector.extract %32[2] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %33[3] : vector<16x9xf32>
%44 = vector.extract %32[3] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %33[4] : vector<16x9xf32>
%47 = vector.extract %32[4] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %33[5] : vector<16x9xf32>
%50 = vector.extract %32[5] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %33[6] : vector<16x9xf32>
%53 = vector.extract %32[6] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %33[7] : vector<16x9xf32>
%56 = vector.extract %32[7] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %33[8] : vector<16x9xf32>
%59 = vector.extract %32[8] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %33[9] : vector<16x9xf32>
%62 = vector.extract %32[9] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %33[10] : vector<16x9xf32>
%65 = vector.extract %32[10] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %33[11] : vector<16x9xf32>
%68 = vector.extract %32[11] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %33[12] : vector<16x9xf32>
%71 = vector.extract %32[12] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %33[13] : vector<16x9xf32>
%74 = vector.extract %32[13] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%76 = vector.extract %33[14] : vector<16x9xf32>
%77 = vector.extract %32[14] : vector<16x32xf32>
%78 = vector.outerproduct %76, %77, %75 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%79 = vector.extract %33[15] : vector<16x9xf32>
%80 = vector.extract %32[15] : vector<16x32xf32>
%81 = vector.outerproduct %79, %80, %78 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %81 : vector<9x32xf32>
}
%28 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%30 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
vector.transfer_write %27, %28[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%29 = arith.xori %24, %true : i1
scf.if %29 {
%30 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%31 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%30, %31) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %0) : f32, memref<16x32xf32>
%42 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%43 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%42, %43) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%44 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %44 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %25[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %25[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %25[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %25[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %25[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%40 = vector.load %25[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.load %25[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %26, %9[%11, %13, %15, %19, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %9[%11, %13, %15, %19, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %9[%11, %13, %15, %19, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %9[%11, %13, %15, %19, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %9[%11, %13, %15, %19, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %9[%11, %13, %15, %19, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %9[%11, %13, %15, %19, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %9[%11, %13, %15, %19, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %9[%11, %13, %15, %19, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %9[%11, %13, %15, %19, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %9[%11, %13, %15, %19, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %9[%11, %13, %15, %19, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %9[%11, %13, %15, %19, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %9[%11, %13, %15, %19, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %40, %9[%11, %13, %15, %19, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %9[%11, %13, %15, %19, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x16xf32>
%34 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%35 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%34, %35) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%36 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %36 : memref<?x?xf32, #map11>
}
%25 = vector.load %24[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %24[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %24[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %24[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %24[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %24[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %24[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%32 = vector.load %24[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%33 = vector.load %24[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %25, %10[%12, %14, %18, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %10[%12, %14, %18, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %10[%12, %14, %18, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %10[%12, %14, %18, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %10[%12, %14, %18, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %10[%12, %14, %18, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %10[%12, %14, %18, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %32, %10[%12, %14, %18, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %33, %10[%12, %14, %18, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %2) : f32, memref<9x32xf32>
%56 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%57 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%56, %57) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%58 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %58 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [1] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [2] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [3] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [4] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [5] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [6] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [7] : vector<32xf32> into vector<9x32xf32>
%42 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%43 = vector.insert %42, %41 [8] : vector<32xf32> into vector<9x32xf32>
%44 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %43) -> (vector<9x32xf32>) {
%56 = affine.apply #map9(%arg8)
%57 = vector.load %10[%13, %18, %56, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %cst [0] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %10[%13, %18, %56, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [1] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %10[%13, %18, %56, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [2] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %10[%13, %18, %56, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [3] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %10[%13, %18, %56, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [4] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %10[%13, %18, %56, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [5] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %10[%13, %18, %56, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [6] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %10[%13, %18, %56, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [7] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %10[%13, %18, %56, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%74 = vector.insert %73, %72 [8] : vector<16xf32> into vector<9x16xf32>
%75 = vector.load %9[%13, %16, %21, %56, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %9[%13, %16, %21, %56, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %9[%13, %16, %21, %56, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %9[%13, %16, %21, %56, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %9[%13, %16, %21, %56, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %9[%13, %16, %21, %56, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %9[%13, %16, %21, %56, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %9[%13, %16, %21, %56, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %9[%13, %16, %21, %56, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %9[%13, %16, %21, %56, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %9[%13, %16, %21, %56, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %9[%13, %16, %21, %56, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %9[%13, %16, %21, %56, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %9[%13, %16, %21, %56, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%89 = vector.load %9[%13, %16, %21, %56, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%90 = vector.load %9[%13, %16, %21, %56, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%91 = vector.transpose %74, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%92 = vector.extract %91[0] : vector<16x9xf32>
%93 = vector.outerproduct %92, %75, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %91[1] : vector<16x9xf32>
%95 = vector.outerproduct %94, %76, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %91[2] : vector<16x9xf32>
%97 = vector.outerproduct %96, %77, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %91[3] : vector<16x9xf32>
%99 = vector.outerproduct %98, %78, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %91[4] : vector<16x9xf32>
%101 = vector.outerproduct %100, %79, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %91[5] : vector<16x9xf32>
%103 = vector.outerproduct %102, %80, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %91[6] : vector<16x9xf32>
%105 = vector.outerproduct %104, %81, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %91[7] : vector<16x9xf32>
%107 = vector.outerproduct %106, %82, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %91[8] : vector<16x9xf32>
%109 = vector.outerproduct %108, %83, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %91[9] : vector<16x9xf32>
%111 = vector.outerproduct %110, %84, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %91[10] : vector<16x9xf32>
%113 = vector.outerproduct %112, %85, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %91[11] : vector<16x9xf32>
%115 = vector.outerproduct %114, %86, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %91[12] : vector<16x9xf32>
%117 = vector.outerproduct %116, %87, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%118 = vector.extract %91[13] : vector<16x9xf32>
%119 = vector.outerproduct %118, %88, %117 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%120 = vector.extract %91[14] : vector<16x9xf32>
%121 = vector.outerproduct %120, %89, %119 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%122 = vector.extract %91[15] : vector<16x9xf32>
%123 = vector.outerproduct %122, %90, %121 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %123 : vector<9x32xf32>
}
%45 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%56 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%46 = vector.extract %44[0] : vector<9x32xf32>
vector.store %46, %45[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %44[1] : vector<9x32xf32>
vector.store %47, %45[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %44[2] : vector<9x32xf32>
vector.store %48, %45[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %44[3] : vector<9x32xf32>
vector.store %49, %45[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %44[4] : vector<9x32xf32>
vector.store %50, %45[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %44[5] : vector<9x32xf32>
vector.store %51, %45[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %44[6] : vector<9x32xf32>
vector.store %52, %45[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = vector.extract %44[7] : vector<9x32xf32>
vector.store %53, %45[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%54 = vector.extract %44[8] : vector<9x32xf32>
vector.store %54, %45[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%55 = arith.xori %24, %true : i1
scf.if %55 {
%56 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%57 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%56, %57) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%42 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%43 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%42, %43) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%44 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %44 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %25[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %25[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %25[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %25[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %25[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%40 = vector.load %25[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.load %25[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %26, %9[%11, %13, %15, %19, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %9[%11, %13, %15, %19, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %9[%11, %13, %15, %19, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %9[%11, %13, %15, %19, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %9[%11, %13, %15, %19, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %9[%11, %13, %15, %19, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %9[%11, %13, %15, %19, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %9[%11, %13, %15, %19, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %9[%11, %13, %15, %19, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %9[%11, %13, %15, %19, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %9[%11, %13, %15, %19, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %9[%11, %13, %15, %19, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %9[%11, %13, %15, %19, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %9[%11, %13, %15, %19, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %40, %9[%11, %13, %15, %19, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %9[%11, %13, %15, %19, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%34 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%35 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%34, %35) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%36 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %36 : memref<?x?xf32, #map11>
}
%25 = vector.load %24[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %24[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %24[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %24[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %24[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %24[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %24[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%32 = vector.load %24[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%33 = vector.load %24[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %25, %10[%12, %14, %18, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %10[%12, %14, %18, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %10[%12, %14, %18, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %10[%12, %14, %18, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %10[%12, %14, %18, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %10[%12, %14, %18, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %10[%12, %14, %18, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %32, %10[%12, %14, %18, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %33, %10[%12, %14, %18, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%56 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%57 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%56, %57) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%58 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %58 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [1] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [2] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [3] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [4] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [5] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [6] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [7] : vector<32xf32> into vector<9x32xf32>
%42 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%43 = vector.insert %42, %41 [8] : vector<32xf32> into vector<9x32xf32>
%44 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %43) -> (vector<9x32xf32>) {
%56 = affine.apply #map9(%arg8)
%57 = vector.load %10[%13, %18, %56, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %cst_1 [0] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %10[%13, %18, %56, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [1] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %10[%13, %18, %56, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [2] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %10[%13, %18, %56, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [3] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %10[%13, %18, %56, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [4] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %10[%13, %18, %56, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [5] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %10[%13, %18, %56, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [6] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %10[%13, %18, %56, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [7] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %10[%13, %18, %56, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%74 = vector.insert %73, %72 [8] : vector<16xf32> into vector<9x16xf32>
%75 = vector.load %9[%13, %16, %21, %56, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %9[%13, %16, %21, %56, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %9[%13, %16, %21, %56, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %9[%13, %16, %21, %56, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %9[%13, %16, %21, %56, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %9[%13, %16, %21, %56, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %9[%13, %16, %21, %56, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %9[%13, %16, %21, %56, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %9[%13, %16, %21, %56, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %9[%13, %16, %21, %56, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %9[%13, %16, %21, %56, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %9[%13, %16, %21, %56, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %9[%13, %16, %21, %56, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %9[%13, %16, %21, %56, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%89 = vector.load %9[%13, %16, %21, %56, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%90 = vector.load %9[%13, %16, %21, %56, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%91 = vector.transpose %74, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%92 = vector.extract %91[0] : vector<16x9xf32>
%93 = vector.outerproduct %92, %75, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %91[1] : vector<16x9xf32>
%95 = vector.outerproduct %94, %76, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %91[2] : vector<16x9xf32>
%97 = vector.outerproduct %96, %77, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %91[3] : vector<16x9xf32>
%99 = vector.outerproduct %98, %78, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %91[4] : vector<16x9xf32>
%101 = vector.outerproduct %100, %79, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %91[5] : vector<16x9xf32>
%103 = vector.outerproduct %102, %80, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %91[6] : vector<16x9xf32>
%105 = vector.outerproduct %104, %81, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %91[7] : vector<16x9xf32>
%107 = vector.outerproduct %106, %82, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %91[8] : vector<16x9xf32>
%109 = vector.outerproduct %108, %83, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %91[9] : vector<16x9xf32>
%111 = vector.outerproduct %110, %84, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %91[10] : vector<16x9xf32>
%113 = vector.outerproduct %112, %85, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %91[11] : vector<16x9xf32>
%115 = vector.outerproduct %114, %86, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %91[12] : vector<16x9xf32>
%117 = vector.outerproduct %116, %87, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%118 = vector.extract %91[13] : vector<16x9xf32>
%119 = vector.outerproduct %118, %88, %117 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%120 = vector.extract %91[14] : vector<16x9xf32>
%121 = vector.outerproduct %120, %89, %119 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%122 = vector.extract %91[15] : vector<16x9xf32>
%123 = vector.outerproduct %122, %90, %121 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %123 : vector<9x32xf32>
}
%45 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%56 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%46 = vector.extract %44[0] : vector<9x32xf32>
vector.store %46, %45[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %44[1] : vector<9x32xf32>
vector.store %47, %45[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %44[2] : vector<9x32xf32>
vector.store %48, %45[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %44[3] : vector<9x32xf32>
vector.store %49, %45[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %44[4] : vector<9x32xf32>
vector.store %50, %45[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %44[5] : vector<9x32xf32>
vector.store %51, %45[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %44[6] : vector<9x32xf32>
vector.store %52, %45[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = vector.extract %44[7] : vector<9x32xf32>
vector.store %53, %45[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%54 = vector.extract %44[8] : vector<9x32xf32>
vector.store %54, %45[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%55 = arith.xori %24, %true : i1
scf.if %55 {
%56 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%57 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%56, %57) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692250>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<16x9xf32>
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_2, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %0) : f32, memref<16x32xf32>
%42 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%43 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%42, %43) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%44 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %44 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %25[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %25[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %25[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %25[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %25[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%40 = vector.load %25[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.load %25[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %26, %9[%11, %13, %15, %19, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %9[%11, %13, %15, %19, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %9[%11, %13, %15, %19, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %9[%11, %13, %15, %19, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %9[%11, %13, %15, %19, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %9[%11, %13, %15, %19, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %9[%11, %13, %15, %19, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %9[%11, %13, %15, %19, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %9[%11, %13, %15, %19, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %9[%11, %13, %15, %19, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %9[%11, %13, %15, %19, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %9[%11, %13, %15, %19, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %9[%11, %13, %15, %19, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %9[%11, %13, %15, %19, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %40, %9[%11, %13, %15, %19, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %9[%11, %13, %15, %19, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %1) : f32, memref<9x16xf32>
%34 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%35 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%34, %35) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%36 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %36 : memref<?x?xf32, #map11>
}
%25 = vector.load %24[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %24[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %24[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %24[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %24[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %24[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %24[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%32 = vector.load %24[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%33 = vector.load %24[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %25, %10[%12, %14, %18, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %10[%12, %14, %18, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %10[%12, %14, %18, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %10[%12, %14, %18, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %10[%12, %14, %18, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %10[%12, %14, %18, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %10[%12, %14, %18, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %32, %10[%12, %14, %18, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %33, %10[%12, %14, %18, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_2, %2) : f32, memref<9x32xf32>
%56 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%57 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%56, %57) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%58 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %58 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %cst_1 [0] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [1] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [2] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [3] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [4] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [5] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [6] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [7] : vector<32xf32> into vector<9x32xf32>
%42 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%43 = vector.insert %42, %41 [8] : vector<32xf32> into vector<9x32xf32>
%44 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %43) -> (vector<9x32xf32>) {
%56 = affine.apply #map9(%arg8)
%57 = vector.load %10[%13, %18, %56, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %cst_0 [0] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %10[%13, %18, %56, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [1] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %10[%13, %18, %56, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [2] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %10[%13, %18, %56, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [3] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %10[%13, %18, %56, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [4] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %10[%13, %18, %56, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [5] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %10[%13, %18, %56, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [6] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %10[%13, %18, %56, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [7] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %10[%13, %18, %56, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%74 = vector.insert %73, %72 [8] : vector<16xf32> into vector<9x16xf32>
%75 = vector.load %9[%13, %16, %21, %56, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %9[%13, %16, %21, %56, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %9[%13, %16, %21, %56, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %9[%13, %16, %21, %56, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %9[%13, %16, %21, %56, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %9[%13, %16, %21, %56, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %9[%13, %16, %21, %56, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %9[%13, %16, %21, %56, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %9[%13, %16, %21, %56, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %9[%13, %16, %21, %56, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %9[%13, %16, %21, %56, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %9[%13, %16, %21, %56, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %9[%13, %16, %21, %56, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %9[%13, %16, %21, %56, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%89 = vector.load %9[%13, %16, %21, %56, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%90 = vector.load %9[%13, %16, %21, %56, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%91 = vector.extract %74[0, 0] : vector<9x16xf32>
%92 = vector.insert %91, %cst [0, 0] : f32 into vector<16x9xf32>
%93 = vector.extract %74[1, 0] : vector<9x16xf32>
%94 = vector.insert %93, %92 [0, 1] : f32 into vector<16x9xf32>
%95 = vector.extract %74[2, 0] : vector<9x16xf32>
%96 = vector.insert %95, %94 [0, 2] : f32 into vector<16x9xf32>
%97 = vector.extract %74[3, 0] : vector<9x16xf32>
%98 = vector.insert %97, %96 [0, 3] : f32 into vector<16x9xf32>
%99 = vector.extract %74[4, 0] : vector<9x16xf32>
%100 = vector.insert %99, %98 [0, 4] : f32 into vector<16x9xf32>
%101 = vector.extract %74[5, 0] : vector<9x16xf32>
%102 = vector.insert %101, %100 [0, 5] : f32 into vector<16x9xf32>
%103 = vector.extract %74[6, 0] : vector<9x16xf32>
%104 = vector.insert %103, %102 [0, 6] : f32 into vector<16x9xf32>
%105 = vector.extract %74[7, 0] : vector<9x16xf32>
%106 = vector.insert %105, %104 [0, 7] : f32 into vector<16x9xf32>
%107 = vector.extract %74[8, 0] : vector<9x16xf32>
%108 = vector.insert %107, %106 [0, 8] : f32 into vector<16x9xf32>
%109 = vector.extract %74[0, 1] : vector<9x16xf32>
%110 = vector.insert %109, %108 [1, 0] : f32 into vector<16x9xf32>
%111 = vector.extract %74[1, 1] : vector<9x16xf32>
%112 = vector.insert %111, %110 [1, 1] : f32 into vector<16x9xf32>
%113 = vector.extract %74[2, 1] : vector<9x16xf32>
%114 = vector.insert %113, %112 [1, 2] : f32 into vector<16x9xf32>
%115 = vector.extract %74[3, 1] : vector<9x16xf32>
%116 = vector.insert %115, %114 [1, 3] : f32 into vector<16x9xf32>
%117 = vector.extract %74[4, 1] : vector<9x16xf32>
%118 = vector.insert %117, %116 [1, 4] : f32 into vector<16x9xf32>
%119 = vector.extract %74[5, 1] : vector<9x16xf32>
%120 = vector.insert %119, %118 [1, 5] : f32 into vector<16x9xf32>
%121 = vector.extract %74[6, 1] : vector<9x16xf32>
%122 = vector.insert %121, %120 [1, 6] : f32 into vector<16x9xf32>
%123 = vector.extract %74[7, 1] : vector<9x16xf32>
%124 = vector.insert %123, %122 [1, 7] : f32 into vector<16x9xf32>
%125 = vector.extract %74[8, 1] : vector<9x16xf32>
%126 = vector.insert %125, %124 [1, 8] : f32 into vector<16x9xf32>
%127 = vector.extract %74[0, 2] : vector<9x16xf32>
%128 = vector.insert %127, %126 [2, 0] : f32 into vector<16x9xf32>
%129 = vector.extract %74[1, 2] : vector<9x16xf32>
%130 = vector.insert %129, %128 [2, 1] : f32 into vector<16x9xf32>
%131 = vector.extract %74[2, 2] : vector<9x16xf32>
%132 = vector.insert %131, %130 [2, 2] : f32 into vector<16x9xf32>
%133 = vector.extract %74[3, 2] : vector<9x16xf32>
%134 = vector.insert %133, %132 [2, 3] : f32 into vector<16x9xf32>
%135 = vector.extract %74[4, 2] : vector<9x16xf32>
%136 = vector.insert %135, %134 [2, 4] : f32 into vector<16x9xf32>
%137 = vector.extract %74[5, 2] : vector<9x16xf32>
%138 = vector.insert %137, %136 [2, 5] : f32 into vector<16x9xf32>
%139 = vector.extract %74[6, 2] : vector<9x16xf32>
%140 = vector.insert %139, %138 [2, 6] : f32 into vector<16x9xf32>
%141 = vector.extract %74[7, 2] : vector<9x16xf32>
%142 = vector.insert %141, %140 [2, 7] : f32 into vector<16x9xf32>
%143 = vector.extract %74[8, 2] : vector<9x16xf32>
%144 = vector.insert %143, %142 [2, 8] : f32 into vector<16x9xf32>
%145 = vector.extract %74[0, 3] : vector<9x16xf32>
%146 = vector.insert %145, %144 [3, 0] : f32 into vector<16x9xf32>
%147 = vector.extract %74[1, 3] : vector<9x16xf32>
%148 = vector.insert %147, %146 [3, 1] : f32 into vector<16x9xf32>
%149 = vector.extract %74[2, 3] : vector<9x16xf32>
%150 = vector.insert %149, %148 [3, 2] : f32 into vector<16x9xf32>
%151 = vector.extract %74[3, 3] : vector<9x16xf32>
%152 = vector.insert %151, %150 [3, 3] : f32 into vector<16x9xf32>
%153 = vector.extract %74[4, 3] : vector<9x16xf32>
%154 = vector.insert %153, %152 [3, 4] : f32 into vector<16x9xf32>
%155 = vector.extract %74[5, 3] : vector<9x16xf32>
%156 = vector.insert %155, %154 [3, 5] : f32 into vector<16x9xf32>
%157 = vector.extract %74[6, 3] : vector<9x16xf32>
%158 = vector.insert %157, %156 [3, 6] : f32 into vector<16x9xf32>
%159 = vector.extract %74[7, 3] : vector<9x16xf32>
%160 = vector.insert %159, %158 [3, 7] : f32 into vector<16x9xf32>
%161 = vector.extract %74[8, 3] : vector<9x16xf32>
%162 = vector.insert %161, %160 [3, 8] : f32 into vector<16x9xf32>
%163 = vector.extract %74[0, 4] : vector<9x16xf32>
%164 = vector.insert %163, %162 [4, 0] : f32 into vector<16x9xf32>
%165 = vector.extract %74[1, 4] : vector<9x16xf32>
%166 = vector.insert %165, %164 [4, 1] : f32 into vector<16x9xf32>
%167 = vector.extract %74[2, 4] : vector<9x16xf32>
%168 = vector.insert %167, %166 [4, 2] : f32 into vector<16x9xf32>
%169 = vector.extract %74[3, 4] : vector<9x16xf32>
%170 = vector.insert %169, %168 [4, 3] : f32 into vector<16x9xf32>
%171 = vector.extract %74[4, 4] : vector<9x16xf32>
%172 = vector.insert %171, %170 [4, 4] : f32 into vector<16x9xf32>
%173 = vector.extract %74[5, 4] : vector<9x16xf32>
%174 = vector.insert %173, %172 [4, 5] : f32 into vector<16x9xf32>
%175 = vector.extract %74[6, 4] : vector<9x16xf32>
%176 = vector.insert %175, %174 [4, 6] : f32 into vector<16x9xf32>
%177 = vector.extract %74[7, 4] : vector<9x16xf32>
%178 = vector.insert %177, %176 [4, 7] : f32 into vector<16x9xf32>
%179 = vector.extract %74[8, 4] : vector<9x16xf32>
%180 = vector.insert %179, %178 [4, 8] : f32 into vector<16x9xf32>
%181 = vector.extract %74[0, 5] : vector<9x16xf32>
%182 = vector.insert %181, %180 [5, 0] : f32 into vector<16x9xf32>
%183 = vector.extract %74[1, 5] : vector<9x16xf32>
%184 = vector.insert %183, %182 [5, 1] : f32 into vector<16x9xf32>
%185 = vector.extract %74[2, 5] : vector<9x16xf32>
%186 = vector.insert %185, %184 [5, 2] : f32 into vector<16x9xf32>
%187 = vector.extract %74[3, 5] : vector<9x16xf32>
%188 = vector.insert %187, %186 [5, 3] : f32 into vector<16x9xf32>
%189 = vector.extract %74[4, 5] : vector<9x16xf32>
%190 = vector.insert %189, %188 [5, 4] : f32 into vector<16x9xf32>
%191 = vector.extract %74[5, 5] : vector<9x16xf32>
%192 = vector.insert %191, %190 [5, 5] : f32 into vector<16x9xf32>
%193 = vector.extract %74[6, 5] : vector<9x16xf32>
%194 = vector.insert %193, %192 [5, 6] : f32 into vector<16x9xf32>
%195 = vector.extract %74[7, 5] : vector<9x16xf32>
%196 = vector.insert %195, %194 [5, 7] : f32 into vector<16x9xf32>
%197 = vector.extract %74[8, 5] : vector<9x16xf32>
%198 = vector.insert %197, %196 [5, 8] : f32 into vector<16x9xf32>
%199 = vector.extract %74[0, 6] : vector<9x16xf32>
%200 = vector.insert %199, %198 [6, 0] : f32 into vector<16x9xf32>
%201 = vector.extract %74[1, 6] : vector<9x16xf32>
%202 = vector.insert %201, %200 [6, 1] : f32 into vector<16x9xf32>
%203 = vector.extract %74[2, 6] : vector<9x16xf32>
%204 = vector.insert %203, %202 [6, 2] : f32 into vector<16x9xf32>
%205 = vector.extract %74[3, 6] : vector<9x16xf32>
%206 = vector.insert %205, %204 [6, 3] : f32 into vector<16x9xf32>
%207 = vector.extract %74[4, 6] : vector<9x16xf32>
%208 = vector.insert %207, %206 [6, 4] : f32 into vector<16x9xf32>
%209 = vector.extract %74[5, 6] : vector<9x16xf32>
%210 = vector.insert %209, %208 [6, 5] : f32 into vector<16x9xf32>
%211 = vector.extract %74[6, 6] : vector<9x16xf32>
%212 = vector.insert %211, %210 [6, 6] : f32 into vector<16x9xf32>
%213 = vector.extract %74[7, 6] : vector<9x16xf32>
%214 = vector.insert %213, %212 [6, 7] : f32 into vector<16x9xf32>
%215 = vector.extract %74[8, 6] : vector<9x16xf32>
%216 = vector.insert %215, %214 [6, 8] : f32 into vector<16x9xf32>
%217 = vector.extract %74[0, 7] : vector<9x16xf32>
%218 = vector.insert %217, %216 [7, 0] : f32 into vector<16x9xf32>
%219 = vector.extract %74[1, 7] : vector<9x16xf32>
%220 = vector.insert %219, %218 [7, 1] : f32 into vector<16x9xf32>
%221 = vector.extract %74[2, 7] : vector<9x16xf32>
%222 = vector.insert %221, %220 [7, 2] : f32 into vector<16x9xf32>
%223 = vector.extract %74[3, 7] : vector<9x16xf32>
%224 = vector.insert %223, %222 [7, 3] : f32 into vector<16x9xf32>
%225 = vector.extract %74[4, 7] : vector<9x16xf32>
%226 = vector.insert %225, %224 [7, 4] : f32 into vector<16x9xf32>
%227 = vector.extract %74[5, 7] : vector<9x16xf32>
%228 = vector.insert %227, %226 [7, 5] : f32 into vector<16x9xf32>
%229 = vector.extract %74[6, 7] : vector<9x16xf32>
%230 = vector.insert %229, %228 [7, 6] : f32 into vector<16x9xf32>
%231 = vector.extract %74[7, 7] : vector<9x16xf32>
%232 = vector.insert %231, %230 [7, 7] : f32 into vector<16x9xf32>
%233 = vector.extract %74[8, 7] : vector<9x16xf32>
%234 = vector.insert %233, %232 [7, 8] : f32 into vector<16x9xf32>
%235 = vector.extract %74[0, 8] : vector<9x16xf32>
%236 = vector.insert %235, %234 [8, 0] : f32 into vector<16x9xf32>
%237 = vector.extract %74[1, 8] : vector<9x16xf32>
%238 = vector.insert %237, %236 [8, 1] : f32 into vector<16x9xf32>
%239 = vector.extract %74[2, 8] : vector<9x16xf32>
%240 = vector.insert %239, %238 [8, 2] : f32 into vector<16x9xf32>
%241 = vector.extract %74[3, 8] : vector<9x16xf32>
%242 = vector.insert %241, %240 [8, 3] : f32 into vector<16x9xf32>
%243 = vector.extract %74[4, 8] : vector<9x16xf32>
%244 = vector.insert %243, %242 [8, 4] : f32 into vector<16x9xf32>
%245 = vector.extract %74[5, 8] : vector<9x16xf32>
%246 = vector.insert %245, %244 [8, 5] : f32 into vector<16x9xf32>
%247 = vector.extract %74[6, 8] : vector<9x16xf32>
%248 = vector.insert %247, %246 [8, 6] : f32 into vector<16x9xf32>
%249 = vector.extract %74[7, 8] : vector<9x16xf32>
%250 = vector.insert %249, %248 [8, 7] : f32 into vector<16x9xf32>
%251 = vector.extract %74[8, 8] : vector<9x16xf32>
%252 = vector.insert %251, %250 [8, 8] : f32 into vector<16x9xf32>
%253 = vector.extract %74[0, 9] : vector<9x16xf32>
%254 = vector.insert %253, %252 [9, 0] : f32 into vector<16x9xf32>
%255 = vector.extract %74[1, 9] : vector<9x16xf32>
%256 = vector.insert %255, %254 [9, 1] : f32 into vector<16x9xf32>
%257 = vector.extract %74[2, 9] : vector<9x16xf32>
%258 = vector.insert %257, %256 [9, 2] : f32 into vector<16x9xf32>
%259 = vector.extract %74[3, 9] : vector<9x16xf32>
%260 = vector.insert %259, %258 [9, 3] : f32 into vector<16x9xf32>
%261 = vector.extract %74[4, 9] : vector<9x16xf32>
%262 = vector.insert %261, %260 [9, 4] : f32 into vector<16x9xf32>
%263 = vector.extract %74[5, 9] : vector<9x16xf32>
%264 = vector.insert %263, %262 [9, 5] : f32 into vector<16x9xf32>
%265 = vector.extract %74[6, 9] : vector<9x16xf32>
%266 = vector.insert %265, %264 [9, 6] : f32 into vector<16x9xf32>
%267 = vector.extract %74[7, 9] : vector<9x16xf32>
%268 = vector.insert %267, %266 [9, 7] : f32 into vector<16x9xf32>
%269 = vector.extract %74[8, 9] : vector<9x16xf32>
%270 = vector.insert %269, %268 [9, 8] : f32 into vector<16x9xf32>
%271 = vector.extract %74[0, 10] : vector<9x16xf32>
%272 = vector.insert %271, %270 [10, 0] : f32 into vector<16x9xf32>
%273 = vector.extract %74[1, 10] : vector<9x16xf32>
%274 = vector.insert %273, %272 [10, 1] : f32 into vector<16x9xf32>
%275 = vector.extract %74[2, 10] : vector<9x16xf32>
%276 = vector.insert %275, %274 [10, 2] : f32 into vector<16x9xf32>
%277 = vector.extract %74[3, 10] : vector<9x16xf32>
%278 = vector.insert %277, %276 [10, 3] : f32 into vector<16x9xf32>
%279 = vector.extract %74[4, 10] : vector<9x16xf32>
%280 = vector.insert %279, %278 [10, 4] : f32 into vector<16x9xf32>
%281 = vector.extract %74[5, 10] : vector<9x16xf32>
%282 = vector.insert %281, %280 [10, 5] : f32 into vector<16x9xf32>
%283 = vector.extract %74[6, 10] : vector<9x16xf32>
%284 = vector.insert %283, %282 [10, 6] : f32 into vector<16x9xf32>
%285 = vector.extract %74[7, 10] : vector<9x16xf32>
%286 = vector.insert %285, %284 [10, 7] : f32 into vector<16x9xf32>
%287 = vector.extract %74[8, 10] : vector<9x16xf32>
%288 = vector.insert %287, %286 [10, 8] : f32 into vector<16x9xf32>
%289 = vector.extract %74[0, 11] : vector<9x16xf32>
%290 = vector.insert %289, %288 [11, 0] : f32 into vector<16x9xf32>
%291 = vector.extract %74[1, 11] : vector<9x16xf32>
%292 = vector.insert %291, %290 [11, 1] : f32 into vector<16x9xf32>
%293 = vector.extract %74[2, 11] : vector<9x16xf32>
%294 = vector.insert %293, %292 [11, 2] : f32 into vector<16x9xf32>
%295 = vector.extract %74[3, 11] : vector<9x16xf32>
%296 = vector.insert %295, %294 [11, 3] : f32 into vector<16x9xf32>
%297 = vector.extract %74[4, 11] : vector<9x16xf32>
%298 = vector.insert %297, %296 [11, 4] : f32 into vector<16x9xf32>
%299 = vector.extract %74[5, 11] : vector<9x16xf32>
%300 = vector.insert %299, %298 [11, 5] : f32 into vector<16x9xf32>
%301 = vector.extract %74[6, 11] : vector<9x16xf32>
%302 = vector.insert %301, %300 [11, 6] : f32 into vector<16x9xf32>
%303 = vector.extract %74[7, 11] : vector<9x16xf32>
%304 = vector.insert %303, %302 [11, 7] : f32 into vector<16x9xf32>
%305 = vector.extract %74[8, 11] : vector<9x16xf32>
%306 = vector.insert %305, %304 [11, 8] : f32 into vector<16x9xf32>
%307 = vector.extract %74[0, 12] : vector<9x16xf32>
%308 = vector.insert %307, %306 [12, 0] : f32 into vector<16x9xf32>
%309 = vector.extract %74[1, 12] : vector<9x16xf32>
%310 = vector.insert %309, %308 [12, 1] : f32 into vector<16x9xf32>
%311 = vector.extract %74[2, 12] : vector<9x16xf32>
%312 = vector.insert %311, %310 [12, 2] : f32 into vector<16x9xf32>
%313 = vector.extract %74[3, 12] : vector<9x16xf32>
%314 = vector.insert %313, %312 [12, 3] : f32 into vector<16x9xf32>
%315 = vector.extract %74[4, 12] : vector<9x16xf32>
%316 = vector.insert %315, %314 [12, 4] : f32 into vector<16x9xf32>
%317 = vector.extract %74[5, 12] : vector<9x16xf32>
%318 = vector.insert %317, %316 [12, 5] : f32 into vector<16x9xf32>
%319 = vector.extract %74[6, 12] : vector<9x16xf32>
%320 = vector.insert %319, %318 [12, 6] : f32 into vector<16x9xf32>
%321 = vector.extract %74[7, 12] : vector<9x16xf32>
%322 = vector.insert %321, %320 [12, 7] : f32 into vector<16x9xf32>
%323 = vector.extract %74[8, 12] : vector<9x16xf32>
%324 = vector.insert %323, %322 [12, 8] : f32 into vector<16x9xf32>
%325 = vector.extract %74[0, 13] : vector<9x16xf32>
%326 = vector.insert %325, %324 [13, 0] : f32 into vector<16x9xf32>
%327 = vector.extract %74[1, 13] : vector<9x16xf32>
%328 = vector.insert %327, %326 [13, 1] : f32 into vector<16x9xf32>
%329 = vector.extract %74[2, 13] : vector<9x16xf32>
%330 = vector.insert %329, %328 [13, 2] : f32 into vector<16x9xf32>
%331 = vector.extract %74[3, 13] : vector<9x16xf32>
%332 = vector.insert %331, %330 [13, 3] : f32 into vector<16x9xf32>
%333 = vector.extract %74[4, 13] : vector<9x16xf32>
%334 = vector.insert %333, %332 [13, 4] : f32 into vector<16x9xf32>
%335 = vector.extract %74[5, 13] : vector<9x16xf32>
%336 = vector.insert %335, %334 [13, 5] : f32 into vector<16x9xf32>
%337 = vector.extract %74[6, 13] : vector<9x16xf32>
%338 = vector.insert %337, %336 [13, 6] : f32 into vector<16x9xf32>
%339 = vector.extract %74[7, 13] : vector<9x16xf32>
%340 = vector.insert %339, %338 [13, 7] : f32 into vector<16x9xf32>
%341 = vector.extract %74[8, 13] : vector<9x16xf32>
%342 = vector.insert %341, %340 [13, 8] : f32 into vector<16x9xf32>
%343 = vector.extract %74[0, 14] : vector<9x16xf32>
%344 = vector.insert %343, %342 [14, 0] : f32 into vector<16x9xf32>
%345 = vector.extract %74[1, 14] : vector<9x16xf32>
%346 = vector.insert %345, %344 [14, 1] : f32 into vector<16x9xf32>
%347 = vector.extract %74[2, 14] : vector<9x16xf32>
%348 = vector.insert %347, %346 [14, 2] : f32 into vector<16x9xf32>
%349 = vector.extract %74[3, 14] : vector<9x16xf32>
%350 = vector.insert %349, %348 [14, 3] : f32 into vector<16x9xf32>
%351 = vector.extract %74[4, 14] : vector<9x16xf32>
%352 = vector.insert %351, %350 [14, 4] : f32 into vector<16x9xf32>
%353 = vector.extract %74[5, 14] : vector<9x16xf32>
%354 = vector.insert %353, %352 [14, 5] : f32 into vector<16x9xf32>
%355 = vector.extract %74[6, 14] : vector<9x16xf32>
%356 = vector.insert %355, %354 [14, 6] : f32 into vector<16x9xf32>
%357 = vector.extract %74[7, 14] : vector<9x16xf32>
%358 = vector.insert %357, %356 [14, 7] : f32 into vector<16x9xf32>
%359 = vector.extract %74[8, 14] : vector<9x16xf32>
%360 = vector.insert %359, %358 [14, 8] : f32 into vector<16x9xf32>
%361 = vector.extract %74[0, 15] : vector<9x16xf32>
%362 = vector.insert %361, %360 [15, 0] : f32 into vector<16x9xf32>
%363 = vector.extract %74[1, 15] : vector<9x16xf32>
%364 = vector.insert %363, %362 [15, 1] : f32 into vector<16x9xf32>
%365 = vector.extract %74[2, 15] : vector<9x16xf32>
%366 = vector.insert %365, %364 [15, 2] : f32 into vector<16x9xf32>
%367 = vector.extract %74[3, 15] : vector<9x16xf32>
%368 = vector.insert %367, %366 [15, 3] : f32 into vector<16x9xf32>
%369 = vector.extract %74[4, 15] : vector<9x16xf32>
%370 = vector.insert %369, %368 [15, 4] : f32 into vector<16x9xf32>
%371 = vector.extract %74[5, 15] : vector<9x16xf32>
%372 = vector.insert %371, %370 [15, 5] : f32 into vector<16x9xf32>
%373 = vector.extract %74[6, 15] : vector<9x16xf32>
%374 = vector.insert %373, %372 [15, 6] : f32 into vector<16x9xf32>
%375 = vector.extract %74[7, 15] : vector<9x16xf32>
%376 = vector.insert %375, %374 [15, 7] : f32 into vector<16x9xf32>
%377 = vector.extract %74[8, 15] : vector<9x16xf32>
%378 = vector.insert %377, %376 [15, 8] : f32 into vector<16x9xf32>
%379 = vector.extract %378[0] : vector<16x9xf32>
%380 = vector.outerproduct %379, %75, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%381 = vector.extract %378[1] : vector<16x9xf32>
%382 = vector.outerproduct %381, %76, %380 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%383 = vector.extract %378[2] : vector<16x9xf32>
%384 = vector.outerproduct %383, %77, %382 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%385 = vector.extract %378[3] : vector<16x9xf32>
%386 = vector.outerproduct %385, %78, %384 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%387 = vector.extract %378[4] : vector<16x9xf32>
%388 = vector.outerproduct %387, %79, %386 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%389 = vector.extract %378[5] : vector<16x9xf32>
%390 = vector.outerproduct %389, %80, %388 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%391 = vector.extract %378[6] : vector<16x9xf32>
%392 = vector.outerproduct %391, %81, %390 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%393 = vector.extract %378[7] : vector<16x9xf32>
%394 = vector.outerproduct %393, %82, %392 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%395 = vector.extract %378[8] : vector<16x9xf32>
%396 = vector.outerproduct %395, %83, %394 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%397 = vector.extract %378[9] : vector<16x9xf32>
%398 = vector.outerproduct %397, %84, %396 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%399 = vector.extract %378[10] : vector<16x9xf32>
%400 = vector.outerproduct %399, %85, %398 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%401 = vector.extract %378[11] : vector<16x9xf32>
%402 = vector.outerproduct %401, %86, %400 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%403 = vector.extract %378[12] : vector<16x9xf32>
%404 = vector.outerproduct %403, %87, %402 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%405 = vector.extract %378[13] : vector<16x9xf32>
%406 = vector.outerproduct %405, %88, %404 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%407 = vector.extract %378[14] : vector<16x9xf32>
%408 = vector.outerproduct %407, %89, %406 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%409 = vector.extract %378[15] : vector<16x9xf32>
%410 = vector.outerproduct %409, %90, %408 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %410 : vector<9x32xf32>
}
%45 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%56 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%46 = vector.extract %44[0] : vector<9x32xf32>
vector.store %46, %45[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %44[1] : vector<9x32xf32>
vector.store %47, %45[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %44[2] : vector<9x32xf32>
vector.store %48, %45[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %44[3] : vector<9x32xf32>
vector.store %49, %45[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %44[4] : vector<9x32xf32>
vector.store %50, %45[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %44[5] : vector<9x32xf32>
vector.store %51, %45[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %44[6] : vector<9x32xf32>
vector.store %52, %45[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = vector.extract %44[7] : vector<9x32xf32>
vector.store %53, %45[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%54 = vector.extract %44[8] : vector<9x32xf32>
vector.store %54, %45[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%55 = arith.xori %24, %true : i1
scf.if %55 {
%56 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%57 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%56, %57) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
compilation in 0.2461s
xxxxxxxxxx : 10 iters time on 1 threads in 0.1642s per iter sec (103.6 GFlop/s, 0.3045 GB/s) total time 1.642s
###############################################################
Runtime problem size {'M': 2048, 'N': 2048, 'K': 2048}
Compile-time problem size {'M': 2048, 'N': 2048, 'K': 2048}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map = affine_map<(d0) -> (288, -d0 + 2048)>
module {
func @matmul_on_tensors(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2048x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2048x2048xf32> -> tensor<2048x2048xf32>
%1 = scf.for %arg3 = %c0 to %c2048 step %c288 iter_args(%arg4 = %0) -> (tensor<2048x2048xf32>) {
%2 = affine.min #map(%arg3)
%3 = scf.for %arg5 = %c0 to %c2048 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2048x2048xf32>) {
%4 = tensor.extract_slice %arg0[%arg3, %arg5] [%2, 512] [1, 1] : tensor<2048x2048xf32> to tensor<?x512xf32>
%5 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2048x2048xf32>) {
%6 = tensor.extract_slice %arg1[%arg5, %arg7] [512, 128] [1, 1] : tensor<2048x2048xf32> to tensor<512x128xf32>
%7 = tensor.extract_slice %arg8[%arg3, %arg7] [%2, 128] [1, 1] : tensor<2048x2048xf32> to tensor<?x128xf32>
%8 = linalg.matmul ins(%4, %6 : tensor<?x512xf32>, tensor<512x128xf32>) outs(%7 : tensor<?x128xf32>) -> tensor<?x128xf32>
%9 = tensor.insert_slice %8 into %arg8[%arg3, %arg7] [%2, 128] [1, 1] : tensor<?x128xf32> into tensor<2048x2048xf32>
scf.yield %9 : tensor<2048x2048xf32>
}
scf.yield %5 : tensor<2048x2048xf32>
}
scf.yield %3 : tensor<2048x2048xf32>
}
return %1 : tensor<2048x2048xf32>
}
func public @matmul_main(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2048x2048xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2048x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2048x2048xf32>, tensor<2048x2048xf32>, tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
scf.yield %1 : tensor<2048x2048xf32>
}
return %0 : tensor<2048x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2048x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2048x2048xf32> -> tensor<2048x2048xf32>
%1 = linalg.init_tensor [4, 16, 4, 32, 16, 32] : tensor<4x16x4x32x16x32xf32>
%2 = tensor.cast %1 : tensor<4x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%3 = scf.for %arg3 = %c0 to %c2048 step %c512 iter_args(%arg4 = %2) -> (tensor<?x?x?x?x16x32xf32>) {
%7 = affine.apply #map0(%arg3)
%8 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%9 = affine.apply #map1(%arg5)
%10 = scf.for %arg7 = %c0 to %c128 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%11 = affine.apply #map2(%arg7)
%12 = affine.apply #map3(%arg7, %arg5)
%13 = scf.for %arg9 = %c0 to %c512 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%14 = affine.apply #map4(%arg9)
%15 = affine.apply #map3(%arg9, %arg3)
%16 = tensor.extract_slice %arg1[%15, %12] [16, 32] [1, 1] : tensor<2048x2048xf32> to tensor<16x32xf32>
%17 = linalg.pad_tensor %16 nofold low[%c0, %c0] high[%c0, %c0] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<16x32xf32> to tensor<16x32xf32>
%18 = tensor.insert_slice %17 into %arg10[%7, %9, %11, %14, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %18 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %13 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %10 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %8 : tensor<?x?x?x?x16x32xf32>
}
%4 = linalg.init_tensor [4, 32, 32, 9, 16] : tensor<4x32x32x9x16xf32>
%5 = tensor.cast %4 : tensor<4x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%6 = scf.for %arg3 = %c0 to %c2048 step %c288 iter_args(%arg4 = %0) -> (tensor<2048x2048xf32>) {
%7 = affine.min #map5(%arg3)
%8 = scf.for %arg5 = %c0 to %c2048 step %c512 iter_args(%arg6 = %5) -> (tensor<?x?x?x9x16xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = scf.for %arg7 = %c0 to %7 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%12 = affine.apply #map6(%arg7)
%13 = affine.apply #map3(%arg7, %arg3)
%14 = affine.min #map7(%arg7, %7)
%15 = affine.apply #map8(%14)
%16 = scf.for %arg9 = %c0 to %c512 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%17 = affine.apply #map4(%arg9)
%18 = affine.apply #map3(%arg9, %arg5)
%19 = tensor.extract_slice %arg0[%13, %18] [%14, 16] [1, 1] : tensor<2048x2048xf32> to tensor<?x16xf32>
%20 = linalg.pad_tensor %19 nofold low[%c0, %c0] high[%15, %c0] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x16xf32> to tensor<9x16xf32>
%21 = tensor.insert_slice %20 into %arg10[%10, %12, %17, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %21 : tensor<?x?x?x9x16xf32>
}
scf.yield %16 : tensor<?x?x?x9x16xf32>
}
scf.yield %11 : tensor<?x?x?x9x16xf32>
}
%9 = scf.for %arg5 = %c0 to %c2048 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2048x2048xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2048x2048xf32>) {
%12 = tensor.extract_slice %arg8[%arg3, %arg7] [%7, 128] [1, 1] : tensor<2048x2048xf32> to tensor<?x128xf32>
%13 = affine.apply #map1(%arg7)
%14 = scf.for %arg9 = %c0 to %7 step %c9 iter_args(%arg10 = %12) -> (tensor<?x128xf32>) {
%16 = affine.min #map7(%arg9, %7)
%17 = affine.apply #map6(%arg9)
%18 = affine.apply #map8(%16)
%19 = scf.for %arg11 = %c0 to %c128 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x128xf32>) {
%20 = affine.apply #map2(%arg11)
%21 = scf.for %arg13 = %c0 to %c512 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x128xf32>) {
%22 = tensor.extract_slice %arg14[%arg9, %arg11] [%16, 32] [1, 1] : tensor<?x128xf32> to tensor<?x32xf32>
%23 = affine.apply #map4(%arg13)
%24 = tensor.extract_slice %8[%10, %17, %23, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%25 = tensor.extract_slice %3[%10, %13, %20, %23, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%26 = linalg.pad_tensor %22 low[%c0, %c0] high[%18, %c0] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x32xf32> to tensor<9x32xf32>
%27 = linalg.matmul ins(%24, %25 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%26 : tensor<9x32xf32>) -> tensor<9x32xf32>
%28 = tensor.extract_slice %27[0, 0] [%16, 32] [1, 1] : tensor<9x32xf32> to tensor<?x32xf32>
%29 = tensor.insert_slice %28 into %arg14[%arg9, %arg11] [%16, 32] [1, 1] : tensor<?x32xf32> into tensor<?x128xf32>
scf.yield %29 : tensor<?x128xf32>
}
scf.yield %21 : tensor<?x128xf32>
}
scf.yield %19 : tensor<?x128xf32>
}
%15 = tensor.insert_slice %14 into %arg8[%arg3, %arg7] [%7, 128] [1, 1] : tensor<?x128xf32> into tensor<2048x2048xf32>
scf.yield %15 : tensor<2048x2048xf32>
}
scf.yield %11 : tensor<2048x2048xf32>
}
scf.yield %9 : tensor<2048x2048xf32>
}
return %6 : tensor<2048x2048xf32>
}
func public @matmul_main(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2048x2048xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2048x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2048x2048xf32>, tensor<2048x2048xf32>, tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
scf.yield %1 : tensor<2048x2048xf32>
}
return %0 : tensor<2048x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map9 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map10 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<2048x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<2048x2048xf32> -> tensor<2048x2048xf32>
%1 = linalg.init_tensor [4, 16, 4, 32, 16, 32] : tensor<4x16x4x32x16x32xf32>
%2 = tensor.cast %1 : tensor<4x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%3 = scf.for %arg3 = %c0 to %c2048 step %c512 iter_args(%arg4 = %2) -> (tensor<?x?x?x?x16x32xf32>) {
%7 = affine.apply #map0(%arg3)
%8 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%9 = affine.apply #map1(%arg5)
%10 = scf.for %arg7 = %c0 to %c128 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%11 = affine.apply #map2(%arg7)
%12 = affine.apply #map3(%arg7, %arg5)
%13 = scf.for %arg9 = %c0 to %c512 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%14 = affine.apply #map4(%arg9)
%15 = affine.apply #map3(%arg9, %arg3)
%16 = vector.transfer_read %arg1[%15, %12], %cst {in_bounds = [true, true]} : tensor<2048x2048xf32>, vector<16x32xf32>
%17 = vector.transfer_write %16, %arg10[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %13 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %10 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %8 : tensor<?x?x?x?x16x32xf32>
}
%4 = linalg.init_tensor [4, 32, 32, 9, 16] : tensor<4x32x32x9x16xf32>
%5 = tensor.cast %4 : tensor<4x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%6 = scf.for %arg3 = %c0 to %c2048 step %c288 iter_args(%arg4 = %0) -> (tensor<2048x2048xf32>) {
%7 = affine.min #map5(%arg3)
%8 = scf.for %arg5 = %c0 to %c2048 step %c512 iter_args(%arg6 = %5) -> (tensor<?x?x?x9x16xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = scf.for %arg7 = %c0 to %7 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%12 = affine.apply #map6(%arg7)
%13 = affine.apply #map3(%arg7, %arg3)
%14 = affine.min #map7(%arg7, %7)
%15 = scf.for %arg9 = %c0 to %c512 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%16 = affine.apply #map4(%arg9)
%17 = affine.apply #map3(%arg9, %arg5)
%18 = tensor.extract_slice %arg0[%13, %17] [%14, 16] [1, 1] : tensor<2048x2048xf32> to tensor<?x16xf32>
%19 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [false, true]} : tensor<?x16xf32>, vector<9x16xf32>
%20 = vector.transfer_write %19, %arg10[%10, %12, %16, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %20 : tensor<?x?x?x9x16xf32>
}
scf.yield %15 : tensor<?x?x?x9x16xf32>
}
scf.yield %11 : tensor<?x?x?x9x16xf32>
}
%9 = scf.for %arg5 = %c0 to %c2048 step %c512 iter_args(%arg6 = %arg4) -> (tensor<2048x2048xf32>) {
%10 = affine.apply #map0(%arg5)
%11 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<2048x2048xf32>) {
%12 = tensor.extract_slice %arg8[%arg3, %arg7] [%7, 128] [1, 1] : tensor<2048x2048xf32> to tensor<?x128xf32>
%13 = affine.apply #map1(%arg7)
%14 = scf.for %arg9 = %c0 to %7 step %c9 iter_args(%arg10 = %12) -> (tensor<?x128xf32>) {
%16 = affine.min #map7(%arg9, %7)
%17 = affine.apply #map6(%arg9)
%18 = scf.for %arg11 = %c0 to %c128 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x128xf32>) {
%19 = affine.apply #map2(%arg11)
%20 = scf.for %arg13 = %c0 to %c512 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x128xf32>) {
%21 = tensor.extract_slice %arg14[%arg9, %arg11] [%16, 32] [1, 1] : tensor<?x128xf32> to tensor<?x32xf32>
%22 = affine.apply #map4(%arg13)
%23 = vector.transfer_read %8[%10, %17, %22, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%24 = vector.transfer_read %3[%10, %13, %19, %22, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%25 = vector.transfer_read %21[%c0, %c0], %cst {in_bounds = [false, true]} : tensor<?x32xf32>, vector<9x32xf32>
%26 = vector.contract {indexing_maps = [#map8, #map9, #map10], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %23, %24, %25 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%27 = vector.transfer_write %26, %21[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, tensor<?x32xf32>
%28 = tensor.insert_slice %27 into %arg14[%arg9, %arg11] [%16, 32] [1, 1] : tensor<?x32xf32> into tensor<?x128xf32>
scf.yield %28 : tensor<?x128xf32>
}
scf.yield %20 : tensor<?x128xf32>
}
scf.yield %18 : tensor<?x128xf32>
}
%15 = tensor.insert_slice %14 into %arg8[%arg3, %arg7] [%7, 128] [1, 1] : tensor<?x128xf32> into tensor<2048x2048xf32>
scf.yield %15 : tensor<2048x2048xf32>
}
scf.yield %11 : tensor<2048x2048xf32>
}
scf.yield %9 : tensor<2048x2048xf32>
}
return %6 : tensor<2048x2048xf32>
}
func public @matmul_main(%arg0: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<2048x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<2048x2048xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<2048x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<2048x2048xf32>, tensor<2048x2048xf32>, tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
scf.yield %1 : tensor<2048x2048xf32>
}
return %0 : tensor<2048x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map10 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map11 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%2 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%3 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%4 = affine.apply #map2(%arg5)
%5 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%6 = affine.apply #map4(%arg6)
%7 = affine.apply #map3(%arg6, %arg3)
%8 = vector.transfer_read %arg1[%7, %5], %cst {in_bounds = [true, true]} : memref<2048x2048xf32>, vector<16x32xf32>
vector.transfer_write %8, %1[%2, %3, %4, %6, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%2 = affine.min #map5(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%4 = affine.apply #map6(%arg5)
%5 = affine.apply #map3(%arg5, %arg3)
%6 = affine.min #map7(%arg5, %2)
scf.for %arg6 = %c0 to %c512 step %c16 {
%7 = affine.apply #map4(%arg6)
%8 = affine.apply #map3(%arg6, %arg4)
%9 = memref.subview %arg0[%5, %8] [%6, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x16xf32, #map8>, vector<9x16xf32>
vector.transfer_write %10, %0[%3, %4, %7, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%4 = memref.subview %arg2[%arg3, %arg5] [%2, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map8>
%5 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%6 = affine.min #map7(%arg6, %2)
%7 = affine.apply #map6(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%8 = affine.apply #map2(%arg7)
%9 = memref.subview %4[%arg6, %arg7] [%6, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%11 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %10) -> (vector<9x32xf32>) {
%12 = affine.apply #map4(%arg8)
%13 = vector.transfer_read %0[%3, %7, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%14 = vector.transfer_read %1[%3, %5, %8, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%15 = vector.contract {indexing_maps = [#map9, #map10, #map11], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %13, %14, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %15 : vector<9x32xf32>
}
vector.transfer_write %11, %9[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%2 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%3 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%4 = affine.apply #map2(%arg5)
%5 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%6 = affine.apply #map4(%arg6)
%7 = affine.apply #map3(%arg6, %arg3)
%8 = vector.transfer_read %arg1[%7, %5], %cst {in_bounds = [true, true]} : memref<2048x2048xf32>, vector<16x32xf32>
vector.transfer_write %8, %1[%2, %3, %4, %6, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%2 = affine.min #map5(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%4 = affine.apply #map6(%arg5)
%5 = affine.apply #map3(%arg5, %arg3)
%6 = affine.min #map7(%arg5, %2)
scf.for %arg6 = %c0 to %c512 step %c16 {
%7 = affine.apply #map4(%arg6)
%8 = affine.apply #map3(%arg6, %arg4)
%9 = memref.subview %arg0[%5, %8] [%6, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x16xf32, #map8>, vector<9x16xf32>
vector.transfer_write %10, %0[%3, %4, %7, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%4 = memref.subview %arg2[%arg3, %arg5] [%2, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map8>
%5 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%6 = affine.min #map7(%arg6, %2)
%7 = affine.apply #map6(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%8 = affine.apply #map2(%arg7)
%9 = memref.subview %4[%arg6, %arg7] [%6, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%11 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %10) -> (vector<9x32xf32>) {
%12 = affine.apply #map4(%arg8)
%13 = vector.transfer_read %0[%3, %7, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%14 = vector.transfer_read %1[%3, %5, %8, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%15 = vector.transpose %13, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%16 = vector.extract %15[0] : vector<16x9xf32>
%17 = vector.extract %14[0] : vector<16x32xf32>
%18 = vector.outerproduct %16, %17, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%19 = vector.extract %15[1] : vector<16x9xf32>
%20 = vector.extract %14[1] : vector<16x32xf32>
%21 = vector.outerproduct %19, %20, %18 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%22 = vector.extract %15[2] : vector<16x9xf32>
%23 = vector.extract %14[2] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %21 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %15[3] : vector<16x9xf32>
%26 = vector.extract %14[3] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %15[4] : vector<16x9xf32>
%29 = vector.extract %14[4] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %15[5] : vector<16x9xf32>
%32 = vector.extract %14[5] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %15[6] : vector<16x9xf32>
%35 = vector.extract %14[6] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %15[7] : vector<16x9xf32>
%38 = vector.extract %14[7] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %15[8] : vector<16x9xf32>
%41 = vector.extract %14[8] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %15[9] : vector<16x9xf32>
%44 = vector.extract %14[9] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %15[10] : vector<16x9xf32>
%47 = vector.extract %14[10] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %15[11] : vector<16x9xf32>
%50 = vector.extract %14[11] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %15[12] : vector<16x9xf32>
%53 = vector.extract %14[12] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %15[13] : vector<16x9xf32>
%56 = vector.extract %14[13] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %15[14] : vector<16x9xf32>
%59 = vector.extract %14[14] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %15[15] : vector<16x9xf32>
%62 = vector.extract %14[15] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %63 : vector<9x32xf32>
}
vector.transfer_write %11, %9[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%2 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%3 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%4 = affine.apply #map2(%arg5)
%5 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%6 = affine.apply #map4(%arg6)
%7 = affine.apply #map3(%arg6, %arg3)
%8 = vector.transfer_read %arg1[%7, %5], %cst {in_bounds = [true, true]} : memref<2048x2048xf32>, vector<16x32xf32>
vector.transfer_write %8, %1[%2, %3, %4, %6, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%2 = affine.min #map5(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %2 step %c9 {
%4 = affine.apply #map6(%arg5)
%5 = affine.apply #map3(%arg5, %arg3)
%6 = affine.min #map7(%arg5, %2)
scf.for %arg6 = %c0 to %c512 step %c16 {
%7 = affine.apply #map4(%arg6)
%8 = affine.apply #map3(%arg6, %arg4)
%9 = memref.subview %arg0[%5, %8] [%6, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x16xf32, #map8>, vector<9x16xf32>
vector.transfer_write %10, %0[%3, %4, %7, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%3 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%4 = memref.subview %arg2[%arg3, %arg5] [%2, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map8>
%5 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %2 step %c9 {
%6 = affine.min #map7(%arg6, %2)
%7 = affine.apply #map6(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%8 = affine.apply #map2(%arg7)
%9 = memref.subview %4[%arg6, %arg7] [%6, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%10 = vector.transfer_read %9[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%11 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %10) -> (vector<9x32xf32>) {
%12 = affine.apply #map4(%arg8)
%13 = vector.transfer_read %0[%3, %7, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%14 = vector.transfer_read %1[%3, %5, %8, %12, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%15 = vector.transpose %13, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%16 = vector.extract %15[0] : vector<16x9xf32>
%17 = vector.extract %14[0] : vector<16x32xf32>
%18 = vector.outerproduct %16, %17, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%19 = vector.extract %15[1] : vector<16x9xf32>
%20 = vector.extract %14[1] : vector<16x32xf32>
%21 = vector.outerproduct %19, %20, %18 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%22 = vector.extract %15[2] : vector<16x9xf32>
%23 = vector.extract %14[2] : vector<16x32xf32>
%24 = vector.outerproduct %22, %23, %21 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%25 = vector.extract %15[3] : vector<16x9xf32>
%26 = vector.extract %14[3] : vector<16x32xf32>
%27 = vector.outerproduct %25, %26, %24 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%28 = vector.extract %15[4] : vector<16x9xf32>
%29 = vector.extract %14[4] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %27 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %15[5] : vector<16x9xf32>
%32 = vector.extract %14[5] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %15[6] : vector<16x9xf32>
%35 = vector.extract %14[6] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %15[7] : vector<16x9xf32>
%38 = vector.extract %14[7] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %15[8] : vector<16x9xf32>
%41 = vector.extract %14[8] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %15[9] : vector<16x9xf32>
%44 = vector.extract %14[9] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %15[10] : vector<16x9xf32>
%47 = vector.extract %14[10] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %15[11] : vector<16x9xf32>
%50 = vector.extract %14[11] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %15[12] : vector<16x9xf32>
%53 = vector.extract %14[12] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %15[13] : vector<16x9xf32>
%56 = vector.extract %14[13] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %15[14] : vector<16x9xf32>
%59 = vector.extract %14[14] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %15[15] : vector<16x9xf32>
%62 = vector.extract %14[15] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %63 : vector<9x32xf32>
}
vector.transfer_write %11, %9[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %1 : memref<4x16x4x32x16x32xf32>
memref.dealloc %0 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%5 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%6 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%7 = affine.apply #map2(%arg5)
%8 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%9 = affine.apply #map4(%arg6)
%10 = affine.apply #map3(%arg6, %arg3)
%11 = vector.transfer_read %arg1[%10, %8], %cst {in_bounds = [true, true]} : memref<2048x2048xf32>, vector<16x32xf32>
vector.transfer_write %11, %4[%5, %6, %7, %9, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%5 = affine.min #map5(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %5 step %c9 {
%7 = affine.apply #map6(%arg5)
%8 = affine.apply #map3(%arg5, %arg3)
%9 = affine.min #map7(%arg5, %5)
%10 = arith.cmpi sle, %c9, %9 : index
scf.for %arg6 = %c0 to %c512 step %c16 {
%11 = affine.apply #map4(%arg6)
%12 = affine.apply #map3(%arg6, %arg4)
%13 = memref.subview %arg0[%8, %12] [%9, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map8>
%14 = scf.if %10 -> (memref<?x16xf32, #map9>) {
%16 = memref.cast %13 : memref<?x16xf32, #map8> to memref<?x16xf32, #map9>
scf.yield %16 : memref<?x16xf32, #map9>
} else {
linalg.fill(%cst, %0) : f32, memref<9x16xf32>
%16 = memref.subview %13[0, 0] [%9, 16] [1, 1] : memref<?x16xf32, #map8> to memref<?x16xf32, #map8>
%17 = memref.subview %0[0, 0] [%9, 16] [1, 1] : memref<9x16xf32> to memref<?x16xf32, #map10>
linalg.copy(%16, %17) : memref<?x16xf32, #map8>, memref<?x16xf32, #map10>
%18 = memref.cast %0 : memref<9x16xf32> to memref<?x16xf32, #map9>
scf.yield %18 : memref<?x16xf32, #map9>
}
%15 = vector.transfer_read %14[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16xf32, #map9>, vector<9x16xf32>
vector.transfer_write %15, %3[%6, %7, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%7 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map8>
%8 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%9 = affine.min #map7(%arg6, %5)
%10 = affine.apply #map6(%arg6)
%11 = arith.cmpi sle, %c9, %9 : index
%12 = arith.cmpi sgt, %c9, %9 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%13 = affine.apply #map2(%arg7)
%14 = memref.subview %7[%arg6, %arg7] [%9, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%15 = scf.if %11 -> (memref<?x32xf32, #map9>) {
%19 = memref.cast %14 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %1) : f32, memref<9x32xf32>
%19 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%20 = memref.subview %1[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map11>
linalg.copy(%19, %20) : memref<?x32xf32, #map8>, memref<?x32xf32, #map11>
%21 = memref.cast %1 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %21 : memref<?x32xf32, #map9>
}
%16 = vector.transfer_read %15[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<9x32xf32>
%17 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %16) -> (vector<9x32xf32>) {
%19 = affine.apply #map4(%arg8)
%20 = vector.transfer_read %3[%6, %10, %19, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%21 = vector.transfer_read %4[%6, %8, %13, %19, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%22 = vector.transpose %20, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%23 = vector.extract %22[0] : vector<16x9xf32>
%24 = vector.extract %21[0] : vector<16x32xf32>
%25 = vector.outerproduct %23, %24, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%26 = vector.extract %22[1] : vector<16x9xf32>
%27 = vector.extract %21[1] : vector<16x32xf32>
%28 = vector.outerproduct %26, %27, %25 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%29 = vector.extract %22[2] : vector<16x9xf32>
%30 = vector.extract %21[2] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %28 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %22[3] : vector<16x9xf32>
%33 = vector.extract %21[3] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %22[4] : vector<16x9xf32>
%36 = vector.extract %21[4] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %22[5] : vector<16x9xf32>
%39 = vector.extract %21[5] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %22[6] : vector<16x9xf32>
%42 = vector.extract %21[6] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %22[7] : vector<16x9xf32>
%45 = vector.extract %21[7] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %22[8] : vector<16x9xf32>
%48 = vector.extract %21[8] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %22[9] : vector<16x9xf32>
%51 = vector.extract %21[9] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %22[10] : vector<16x9xf32>
%54 = vector.extract %21[10] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %22[11] : vector<16x9xf32>
%57 = vector.extract %21[11] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %22[12] : vector<16x9xf32>
%60 = vector.extract %21[12] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %22[13] : vector<16x9xf32>
%63 = vector.extract %21[13] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %22[14] : vector<16x9xf32>
%66 = vector.extract %21[14] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %22[15] : vector<16x9xf32>
%69 = vector.extract %21[15] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %70 : vector<9x32xf32>
}
%18 = scf.if %11 -> (memref<?x32xf32, #map9>) {
%19 = memref.cast %14 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
} else {
%19 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
}
vector.transfer_write %17, %18[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x32xf32, #map9>
scf.if %12 {
%19 = memref.subview %2[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map11>
%20 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%19, %20) : memref<?x32xf32, #map11>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %4 : memref<4x16x4x32x16x32xf32>
memref.dealloc %3 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0) -> (288, -d0 + 2048)>
#map6 = affine_map<(d0) -> (d0 ceildiv 9)>
#map7 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%5 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%6 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%7 = affine.apply #map2(%arg5)
%8 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%9 = affine.apply #map4(%arg6)
%10 = affine.apply #map3(%arg6, %arg3)
%11 = vector.transfer_read %arg1[%10, %8], %cst {in_bounds = [true, true]} : memref<2048x2048xf32>, vector<16x32xf32>
vector.transfer_write %11, %4[%5, %6, %7, %9, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<4x16x4x32x16x32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%5 = affine.min #map5(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %5 step %c9 {
%7 = affine.apply #map6(%arg5)
%8 = affine.apply #map3(%arg5, %arg3)
%9 = affine.min #map7(%arg5, %5)
%10 = arith.cmpi sle, %c9, %9 : index
scf.for %arg6 = %c0 to %c512 step %c16 {
%11 = affine.apply #map4(%arg6)
%12 = affine.apply #map3(%arg6, %arg4)
%13 = memref.subview %arg0[%8, %12] [%9, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map8>
%14 = scf.if %10 -> (memref<?x16xf32, #map9>) {
%16 = memref.cast %13 : memref<?x16xf32, #map8> to memref<?x16xf32, #map9>
scf.yield %16 : memref<?x16xf32, #map9>
} else {
linalg.fill(%cst, %0) : f32, memref<9x16xf32>
%16 = memref.subview %13[0, 0] [%9, 16] [1, 1] : memref<?x16xf32, #map8> to memref<?x16xf32, #map8>
%17 = memref.subview %0[0, 0] [%9, 16] [1, 1] : memref<9x16xf32> to memref<?x16xf32, #map10>
linalg.copy(%16, %17) : memref<?x16xf32, #map8>, memref<?x16xf32, #map10>
%18 = memref.cast %0 : memref<9x16xf32> to memref<?x16xf32, #map9>
scf.yield %18 : memref<?x16xf32, #map9>
}
%15 = vector.transfer_read %14[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16xf32, #map9>, vector<9x16xf32>
vector.transfer_write %15, %3[%6, %7, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<4x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%7 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map8>
%8 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%9 = affine.min #map7(%arg6, %5)
%10 = affine.apply #map6(%arg6)
%11 = arith.cmpi sle, %c9, %9 : index
%12 = arith.cmpi sgt, %c9, %9 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%13 = affine.apply #map2(%arg7)
%14 = memref.subview %7[%arg6, %arg7] [%9, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%15 = scf.if %11 -> (memref<?x32xf32, #map9>) {
%19 = memref.cast %14 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %1) : f32, memref<9x32xf32>
%19 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%20 = memref.subview %1[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map11>
linalg.copy(%19, %20) : memref<?x32xf32, #map8>, memref<?x32xf32, #map11>
%21 = memref.cast %1 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %21 : memref<?x32xf32, #map9>
}
%16 = vector.transfer_read %15[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<9x32xf32>
%17 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %16) -> (vector<9x32xf32>) {
%19 = affine.apply #map4(%arg8)
%20 = vector.transfer_read %3[%6, %10, %19, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x32x32x9x16xf32>, vector<9x16xf32>
%21 = vector.transfer_read %4[%6, %8, %13, %19, %c0, %c0], %cst {in_bounds = [true, true]} : memref<4x16x4x32x16x32xf32>, vector<16x32xf32>
%22 = vector.transpose %20, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%23 = vector.extract %22[0] : vector<16x9xf32>
%24 = vector.extract %21[0] : vector<16x32xf32>
%25 = vector.outerproduct %23, %24, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%26 = vector.extract %22[1] : vector<16x9xf32>
%27 = vector.extract %21[1] : vector<16x32xf32>
%28 = vector.outerproduct %26, %27, %25 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%29 = vector.extract %22[2] : vector<16x9xf32>
%30 = vector.extract %21[2] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %28 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %22[3] : vector<16x9xf32>
%33 = vector.extract %21[3] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %22[4] : vector<16x9xf32>
%36 = vector.extract %21[4] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %22[5] : vector<16x9xf32>
%39 = vector.extract %21[5] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %22[6] : vector<16x9xf32>
%42 = vector.extract %21[6] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %22[7] : vector<16x9xf32>
%45 = vector.extract %21[7] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %22[8] : vector<16x9xf32>
%48 = vector.extract %21[8] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %22[9] : vector<16x9xf32>
%51 = vector.extract %21[9] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %22[10] : vector<16x9xf32>
%54 = vector.extract %21[10] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %22[11] : vector<16x9xf32>
%57 = vector.extract %21[11] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %22[12] : vector<16x9xf32>
%60 = vector.extract %21[12] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %22[13] : vector<16x9xf32>
%63 = vector.extract %21[13] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %22[14] : vector<16x9xf32>
%66 = vector.extract %21[14] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%68 = vector.extract %22[15] : vector<16x9xf32>
%69 = vector.extract %21[15] : vector<16x32xf32>
%70 = vector.outerproduct %68, %69, %67 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %70 : vector<9x32xf32>
}
%18 = scf.if %11 -> (memref<?x32xf32, #map9>) {
%19 = memref.cast %14 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
} else {
%19 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %19 : memref<?x32xf32, #map9>
}
vector.transfer_write %17, %18[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x32xf32, #map9>
scf.if %12 {
%19 = memref.subview %2[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map11>
%20 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%19, %20) : memref<?x32xf32, #map11>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %4 : memref<4x16x4x32x16x32xf32>
memref.dealloc %3 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0, d1) -> (d0 + d1 + 1)>
#map6 = affine_map<(d0, d1) -> (d0 + d1 + 2)>
#map7 = affine_map<(d0, d1) -> (d0 + d1 + 3)>
#map8 = affine_map<(d0, d1) -> (d0 + d1 + 4)>
#map9 = affine_map<(d0, d1) -> (d0 + d1 + 5)>
#map10 = affine_map<(d0, d1) -> (d0 + d1 + 6)>
#map11 = affine_map<(d0, d1) -> (d0 + d1 + 7)>
#map12 = affine_map<(d0, d1) -> (d0 + d1 + 8)>
#map13 = affine_map<(d0, d1) -> (d0 + d1 + 9)>
#map14 = affine_map<(d0, d1) -> (d0 + d1 + 10)>
#map15 = affine_map<(d0, d1) -> (d0 + d1 + 11)>
#map16 = affine_map<(d0, d1) -> (d0 + d1 + 12)>
#map17 = affine_map<(d0, d1) -> (d0 + d1 + 13)>
#map18 = affine_map<(d0, d1) -> (d0 + d1 + 14)>
#map19 = affine_map<(d0, d1) -> (d0 + d1 + 15)>
#map20 = affine_map<(d0) -> (288, -d0 + 2048)>
#map21 = affine_map<(d0) -> (d0 ceildiv 9)>
#map22 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map23 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map24 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map25 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map26 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%5 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%6 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%7 = affine.apply #map2(%arg5)
%8 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%9 = affine.apply #map4(%arg6)
%10 = affine.apply #map3(%arg6, %arg3)
%11 = vector.load %arg1[%10, %8] : memref<2048x2048xf32>, vector<32xf32>
%12 = affine.apply #map5(%arg6, %arg3)
%13 = vector.load %arg1[%12, %8] : memref<2048x2048xf32>, vector<32xf32>
%14 = affine.apply #map6(%arg6, %arg3)
%15 = vector.load %arg1[%14, %8] : memref<2048x2048xf32>, vector<32xf32>
%16 = affine.apply #map7(%arg6, %arg3)
%17 = vector.load %arg1[%16, %8] : memref<2048x2048xf32>, vector<32xf32>
%18 = affine.apply #map8(%arg6, %arg3)
%19 = vector.load %arg1[%18, %8] : memref<2048x2048xf32>, vector<32xf32>
%20 = affine.apply #map9(%arg6, %arg3)
%21 = vector.load %arg1[%20, %8] : memref<2048x2048xf32>, vector<32xf32>
%22 = affine.apply #map10(%arg6, %arg3)
%23 = vector.load %arg1[%22, %8] : memref<2048x2048xf32>, vector<32xf32>
%24 = affine.apply #map11(%arg6, %arg3)
%25 = vector.load %arg1[%24, %8] : memref<2048x2048xf32>, vector<32xf32>
%26 = affine.apply #map12(%arg6, %arg3)
%27 = vector.load %arg1[%26, %8] : memref<2048x2048xf32>, vector<32xf32>
%28 = affine.apply #map13(%arg6, %arg3)
%29 = vector.load %arg1[%28, %8] : memref<2048x2048xf32>, vector<32xf32>
%30 = affine.apply #map14(%arg6, %arg3)
%31 = vector.load %arg1[%30, %8] : memref<2048x2048xf32>, vector<32xf32>
%32 = affine.apply #map15(%arg6, %arg3)
%33 = vector.load %arg1[%32, %8] : memref<2048x2048xf32>, vector<32xf32>
%34 = affine.apply #map16(%arg6, %arg3)
%35 = vector.load %arg1[%34, %8] : memref<2048x2048xf32>, vector<32xf32>
%36 = affine.apply #map17(%arg6, %arg3)
%37 = vector.load %arg1[%36, %8] : memref<2048x2048xf32>, vector<32xf32>
%38 = affine.apply #map18(%arg6, %arg3)
%39 = vector.load %arg1[%38, %8] : memref<2048x2048xf32>, vector<32xf32>
%40 = affine.apply #map19(%arg6, %arg3)
%41 = vector.load %arg1[%40, %8] : memref<2048x2048xf32>, vector<32xf32>
vector.store %11, %4[%5, %6, %7, %9, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %13, %4[%5, %6, %7, %9, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %15, %4[%5, %6, %7, %9, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %17, %4[%5, %6, %7, %9, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %19, %4[%5, %6, %7, %9, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %4[%5, %6, %7, %9, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %4[%5, %6, %7, %9, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %4[%5, %6, %7, %9, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %4[%5, %6, %7, %9, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %4[%5, %6, %7, %9, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %4[%5, %6, %7, %9, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %4[%5, %6, %7, %9, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %4[%5, %6, %7, %9, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %4[%5, %6, %7, %9, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %4[%5, %6, %7, %9, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %4[%5, %6, %7, %9, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%5 = affine.min #map20(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %5 step %c9 {
%7 = affine.apply #map21(%arg5)
%8 = affine.apply #map3(%arg5, %arg3)
%9 = affine.min #map22(%arg5, %5)
%10 = arith.cmpi sle, %c9, %9 : index
scf.for %arg6 = %c0 to %c512 step %c16 {
%11 = affine.apply #map4(%arg6)
%12 = affine.apply #map3(%arg6, %arg4)
%13 = memref.subview %arg0[%8, %12] [%9, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map23>
%14 = scf.if %10 -> (memref<?x16xf32, #map24>) {
%24 = memref.cast %13 : memref<?x16xf32, #map23> to memref<?x16xf32, #map24>
scf.yield %24 : memref<?x16xf32, #map24>
} else {
linalg.fill(%cst_1, %0) : f32, memref<9x16xf32>
%24 = memref.subview %13[0, 0] [%9, 16] [1, 1] : memref<?x16xf32, #map23> to memref<?x16xf32, #map23>
%25 = memref.subview %0[0, 0] [%9, 16] [1, 1] : memref<9x16xf32> to memref<?x16xf32, #map25>
linalg.copy(%24, %25) : memref<?x16xf32, #map23>, memref<?x16xf32, #map25>
%26 = memref.cast %0 : memref<9x16xf32> to memref<?x16xf32, #map24>
scf.yield %26 : memref<?x16xf32, #map24>
}
%15 = vector.load %14[%c0, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%16 = vector.load %14[%c1, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%17 = vector.load %14[%c2, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%18 = vector.load %14[%c3, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%19 = vector.load %14[%c4, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%20 = vector.load %14[%c5, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%21 = vector.load %14[%c6, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%22 = vector.load %14[%c7, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%23 = vector.load %14[%c8, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
vector.store %15, %3[%6, %7, %11, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %16, %3[%6, %7, %11, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %17, %3[%6, %7, %11, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %18, %3[%6, %7, %11, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %19, %3[%6, %7, %11, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %20, %3[%6, %7, %11, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %3[%6, %7, %11, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %3[%6, %7, %11, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %3[%6, %7, %11, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%7 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map23>
%8 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%9 = affine.min #map22(%arg6, %5)
%10 = affine.apply #map21(%arg6)
%11 = arith.cmpi sle, %c9, %9 : index
%12 = arith.cmpi sgt, %c9, %9 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%13 = affine.apply #map2(%arg7)
%14 = memref.subview %7[%arg6, %arg7] [%9, 32] [1, 1] : memref<?x128xf32, #map23> to memref<?x32xf32, #map23>
%15 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x32xf32>
%45 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
%46 = memref.subview %1[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
linalg.copy(%45, %46) : memref<?x32xf32, #map23>, memref<?x32xf32, #map26>
%47 = memref.cast %1 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %47 : memref<?x32xf32, #map24>
}
%16 = vector.load %15[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%17 = vector.insert %16, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%18 = vector.load %15[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%19 = vector.insert %18, %17 [1] : vector<32xf32> into vector<9x32xf32>
%20 = vector.load %15[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%21 = vector.insert %20, %19 [2] : vector<32xf32> into vector<9x32xf32>
%22 = vector.load %15[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%23 = vector.insert %22, %21 [3] : vector<32xf32> into vector<9x32xf32>
%24 = vector.load %15[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%25 = vector.insert %24, %23 [4] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %15[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%27 = vector.insert %26, %25 [5] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %15[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%29 = vector.insert %28, %27 [6] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %15[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%31 = vector.insert %30, %29 [7] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %15[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%33 = vector.insert %32, %31 [8] : vector<32xf32> into vector<9x32xf32>
%34 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %33) -> (vector<9x32xf32>) {
%45 = affine.apply #map4(%arg8)
%46 = vector.load %3[%6, %10, %45, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%47 = vector.insert %46, %cst [0] : vector<16xf32> into vector<9x16xf32>
%48 = vector.load %3[%6, %10, %45, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%49 = vector.insert %48, %47 [1] : vector<16xf32> into vector<9x16xf32>
%50 = vector.load %3[%6, %10, %45, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%51 = vector.insert %50, %49 [2] : vector<16xf32> into vector<9x16xf32>
%52 = vector.load %3[%6, %10, %45, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %51 [3] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %3[%6, %10, %45, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [4] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %3[%6, %10, %45, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [5] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %3[%6, %10, %45, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [6] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %3[%6, %10, %45, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [7] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %3[%6, %10, %45, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [8] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%6, %8, %13, %45, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%65 = vector.load %4[%6, %8, %13, %45, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%66 = vector.load %4[%6, %8, %13, %45, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%67 = vector.load %4[%6, %8, %13, %45, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%68 = vector.load %4[%6, %8, %13, %45, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%69 = vector.load %4[%6, %8, %13, %45, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %4[%6, %8, %13, %45, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %4[%6, %8, %13, %45, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %4[%6, %8, %13, %45, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %4[%6, %8, %13, %45, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %4[%6, %8, %13, %45, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %4[%6, %8, %13, %45, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %4[%6, %8, %13, %45, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %4[%6, %8, %13, %45, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %4[%6, %8, %13, %45, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %4[%6, %8, %13, %45, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.transpose %63, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%81 = vector.extract %80[0] : vector<16x9xf32>
%82 = vector.outerproduct %81, %64, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%83 = vector.extract %80[1] : vector<16x9xf32>
%84 = vector.outerproduct %83, %65, %82 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%85 = vector.extract %80[2] : vector<16x9xf32>
%86 = vector.outerproduct %85, %66, %84 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%87 = vector.extract %80[3] : vector<16x9xf32>
%88 = vector.outerproduct %87, %67, %86 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%89 = vector.extract %80[4] : vector<16x9xf32>
%90 = vector.outerproduct %89, %68, %88 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%91 = vector.extract %80[5] : vector<16x9xf32>
%92 = vector.outerproduct %91, %69, %90 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%93 = vector.extract %80[6] : vector<16x9xf32>
%94 = vector.outerproduct %93, %70, %92 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%95 = vector.extract %80[7] : vector<16x9xf32>
%96 = vector.outerproduct %95, %71, %94 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%97 = vector.extract %80[8] : vector<16x9xf32>
%98 = vector.outerproduct %97, %72, %96 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%99 = vector.extract %80[9] : vector<16x9xf32>
%100 = vector.outerproduct %99, %73, %98 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%101 = vector.extract %80[10] : vector<16x9xf32>
%102 = vector.outerproduct %101, %74, %100 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%103 = vector.extract %80[11] : vector<16x9xf32>
%104 = vector.outerproduct %103, %75, %102 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%105 = vector.extract %80[12] : vector<16x9xf32>
%106 = vector.outerproduct %105, %76, %104 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%107 = vector.extract %80[13] : vector<16x9xf32>
%108 = vector.outerproduct %107, %77, %106 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%109 = vector.extract %80[14] : vector<16x9xf32>
%110 = vector.outerproduct %109, %78, %108 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%111 = vector.extract %80[15] : vector<16x9xf32>
%112 = vector.outerproduct %111, %79, %110 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %112 : vector<9x32xf32>
}
%35 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
%45 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
}
%36 = vector.extract %34[0] : vector<9x32xf32>
vector.store %36, %35[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%37 = vector.extract %34[1] : vector<9x32xf32>
vector.store %37, %35[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%38 = vector.extract %34[2] : vector<9x32xf32>
vector.store %38, %35[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%39 = vector.extract %34[3] : vector<9x32xf32>
vector.store %39, %35[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%40 = vector.extract %34[4] : vector<9x32xf32>
vector.store %40, %35[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%41 = vector.extract %34[5] : vector<9x32xf32>
vector.store %41, %35[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%42 = vector.extract %34[6] : vector<9x32xf32>
vector.store %42, %35[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%43 = vector.extract %34[7] : vector<9x32xf32>
vector.store %43, %35[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%44 = vector.extract %34[8] : vector<9x32xf32>
vector.store %44, %35[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
scf.if %12 {
%45 = memref.subview %2[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
%46 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
linalg.copy(%45, %46) : memref<?x32xf32, #map26>, memref<?x32xf32, #map23>
}
}
}
}
}
}
memref.dealloc %4 : memref<4x16x4x32x16x32xf32>
memref.dealloc %3 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0, d1) -> (d0 + d1 + 1)>
#map6 = affine_map<(d0, d1) -> (d0 + d1 + 2)>
#map7 = affine_map<(d0, d1) -> (d0 + d1 + 3)>
#map8 = affine_map<(d0, d1) -> (d0 + d1 + 4)>
#map9 = affine_map<(d0, d1) -> (d0 + d1 + 5)>
#map10 = affine_map<(d0, d1) -> (d0 + d1 + 6)>
#map11 = affine_map<(d0, d1) -> (d0 + d1 + 7)>
#map12 = affine_map<(d0, d1) -> (d0 + d1 + 8)>
#map13 = affine_map<(d0, d1) -> (d0 + d1 + 9)>
#map14 = affine_map<(d0, d1) -> (d0 + d1 + 10)>
#map15 = affine_map<(d0, d1) -> (d0 + d1 + 11)>
#map16 = affine_map<(d0, d1) -> (d0 + d1 + 12)>
#map17 = affine_map<(d0, d1) -> (d0 + d1 + 13)>
#map18 = affine_map<(d0, d1) -> (d0 + d1 + 14)>
#map19 = affine_map<(d0, d1) -> (d0 + d1 + 15)>
#map20 = affine_map<(d0) -> (288, -d0 + 2048)>
#map21 = affine_map<(d0) -> (d0 ceildiv 9)>
#map22 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map23 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map24 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map25 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map26 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%5 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%6 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%7 = affine.apply #map2(%arg5)
%8 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%9 = affine.apply #map4(%arg6)
%10 = affine.apply #map3(%arg6, %arg3)
%11 = vector.load %arg1[%10, %8] : memref<2048x2048xf32>, vector<32xf32>
%12 = affine.apply #map5(%arg6, %arg3)
%13 = vector.load %arg1[%12, %8] : memref<2048x2048xf32>, vector<32xf32>
%14 = affine.apply #map6(%arg6, %arg3)
%15 = vector.load %arg1[%14, %8] : memref<2048x2048xf32>, vector<32xf32>
%16 = affine.apply #map7(%arg6, %arg3)
%17 = vector.load %arg1[%16, %8] : memref<2048x2048xf32>, vector<32xf32>
%18 = affine.apply #map8(%arg6, %arg3)
%19 = vector.load %arg1[%18, %8] : memref<2048x2048xf32>, vector<32xf32>
%20 = affine.apply #map9(%arg6, %arg3)
%21 = vector.load %arg1[%20, %8] : memref<2048x2048xf32>, vector<32xf32>
%22 = affine.apply #map10(%arg6, %arg3)
%23 = vector.load %arg1[%22, %8] : memref<2048x2048xf32>, vector<32xf32>
%24 = affine.apply #map11(%arg6, %arg3)
%25 = vector.load %arg1[%24, %8] : memref<2048x2048xf32>, vector<32xf32>
%26 = affine.apply #map12(%arg6, %arg3)
%27 = vector.load %arg1[%26, %8] : memref<2048x2048xf32>, vector<32xf32>
%28 = affine.apply #map13(%arg6, %arg3)
%29 = vector.load %arg1[%28, %8] : memref<2048x2048xf32>, vector<32xf32>
%30 = affine.apply #map14(%arg6, %arg3)
%31 = vector.load %arg1[%30, %8] : memref<2048x2048xf32>, vector<32xf32>
%32 = affine.apply #map15(%arg6, %arg3)
%33 = vector.load %arg1[%32, %8] : memref<2048x2048xf32>, vector<32xf32>
%34 = affine.apply #map16(%arg6, %arg3)
%35 = vector.load %arg1[%34, %8] : memref<2048x2048xf32>, vector<32xf32>
%36 = affine.apply #map17(%arg6, %arg3)
%37 = vector.load %arg1[%36, %8] : memref<2048x2048xf32>, vector<32xf32>
%38 = affine.apply #map18(%arg6, %arg3)
%39 = vector.load %arg1[%38, %8] : memref<2048x2048xf32>, vector<32xf32>
%40 = affine.apply #map19(%arg6, %arg3)
%41 = vector.load %arg1[%40, %8] : memref<2048x2048xf32>, vector<32xf32>
vector.store %11, %4[%5, %6, %7, %9, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %13, %4[%5, %6, %7, %9, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %15, %4[%5, %6, %7, %9, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %17, %4[%5, %6, %7, %9, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %19, %4[%5, %6, %7, %9, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %4[%5, %6, %7, %9, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %4[%5, %6, %7, %9, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %4[%5, %6, %7, %9, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %4[%5, %6, %7, %9, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %4[%5, %6, %7, %9, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %4[%5, %6, %7, %9, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %4[%5, %6, %7, %9, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %4[%5, %6, %7, %9, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %4[%5, %6, %7, %9, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %4[%5, %6, %7, %9, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %4[%5, %6, %7, %9, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%5 = affine.min #map20(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %5 step %c9 {
%7 = affine.apply #map21(%arg5)
%8 = affine.apply #map3(%arg5, %arg3)
%9 = affine.min #map22(%arg5, %5)
%10 = arith.cmpi sle, %c9, %9 : index
scf.for %arg6 = %c0 to %c512 step %c16 {
%11 = affine.apply #map4(%arg6)
%12 = affine.apply #map3(%arg6, %arg4)
%13 = memref.subview %arg0[%8, %12] [%9, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map23>
%14 = scf.if %10 -> (memref<?x16xf32, #map24>) {
%24 = memref.cast %13 : memref<?x16xf32, #map23> to memref<?x16xf32, #map24>
scf.yield %24 : memref<?x16xf32, #map24>
} else {
linalg.fill(%cst, %0) : f32, memref<9x16xf32>
%24 = memref.subview %13[0, 0] [%9, 16] [1, 1] : memref<?x16xf32, #map23> to memref<?x16xf32, #map23>
%25 = memref.subview %0[0, 0] [%9, 16] [1, 1] : memref<9x16xf32> to memref<?x16xf32, #map25>
linalg.copy(%24, %25) : memref<?x16xf32, #map23>, memref<?x16xf32, #map25>
%26 = memref.cast %0 : memref<9x16xf32> to memref<?x16xf32, #map24>
scf.yield %26 : memref<?x16xf32, #map24>
}
%15 = vector.load %14[%c0, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%16 = vector.load %14[%c1, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%17 = vector.load %14[%c2, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%18 = vector.load %14[%c3, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%19 = vector.load %14[%c4, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%20 = vector.load %14[%c5, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%21 = vector.load %14[%c6, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%22 = vector.load %14[%c7, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%23 = vector.load %14[%c8, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
vector.store %15, %3[%6, %7, %11, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %16, %3[%6, %7, %11, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %17, %3[%6, %7, %11, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %18, %3[%6, %7, %11, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %19, %3[%6, %7, %11, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %20, %3[%6, %7, %11, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %3[%6, %7, %11, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %3[%6, %7, %11, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %3[%6, %7, %11, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%7 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map23>
%8 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%9 = affine.min #map22(%arg6, %5)
%10 = affine.apply #map21(%arg6)
%11 = arith.cmpi sle, %c9, %9 : index
%12 = arith.cmpi sgt, %c9, %9 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%13 = affine.apply #map2(%arg7)
%14 = memref.subview %7[%arg6, %arg7] [%9, 32] [1, 1] : memref<?x128xf32, #map23> to memref<?x32xf32, #map23>
%15 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
linalg.fill(%cst, %1) : f32, memref<9x32xf32>
%45 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
%46 = memref.subview %1[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
linalg.copy(%45, %46) : memref<?x32xf32, #map23>, memref<?x32xf32, #map26>
%47 = memref.cast %1 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %47 : memref<?x32xf32, #map24>
}
%16 = vector.load %15[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%17 = vector.insert %16, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%18 = vector.load %15[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%19 = vector.insert %18, %17 [1] : vector<32xf32> into vector<9x32xf32>
%20 = vector.load %15[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%21 = vector.insert %20, %19 [2] : vector<32xf32> into vector<9x32xf32>
%22 = vector.load %15[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%23 = vector.insert %22, %21 [3] : vector<32xf32> into vector<9x32xf32>
%24 = vector.load %15[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%25 = vector.insert %24, %23 [4] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %15[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%27 = vector.insert %26, %25 [5] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %15[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%29 = vector.insert %28, %27 [6] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %15[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%31 = vector.insert %30, %29 [7] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %15[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%33 = vector.insert %32, %31 [8] : vector<32xf32> into vector<9x32xf32>
%34 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %33) -> (vector<9x32xf32>) {
%45 = affine.apply #map4(%arg8)
%46 = vector.load %3[%6, %10, %45, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%47 = vector.insert %46, %cst_1 [0] : vector<16xf32> into vector<9x16xf32>
%48 = vector.load %3[%6, %10, %45, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%49 = vector.insert %48, %47 [1] : vector<16xf32> into vector<9x16xf32>
%50 = vector.load %3[%6, %10, %45, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%51 = vector.insert %50, %49 [2] : vector<16xf32> into vector<9x16xf32>
%52 = vector.load %3[%6, %10, %45, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %51 [3] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %3[%6, %10, %45, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [4] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %3[%6, %10, %45, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [5] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %3[%6, %10, %45, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [6] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %3[%6, %10, %45, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [7] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %3[%6, %10, %45, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [8] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%6, %8, %13, %45, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%65 = vector.load %4[%6, %8, %13, %45, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%66 = vector.load %4[%6, %8, %13, %45, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%67 = vector.load %4[%6, %8, %13, %45, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%68 = vector.load %4[%6, %8, %13, %45, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%69 = vector.load %4[%6, %8, %13, %45, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %4[%6, %8, %13, %45, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %4[%6, %8, %13, %45, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %4[%6, %8, %13, %45, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %4[%6, %8, %13, %45, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %4[%6, %8, %13, %45, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %4[%6, %8, %13, %45, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %4[%6, %8, %13, %45, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %4[%6, %8, %13, %45, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %4[%6, %8, %13, %45, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %4[%6, %8, %13, %45, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.transpose %63, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%81 = vector.extract %80[0] : vector<16x9xf32>
%82 = vector.outerproduct %81, %64, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%83 = vector.extract %80[1] : vector<16x9xf32>
%84 = vector.outerproduct %83, %65, %82 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%85 = vector.extract %80[2] : vector<16x9xf32>
%86 = vector.outerproduct %85, %66, %84 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%87 = vector.extract %80[3] : vector<16x9xf32>
%88 = vector.outerproduct %87, %67, %86 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%89 = vector.extract %80[4] : vector<16x9xf32>
%90 = vector.outerproduct %89, %68, %88 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%91 = vector.extract %80[5] : vector<16x9xf32>
%92 = vector.outerproduct %91, %69, %90 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%93 = vector.extract %80[6] : vector<16x9xf32>
%94 = vector.outerproduct %93, %70, %92 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%95 = vector.extract %80[7] : vector<16x9xf32>
%96 = vector.outerproduct %95, %71, %94 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%97 = vector.extract %80[8] : vector<16x9xf32>
%98 = vector.outerproduct %97, %72, %96 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%99 = vector.extract %80[9] : vector<16x9xf32>
%100 = vector.outerproduct %99, %73, %98 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%101 = vector.extract %80[10] : vector<16x9xf32>
%102 = vector.outerproduct %101, %74, %100 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%103 = vector.extract %80[11] : vector<16x9xf32>
%104 = vector.outerproduct %103, %75, %102 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%105 = vector.extract %80[12] : vector<16x9xf32>
%106 = vector.outerproduct %105, %76, %104 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%107 = vector.extract %80[13] : vector<16x9xf32>
%108 = vector.outerproduct %107, %77, %106 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%109 = vector.extract %80[14] : vector<16x9xf32>
%110 = vector.outerproduct %109, %78, %108 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%111 = vector.extract %80[15] : vector<16x9xf32>
%112 = vector.outerproduct %111, %79, %110 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %112 : vector<9x32xf32>
}
%35 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
%45 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
}
%36 = vector.extract %34[0] : vector<9x32xf32>
vector.store %36, %35[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%37 = vector.extract %34[1] : vector<9x32xf32>
vector.store %37, %35[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%38 = vector.extract %34[2] : vector<9x32xf32>
vector.store %38, %35[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%39 = vector.extract %34[3] : vector<9x32xf32>
vector.store %39, %35[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%40 = vector.extract %34[4] : vector<9x32xf32>
vector.store %40, %35[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%41 = vector.extract %34[5] : vector<9x32xf32>
vector.store %41, %35[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%42 = vector.extract %34[6] : vector<9x32xf32>
vector.store %42, %35[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%43 = vector.extract %34[7] : vector<9x32xf32>
vector.store %43, %35[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%44 = vector.extract %34[8] : vector<9x32xf32>
vector.store %44, %35[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
scf.if %12 {
%45 = memref.subview %2[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
%46 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
linalg.copy(%45, %46) : memref<?x32xf32, #map26>, memref<?x32xf32, #map23>
}
}
}
}
}
}
memref.dealloc %4 : memref<4x16x4x32x16x32xf32>
memref.dealloc %3 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692250>]]]
#map0 = affine_map<(d0) -> (d0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 32)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
#map4 = affine_map<(d0) -> (d0 ceildiv 16)>
#map5 = affine_map<(d0, d1) -> (d0 + d1 + 1)>
#map6 = affine_map<(d0, d1) -> (d0 + d1 + 2)>
#map7 = affine_map<(d0, d1) -> (d0 + d1 + 3)>
#map8 = affine_map<(d0, d1) -> (d0 + d1 + 4)>
#map9 = affine_map<(d0, d1) -> (d0 + d1 + 5)>
#map10 = affine_map<(d0, d1) -> (d0 + d1 + 6)>
#map11 = affine_map<(d0, d1) -> (d0 + d1 + 7)>
#map12 = affine_map<(d0, d1) -> (d0 + d1 + 8)>
#map13 = affine_map<(d0, d1) -> (d0 + d1 + 9)>
#map14 = affine_map<(d0, d1) -> (d0 + d1 + 10)>
#map15 = affine_map<(d0, d1) -> (d0 + d1 + 11)>
#map16 = affine_map<(d0, d1) -> (d0 + d1 + 12)>
#map17 = affine_map<(d0, d1) -> (d0 + d1 + 13)>
#map18 = affine_map<(d0, d1) -> (d0 + d1 + 14)>
#map19 = affine_map<(d0, d1) -> (d0 + d1 + 15)>
#map20 = affine_map<(d0) -> (288, -d0 + 2048)>
#map21 = affine_map<(d0) -> (d0 ceildiv 9)>
#map22 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map23 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map24 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map25 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map26 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<16x9xf32>
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<4x32x32x9x16xf32>
%4 = memref.alloc() {alignment = 128 : i64} : memref<4x16x4x32x16x32xf32>
linalg.fill(%cst_2, %arg2) : f32, memref<2048x2048xf32>
scf.for %arg3 = %c0 to %c2048 step %c512 {
%5 = affine.apply #map0(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c128 {
%6 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%7 = affine.apply #map2(%arg5)
%8 = affine.apply #map3(%arg5, %arg4)
scf.for %arg6 = %c0 to %c512 step %c16 {
%9 = affine.apply #map4(%arg6)
%10 = affine.apply #map3(%arg6, %arg3)
%11 = vector.load %arg1[%10, %8] : memref<2048x2048xf32>, vector<32xf32>
%12 = affine.apply #map5(%arg6, %arg3)
%13 = vector.load %arg1[%12, %8] : memref<2048x2048xf32>, vector<32xf32>
%14 = affine.apply #map6(%arg6, %arg3)
%15 = vector.load %arg1[%14, %8] : memref<2048x2048xf32>, vector<32xf32>
%16 = affine.apply #map7(%arg6, %arg3)
%17 = vector.load %arg1[%16, %8] : memref<2048x2048xf32>, vector<32xf32>
%18 = affine.apply #map8(%arg6, %arg3)
%19 = vector.load %arg1[%18, %8] : memref<2048x2048xf32>, vector<32xf32>
%20 = affine.apply #map9(%arg6, %arg3)
%21 = vector.load %arg1[%20, %8] : memref<2048x2048xf32>, vector<32xf32>
%22 = affine.apply #map10(%arg6, %arg3)
%23 = vector.load %arg1[%22, %8] : memref<2048x2048xf32>, vector<32xf32>
%24 = affine.apply #map11(%arg6, %arg3)
%25 = vector.load %arg1[%24, %8] : memref<2048x2048xf32>, vector<32xf32>
%26 = affine.apply #map12(%arg6, %arg3)
%27 = vector.load %arg1[%26, %8] : memref<2048x2048xf32>, vector<32xf32>
%28 = affine.apply #map13(%arg6, %arg3)
%29 = vector.load %arg1[%28, %8] : memref<2048x2048xf32>, vector<32xf32>
%30 = affine.apply #map14(%arg6, %arg3)
%31 = vector.load %arg1[%30, %8] : memref<2048x2048xf32>, vector<32xf32>
%32 = affine.apply #map15(%arg6, %arg3)
%33 = vector.load %arg1[%32, %8] : memref<2048x2048xf32>, vector<32xf32>
%34 = affine.apply #map16(%arg6, %arg3)
%35 = vector.load %arg1[%34, %8] : memref<2048x2048xf32>, vector<32xf32>
%36 = affine.apply #map17(%arg6, %arg3)
%37 = vector.load %arg1[%36, %8] : memref<2048x2048xf32>, vector<32xf32>
%38 = affine.apply #map18(%arg6, %arg3)
%39 = vector.load %arg1[%38, %8] : memref<2048x2048xf32>, vector<32xf32>
%40 = affine.apply #map19(%arg6, %arg3)
%41 = vector.load %arg1[%40, %8] : memref<2048x2048xf32>, vector<32xf32>
vector.store %11, %4[%5, %6, %7, %9, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %13, %4[%5, %6, %7, %9, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %15, %4[%5, %6, %7, %9, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %17, %4[%5, %6, %7, %9, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %19, %4[%5, %6, %7, %9, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %4[%5, %6, %7, %9, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %4[%5, %6, %7, %9, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %4[%5, %6, %7, %9, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %4[%5, %6, %7, %9, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %4[%5, %6, %7, %9, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %4[%5, %6, %7, %9, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %4[%5, %6, %7, %9, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %4[%5, %6, %7, %9, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %4[%5, %6, %7, %9, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %4[%5, %6, %7, %9, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %4[%5, %6, %7, %9, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
scf.for %arg3 = %c0 to %c2048 step %c288 {
%5 = affine.min #map20(%arg3)
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %5 step %c9 {
%7 = affine.apply #map21(%arg5)
%8 = affine.apply #map3(%arg5, %arg3)
%9 = affine.min #map22(%arg5, %5)
%10 = arith.cmpi sle, %c9, %9 : index
scf.for %arg6 = %c0 to %c512 step %c16 {
%11 = affine.apply #map4(%arg6)
%12 = affine.apply #map3(%arg6, %arg4)
%13 = memref.subview %arg0[%8, %12] [%9, 16] [1, 1] : memref<2048x2048xf32> to memref<?x16xf32, #map23>
%14 = scf.if %10 -> (memref<?x16xf32, #map24>) {
%24 = memref.cast %13 : memref<?x16xf32, #map23> to memref<?x16xf32, #map24>
scf.yield %24 : memref<?x16xf32, #map24>
} else {
linalg.fill(%cst_2, %0) : f32, memref<9x16xf32>
%24 = memref.subview %13[0, 0] [%9, 16] [1, 1] : memref<?x16xf32, #map23> to memref<?x16xf32, #map23>
%25 = memref.subview %0[0, 0] [%9, 16] [1, 1] : memref<9x16xf32> to memref<?x16xf32, #map25>
linalg.copy(%24, %25) : memref<?x16xf32, #map23>, memref<?x16xf32, #map25>
%26 = memref.cast %0 : memref<9x16xf32> to memref<?x16xf32, #map24>
scf.yield %26 : memref<?x16xf32, #map24>
}
%15 = vector.load %14[%c0, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%16 = vector.load %14[%c1, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%17 = vector.load %14[%c2, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%18 = vector.load %14[%c3, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%19 = vector.load %14[%c4, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%20 = vector.load %14[%c5, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%21 = vector.load %14[%c6, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%22 = vector.load %14[%c7, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
%23 = vector.load %14[%c8, %c0] : memref<?x16xf32, #map24>, vector<16xf32>
vector.store %15, %3[%6, %7, %11, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %16, %3[%6, %7, %11, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %17, %3[%6, %7, %11, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %18, %3[%6, %7, %11, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %19, %3[%6, %7, %11, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %20, %3[%6, %7, %11, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %21, %3[%6, %7, %11, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %22, %3[%6, %7, %11, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
vector.store %23, %3[%6, %7, %11, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %c2048 step %c512 {
%6 = affine.apply #map0(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%7 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<2048x2048xf32> to memref<?x128xf32, #map23>
%8 = affine.apply #map1(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%9 = affine.min #map22(%arg6, %5)
%10 = affine.apply #map21(%arg6)
%11 = arith.cmpi sle, %c9, %9 : index
%12 = arith.cmpi sgt, %c9, %9 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%13 = affine.apply #map2(%arg7)
%14 = memref.subview %7[%arg6, %arg7] [%9, 32] [1, 1] : memref<?x128xf32, #map23> to memref<?x32xf32, #map23>
%15 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
linalg.fill(%cst_2, %1) : f32, memref<9x32xf32>
%45 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
%46 = memref.subview %1[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
linalg.copy(%45, %46) : memref<?x32xf32, #map23>, memref<?x32xf32, #map26>
%47 = memref.cast %1 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %47 : memref<?x32xf32, #map24>
}
%16 = vector.load %15[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%17 = vector.insert %16, %cst_1 [0] : vector<32xf32> into vector<9x32xf32>
%18 = vector.load %15[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%19 = vector.insert %18, %17 [1] : vector<32xf32> into vector<9x32xf32>
%20 = vector.load %15[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%21 = vector.insert %20, %19 [2] : vector<32xf32> into vector<9x32xf32>
%22 = vector.load %15[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%23 = vector.insert %22, %21 [3] : vector<32xf32> into vector<9x32xf32>
%24 = vector.load %15[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%25 = vector.insert %24, %23 [4] : vector<32xf32> into vector<9x32xf32>
%26 = vector.load %15[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%27 = vector.insert %26, %25 [5] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %15[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%29 = vector.insert %28, %27 [6] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %15[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%31 = vector.insert %30, %29 [7] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %15[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%33 = vector.insert %32, %31 [8] : vector<32xf32> into vector<9x32xf32>
%34 = scf.for %arg8 = %c0 to %c512 step %c16 iter_args(%arg9 = %33) -> (vector<9x32xf32>) {
%45 = affine.apply #map4(%arg8)
%46 = vector.load %3[%6, %10, %45, %c0, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%47 = vector.insert %46, %cst_0 [0] : vector<16xf32> into vector<9x16xf32>
%48 = vector.load %3[%6, %10, %45, %c1, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%49 = vector.insert %48, %47 [1] : vector<16xf32> into vector<9x16xf32>
%50 = vector.load %3[%6, %10, %45, %c2, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%51 = vector.insert %50, %49 [2] : vector<16xf32> into vector<9x16xf32>
%52 = vector.load %3[%6, %10, %45, %c3, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%53 = vector.insert %52, %51 [3] : vector<16xf32> into vector<9x16xf32>
%54 = vector.load %3[%6, %10, %45, %c4, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%55 = vector.insert %54, %53 [4] : vector<16xf32> into vector<9x16xf32>
%56 = vector.load %3[%6, %10, %45, %c5, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%57 = vector.insert %56, %55 [5] : vector<16xf32> into vector<9x16xf32>
%58 = vector.load %3[%6, %10, %45, %c6, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%59 = vector.insert %58, %57 [6] : vector<16xf32> into vector<9x16xf32>
%60 = vector.load %3[%6, %10, %45, %c7, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%61 = vector.insert %60, %59 [7] : vector<16xf32> into vector<9x16xf32>
%62 = vector.load %3[%6, %10, %45, %c8, %c0] : memref<4x32x32x9x16xf32>, vector<16xf32>
%63 = vector.insert %62, %61 [8] : vector<16xf32> into vector<9x16xf32>
%64 = vector.load %4[%6, %8, %13, %45, %c0, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%65 = vector.load %4[%6, %8, %13, %45, %c1, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%66 = vector.load %4[%6, %8, %13, %45, %c2, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%67 = vector.load %4[%6, %8, %13, %45, %c3, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%68 = vector.load %4[%6, %8, %13, %45, %c4, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%69 = vector.load %4[%6, %8, %13, %45, %c5, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %4[%6, %8, %13, %45, %c6, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %4[%6, %8, %13, %45, %c7, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %4[%6, %8, %13, %45, %c8, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %4[%6, %8, %13, %45, %c9, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %4[%6, %8, %13, %45, %c10, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %4[%6, %8, %13, %45, %c11, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %4[%6, %8, %13, %45, %c12, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %4[%6, %8, %13, %45, %c13, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %4[%6, %8, %13, %45, %c14, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %4[%6, %8, %13, %45, %c15, %c0] : memref<4x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.extract %63[0, 0] : vector<9x16xf32>
%81 = vector.insert %80, %cst [0, 0] : f32 into vector<16x9xf32>
%82 = vector.extract %63[1, 0] : vector<9x16xf32>
%83 = vector.insert %82, %81 [0, 1] : f32 into vector<16x9xf32>
%84 = vector.extract %63[2, 0] : vector<9x16xf32>
%85 = vector.insert %84, %83 [0, 2] : f32 into vector<16x9xf32>
%86 = vector.extract %63[3, 0] : vector<9x16xf32>
%87 = vector.insert %86, %85 [0, 3] : f32 into vector<16x9xf32>
%88 = vector.extract %63[4, 0] : vector<9x16xf32>
%89 = vector.insert %88, %87 [0, 4] : f32 into vector<16x9xf32>
%90 = vector.extract %63[5, 0] : vector<9x16xf32>
%91 = vector.insert %90, %89 [0, 5] : f32 into vector<16x9xf32>
%92 = vector.extract %63[6, 0] : vector<9x16xf32>
%93 = vector.insert %92, %91 [0, 6] : f32 into vector<16x9xf32>
%94 = vector.extract %63[7, 0] : vector<9x16xf32>
%95 = vector.insert %94, %93 [0, 7] : f32 into vector<16x9xf32>
%96 = vector.extract %63[8, 0] : vector<9x16xf32>
%97 = vector.insert %96, %95 [0, 8] : f32 into vector<16x9xf32>
%98 = vector.extract %63[0, 1] : vector<9x16xf32>
%99 = vector.insert %98, %97 [1, 0] : f32 into vector<16x9xf32>
%100 = vector.extract %63[1, 1] : vector<9x16xf32>
%101 = vector.insert %100, %99 [1, 1] : f32 into vector<16x9xf32>
%102 = vector.extract %63[2, 1] : vector<9x16xf32>
%103 = vector.insert %102, %101 [1, 2] : f32 into vector<16x9xf32>
%104 = vector.extract %63[3, 1] : vector<9x16xf32>
%105 = vector.insert %104, %103 [1, 3] : f32 into vector<16x9xf32>
%106 = vector.extract %63[4, 1] : vector<9x16xf32>
%107 = vector.insert %106, %105 [1, 4] : f32 into vector<16x9xf32>
%108 = vector.extract %63[5, 1] : vector<9x16xf32>
%109 = vector.insert %108, %107 [1, 5] : f32 into vector<16x9xf32>
%110 = vector.extract %63[6, 1] : vector<9x16xf32>
%111 = vector.insert %110, %109 [1, 6] : f32 into vector<16x9xf32>
%112 = vector.extract %63[7, 1] : vector<9x16xf32>
%113 = vector.insert %112, %111 [1, 7] : f32 into vector<16x9xf32>
%114 = vector.extract %63[8, 1] : vector<9x16xf32>
%115 = vector.insert %114, %113 [1, 8] : f32 into vector<16x9xf32>
%116 = vector.extract %63[0, 2] : vector<9x16xf32>
%117 = vector.insert %116, %115 [2, 0] : f32 into vector<16x9xf32>
%118 = vector.extract %63[1, 2] : vector<9x16xf32>
%119 = vector.insert %118, %117 [2, 1] : f32 into vector<16x9xf32>
%120 = vector.extract %63[2, 2] : vector<9x16xf32>
%121 = vector.insert %120, %119 [2, 2] : f32 into vector<16x9xf32>
%122 = vector.extract %63[3, 2] : vector<9x16xf32>
%123 = vector.insert %122, %121 [2, 3] : f32 into vector<16x9xf32>
%124 = vector.extract %63[4, 2] : vector<9x16xf32>
%125 = vector.insert %124, %123 [2, 4] : f32 into vector<16x9xf32>
%126 = vector.extract %63[5, 2] : vector<9x16xf32>
%127 = vector.insert %126, %125 [2, 5] : f32 into vector<16x9xf32>
%128 = vector.extract %63[6, 2] : vector<9x16xf32>
%129 = vector.insert %128, %127 [2, 6] : f32 into vector<16x9xf32>
%130 = vector.extract %63[7, 2] : vector<9x16xf32>
%131 = vector.insert %130, %129 [2, 7] : f32 into vector<16x9xf32>
%132 = vector.extract %63[8, 2] : vector<9x16xf32>
%133 = vector.insert %132, %131 [2, 8] : f32 into vector<16x9xf32>
%134 = vector.extract %63[0, 3] : vector<9x16xf32>
%135 = vector.insert %134, %133 [3, 0] : f32 into vector<16x9xf32>
%136 = vector.extract %63[1, 3] : vector<9x16xf32>
%137 = vector.insert %136, %135 [3, 1] : f32 into vector<16x9xf32>
%138 = vector.extract %63[2, 3] : vector<9x16xf32>
%139 = vector.insert %138, %137 [3, 2] : f32 into vector<16x9xf32>
%140 = vector.extract %63[3, 3] : vector<9x16xf32>
%141 = vector.insert %140, %139 [3, 3] : f32 into vector<16x9xf32>
%142 = vector.extract %63[4, 3] : vector<9x16xf32>
%143 = vector.insert %142, %141 [3, 4] : f32 into vector<16x9xf32>
%144 = vector.extract %63[5, 3] : vector<9x16xf32>
%145 = vector.insert %144, %143 [3, 5] : f32 into vector<16x9xf32>
%146 = vector.extract %63[6, 3] : vector<9x16xf32>
%147 = vector.insert %146, %145 [3, 6] : f32 into vector<16x9xf32>
%148 = vector.extract %63[7, 3] : vector<9x16xf32>
%149 = vector.insert %148, %147 [3, 7] : f32 into vector<16x9xf32>
%150 = vector.extract %63[8, 3] : vector<9x16xf32>
%151 = vector.insert %150, %149 [3, 8] : f32 into vector<16x9xf32>
%152 = vector.extract %63[0, 4] : vector<9x16xf32>
%153 = vector.insert %152, %151 [4, 0] : f32 into vector<16x9xf32>
%154 = vector.extract %63[1, 4] : vector<9x16xf32>
%155 = vector.insert %154, %153 [4, 1] : f32 into vector<16x9xf32>
%156 = vector.extract %63[2, 4] : vector<9x16xf32>
%157 = vector.insert %156, %155 [4, 2] : f32 into vector<16x9xf32>
%158 = vector.extract %63[3, 4] : vector<9x16xf32>
%159 = vector.insert %158, %157 [4, 3] : f32 into vector<16x9xf32>
%160 = vector.extract %63[4, 4] : vector<9x16xf32>
%161 = vector.insert %160, %159 [4, 4] : f32 into vector<16x9xf32>
%162 = vector.extract %63[5, 4] : vector<9x16xf32>
%163 = vector.insert %162, %161 [4, 5] : f32 into vector<16x9xf32>
%164 = vector.extract %63[6, 4] : vector<9x16xf32>
%165 = vector.insert %164, %163 [4, 6] : f32 into vector<16x9xf32>
%166 = vector.extract %63[7, 4] : vector<9x16xf32>
%167 = vector.insert %166, %165 [4, 7] : f32 into vector<16x9xf32>
%168 = vector.extract %63[8, 4] : vector<9x16xf32>
%169 = vector.insert %168, %167 [4, 8] : f32 into vector<16x9xf32>
%170 = vector.extract %63[0, 5] : vector<9x16xf32>
%171 = vector.insert %170, %169 [5, 0] : f32 into vector<16x9xf32>
%172 = vector.extract %63[1, 5] : vector<9x16xf32>
%173 = vector.insert %172, %171 [5, 1] : f32 into vector<16x9xf32>
%174 = vector.extract %63[2, 5] : vector<9x16xf32>
%175 = vector.insert %174, %173 [5, 2] : f32 into vector<16x9xf32>
%176 = vector.extract %63[3, 5] : vector<9x16xf32>
%177 = vector.insert %176, %175 [5, 3] : f32 into vector<16x9xf32>
%178 = vector.extract %63[4, 5] : vector<9x16xf32>
%179 = vector.insert %178, %177 [5, 4] : f32 into vector<16x9xf32>
%180 = vector.extract %63[5, 5] : vector<9x16xf32>
%181 = vector.insert %180, %179 [5, 5] : f32 into vector<16x9xf32>
%182 = vector.extract %63[6, 5] : vector<9x16xf32>
%183 = vector.insert %182, %181 [5, 6] : f32 into vector<16x9xf32>
%184 = vector.extract %63[7, 5] : vector<9x16xf32>
%185 = vector.insert %184, %183 [5, 7] : f32 into vector<16x9xf32>
%186 = vector.extract %63[8, 5] : vector<9x16xf32>
%187 = vector.insert %186, %185 [5, 8] : f32 into vector<16x9xf32>
%188 = vector.extract %63[0, 6] : vector<9x16xf32>
%189 = vector.insert %188, %187 [6, 0] : f32 into vector<16x9xf32>
%190 = vector.extract %63[1, 6] : vector<9x16xf32>
%191 = vector.insert %190, %189 [6, 1] : f32 into vector<16x9xf32>
%192 = vector.extract %63[2, 6] : vector<9x16xf32>
%193 = vector.insert %192, %191 [6, 2] : f32 into vector<16x9xf32>
%194 = vector.extract %63[3, 6] : vector<9x16xf32>
%195 = vector.insert %194, %193 [6, 3] : f32 into vector<16x9xf32>
%196 = vector.extract %63[4, 6] : vector<9x16xf32>
%197 = vector.insert %196, %195 [6, 4] : f32 into vector<16x9xf32>
%198 = vector.extract %63[5, 6] : vector<9x16xf32>
%199 = vector.insert %198, %197 [6, 5] : f32 into vector<16x9xf32>
%200 = vector.extract %63[6, 6] : vector<9x16xf32>
%201 = vector.insert %200, %199 [6, 6] : f32 into vector<16x9xf32>
%202 = vector.extract %63[7, 6] : vector<9x16xf32>
%203 = vector.insert %202, %201 [6, 7] : f32 into vector<16x9xf32>
%204 = vector.extract %63[8, 6] : vector<9x16xf32>
%205 = vector.insert %204, %203 [6, 8] : f32 into vector<16x9xf32>
%206 = vector.extract %63[0, 7] : vector<9x16xf32>
%207 = vector.insert %206, %205 [7, 0] : f32 into vector<16x9xf32>
%208 = vector.extract %63[1, 7] : vector<9x16xf32>
%209 = vector.insert %208, %207 [7, 1] : f32 into vector<16x9xf32>
%210 = vector.extract %63[2, 7] : vector<9x16xf32>
%211 = vector.insert %210, %209 [7, 2] : f32 into vector<16x9xf32>
%212 = vector.extract %63[3, 7] : vector<9x16xf32>
%213 = vector.insert %212, %211 [7, 3] : f32 into vector<16x9xf32>
%214 = vector.extract %63[4, 7] : vector<9x16xf32>
%215 = vector.insert %214, %213 [7, 4] : f32 into vector<16x9xf32>
%216 = vector.extract %63[5, 7] : vector<9x16xf32>
%217 = vector.insert %216, %215 [7, 5] : f32 into vector<16x9xf32>
%218 = vector.extract %63[6, 7] : vector<9x16xf32>
%219 = vector.insert %218, %217 [7, 6] : f32 into vector<16x9xf32>
%220 = vector.extract %63[7, 7] : vector<9x16xf32>
%221 = vector.insert %220, %219 [7, 7] : f32 into vector<16x9xf32>
%222 = vector.extract %63[8, 7] : vector<9x16xf32>
%223 = vector.insert %222, %221 [7, 8] : f32 into vector<16x9xf32>
%224 = vector.extract %63[0, 8] : vector<9x16xf32>
%225 = vector.insert %224, %223 [8, 0] : f32 into vector<16x9xf32>
%226 = vector.extract %63[1, 8] : vector<9x16xf32>
%227 = vector.insert %226, %225 [8, 1] : f32 into vector<16x9xf32>
%228 = vector.extract %63[2, 8] : vector<9x16xf32>
%229 = vector.insert %228, %227 [8, 2] : f32 into vector<16x9xf32>
%230 = vector.extract %63[3, 8] : vector<9x16xf32>
%231 = vector.insert %230, %229 [8, 3] : f32 into vector<16x9xf32>
%232 = vector.extract %63[4, 8] : vector<9x16xf32>
%233 = vector.insert %232, %231 [8, 4] : f32 into vector<16x9xf32>
%234 = vector.extract %63[5, 8] : vector<9x16xf32>
%235 = vector.insert %234, %233 [8, 5] : f32 into vector<16x9xf32>
%236 = vector.extract %63[6, 8] : vector<9x16xf32>
%237 = vector.insert %236, %235 [8, 6] : f32 into vector<16x9xf32>
%238 = vector.extract %63[7, 8] : vector<9x16xf32>
%239 = vector.insert %238, %237 [8, 7] : f32 into vector<16x9xf32>
%240 = vector.extract %63[8, 8] : vector<9x16xf32>
%241 = vector.insert %240, %239 [8, 8] : f32 into vector<16x9xf32>
%242 = vector.extract %63[0, 9] : vector<9x16xf32>
%243 = vector.insert %242, %241 [9, 0] : f32 into vector<16x9xf32>
%244 = vector.extract %63[1, 9] : vector<9x16xf32>
%245 = vector.insert %244, %243 [9, 1] : f32 into vector<16x9xf32>
%246 = vector.extract %63[2, 9] : vector<9x16xf32>
%247 = vector.insert %246, %245 [9, 2] : f32 into vector<16x9xf32>
%248 = vector.extract %63[3, 9] : vector<9x16xf32>
%249 = vector.insert %248, %247 [9, 3] : f32 into vector<16x9xf32>
%250 = vector.extract %63[4, 9] : vector<9x16xf32>
%251 = vector.insert %250, %249 [9, 4] : f32 into vector<16x9xf32>
%252 = vector.extract %63[5, 9] : vector<9x16xf32>
%253 = vector.insert %252, %251 [9, 5] : f32 into vector<16x9xf32>
%254 = vector.extract %63[6, 9] : vector<9x16xf32>
%255 = vector.insert %254, %253 [9, 6] : f32 into vector<16x9xf32>
%256 = vector.extract %63[7, 9] : vector<9x16xf32>
%257 = vector.insert %256, %255 [9, 7] : f32 into vector<16x9xf32>
%258 = vector.extract %63[8, 9] : vector<9x16xf32>
%259 = vector.insert %258, %257 [9, 8] : f32 into vector<16x9xf32>
%260 = vector.extract %63[0, 10] : vector<9x16xf32>
%261 = vector.insert %260, %259 [10, 0] : f32 into vector<16x9xf32>
%262 = vector.extract %63[1, 10] : vector<9x16xf32>
%263 = vector.insert %262, %261 [10, 1] : f32 into vector<16x9xf32>
%264 = vector.extract %63[2, 10] : vector<9x16xf32>
%265 = vector.insert %264, %263 [10, 2] : f32 into vector<16x9xf32>
%266 = vector.extract %63[3, 10] : vector<9x16xf32>
%267 = vector.insert %266, %265 [10, 3] : f32 into vector<16x9xf32>
%268 = vector.extract %63[4, 10] : vector<9x16xf32>
%269 = vector.insert %268, %267 [10, 4] : f32 into vector<16x9xf32>
%270 = vector.extract %63[5, 10] : vector<9x16xf32>
%271 = vector.insert %270, %269 [10, 5] : f32 into vector<16x9xf32>
%272 = vector.extract %63[6, 10] : vector<9x16xf32>
%273 = vector.insert %272, %271 [10, 6] : f32 into vector<16x9xf32>
%274 = vector.extract %63[7, 10] : vector<9x16xf32>
%275 = vector.insert %274, %273 [10, 7] : f32 into vector<16x9xf32>
%276 = vector.extract %63[8, 10] : vector<9x16xf32>
%277 = vector.insert %276, %275 [10, 8] : f32 into vector<16x9xf32>
%278 = vector.extract %63[0, 11] : vector<9x16xf32>
%279 = vector.insert %278, %277 [11, 0] : f32 into vector<16x9xf32>
%280 = vector.extract %63[1, 11] : vector<9x16xf32>
%281 = vector.insert %280, %279 [11, 1] : f32 into vector<16x9xf32>
%282 = vector.extract %63[2, 11] : vector<9x16xf32>
%283 = vector.insert %282, %281 [11, 2] : f32 into vector<16x9xf32>
%284 = vector.extract %63[3, 11] : vector<9x16xf32>
%285 = vector.insert %284, %283 [11, 3] : f32 into vector<16x9xf32>
%286 = vector.extract %63[4, 11] : vector<9x16xf32>
%287 = vector.insert %286, %285 [11, 4] : f32 into vector<16x9xf32>
%288 = vector.extract %63[5, 11] : vector<9x16xf32>
%289 = vector.insert %288, %287 [11, 5] : f32 into vector<16x9xf32>
%290 = vector.extract %63[6, 11] : vector<9x16xf32>
%291 = vector.insert %290, %289 [11, 6] : f32 into vector<16x9xf32>
%292 = vector.extract %63[7, 11] : vector<9x16xf32>
%293 = vector.insert %292, %291 [11, 7] : f32 into vector<16x9xf32>
%294 = vector.extract %63[8, 11] : vector<9x16xf32>
%295 = vector.insert %294, %293 [11, 8] : f32 into vector<16x9xf32>
%296 = vector.extract %63[0, 12] : vector<9x16xf32>
%297 = vector.insert %296, %295 [12, 0] : f32 into vector<16x9xf32>
%298 = vector.extract %63[1, 12] : vector<9x16xf32>
%299 = vector.insert %298, %297 [12, 1] : f32 into vector<16x9xf32>
%300 = vector.extract %63[2, 12] : vector<9x16xf32>
%301 = vector.insert %300, %299 [12, 2] : f32 into vector<16x9xf32>
%302 = vector.extract %63[3, 12] : vector<9x16xf32>
%303 = vector.insert %302, %301 [12, 3] : f32 into vector<16x9xf32>
%304 = vector.extract %63[4, 12] : vector<9x16xf32>
%305 = vector.insert %304, %303 [12, 4] : f32 into vector<16x9xf32>
%306 = vector.extract %63[5, 12] : vector<9x16xf32>
%307 = vector.insert %306, %305 [12, 5] : f32 into vector<16x9xf32>
%308 = vector.extract %63[6, 12] : vector<9x16xf32>
%309 = vector.insert %308, %307 [12, 6] : f32 into vector<16x9xf32>
%310 = vector.extract %63[7, 12] : vector<9x16xf32>
%311 = vector.insert %310, %309 [12, 7] : f32 into vector<16x9xf32>
%312 = vector.extract %63[8, 12] : vector<9x16xf32>
%313 = vector.insert %312, %311 [12, 8] : f32 into vector<16x9xf32>
%314 = vector.extract %63[0, 13] : vector<9x16xf32>
%315 = vector.insert %314, %313 [13, 0] : f32 into vector<16x9xf32>
%316 = vector.extract %63[1, 13] : vector<9x16xf32>
%317 = vector.insert %316, %315 [13, 1] : f32 into vector<16x9xf32>
%318 = vector.extract %63[2, 13] : vector<9x16xf32>
%319 = vector.insert %318, %317 [13, 2] : f32 into vector<16x9xf32>
%320 = vector.extract %63[3, 13] : vector<9x16xf32>
%321 = vector.insert %320, %319 [13, 3] : f32 into vector<16x9xf32>
%322 = vector.extract %63[4, 13] : vector<9x16xf32>
%323 = vector.insert %322, %321 [13, 4] : f32 into vector<16x9xf32>
%324 = vector.extract %63[5, 13] : vector<9x16xf32>
%325 = vector.insert %324, %323 [13, 5] : f32 into vector<16x9xf32>
%326 = vector.extract %63[6, 13] : vector<9x16xf32>
%327 = vector.insert %326, %325 [13, 6] : f32 into vector<16x9xf32>
%328 = vector.extract %63[7, 13] : vector<9x16xf32>
%329 = vector.insert %328, %327 [13, 7] : f32 into vector<16x9xf32>
%330 = vector.extract %63[8, 13] : vector<9x16xf32>
%331 = vector.insert %330, %329 [13, 8] : f32 into vector<16x9xf32>
%332 = vector.extract %63[0, 14] : vector<9x16xf32>
%333 = vector.insert %332, %331 [14, 0] : f32 into vector<16x9xf32>
%334 = vector.extract %63[1, 14] : vector<9x16xf32>
%335 = vector.insert %334, %333 [14, 1] : f32 into vector<16x9xf32>
%336 = vector.extract %63[2, 14] : vector<9x16xf32>
%337 = vector.insert %336, %335 [14, 2] : f32 into vector<16x9xf32>
%338 = vector.extract %63[3, 14] : vector<9x16xf32>
%339 = vector.insert %338, %337 [14, 3] : f32 into vector<16x9xf32>
%340 = vector.extract %63[4, 14] : vector<9x16xf32>
%341 = vector.insert %340, %339 [14, 4] : f32 into vector<16x9xf32>
%342 = vector.extract %63[5, 14] : vector<9x16xf32>
%343 = vector.insert %342, %341 [14, 5] : f32 into vector<16x9xf32>
%344 = vector.extract %63[6, 14] : vector<9x16xf32>
%345 = vector.insert %344, %343 [14, 6] : f32 into vector<16x9xf32>
%346 = vector.extract %63[7, 14] : vector<9x16xf32>
%347 = vector.insert %346, %345 [14, 7] : f32 into vector<16x9xf32>
%348 = vector.extract %63[8, 14] : vector<9x16xf32>
%349 = vector.insert %348, %347 [14, 8] : f32 into vector<16x9xf32>
%350 = vector.extract %63[0, 15] : vector<9x16xf32>
%351 = vector.insert %350, %349 [15, 0] : f32 into vector<16x9xf32>
%352 = vector.extract %63[1, 15] : vector<9x16xf32>
%353 = vector.insert %352, %351 [15, 1] : f32 into vector<16x9xf32>
%354 = vector.extract %63[2, 15] : vector<9x16xf32>
%355 = vector.insert %354, %353 [15, 2] : f32 into vector<16x9xf32>
%356 = vector.extract %63[3, 15] : vector<9x16xf32>
%357 = vector.insert %356, %355 [15, 3] : f32 into vector<16x9xf32>
%358 = vector.extract %63[4, 15] : vector<9x16xf32>
%359 = vector.insert %358, %357 [15, 4] : f32 into vector<16x9xf32>
%360 = vector.extract %63[5, 15] : vector<9x16xf32>
%361 = vector.insert %360, %359 [15, 5] : f32 into vector<16x9xf32>
%362 = vector.extract %63[6, 15] : vector<9x16xf32>
%363 = vector.insert %362, %361 [15, 6] : f32 into vector<16x9xf32>
%364 = vector.extract %63[7, 15] : vector<9x16xf32>
%365 = vector.insert %364, %363 [15, 7] : f32 into vector<16x9xf32>
%366 = vector.extract %63[8, 15] : vector<9x16xf32>
%367 = vector.insert %366, %365 [15, 8] : f32 into vector<16x9xf32>
%368 = vector.extract %367[0] : vector<16x9xf32>
%369 = vector.outerproduct %368, %64, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%370 = vector.extract %367[1] : vector<16x9xf32>
%371 = vector.outerproduct %370, %65, %369 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%372 = vector.extract %367[2] : vector<16x9xf32>
%373 = vector.outerproduct %372, %66, %371 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%374 = vector.extract %367[3] : vector<16x9xf32>
%375 = vector.outerproduct %374, %67, %373 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%376 = vector.extract %367[4] : vector<16x9xf32>
%377 = vector.outerproduct %376, %68, %375 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%378 = vector.extract %367[5] : vector<16x9xf32>
%379 = vector.outerproduct %378, %69, %377 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%380 = vector.extract %367[6] : vector<16x9xf32>
%381 = vector.outerproduct %380, %70, %379 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%382 = vector.extract %367[7] : vector<16x9xf32>
%383 = vector.outerproduct %382, %71, %381 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%384 = vector.extract %367[8] : vector<16x9xf32>
%385 = vector.outerproduct %384, %72, %383 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%386 = vector.extract %367[9] : vector<16x9xf32>
%387 = vector.outerproduct %386, %73, %385 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%388 = vector.extract %367[10] : vector<16x9xf32>
%389 = vector.outerproduct %388, %74, %387 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%390 = vector.extract %367[11] : vector<16x9xf32>
%391 = vector.outerproduct %390, %75, %389 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%392 = vector.extract %367[12] : vector<16x9xf32>
%393 = vector.outerproduct %392, %76, %391 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%394 = vector.extract %367[13] : vector<16x9xf32>
%395 = vector.outerproduct %394, %77, %393 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%396 = vector.extract %367[14] : vector<16x9xf32>
%397 = vector.outerproduct %396, %78, %395 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%398 = vector.extract %367[15] : vector<16x9xf32>
%399 = vector.outerproduct %398, %79, %397 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %399 : vector<9x32xf32>
}
%35 = scf.if %11 -> (memref<?x32xf32, #map24>) {
%45 = memref.cast %14 : memref<?x32xf32, #map23> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
} else {
%45 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map24>
scf.yield %45 : memref<?x32xf32, #map24>
}
%36 = vector.extract %34[0] : vector<9x32xf32>
vector.store %36, %35[%c0, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%37 = vector.extract %34[1] : vector<9x32xf32>
vector.store %37, %35[%c1, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%38 = vector.extract %34[2] : vector<9x32xf32>
vector.store %38, %35[%c2, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%39 = vector.extract %34[3] : vector<9x32xf32>
vector.store %39, %35[%c3, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%40 = vector.extract %34[4] : vector<9x32xf32>
vector.store %40, %35[%c4, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%41 = vector.extract %34[5] : vector<9x32xf32>
vector.store %41, %35[%c5, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%42 = vector.extract %34[6] : vector<9x32xf32>
vector.store %42, %35[%c6, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%43 = vector.extract %34[7] : vector<9x32xf32>
vector.store %43, %35[%c7, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
%44 = vector.extract %34[8] : vector<9x32xf32>
vector.store %44, %35[%c8, %c0] : memref<?x32xf32, #map24>, vector<32xf32>
scf.if %12 {
%45 = memref.subview %2[0, 0] [%9, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map26>
%46 = memref.subview %14[0, 0] [%9, 32] [1, 1] : memref<?x32xf32, #map23> to memref<?x32xf32, #map23>
linalg.copy(%45, %46) : memref<?x32xf32, #map26>, memref<?x32xf32, #map23>
}
}
}
}
}
}
memref.dealloc %4 : memref<4x16x4x32x16x32xf32>
memref.dealloc %3 : memref<4x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<2048x2048xf32>, memref<2048x2048xf32>, memref<2048x2048xf32>) -> ()
}
return
}
}
compilation in 0.2177s
xxxxxxxxxx : 10 iters time on 1 threads in 0.163s per iter sec (105.4 GFlop/s, 0.3087 GB/s) total time 1.63s
###############################################################
Runtime problem size {'M': 2048, 'N': 2048, 'K': 2048}
Compile-time problem size {'M': -1, 'N': 2048, 'K': -1}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map0 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2048xf32> -> tensor<?x2048xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2048xf32>) {
%4 = affine.min #map0(%arg3)[%1]
%5 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2048xf32>) {
%6 = affine.min #map1(%arg5)[%2]
%7 = tensor.extract_slice %arg0[%arg3, %arg5] [%4, %6] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%8 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2048xf32>) {
%9 = tensor.extract_slice %arg1[%arg5, %arg7] [%6, 128] [1, 1] : tensor<?x2048xf32> to tensor<?x128xf32>
%10 = tensor.extract_slice %arg8[%arg3, %arg7] [%4, 128] [1, 1] : tensor<?x2048xf32> to tensor<?x128xf32>
%11 = linalg.matmul ins(%7, %9 : tensor<?x?xf32>, tensor<?x128xf32>) outs(%10 : tensor<?x128xf32>) -> tensor<?x128xf32>
%12 = tensor.insert_slice %11 into %arg8[%arg3, %arg7] [%4, 128] [1, 1] : tensor<?x128xf32> into tensor<?x2048xf32>
scf.yield %12 : tensor<?x2048xf32>
}
scf.yield %8 : tensor<?x2048xf32>
}
scf.yield %5 : tensor<?x2048xf32>
}
return %3 : tensor<?x2048xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2048xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2048xf32>, tensor<?x2048xf32>) -> tensor<?x2048xf32>
scf.yield %1 : tensor<?x2048xf32>
}
return %0 : tensor<?x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0) -> (-d0 + 16)>
#map9 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map10 = affine_map<(d0) -> (d0 ceildiv 9)>
#map11 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map12 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2048xf32> -> tensor<?x2048xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply #map0()[%2]
%4 = linalg.init_tensor [%3, 16, 4, 32, 16, 32] : tensor<?x16x4x32x16x32xf32>
%5 = tensor.cast %4 : tensor<?x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%6 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %5) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map1(%arg3)
%11 = affine.min #map2(%arg3)[%2]
%12 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map3(%arg5)
%14 = scf.for %arg7 = %c0 to %c128 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg7)
%16 = affine.apply #map5(%arg7, %arg5)
%17 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg9)
%19 = affine.apply #map5(%arg9, %arg3)
%20 = affine.min #map7(%arg9, %11)
%21 = tensor.extract_slice %arg1[%19, %16] [%20, 32] [1, 1] : tensor<?x2048xf32> to tensor<?x32xf32>
%22 = affine.apply #map8(%20)
%23 = linalg.pad_tensor %21 nofold low[%c0, %c0] high[%22, %c0] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x32xf32> to tensor<16x32xf32>
%24 = tensor.insert_slice %23 into %arg10[%10, %13, %15, %18, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %24 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
%7 = linalg.init_tensor [%3, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%8 = tensor.cast %7 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%9 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2048xf32>) {
%10 = affine.min #map9(%arg3)[%1]
%11 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %8) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map1(%arg5)
%14 = affine.min #map2(%arg5)[%2]
%15 = scf.for %arg7 = %c0 to %10 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%16 = affine.apply #map10(%arg7)
%17 = affine.apply #map5(%arg7, %arg3)
%18 = affine.min #map11(%arg7, %10)
%19 = affine.apply #map12(%18)
%20 = scf.for %arg9 = %c0 to %14 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%21 = affine.apply #map6(%arg9)
%22 = affine.apply #map5(%arg9, %arg5)
%23 = affine.min #map7(%arg9, %14)
%24 = tensor.extract_slice %arg0[%17, %22] [%18, %23] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%25 = affine.apply #map8(%23)
%26 = linalg.pad_tensor %24 nofold low[%c0, %c0] high[%19, %25] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x16xf32>
%27 = tensor.insert_slice %26 into %arg10[%13, %16, %21, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %27 : tensor<?x?x?x9x16xf32>
}
scf.yield %20 : tensor<?x?x?x9x16xf32>
}
scf.yield %15 : tensor<?x?x?x9x16xf32>
}
%12 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2048xf32>) {
%13 = affine.min #map2(%arg5)[%2]
%14 = affine.apply #map1(%arg5)
%15 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2048xf32>) {
%16 = tensor.extract_slice %arg8[%arg3, %arg7] [%10, 128] [1, 1] : tensor<?x2048xf32> to tensor<?x128xf32>
%17 = affine.apply #map3(%arg7)
%18 = scf.for %arg9 = %c0 to %10 step %c9 iter_args(%arg10 = %16) -> (tensor<?x128xf32>) {
%20 = affine.min #map11(%arg9, %10)
%21 = affine.apply #map10(%arg9)
%22 = affine.apply #map12(%20)
%23 = scf.for %arg11 = %c0 to %c128 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x128xf32>) {
%24 = affine.apply #map4(%arg11)
%25 = scf.for %arg13 = %c0 to %13 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x128xf32>) {
%26 = tensor.extract_slice %arg14[%arg9, %arg11] [%20, 32] [1, 1] : tensor<?x128xf32> to tensor<?x32xf32>
%27 = affine.apply #map6(%arg13)
%28 = tensor.extract_slice %11[%14, %21, %27, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%29 = tensor.extract_slice %6[%14, %17, %24, %27, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%30 = linalg.pad_tensor %26 low[%c0, %c0] high[%22, %c0] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x32xf32> to tensor<9x32xf32>
%31 = linalg.matmul ins(%28, %29 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%30 : tensor<9x32xf32>) -> tensor<9x32xf32>
%32 = tensor.extract_slice %31[0, 0] [%20, 32] [1, 1] : tensor<9x32xf32> to tensor<?x32xf32>
%33 = tensor.insert_slice %32 into %arg14[%arg9, %arg11] [%20, 32] [1, 1] : tensor<?x32xf32> into tensor<?x128xf32>
scf.yield %33 : tensor<?x128xf32>
}
scf.yield %25 : tensor<?x128xf32>
}
scf.yield %23 : tensor<?x128xf32>
}
%19 = tensor.insert_slice %18 into %arg8[%arg3, %arg7] [%10, 128] [1, 1] : tensor<?x128xf32> into tensor<?x2048xf32>
scf.yield %19 : tensor<?x2048xf32>
}
scf.yield %15 : tensor<?x2048xf32>
}
scf.yield %12 : tensor<?x2048xf32>
}
return %9 : tensor<?x2048xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2048xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2048xf32>, tensor<?x2048xf32>) -> tensor<?x2048xf32>
scf.yield %1 : tensor<?x2048xf32>
}
return %0 : tensor<?x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map9 = affine_map<(d0) -> (d0 ceildiv 9)>
#map10 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map11 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map12 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map13 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x2048xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x2048xf32> -> tensor<?x2048xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply #map0()[%2]
%4 = linalg.init_tensor [%3, 16, 4, 32, 16, 32] : tensor<?x16x4x32x16x32xf32>
%5 = tensor.cast %4 : tensor<?x16x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%6 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %5) -> (tensor<?x?x?x?x16x32xf32>) {
%10 = affine.apply #map1(%arg3)
%11 = affine.min #map2(%arg3)[%2]
%12 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%13 = affine.apply #map3(%arg5)
%14 = scf.for %arg7 = %c0 to %c128 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg7)
%16 = affine.apply #map5(%arg7, %arg5)
%17 = scf.for %arg9 = %c0 to %11 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg9)
%19 = affine.apply #map5(%arg9, %arg3)
%20 = affine.min #map7(%arg9, %11)
%21 = tensor.extract_slice %arg1[%19, %16] [%20, 32] [1, 1] : tensor<?x2048xf32> to tensor<?x32xf32>
%22 = vector.transfer_read %21[%c0, %c0], %cst {in_bounds = [false, true]} : tensor<?x32xf32>, vector<16x32xf32>
%23 = vector.transfer_write %22, %arg10[%10, %13, %15, %18, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %23 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %12 : tensor<?x?x?x?x16x32xf32>
}
%7 = linalg.init_tensor [%3, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%8 = tensor.cast %7 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%9 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x2048xf32>) {
%10 = affine.min #map8(%arg3)[%1]
%11 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %8) -> (tensor<?x?x?x9x16xf32>) {
%13 = affine.apply #map1(%arg5)
%14 = affine.min #map2(%arg5)[%2]
%15 = scf.for %arg7 = %c0 to %10 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%16 = affine.apply #map9(%arg7)
%17 = affine.apply #map5(%arg7, %arg3)
%18 = affine.min #map10(%arg7, %10)
%19 = scf.for %arg9 = %c0 to %14 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%20 = affine.apply #map6(%arg9)
%21 = affine.apply #map5(%arg9, %arg5)
%22 = affine.min #map7(%arg9, %14)
%23 = tensor.extract_slice %arg0[%17, %21] [%18, %22] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%24 = vector.transfer_read %23[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x16xf32>
%25 = vector.transfer_write %24, %arg10[%13, %16, %20, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %25 : tensor<?x?x?x9x16xf32>
}
scf.yield %19 : tensor<?x?x?x9x16xf32>
}
scf.yield %15 : tensor<?x?x?x9x16xf32>
}
%12 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x2048xf32>) {
%13 = affine.min #map2(%arg5)[%2]
%14 = affine.apply #map1(%arg5)
%15 = scf.for %arg7 = %c0 to %c2048 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x2048xf32>) {
%16 = tensor.extract_slice %arg8[%arg3, %arg7] [%10, 128] [1, 1] : tensor<?x2048xf32> to tensor<?x128xf32>
%17 = affine.apply #map3(%arg7)
%18 = scf.for %arg9 = %c0 to %10 step %c9 iter_args(%arg10 = %16) -> (tensor<?x128xf32>) {
%20 = affine.min #map10(%arg9, %10)
%21 = affine.apply #map9(%arg9)
%22 = scf.for %arg11 = %c0 to %c128 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x128xf32>) {
%23 = affine.apply #map4(%arg11)
%24 = scf.for %arg13 = %c0 to %13 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x128xf32>) {
%25 = tensor.extract_slice %arg14[%arg9, %arg11] [%20, 32] [1, 1] : tensor<?x128xf32> to tensor<?x32xf32>
%26 = affine.apply #map6(%arg13)
%27 = vector.transfer_read %11[%14, %21, %26, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%28 = vector.transfer_read %6[%14, %17, %23, %26, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%29 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [false, true]} : tensor<?x32xf32>, vector<9x32xf32>
%30 = vector.contract {indexing_maps = [#map11, #map12, #map13], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %27, %28, %29 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%31 = vector.transfer_write %30, %25[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, tensor<?x32xf32>
%32 = tensor.insert_slice %31 into %arg14[%arg9, %arg11] [%20, 32] [1, 1] : tensor<?x32xf32> into tensor<?x128xf32>
scf.yield %32 : tensor<?x128xf32>
}
scf.yield %24 : tensor<?x128xf32>
}
scf.yield %22 : tensor<?x128xf32>
}
%19 = tensor.insert_slice %18 into %arg8[%arg3, %arg7] [%10, 128] [1, 1] : tensor<?x128xf32> into tensor<?x2048xf32>
scf.yield %19 : tensor<?x2048xf32>
}
scf.yield %15 : tensor<?x2048xf32>
}
scf.yield %12 : tensor<?x2048xf32>
}
return %9 : tensor<?x2048xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x2048xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x2048xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x2048xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x2048xf32>, tensor<?x2048xf32>) -> tensor<?x2048xf32>
scf.yield %1 : tensor<?x2048xf32>
}
return %0 : tensor<?x2048xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map10 = affine_map<(d0) -> (d0 ceildiv 9)>
#map11 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map12 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map13 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map14 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map15 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%7 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%8 = affine.apply #map4(%arg5)
%9 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %6 step %c16 {
%10 = affine.apply #map6(%arg6)
%11 = affine.apply #map5(%arg6, %arg3)
%12 = affine.min #map7(%arg6, %6)
%13 = memref.subview %arg1[%11, %9] [%12, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<16x32xf32>
vector.transfer_write %14, %3[%5, %7, %8, %10, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map9(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map10(%arg5)
%9 = affine.apply #map5(%arg5, %arg3)
%10 = affine.min #map11(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map6(%arg6)
%12 = affine.apply #map5(%arg6, %arg4)
%13 = affine.min #map7(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map12>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map12>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%8 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%9 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%10 = affine.min #map11(%arg6, %5)
%11 = affine.apply #map10(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg7)
%13 = memref.subview %8[%arg6, %arg7] [%10, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%15 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %14) -> (vector<9x32xf32>) {
%16 = affine.apply #map6(%arg8)
%17 = vector.transfer_read %4[%7, %11, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%18 = vector.transfer_read %3[%7, %9, %12, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%19 = vector.contract {indexing_maps = [#map13, #map14, #map15], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %17, %18, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %19 : vector<9x32xf32>
}
vector.transfer_write %15, %13[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map10 = affine_map<(d0) -> (d0 ceildiv 9)>
#map11 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map12 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%7 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%8 = affine.apply #map4(%arg5)
%9 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %6 step %c16 {
%10 = affine.apply #map6(%arg6)
%11 = affine.apply #map5(%arg6, %arg3)
%12 = affine.min #map7(%arg6, %6)
%13 = memref.subview %arg1[%11, %9] [%12, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<16x32xf32>
vector.transfer_write %14, %3[%5, %7, %8, %10, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map9(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map10(%arg5)
%9 = affine.apply #map5(%arg5, %arg3)
%10 = affine.min #map11(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map6(%arg6)
%12 = affine.apply #map5(%arg6, %arg4)
%13 = affine.min #map7(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map12>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map12>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%8 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%9 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%10 = affine.min #map11(%arg6, %5)
%11 = affine.apply #map10(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg7)
%13 = memref.subview %8[%arg6, %arg7] [%10, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%15 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %14) -> (vector<9x32xf32>) {
%16 = affine.apply #map6(%arg8)
%17 = vector.transfer_read %4[%7, %11, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%18 = vector.transfer_read %3[%7, %9, %12, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%19 = vector.transpose %17, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%20 = vector.extract %19[0] : vector<16x9xf32>
%21 = vector.extract %18[0] : vector<16x32xf32>
%22 = vector.outerproduct %20, %21, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%23 = vector.extract %19[1] : vector<16x9xf32>
%24 = vector.extract %18[1] : vector<16x32xf32>
%25 = vector.outerproduct %23, %24, %22 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%26 = vector.extract %19[2] : vector<16x9xf32>
%27 = vector.extract %18[2] : vector<16x32xf32>
%28 = vector.outerproduct %26, %27, %25 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%29 = vector.extract %19[3] : vector<16x9xf32>
%30 = vector.extract %18[3] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %28 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %19[4] : vector<16x9xf32>
%33 = vector.extract %18[4] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %19[5] : vector<16x9xf32>
%36 = vector.extract %18[5] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %19[6] : vector<16x9xf32>
%39 = vector.extract %18[6] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %19[7] : vector<16x9xf32>
%42 = vector.extract %18[7] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %19[8] : vector<16x9xf32>
%45 = vector.extract %18[8] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %19[9] : vector<16x9xf32>
%48 = vector.extract %18[9] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %19[10] : vector<16x9xf32>
%51 = vector.extract %18[10] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %19[11] : vector<16x9xf32>
%54 = vector.extract %18[11] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %19[12] : vector<16x9xf32>
%57 = vector.extract %18[12] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %19[13] : vector<16x9xf32>
%60 = vector.extract %18[13] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %19[14] : vector<16x9xf32>
%63 = vector.extract %18[14] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %19[15] : vector<16x9xf32>
%66 = vector.extract %18[15] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %67 : vector<9x32xf32>
}
vector.transfer_write %15, %13[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map10 = affine_map<(d0) -> (d0 ceildiv 9)>
#map11 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map12 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = affine.apply #map0()[%1]
%3 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%5 = affine.apply #map1(%arg3)
%6 = affine.min #map2(%arg3)[%1]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%7 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%8 = affine.apply #map4(%arg5)
%9 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %6 step %c16 {
%10 = affine.apply #map6(%arg6)
%11 = affine.apply #map5(%arg6, %arg3)
%12 = affine.min #map7(%arg6, %6)
%13 = memref.subview %arg1[%11, %9] [%12, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<16x32xf32>
vector.transfer_write %14, %3[%5, %7, %8, %10, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%4 = memref.alloc(%2) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%5 = affine.min #map9(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.apply #map1(%arg4)
%7 = affine.min #map2(%arg4)[%1]
scf.for %arg5 = %c0 to %5 step %c9 {
%8 = affine.apply #map10(%arg5)
%9 = affine.apply #map5(%arg5, %arg3)
%10 = affine.min #map11(%arg5, %5)
scf.for %arg6 = %c0 to %7 step %c16 {
%11 = affine.apply #map6(%arg6)
%12 = affine.apply #map5(%arg6, %arg4)
%13 = affine.min #map7(%arg6, %7)
%14 = memref.subview %arg0[%9, %12] [%10, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map12>
%15 = vector.transfer_read %14[%c0, %c0], %cst : memref<?x?xf32, #map12>, vector<9x16xf32>
vector.transfer_write %15, %4[%6, %8, %11, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%6 = affine.min #map2(%arg4)[%1]
%7 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%8 = memref.subview %arg2[%arg3, %arg5] [%5, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%9 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %5 step %c9 {
%10 = affine.min #map11(%arg6, %5)
%11 = affine.apply #map10(%arg6)
scf.for %arg7 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg7)
%13 = memref.subview %8[%arg6, %arg7] [%10, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%14 = vector.transfer_read %13[%c0, %c0], %cst {in_bounds = [false, true]} : memref<?x32xf32, #map8>, vector<9x32xf32>
%15 = scf.for %arg8 = %c0 to %6 step %c16 iter_args(%arg9 = %14) -> (vector<9x32xf32>) {
%16 = affine.apply #map6(%arg8)
%17 = vector.transfer_read %4[%7, %11, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%18 = vector.transfer_read %3[%7, %9, %12, %16, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%19 = vector.transpose %17, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%20 = vector.extract %19[0] : vector<16x9xf32>
%21 = vector.extract %18[0] : vector<16x32xf32>
%22 = vector.outerproduct %20, %21, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%23 = vector.extract %19[1] : vector<16x9xf32>
%24 = vector.extract %18[1] : vector<16x32xf32>
%25 = vector.outerproduct %23, %24, %22 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%26 = vector.extract %19[2] : vector<16x9xf32>
%27 = vector.extract %18[2] : vector<16x32xf32>
%28 = vector.outerproduct %26, %27, %25 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%29 = vector.extract %19[3] : vector<16x9xf32>
%30 = vector.extract %18[3] : vector<16x32xf32>
%31 = vector.outerproduct %29, %30, %28 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%32 = vector.extract %19[4] : vector<16x9xf32>
%33 = vector.extract %18[4] : vector<16x32xf32>
%34 = vector.outerproduct %32, %33, %31 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%35 = vector.extract %19[5] : vector<16x9xf32>
%36 = vector.extract %18[5] : vector<16x32xf32>
%37 = vector.outerproduct %35, %36, %34 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%38 = vector.extract %19[6] : vector<16x9xf32>
%39 = vector.extract %18[6] : vector<16x32xf32>
%40 = vector.outerproduct %38, %39, %37 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%41 = vector.extract %19[7] : vector<16x9xf32>
%42 = vector.extract %18[7] : vector<16x32xf32>
%43 = vector.outerproduct %41, %42, %40 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%44 = vector.extract %19[8] : vector<16x9xf32>
%45 = vector.extract %18[8] : vector<16x32xf32>
%46 = vector.outerproduct %44, %45, %43 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%47 = vector.extract %19[9] : vector<16x9xf32>
%48 = vector.extract %18[9] : vector<16x32xf32>
%49 = vector.outerproduct %47, %48, %46 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%50 = vector.extract %19[10] : vector<16x9xf32>
%51 = vector.extract %18[10] : vector<16x32xf32>
%52 = vector.outerproduct %50, %51, %49 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%53 = vector.extract %19[11] : vector<16x9xf32>
%54 = vector.extract %18[11] : vector<16x32xf32>
%55 = vector.outerproduct %53, %54, %52 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%56 = vector.extract %19[12] : vector<16x9xf32>
%57 = vector.extract %18[12] : vector<16x32xf32>
%58 = vector.outerproduct %56, %57, %55 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%59 = vector.extract %19[13] : vector<16x9xf32>
%60 = vector.extract %18[13] : vector<16x32xf32>
%61 = vector.outerproduct %59, %60, %58 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%62 = vector.extract %19[14] : vector<16x9xf32>
%63 = vector.extract %18[14] : vector<16x32xf32>
%64 = vector.outerproduct %62, %63, %61 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%65 = vector.extract %19[15] : vector<16x9xf32>
%66 = vector.extract %18[15] : vector<16x32xf32>
%67 = vector.outerproduct %65, %66, %64 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %67 : vector<9x32xf32>
}
vector.transfer_write %15, %13[%c0, %c0] {in_bounds = [false, true]} : vector<9x32xf32>, memref<?x32xf32, #map8>
}
}
}
}
}
memref.dealloc %3 : memref<?x16x4x32x16x32xf32>
memref.dealloc %4 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%11 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg5)
%13 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %10 step %c16 {
%14 = affine.apply #map6(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map7(%arg6, %10)
%17 = memref.subview %arg1[%15, %13] [%16, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = scf.if %18 -> (memref<?x32xf32, #map9>) {
%21 = memref.cast %17 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %21 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%21 = memref.subview %17[0, 0] [%16, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%22 = memref.subview %0[0, 0] [%16, 32] [1, 1] : memref<16x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%21, %22) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%23 = memref.cast %0 : memref<16x32xf32> to memref<?x32xf32, #map9>
scf.yield %23 : memref<?x32xf32, #map9>
}
%20 = vector.transfer_read %19[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<16x32xf32>
vector.transfer_write %20, %7[%9, %11, %12, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map11(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map12(%arg5)
%13 = affine.apply #map5(%arg5, %arg3)
%14 = affine.min #map13(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map6(%arg6)
%17 = affine.apply #map5(%arg6, %arg4)
%18 = affine.min #map7(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map9>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map9>) {
scf.yield %19 : memref<?x?xf32, #map9>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%24 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%25 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map14>
linalg.copy(%24, %25) : memref<?x?xf32, #map9>, memref<?x?xf32, #map14>
%26 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map9>
scf.yield %26 : memref<?x?xf32, #map9>
}
%23 = vector.transfer_read %22[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map9>, vector<9x16xf32>
vector.transfer_write %23, %8[%10, %12, %16, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%12 = memref.subview %arg2[%arg3, %arg5] [%9, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%13 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%14 = affine.min #map13(%arg6, %9)
%15 = affine.apply #map12(%arg6)
%16 = arith.cmpi sle, %c9, %14 : index
%17 = arith.cmpi sgt, %c9, %14 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%18 = affine.apply #map4(%arg7)
%19 = memref.subview %12[%arg6, %arg7] [%14, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%20 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%24 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%24 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%25 = memref.subview %2[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%24, %25) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%26 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %26 : memref<?x32xf32, #map9>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<9x32xf32>
%22 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %21) -> (vector<9x32xf32>) {
%24 = affine.apply #map6(%arg8)
%25 = vector.transfer_read %8[%11, %15, %24, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%26 = vector.transfer_read %7[%11, %13, %18, %24, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%27 = vector.transpose %25, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%28 = vector.extract %27[0] : vector<16x9xf32>
%29 = vector.extract %26[0] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %27[1] : vector<16x9xf32>
%32 = vector.extract %26[1] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %27[2] : vector<16x9xf32>
%35 = vector.extract %26[2] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %27[3] : vector<16x9xf32>
%38 = vector.extract %26[3] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %27[4] : vector<16x9xf32>
%41 = vector.extract %26[4] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %27[5] : vector<16x9xf32>
%44 = vector.extract %26[5] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %27[6] : vector<16x9xf32>
%47 = vector.extract %26[6] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %27[7] : vector<16x9xf32>
%50 = vector.extract %26[7] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %27[8] : vector<16x9xf32>
%53 = vector.extract %26[8] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %27[9] : vector<16x9xf32>
%56 = vector.extract %26[9] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %27[10] : vector<16x9xf32>
%59 = vector.extract %26[10] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %27[11] : vector<16x9xf32>
%62 = vector.extract %26[11] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %27[12] : vector<16x9xf32>
%65 = vector.extract %26[12] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %27[13] : vector<16x9xf32>
%68 = vector.extract %26[13] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %27[14] : vector<16x9xf32>
%71 = vector.extract %26[14] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %27[15] : vector<16x9xf32>
%74 = vector.extract %26[15] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %75 : vector<9x32xf32>
}
%23 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%24 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
} else {
%24 = memref.cast %3 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
}
vector.transfer_write %22, %23[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x32xf32, #map9>
scf.if %17 {
%24 = memref.subview %3[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
%25 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%24, %25) : memref<?x32xf32, #map10>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%11 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg5)
%13 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %10 step %c16 {
%14 = affine.apply #map6(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map7(%arg6, %10)
%17 = memref.subview %arg1[%15, %13] [%16, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = scf.if %18 -> (memref<?x32xf32, #map9>) {
%21 = memref.cast %17 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %21 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%21 = memref.subview %17[0, 0] [%16, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%22 = memref.subview %0[0, 0] [%16, 32] [1, 1] : memref<16x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%21, %22) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%23 = memref.cast %0 : memref<16x32xf32> to memref<?x32xf32, #map9>
scf.yield %23 : memref<?x32xf32, #map9>
}
%20 = vector.transfer_read %19[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<16x32xf32>
vector.transfer_write %20, %7[%9, %11, %12, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x16x4x32x16x32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map11(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map12(%arg5)
%13 = affine.apply #map5(%arg5, %arg3)
%14 = affine.min #map13(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map6(%arg6)
%17 = affine.apply #map5(%arg6, %arg4)
%18 = affine.min #map7(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map9>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map9>) {
scf.yield %19 : memref<?x?xf32, #map9>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%24 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%25 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map14>
linalg.copy(%24, %25) : memref<?x?xf32, #map9>, memref<?x?xf32, #map14>
%26 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map9>
scf.yield %26 : memref<?x?xf32, #map9>
}
%23 = vector.transfer_read %22[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map9>, vector<9x16xf32>
vector.transfer_write %23, %8[%10, %12, %16, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%12 = memref.subview %arg2[%arg3, %arg5] [%9, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%13 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%14 = affine.min #map13(%arg6, %9)
%15 = affine.apply #map12(%arg6)
%16 = arith.cmpi sle, %c9, %14 : index
%17 = arith.cmpi sgt, %c9, %14 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%18 = affine.apply #map4(%arg7)
%19 = memref.subview %12[%arg6, %arg7] [%14, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%20 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%24 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%24 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%25 = memref.subview %2[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%24, %25) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%26 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %26 : memref<?x32xf32, #map9>
}
%21 = vector.transfer_read %20[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32xf32, #map9>, vector<9x32xf32>
%22 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %21) -> (vector<9x32xf32>) {
%24 = affine.apply #map6(%arg8)
%25 = vector.transfer_read %8[%11, %15, %24, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%26 = vector.transfer_read %7[%11, %13, %18, %24, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x16x4x32x16x32xf32>, vector<16x32xf32>
%27 = vector.transpose %25, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%28 = vector.extract %27[0] : vector<16x9xf32>
%29 = vector.extract %26[0] : vector<16x32xf32>
%30 = vector.outerproduct %28, %29, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%31 = vector.extract %27[1] : vector<16x9xf32>
%32 = vector.extract %26[1] : vector<16x32xf32>
%33 = vector.outerproduct %31, %32, %30 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%34 = vector.extract %27[2] : vector<16x9xf32>
%35 = vector.extract %26[2] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %33 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %27[3] : vector<16x9xf32>
%38 = vector.extract %26[3] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %27[4] : vector<16x9xf32>
%41 = vector.extract %26[4] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %27[5] : vector<16x9xf32>
%44 = vector.extract %26[5] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %27[6] : vector<16x9xf32>
%47 = vector.extract %26[6] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %27[7] : vector<16x9xf32>
%50 = vector.extract %26[7] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %27[8] : vector<16x9xf32>
%53 = vector.extract %26[8] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %27[9] : vector<16x9xf32>
%56 = vector.extract %26[9] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %27[10] : vector<16x9xf32>
%59 = vector.extract %26[10] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %27[11] : vector<16x9xf32>
%62 = vector.extract %26[11] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %27[12] : vector<16x9xf32>
%65 = vector.extract %26[12] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %27[13] : vector<16x9xf32>
%68 = vector.extract %26[13] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %27[14] : vector<16x9xf32>
%71 = vector.extract %26[14] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %27[15] : vector<16x9xf32>
%74 = vector.extract %26[15] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %75 : vector<9x32xf32>
}
%23 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%24 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
} else {
%24 = memref.cast %3 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %24 : memref<?x32xf32, #map9>
}
vector.transfer_write %22, %23[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x32xf32, #map9>
scf.if %17 {
%24 = memref.subview %3[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
%25 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%24, %25) : memref<?x32xf32, #map10>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<?x2048xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%11 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg5)
%13 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %10 step %c16 {
%14 = affine.apply #map6(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map7(%arg6, %10)
%17 = memref.subview %arg1[%15, %13] [%16, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = scf.if %18 -> (memref<?x32xf32, #map9>) {
%36 = memref.cast %17 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %36 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst_1, %0) : f32, memref<16x32xf32>
%36 = memref.subview %17[0, 0] [%16, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%37 = memref.subview %0[0, 0] [%16, 32] [1, 1] : memref<16x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%36, %37) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%38 = memref.cast %0 : memref<16x32xf32> to memref<?x32xf32, #map9>
scf.yield %38 : memref<?x32xf32, #map9>
}
%20 = vector.load %19[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%29 = vector.load %19[%c9, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.load %19[%c10, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%31 = vector.load %19[%c11, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.load %19[%c12, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%33 = vector.load %19[%c13, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.load %19[%c14, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%35 = vector.load %19[%c15, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
vector.store %20, %7[%9, %11, %12, %14, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %7[%9, %11, %12, %14, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %7[%9, %11, %12, %14, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %7[%9, %11, %12, %14, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %7[%9, %11, %12, %14, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %12, %14, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %12, %14, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %12, %14, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %12, %14, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %12, %14, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %12, %14, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %12, %14, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %12, %14, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %12, %14, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %12, %14, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %12, %14, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map11(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map12(%arg5)
%13 = affine.apply #map5(%arg5, %arg3)
%14 = affine.min #map13(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map6(%arg6)
%17 = affine.apply #map5(%arg6, %arg4)
%18 = affine.min #map7(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map9>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map9>) {
scf.yield %19 : memref<?x?xf32, #map9>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map14>
linalg.copy(%32, %33) : memref<?x?xf32, #map9>, memref<?x?xf32, #map14>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map9>
scf.yield %34 : memref<?x?xf32, #map9>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%12 = memref.subview %arg2[%arg3, %arg5] [%9, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%13 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%14 = affine.min #map13(%arg6, %9)
%15 = affine.apply #map12(%arg6)
%16 = arith.cmpi sle, %c9, %14 : index
%17 = arith.cmpi sgt, %c9, %14 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%18 = affine.apply #map4(%arg7)
%19 = memref.subview %12[%arg6, %arg7] [%14, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%20 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst_1, %2) : f32, memref<9x32xf32>
%50 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%51 = memref.subview %2[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%50, %51) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%52 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %52 : memref<?x32xf32, #map9>
}
%21 = vector.load %20[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.insert %21, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%50 = affine.apply #map6(%arg8)
%51 = vector.load %8[%11, %15, %50, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%52 = vector.insert %51, %cst [0] : vector<16xf32> into vector<9x16xf32>
%53 = vector.load %8[%11, %15, %50, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%54 = vector.insert %53, %52 [1] : vector<16xf32> into vector<9x16xf32>
%55 = vector.load %8[%11, %15, %50, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %54 [2] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %15, %50, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [3] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %15, %50, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [4] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %15, %50, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [5] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %15, %50, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [6] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %15, %50, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [7] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %15, %50, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [8] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %7[%11, %13, %18, %50, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %7[%11, %13, %18, %50, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %7[%11, %13, %18, %50, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %7[%11, %13, %18, %50, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %7[%11, %13, %18, %50, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %13, %18, %50, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %13, %18, %50, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %13, %18, %50, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %13, %18, %50, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %13, %18, %50, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %13, %18, %50, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %13, %18, %50, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %13, %18, %50, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %13, %18, %50, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %13, %18, %50, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %13, %18, %50, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.transpose %68, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%86 = vector.extract %85[0] : vector<16x9xf32>
%87 = vector.outerproduct %86, %69, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%88 = vector.extract %85[1] : vector<16x9xf32>
%89 = vector.outerproduct %88, %70, %87 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%90 = vector.extract %85[2] : vector<16x9xf32>
%91 = vector.outerproduct %90, %71, %89 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%92 = vector.extract %85[3] : vector<16x9xf32>
%93 = vector.outerproduct %92, %72, %91 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %85[4] : vector<16x9xf32>
%95 = vector.outerproduct %94, %73, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %85[5] : vector<16x9xf32>
%97 = vector.outerproduct %96, %74, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %85[6] : vector<16x9xf32>
%99 = vector.outerproduct %98, %75, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %85[7] : vector<16x9xf32>
%101 = vector.outerproduct %100, %76, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %85[8] : vector<16x9xf32>
%103 = vector.outerproduct %102, %77, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %85[9] : vector<16x9xf32>
%105 = vector.outerproduct %104, %78, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %85[10] : vector<16x9xf32>
%107 = vector.outerproduct %106, %79, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %85[11] : vector<16x9xf32>
%109 = vector.outerproduct %108, %80, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %85[12] : vector<16x9xf32>
%111 = vector.outerproduct %110, %81, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %85[13] : vector<16x9xf32>
%113 = vector.outerproduct %112, %82, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %85[14] : vector<16x9xf32>
%115 = vector.outerproduct %114, %83, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %85[15] : vector<16x9xf32>
%117 = vector.outerproduct %116, %84, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %117 : vector<9x32xf32>
}
%40 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
%50 = memref.cast %3 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
scf.if %17 {
%50 = memref.subview %3[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
%51 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%50, %51) : memref<?x32xf32, #map10>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x2048xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%11 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg5)
%13 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %10 step %c16 {
%14 = affine.apply #map6(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map7(%arg6, %10)
%17 = memref.subview %arg1[%15, %13] [%16, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = scf.if %18 -> (memref<?x32xf32, #map9>) {
%36 = memref.cast %17 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %36 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%36 = memref.subview %17[0, 0] [%16, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%37 = memref.subview %0[0, 0] [%16, 32] [1, 1] : memref<16x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%36, %37) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%38 = memref.cast %0 : memref<16x32xf32> to memref<?x32xf32, #map9>
scf.yield %38 : memref<?x32xf32, #map9>
}
%20 = vector.load %19[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%29 = vector.load %19[%c9, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.load %19[%c10, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%31 = vector.load %19[%c11, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.load %19[%c12, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%33 = vector.load %19[%c13, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.load %19[%c14, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%35 = vector.load %19[%c15, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
vector.store %20, %7[%9, %11, %12, %14, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %7[%9, %11, %12, %14, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %7[%9, %11, %12, %14, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %7[%9, %11, %12, %14, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %7[%9, %11, %12, %14, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %12, %14, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %12, %14, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %12, %14, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %12, %14, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %12, %14, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %12, %14, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %12, %14, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %12, %14, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %12, %14, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %12, %14, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %12, %14, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map11(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map12(%arg5)
%13 = affine.apply #map5(%arg5, %arg3)
%14 = affine.min #map13(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map6(%arg6)
%17 = affine.apply #map5(%arg6, %arg4)
%18 = affine.min #map7(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map9>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map9>) {
scf.yield %19 : memref<?x?xf32, #map9>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map14>
linalg.copy(%32, %33) : memref<?x?xf32, #map9>, memref<?x?xf32, #map14>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map9>
scf.yield %34 : memref<?x?xf32, #map9>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%12 = memref.subview %arg2[%arg3, %arg5] [%9, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%13 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%14 = affine.min #map13(%arg6, %9)
%15 = affine.apply #map12(%arg6)
%16 = arith.cmpi sle, %c9, %14 : index
%17 = arith.cmpi sgt, %c9, %14 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%18 = affine.apply #map4(%arg7)
%19 = memref.subview %12[%arg6, %arg7] [%14, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%20 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%50 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%51 = memref.subview %2[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%50, %51) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%52 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %52 : memref<?x32xf32, #map9>
}
%21 = vector.load %20[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.insert %21, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%50 = affine.apply #map6(%arg8)
%51 = vector.load %8[%11, %15, %50, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%52 = vector.insert %51, %cst_1 [0] : vector<16xf32> into vector<9x16xf32>
%53 = vector.load %8[%11, %15, %50, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%54 = vector.insert %53, %52 [1] : vector<16xf32> into vector<9x16xf32>
%55 = vector.load %8[%11, %15, %50, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %54 [2] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %15, %50, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [3] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %15, %50, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [4] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %15, %50, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [5] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %15, %50, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [6] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %15, %50, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [7] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %15, %50, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [8] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %7[%11, %13, %18, %50, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %7[%11, %13, %18, %50, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %7[%11, %13, %18, %50, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %7[%11, %13, %18, %50, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %7[%11, %13, %18, %50, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %13, %18, %50, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %13, %18, %50, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %13, %18, %50, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %13, %18, %50, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %13, %18, %50, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %13, %18, %50, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %13, %18, %50, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %13, %18, %50, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %13, %18, %50, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %13, %18, %50, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %13, %18, %50, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.transpose %68, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%86 = vector.extract %85[0] : vector<16x9xf32>
%87 = vector.outerproduct %86, %69, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%88 = vector.extract %85[1] : vector<16x9xf32>
%89 = vector.outerproduct %88, %70, %87 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%90 = vector.extract %85[2] : vector<16x9xf32>
%91 = vector.outerproduct %90, %71, %89 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%92 = vector.extract %85[3] : vector<16x9xf32>
%93 = vector.outerproduct %92, %72, %91 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %85[4] : vector<16x9xf32>
%95 = vector.outerproduct %94, %73, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %85[5] : vector<16x9xf32>
%97 = vector.outerproduct %96, %74, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %85[6] : vector<16x9xf32>
%99 = vector.outerproduct %98, %75, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %85[7] : vector<16x9xf32>
%101 = vector.outerproduct %100, %76, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %85[8] : vector<16x9xf32>
%103 = vector.outerproduct %102, %77, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %85[9] : vector<16x9xf32>
%105 = vector.outerproduct %104, %78, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %85[10] : vector<16x9xf32>
%107 = vector.outerproduct %106, %79, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %85[11] : vector<16x9xf32>
%109 = vector.outerproduct %108, %80, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %85[12] : vector<16x9xf32>
%111 = vector.outerproduct %110, %81, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %85[13] : vector<16x9xf32>
%113 = vector.outerproduct %112, %82, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %85[14] : vector<16x9xf32>
%115 = vector.outerproduct %114, %83, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %85[15] : vector<16x9xf32>
%117 = vector.outerproduct %116, %84, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %117 : vector<9x32xf32>
}
%40 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
%50 = memref.cast %3 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
scf.if %17 {
%50 = memref.subview %3[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
%51 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%50, %51) : memref<?x32xf32, #map10>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692250>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<(d0) -> (d0 ceildiv 512)>
#map2 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map3 = affine_map<(d0) -> (d0 ceildiv 128)>
#map4 = affine_map<(d0) -> (d0 ceildiv 32)>
#map5 = affine_map<(d0, d1) -> (d0 + d1)>
#map6 = affine_map<(d0) -> (d0 ceildiv 16)>
#map7 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map8 = affine_map<(d0, d1)[s0] -> (d0 * 2048 + s0 + d1)>
#map9 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map10 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<16x9xf32>
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c2048 = arith.constant 2048 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_2, %arg2) : f32, memref<?x2048xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = affine.apply #map0()[%5]
%7 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x16x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%9 = affine.apply #map1(%arg3)
%10 = affine.min #map2(%arg3)[%5]
scf.for %arg4 = %c0 to %c2048 step %c128 {
%11 = affine.apply #map3(%arg4)
scf.for %arg5 = %c0 to %c128 step %c32 {
%12 = affine.apply #map4(%arg5)
%13 = affine.apply #map5(%arg5, %arg4)
scf.for %arg6 = %c0 to %10 step %c16 {
%14 = affine.apply #map6(%arg6)
%15 = affine.apply #map5(%arg6, %arg3)
%16 = affine.min #map7(%arg6, %10)
%17 = memref.subview %arg1[%15, %13] [%16, 32] [1, 1] : memref<?x2048xf32> to memref<?x32xf32, #map8>
%18 = arith.cmpi sle, %c16, %16 : index
%19 = scf.if %18 -> (memref<?x32xf32, #map9>) {
%36 = memref.cast %17 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %36 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst_2, %0) : f32, memref<16x32xf32>
%36 = memref.subview %17[0, 0] [%16, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%37 = memref.subview %0[0, 0] [%16, 32] [1, 1] : memref<16x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%36, %37) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%38 = memref.cast %0 : memref<16x32xf32> to memref<?x32xf32, #map9>
scf.yield %38 : memref<?x32xf32, #map9>
}
%20 = vector.load %19[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%21 = vector.load %19[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.load %19[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%23 = vector.load %19[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.load %19[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%25 = vector.load %19[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.load %19[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%27 = vector.load %19[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.load %19[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%29 = vector.load %19[%c9, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.load %19[%c10, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%31 = vector.load %19[%c11, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.load %19[%c12, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%33 = vector.load %19[%c13, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.load %19[%c14, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%35 = vector.load %19[%c15, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
vector.store %20, %7[%9, %11, %12, %14, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %21, %7[%9, %11, %12, %14, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %22, %7[%9, %11, %12, %14, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %23, %7[%9, %11, %12, %14, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %24, %7[%9, %11, %12, %14, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %25, %7[%9, %11, %12, %14, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %26, %7[%9, %11, %12, %14, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %7[%9, %11, %12, %14, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %7[%9, %11, %12, %14, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %7[%9, %11, %12, %14, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %7[%9, %11, %12, %14, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %7[%9, %11, %12, %14, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %7[%9, %11, %12, %14, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %7[%9, %11, %12, %14, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %7[%9, %11, %12, %14, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %7[%9, %11, %12, %14, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%8 = memref.alloc(%6) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%9 = affine.min #map11(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.apply #map1(%arg4)
%11 = affine.min #map2(%arg4)[%5]
scf.for %arg5 = %c0 to %9 step %c9 {
%12 = affine.apply #map12(%arg5)
%13 = affine.apply #map5(%arg5, %arg3)
%14 = affine.min #map13(%arg5, %9)
%15 = arith.cmpi sle, %c9, %14 : index
scf.for %arg6 = %c0 to %11 step %c16 {
%16 = affine.apply #map6(%arg6)
%17 = affine.apply #map5(%arg6, %arg4)
%18 = affine.min #map7(%arg6, %11)
%19 = memref.subview %arg0[%13, %17] [%14, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map9>
%20 = arith.cmpi sle, %c16, %18 : index
%21 = arith.andi %15, %20 : i1
%22 = scf.if %21 -> (memref<?x?xf32, #map9>) {
scf.yield %19 : memref<?x?xf32, #map9>
} else {
linalg.fill(%cst_2, %1) : f32, memref<9x16xf32>
%32 = memref.subview %19[0, 0] [%14, %18] [1, 1] : memref<?x?xf32, #map9> to memref<?x?xf32, #map9>
%33 = memref.subview %1[0, 0] [%14, %18] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map14>
linalg.copy(%32, %33) : memref<?x?xf32, #map9>, memref<?x?xf32, #map14>
%34 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map9>
scf.yield %34 : memref<?x?xf32, #map9>
}
%23 = vector.load %22[%c0, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%24 = vector.load %22[%c1, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%25 = vector.load %22[%c2, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%26 = vector.load %22[%c3, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%27 = vector.load %22[%c4, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%28 = vector.load %22[%c5, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%29 = vector.load %22[%c6, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%30 = vector.load %22[%c7, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
%31 = vector.load %22[%c8, %c0] : memref<?x?xf32, #map9>, vector<16xf32>
vector.store %23, %8[%10, %12, %16, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %24, %8[%10, %12, %16, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %25, %8[%10, %12, %16, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %8[%10, %12, %16, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %8[%10, %12, %16, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %8[%10, %12, %16, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %8[%10, %12, %16, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %8[%10, %12, %16, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %8[%10, %12, %16, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%10 = affine.min #map2(%arg4)[%5]
%11 = affine.apply #map1(%arg4)
scf.for %arg5 = %c0 to %c2048 step %c128 {
%12 = memref.subview %arg2[%arg3, %arg5] [%9, 128] [1, 1] : memref<?x2048xf32> to memref<?x128xf32, #map8>
%13 = affine.apply #map3(%arg5)
scf.for %arg6 = %c0 to %9 step %c9 {
%14 = affine.min #map13(%arg6, %9)
%15 = affine.apply #map12(%arg6)
%16 = arith.cmpi sle, %c9, %14 : index
%17 = arith.cmpi sgt, %c9, %14 : index
scf.for %arg7 = %c0 to %c128 step %c32 {
%18 = affine.apply #map4(%arg7)
%19 = memref.subview %12[%arg6, %arg7] [%14, 32] [1, 1] : memref<?x128xf32, #map8> to memref<?x32xf32, #map8>
%20 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
linalg.fill(%cst_2, %2) : f32, memref<9x32xf32>
%50 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
%51 = memref.subview %2[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
linalg.copy(%50, %51) : memref<?x32xf32, #map8>, memref<?x32xf32, #map10>
%52 = memref.cast %2 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %52 : memref<?x32xf32, #map9>
}
%21 = vector.load %20[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%22 = vector.insert %21, %cst_1 [0] : vector<32xf32> into vector<9x32xf32>
%23 = vector.load %20[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%24 = vector.insert %23, %22 [1] : vector<32xf32> into vector<9x32xf32>
%25 = vector.load %20[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%26 = vector.insert %25, %24 [2] : vector<32xf32> into vector<9x32xf32>
%27 = vector.load %20[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%28 = vector.insert %27, %26 [3] : vector<32xf32> into vector<9x32xf32>
%29 = vector.load %20[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%30 = vector.insert %29, %28 [4] : vector<32xf32> into vector<9x32xf32>
%31 = vector.load %20[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%32 = vector.insert %31, %30 [5] : vector<32xf32> into vector<9x32xf32>
%33 = vector.load %20[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%34 = vector.insert %33, %32 [6] : vector<32xf32> into vector<9x32xf32>
%35 = vector.load %20[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%36 = vector.insert %35, %34 [7] : vector<32xf32> into vector<9x32xf32>
%37 = vector.load %20[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%38 = vector.insert %37, %36 [8] : vector<32xf32> into vector<9x32xf32>
%39 = scf.for %arg8 = %c0 to %10 step %c16 iter_args(%arg9 = %38) -> (vector<9x32xf32>) {
%50 = affine.apply #map6(%arg8)
%51 = vector.load %8[%11, %15, %50, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%52 = vector.insert %51, %cst_0 [0] : vector<16xf32> into vector<9x16xf32>
%53 = vector.load %8[%11, %15, %50, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%54 = vector.insert %53, %52 [1] : vector<16xf32> into vector<9x16xf32>
%55 = vector.load %8[%11, %15, %50, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%56 = vector.insert %55, %54 [2] : vector<16xf32> into vector<9x16xf32>
%57 = vector.load %8[%11, %15, %50, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %56 [3] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %8[%11, %15, %50, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [4] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %8[%11, %15, %50, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [5] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %8[%11, %15, %50, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [6] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %8[%11, %15, %50, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [7] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %8[%11, %15, %50, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [8] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %7[%11, %13, %18, %50, %c0, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%70 = vector.load %7[%11, %13, %18, %50, %c1, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%71 = vector.load %7[%11, %13, %18, %50, %c2, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%72 = vector.load %7[%11, %13, %18, %50, %c3, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%73 = vector.load %7[%11, %13, %18, %50, %c4, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%74 = vector.load %7[%11, %13, %18, %50, %c5, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%75 = vector.load %7[%11, %13, %18, %50, %c6, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %7[%11, %13, %18, %50, %c7, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %7[%11, %13, %18, %50, %c8, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %7[%11, %13, %18, %50, %c9, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %7[%11, %13, %18, %50, %c10, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %7[%11, %13, %18, %50, %c11, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %7[%11, %13, %18, %50, %c12, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %7[%11, %13, %18, %50, %c13, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %7[%11, %13, %18, %50, %c14, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %7[%11, %13, %18, %50, %c15, %c0] : memref<?x16x4x32x16x32xf32>, vector<32xf32>
%85 = vector.extract %68[0, 0] : vector<9x16xf32>
%86 = vector.insert %85, %cst [0, 0] : f32 into vector<16x9xf32>
%87 = vector.extract %68[1, 0] : vector<9x16xf32>
%88 = vector.insert %87, %86 [0, 1] : f32 into vector<16x9xf32>
%89 = vector.extract %68[2, 0] : vector<9x16xf32>
%90 = vector.insert %89, %88 [0, 2] : f32 into vector<16x9xf32>
%91 = vector.extract %68[3, 0] : vector<9x16xf32>
%92 = vector.insert %91, %90 [0, 3] : f32 into vector<16x9xf32>
%93 = vector.extract %68[4, 0] : vector<9x16xf32>
%94 = vector.insert %93, %92 [0, 4] : f32 into vector<16x9xf32>
%95 = vector.extract %68[5, 0] : vector<9x16xf32>
%96 = vector.insert %95, %94 [0, 5] : f32 into vector<16x9xf32>
%97 = vector.extract %68[6, 0] : vector<9x16xf32>
%98 = vector.insert %97, %96 [0, 6] : f32 into vector<16x9xf32>
%99 = vector.extract %68[7, 0] : vector<9x16xf32>
%100 = vector.insert %99, %98 [0, 7] : f32 into vector<16x9xf32>
%101 = vector.extract %68[8, 0] : vector<9x16xf32>
%102 = vector.insert %101, %100 [0, 8] : f32 into vector<16x9xf32>
%103 = vector.extract %68[0, 1] : vector<9x16xf32>
%104 = vector.insert %103, %102 [1, 0] : f32 into vector<16x9xf32>
%105 = vector.extract %68[1, 1] : vector<9x16xf32>
%106 = vector.insert %105, %104 [1, 1] : f32 into vector<16x9xf32>
%107 = vector.extract %68[2, 1] : vector<9x16xf32>
%108 = vector.insert %107, %106 [1, 2] : f32 into vector<16x9xf32>
%109 = vector.extract %68[3, 1] : vector<9x16xf32>
%110 = vector.insert %109, %108 [1, 3] : f32 into vector<16x9xf32>
%111 = vector.extract %68[4, 1] : vector<9x16xf32>
%112 = vector.insert %111, %110 [1, 4] : f32 into vector<16x9xf32>
%113 = vector.extract %68[5, 1] : vector<9x16xf32>
%114 = vector.insert %113, %112 [1, 5] : f32 into vector<16x9xf32>
%115 = vector.extract %68[6, 1] : vector<9x16xf32>
%116 = vector.insert %115, %114 [1, 6] : f32 into vector<16x9xf32>
%117 = vector.extract %68[7, 1] : vector<9x16xf32>
%118 = vector.insert %117, %116 [1, 7] : f32 into vector<16x9xf32>
%119 = vector.extract %68[8, 1] : vector<9x16xf32>
%120 = vector.insert %119, %118 [1, 8] : f32 into vector<16x9xf32>
%121 = vector.extract %68[0, 2] : vector<9x16xf32>
%122 = vector.insert %121, %120 [2, 0] : f32 into vector<16x9xf32>
%123 = vector.extract %68[1, 2] : vector<9x16xf32>
%124 = vector.insert %123, %122 [2, 1] : f32 into vector<16x9xf32>
%125 = vector.extract %68[2, 2] : vector<9x16xf32>
%126 = vector.insert %125, %124 [2, 2] : f32 into vector<16x9xf32>
%127 = vector.extract %68[3, 2] : vector<9x16xf32>
%128 = vector.insert %127, %126 [2, 3] : f32 into vector<16x9xf32>
%129 = vector.extract %68[4, 2] : vector<9x16xf32>
%130 = vector.insert %129, %128 [2, 4] : f32 into vector<16x9xf32>
%131 = vector.extract %68[5, 2] : vector<9x16xf32>
%132 = vector.insert %131, %130 [2, 5] : f32 into vector<16x9xf32>
%133 = vector.extract %68[6, 2] : vector<9x16xf32>
%134 = vector.insert %133, %132 [2, 6] : f32 into vector<16x9xf32>
%135 = vector.extract %68[7, 2] : vector<9x16xf32>
%136 = vector.insert %135, %134 [2, 7] : f32 into vector<16x9xf32>
%137 = vector.extract %68[8, 2] : vector<9x16xf32>
%138 = vector.insert %137, %136 [2, 8] : f32 into vector<16x9xf32>
%139 = vector.extract %68[0, 3] : vector<9x16xf32>
%140 = vector.insert %139, %138 [3, 0] : f32 into vector<16x9xf32>
%141 = vector.extract %68[1, 3] : vector<9x16xf32>
%142 = vector.insert %141, %140 [3, 1] : f32 into vector<16x9xf32>
%143 = vector.extract %68[2, 3] : vector<9x16xf32>
%144 = vector.insert %143, %142 [3, 2] : f32 into vector<16x9xf32>
%145 = vector.extract %68[3, 3] : vector<9x16xf32>
%146 = vector.insert %145, %144 [3, 3] : f32 into vector<16x9xf32>
%147 = vector.extract %68[4, 3] : vector<9x16xf32>
%148 = vector.insert %147, %146 [3, 4] : f32 into vector<16x9xf32>
%149 = vector.extract %68[5, 3] : vector<9x16xf32>
%150 = vector.insert %149, %148 [3, 5] : f32 into vector<16x9xf32>
%151 = vector.extract %68[6, 3] : vector<9x16xf32>
%152 = vector.insert %151, %150 [3, 6] : f32 into vector<16x9xf32>
%153 = vector.extract %68[7, 3] : vector<9x16xf32>
%154 = vector.insert %153, %152 [3, 7] : f32 into vector<16x9xf32>
%155 = vector.extract %68[8, 3] : vector<9x16xf32>
%156 = vector.insert %155, %154 [3, 8] : f32 into vector<16x9xf32>
%157 = vector.extract %68[0, 4] : vector<9x16xf32>
%158 = vector.insert %157, %156 [4, 0] : f32 into vector<16x9xf32>
%159 = vector.extract %68[1, 4] : vector<9x16xf32>
%160 = vector.insert %159, %158 [4, 1] : f32 into vector<16x9xf32>
%161 = vector.extract %68[2, 4] : vector<9x16xf32>
%162 = vector.insert %161, %160 [4, 2] : f32 into vector<16x9xf32>
%163 = vector.extract %68[3, 4] : vector<9x16xf32>
%164 = vector.insert %163, %162 [4, 3] : f32 into vector<16x9xf32>
%165 = vector.extract %68[4, 4] : vector<9x16xf32>
%166 = vector.insert %165, %164 [4, 4] : f32 into vector<16x9xf32>
%167 = vector.extract %68[5, 4] : vector<9x16xf32>
%168 = vector.insert %167, %166 [4, 5] : f32 into vector<16x9xf32>
%169 = vector.extract %68[6, 4] : vector<9x16xf32>
%170 = vector.insert %169, %168 [4, 6] : f32 into vector<16x9xf32>
%171 = vector.extract %68[7, 4] : vector<9x16xf32>
%172 = vector.insert %171, %170 [4, 7] : f32 into vector<16x9xf32>
%173 = vector.extract %68[8, 4] : vector<9x16xf32>
%174 = vector.insert %173, %172 [4, 8] : f32 into vector<16x9xf32>
%175 = vector.extract %68[0, 5] : vector<9x16xf32>
%176 = vector.insert %175, %174 [5, 0] : f32 into vector<16x9xf32>
%177 = vector.extract %68[1, 5] : vector<9x16xf32>
%178 = vector.insert %177, %176 [5, 1] : f32 into vector<16x9xf32>
%179 = vector.extract %68[2, 5] : vector<9x16xf32>
%180 = vector.insert %179, %178 [5, 2] : f32 into vector<16x9xf32>
%181 = vector.extract %68[3, 5] : vector<9x16xf32>
%182 = vector.insert %181, %180 [5, 3] : f32 into vector<16x9xf32>
%183 = vector.extract %68[4, 5] : vector<9x16xf32>
%184 = vector.insert %183, %182 [5, 4] : f32 into vector<16x9xf32>
%185 = vector.extract %68[5, 5] : vector<9x16xf32>
%186 = vector.insert %185, %184 [5, 5] : f32 into vector<16x9xf32>
%187 = vector.extract %68[6, 5] : vector<9x16xf32>
%188 = vector.insert %187, %186 [5, 6] : f32 into vector<16x9xf32>
%189 = vector.extract %68[7, 5] : vector<9x16xf32>
%190 = vector.insert %189, %188 [5, 7] : f32 into vector<16x9xf32>
%191 = vector.extract %68[8, 5] : vector<9x16xf32>
%192 = vector.insert %191, %190 [5, 8] : f32 into vector<16x9xf32>
%193 = vector.extract %68[0, 6] : vector<9x16xf32>
%194 = vector.insert %193, %192 [6, 0] : f32 into vector<16x9xf32>
%195 = vector.extract %68[1, 6] : vector<9x16xf32>
%196 = vector.insert %195, %194 [6, 1] : f32 into vector<16x9xf32>
%197 = vector.extract %68[2, 6] : vector<9x16xf32>
%198 = vector.insert %197, %196 [6, 2] : f32 into vector<16x9xf32>
%199 = vector.extract %68[3, 6] : vector<9x16xf32>
%200 = vector.insert %199, %198 [6, 3] : f32 into vector<16x9xf32>
%201 = vector.extract %68[4, 6] : vector<9x16xf32>
%202 = vector.insert %201, %200 [6, 4] : f32 into vector<16x9xf32>
%203 = vector.extract %68[5, 6] : vector<9x16xf32>
%204 = vector.insert %203, %202 [6, 5] : f32 into vector<16x9xf32>
%205 = vector.extract %68[6, 6] : vector<9x16xf32>
%206 = vector.insert %205, %204 [6, 6] : f32 into vector<16x9xf32>
%207 = vector.extract %68[7, 6] : vector<9x16xf32>
%208 = vector.insert %207, %206 [6, 7] : f32 into vector<16x9xf32>
%209 = vector.extract %68[8, 6] : vector<9x16xf32>
%210 = vector.insert %209, %208 [6, 8] : f32 into vector<16x9xf32>
%211 = vector.extract %68[0, 7] : vector<9x16xf32>
%212 = vector.insert %211, %210 [7, 0] : f32 into vector<16x9xf32>
%213 = vector.extract %68[1, 7] : vector<9x16xf32>
%214 = vector.insert %213, %212 [7, 1] : f32 into vector<16x9xf32>
%215 = vector.extract %68[2, 7] : vector<9x16xf32>
%216 = vector.insert %215, %214 [7, 2] : f32 into vector<16x9xf32>
%217 = vector.extract %68[3, 7] : vector<9x16xf32>
%218 = vector.insert %217, %216 [7, 3] : f32 into vector<16x9xf32>
%219 = vector.extract %68[4, 7] : vector<9x16xf32>
%220 = vector.insert %219, %218 [7, 4] : f32 into vector<16x9xf32>
%221 = vector.extract %68[5, 7] : vector<9x16xf32>
%222 = vector.insert %221, %220 [7, 5] : f32 into vector<16x9xf32>
%223 = vector.extract %68[6, 7] : vector<9x16xf32>
%224 = vector.insert %223, %222 [7, 6] : f32 into vector<16x9xf32>
%225 = vector.extract %68[7, 7] : vector<9x16xf32>
%226 = vector.insert %225, %224 [7, 7] : f32 into vector<16x9xf32>
%227 = vector.extract %68[8, 7] : vector<9x16xf32>
%228 = vector.insert %227, %226 [7, 8] : f32 into vector<16x9xf32>
%229 = vector.extract %68[0, 8] : vector<9x16xf32>
%230 = vector.insert %229, %228 [8, 0] : f32 into vector<16x9xf32>
%231 = vector.extract %68[1, 8] : vector<9x16xf32>
%232 = vector.insert %231, %230 [8, 1] : f32 into vector<16x9xf32>
%233 = vector.extract %68[2, 8] : vector<9x16xf32>
%234 = vector.insert %233, %232 [8, 2] : f32 into vector<16x9xf32>
%235 = vector.extract %68[3, 8] : vector<9x16xf32>
%236 = vector.insert %235, %234 [8, 3] : f32 into vector<16x9xf32>
%237 = vector.extract %68[4, 8] : vector<9x16xf32>
%238 = vector.insert %237, %236 [8, 4] : f32 into vector<16x9xf32>
%239 = vector.extract %68[5, 8] : vector<9x16xf32>
%240 = vector.insert %239, %238 [8, 5] : f32 into vector<16x9xf32>
%241 = vector.extract %68[6, 8] : vector<9x16xf32>
%242 = vector.insert %241, %240 [8, 6] : f32 into vector<16x9xf32>
%243 = vector.extract %68[7, 8] : vector<9x16xf32>
%244 = vector.insert %243, %242 [8, 7] : f32 into vector<16x9xf32>
%245 = vector.extract %68[8, 8] : vector<9x16xf32>
%246 = vector.insert %245, %244 [8, 8] : f32 into vector<16x9xf32>
%247 = vector.extract %68[0, 9] : vector<9x16xf32>
%248 = vector.insert %247, %246 [9, 0] : f32 into vector<16x9xf32>
%249 = vector.extract %68[1, 9] : vector<9x16xf32>
%250 = vector.insert %249, %248 [9, 1] : f32 into vector<16x9xf32>
%251 = vector.extract %68[2, 9] : vector<9x16xf32>
%252 = vector.insert %251, %250 [9, 2] : f32 into vector<16x9xf32>
%253 = vector.extract %68[3, 9] : vector<9x16xf32>
%254 = vector.insert %253, %252 [9, 3] : f32 into vector<16x9xf32>
%255 = vector.extract %68[4, 9] : vector<9x16xf32>
%256 = vector.insert %255, %254 [9, 4] : f32 into vector<16x9xf32>
%257 = vector.extract %68[5, 9] : vector<9x16xf32>
%258 = vector.insert %257, %256 [9, 5] : f32 into vector<16x9xf32>
%259 = vector.extract %68[6, 9] : vector<9x16xf32>
%260 = vector.insert %259, %258 [9, 6] : f32 into vector<16x9xf32>
%261 = vector.extract %68[7, 9] : vector<9x16xf32>
%262 = vector.insert %261, %260 [9, 7] : f32 into vector<16x9xf32>
%263 = vector.extract %68[8, 9] : vector<9x16xf32>
%264 = vector.insert %263, %262 [9, 8] : f32 into vector<16x9xf32>
%265 = vector.extract %68[0, 10] : vector<9x16xf32>
%266 = vector.insert %265, %264 [10, 0] : f32 into vector<16x9xf32>
%267 = vector.extract %68[1, 10] : vector<9x16xf32>
%268 = vector.insert %267, %266 [10, 1] : f32 into vector<16x9xf32>
%269 = vector.extract %68[2, 10] : vector<9x16xf32>
%270 = vector.insert %269, %268 [10, 2] : f32 into vector<16x9xf32>
%271 = vector.extract %68[3, 10] : vector<9x16xf32>
%272 = vector.insert %271, %270 [10, 3] : f32 into vector<16x9xf32>
%273 = vector.extract %68[4, 10] : vector<9x16xf32>
%274 = vector.insert %273, %272 [10, 4] : f32 into vector<16x9xf32>
%275 = vector.extract %68[5, 10] : vector<9x16xf32>
%276 = vector.insert %275, %274 [10, 5] : f32 into vector<16x9xf32>
%277 = vector.extract %68[6, 10] : vector<9x16xf32>
%278 = vector.insert %277, %276 [10, 6] : f32 into vector<16x9xf32>
%279 = vector.extract %68[7, 10] : vector<9x16xf32>
%280 = vector.insert %279, %278 [10, 7] : f32 into vector<16x9xf32>
%281 = vector.extract %68[8, 10] : vector<9x16xf32>
%282 = vector.insert %281, %280 [10, 8] : f32 into vector<16x9xf32>
%283 = vector.extract %68[0, 11] : vector<9x16xf32>
%284 = vector.insert %283, %282 [11, 0] : f32 into vector<16x9xf32>
%285 = vector.extract %68[1, 11] : vector<9x16xf32>
%286 = vector.insert %285, %284 [11, 1] : f32 into vector<16x9xf32>
%287 = vector.extract %68[2, 11] : vector<9x16xf32>
%288 = vector.insert %287, %286 [11, 2] : f32 into vector<16x9xf32>
%289 = vector.extract %68[3, 11] : vector<9x16xf32>
%290 = vector.insert %289, %288 [11, 3] : f32 into vector<16x9xf32>
%291 = vector.extract %68[4, 11] : vector<9x16xf32>
%292 = vector.insert %291, %290 [11, 4] : f32 into vector<16x9xf32>
%293 = vector.extract %68[5, 11] : vector<9x16xf32>
%294 = vector.insert %293, %292 [11, 5] : f32 into vector<16x9xf32>
%295 = vector.extract %68[6, 11] : vector<9x16xf32>
%296 = vector.insert %295, %294 [11, 6] : f32 into vector<16x9xf32>
%297 = vector.extract %68[7, 11] : vector<9x16xf32>
%298 = vector.insert %297, %296 [11, 7] : f32 into vector<16x9xf32>
%299 = vector.extract %68[8, 11] : vector<9x16xf32>
%300 = vector.insert %299, %298 [11, 8] : f32 into vector<16x9xf32>
%301 = vector.extract %68[0, 12] : vector<9x16xf32>
%302 = vector.insert %301, %300 [12, 0] : f32 into vector<16x9xf32>
%303 = vector.extract %68[1, 12] : vector<9x16xf32>
%304 = vector.insert %303, %302 [12, 1] : f32 into vector<16x9xf32>
%305 = vector.extract %68[2, 12] : vector<9x16xf32>
%306 = vector.insert %305, %304 [12, 2] : f32 into vector<16x9xf32>
%307 = vector.extract %68[3, 12] : vector<9x16xf32>
%308 = vector.insert %307, %306 [12, 3] : f32 into vector<16x9xf32>
%309 = vector.extract %68[4, 12] : vector<9x16xf32>
%310 = vector.insert %309, %308 [12, 4] : f32 into vector<16x9xf32>
%311 = vector.extract %68[5, 12] : vector<9x16xf32>
%312 = vector.insert %311, %310 [12, 5] : f32 into vector<16x9xf32>
%313 = vector.extract %68[6, 12] : vector<9x16xf32>
%314 = vector.insert %313, %312 [12, 6] : f32 into vector<16x9xf32>
%315 = vector.extract %68[7, 12] : vector<9x16xf32>
%316 = vector.insert %315, %314 [12, 7] : f32 into vector<16x9xf32>
%317 = vector.extract %68[8, 12] : vector<9x16xf32>
%318 = vector.insert %317, %316 [12, 8] : f32 into vector<16x9xf32>
%319 = vector.extract %68[0, 13] : vector<9x16xf32>
%320 = vector.insert %319, %318 [13, 0] : f32 into vector<16x9xf32>
%321 = vector.extract %68[1, 13] : vector<9x16xf32>
%322 = vector.insert %321, %320 [13, 1] : f32 into vector<16x9xf32>
%323 = vector.extract %68[2, 13] : vector<9x16xf32>
%324 = vector.insert %323, %322 [13, 2] : f32 into vector<16x9xf32>
%325 = vector.extract %68[3, 13] : vector<9x16xf32>
%326 = vector.insert %325, %324 [13, 3] : f32 into vector<16x9xf32>
%327 = vector.extract %68[4, 13] : vector<9x16xf32>
%328 = vector.insert %327, %326 [13, 4] : f32 into vector<16x9xf32>
%329 = vector.extract %68[5, 13] : vector<9x16xf32>
%330 = vector.insert %329, %328 [13, 5] : f32 into vector<16x9xf32>
%331 = vector.extract %68[6, 13] : vector<9x16xf32>
%332 = vector.insert %331, %330 [13, 6] : f32 into vector<16x9xf32>
%333 = vector.extract %68[7, 13] : vector<9x16xf32>
%334 = vector.insert %333, %332 [13, 7] : f32 into vector<16x9xf32>
%335 = vector.extract %68[8, 13] : vector<9x16xf32>
%336 = vector.insert %335, %334 [13, 8] : f32 into vector<16x9xf32>
%337 = vector.extract %68[0, 14] : vector<9x16xf32>
%338 = vector.insert %337, %336 [14, 0] : f32 into vector<16x9xf32>
%339 = vector.extract %68[1, 14] : vector<9x16xf32>
%340 = vector.insert %339, %338 [14, 1] : f32 into vector<16x9xf32>
%341 = vector.extract %68[2, 14] : vector<9x16xf32>
%342 = vector.insert %341, %340 [14, 2] : f32 into vector<16x9xf32>
%343 = vector.extract %68[3, 14] : vector<9x16xf32>
%344 = vector.insert %343, %342 [14, 3] : f32 into vector<16x9xf32>
%345 = vector.extract %68[4, 14] : vector<9x16xf32>
%346 = vector.insert %345, %344 [14, 4] : f32 into vector<16x9xf32>
%347 = vector.extract %68[5, 14] : vector<9x16xf32>
%348 = vector.insert %347, %346 [14, 5] : f32 into vector<16x9xf32>
%349 = vector.extract %68[6, 14] : vector<9x16xf32>
%350 = vector.insert %349, %348 [14, 6] : f32 into vector<16x9xf32>
%351 = vector.extract %68[7, 14] : vector<9x16xf32>
%352 = vector.insert %351, %350 [14, 7] : f32 into vector<16x9xf32>
%353 = vector.extract %68[8, 14] : vector<9x16xf32>
%354 = vector.insert %353, %352 [14, 8] : f32 into vector<16x9xf32>
%355 = vector.extract %68[0, 15] : vector<9x16xf32>
%356 = vector.insert %355, %354 [15, 0] : f32 into vector<16x9xf32>
%357 = vector.extract %68[1, 15] : vector<9x16xf32>
%358 = vector.insert %357, %356 [15, 1] : f32 into vector<16x9xf32>
%359 = vector.extract %68[2, 15] : vector<9x16xf32>
%360 = vector.insert %359, %358 [15, 2] : f32 into vector<16x9xf32>
%361 = vector.extract %68[3, 15] : vector<9x16xf32>
%362 = vector.insert %361, %360 [15, 3] : f32 into vector<16x9xf32>
%363 = vector.extract %68[4, 15] : vector<9x16xf32>
%364 = vector.insert %363, %362 [15, 4] : f32 into vector<16x9xf32>
%365 = vector.extract %68[5, 15] : vector<9x16xf32>
%366 = vector.insert %365, %364 [15, 5] : f32 into vector<16x9xf32>
%367 = vector.extract %68[6, 15] : vector<9x16xf32>
%368 = vector.insert %367, %366 [15, 6] : f32 into vector<16x9xf32>
%369 = vector.extract %68[7, 15] : vector<9x16xf32>
%370 = vector.insert %369, %368 [15, 7] : f32 into vector<16x9xf32>
%371 = vector.extract %68[8, 15] : vector<9x16xf32>
%372 = vector.insert %371, %370 [15, 8] : f32 into vector<16x9xf32>
%373 = vector.extract %372[0] : vector<16x9xf32>
%374 = vector.outerproduct %373, %69, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%375 = vector.extract %372[1] : vector<16x9xf32>
%376 = vector.outerproduct %375, %70, %374 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%377 = vector.extract %372[2] : vector<16x9xf32>
%378 = vector.outerproduct %377, %71, %376 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%379 = vector.extract %372[3] : vector<16x9xf32>
%380 = vector.outerproduct %379, %72, %378 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%381 = vector.extract %372[4] : vector<16x9xf32>
%382 = vector.outerproduct %381, %73, %380 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%383 = vector.extract %372[5] : vector<16x9xf32>
%384 = vector.outerproduct %383, %74, %382 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%385 = vector.extract %372[6] : vector<16x9xf32>
%386 = vector.outerproduct %385, %75, %384 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%387 = vector.extract %372[7] : vector<16x9xf32>
%388 = vector.outerproduct %387, %76, %386 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%389 = vector.extract %372[8] : vector<16x9xf32>
%390 = vector.outerproduct %389, %77, %388 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%391 = vector.extract %372[9] : vector<16x9xf32>
%392 = vector.outerproduct %391, %78, %390 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%393 = vector.extract %372[10] : vector<16x9xf32>
%394 = vector.outerproduct %393, %79, %392 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%395 = vector.extract %372[11] : vector<16x9xf32>
%396 = vector.outerproduct %395, %80, %394 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%397 = vector.extract %372[12] : vector<16x9xf32>
%398 = vector.outerproduct %397, %81, %396 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%399 = vector.extract %372[13] : vector<16x9xf32>
%400 = vector.outerproduct %399, %82, %398 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%401 = vector.extract %372[14] : vector<16x9xf32>
%402 = vector.outerproduct %401, %83, %400 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%403 = vector.extract %372[15] : vector<16x9xf32>
%404 = vector.outerproduct %403, %84, %402 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %404 : vector<9x32xf32>
}
%40 = scf.if %16 -> (memref<?x32xf32, #map9>) {
%50 = memref.cast %19 : memref<?x32xf32, #map8> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
} else {
%50 = memref.cast %3 : memref<9x32xf32> to memref<?x32xf32, #map9>
scf.yield %50 : memref<?x32xf32, #map9>
}
%41 = vector.extract %39[0] : vector<9x32xf32>
vector.store %41, %40[%c0, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%42 = vector.extract %39[1] : vector<9x32xf32>
vector.store %42, %40[%c1, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%43 = vector.extract %39[2] : vector<9x32xf32>
vector.store %43, %40[%c2, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%44 = vector.extract %39[3] : vector<9x32xf32>
vector.store %44, %40[%c3, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%45 = vector.extract %39[4] : vector<9x32xf32>
vector.store %45, %40[%c4, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%46 = vector.extract %39[5] : vector<9x32xf32>
vector.store %46, %40[%c5, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%47 = vector.extract %39[6] : vector<9x32xf32>
vector.store %47, %40[%c6, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%48 = vector.extract %39[7] : vector<9x32xf32>
vector.store %48, %40[%c7, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
%49 = vector.extract %39[8] : vector<9x32xf32>
vector.store %49, %40[%c8, %c0] : memref<?x32xf32, #map9>, vector<32xf32>
scf.if %17 {
%50 = memref.subview %3[0, 0] [%14, 32] [1, 1] : memref<9x32xf32> to memref<?x32xf32, #map10>
%51 = memref.subview %19[0, 0] [%14, 32] [1, 1] : memref<?x32xf32, #map8> to memref<?x32xf32, #map8>
linalg.copy(%50, %51) : memref<?x32xf32, #map10>, memref<?x32xf32, #map8>
}
}
}
}
}
}
memref.dealloc %7 : memref<?x16x4x32x16x32xf32>
memref.dealloc %8 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x2048xf32>, %arg2: memref<?x2048xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x2048xf32>, memref<?x2048xf32>) -> ()
}
return
}
}
compilation in 0.2309s
xxxxxxxxxx : 10 iters time on 1 threads in 0.1534s per iter sec (112.0 GFlop/s, 0.3281 GB/s) total time 1.534s
###############################################################
Runtime problem size {'M': 2048, 'N': 2048, 'K': 2048}
Compile-time problem size {'M': -1, 'N': -1, 'K': -1}
Problem types [<class 'numpy.float32'>, <class 'numpy.float32'>, <class 'numpy.float32'>]
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43d6affd0>]]]
#map0 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map2 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%5 = affine.min #map0(%arg3)[%1]
%6 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%7 = affine.min #map1(%arg5)[%2]
%8 = tensor.extract_slice %arg0[%arg3, %arg5] [%5, %7] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%9 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%10 = affine.min #map2(%arg7)[%3]
%11 = tensor.extract_slice %arg1[%arg5, %arg7] [%7, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%12 = tensor.extract_slice %arg8[%arg3, %arg7] [%5, %10] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%13 = linalg.matmul ins(%8, %11 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) -> tensor<?x?xf32>
%14 = tensor.insert_slice %13 into %arg8[%arg3, %arg7] [%5, %10] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %14 : tensor<?x?xf32>
}
scf.yield %9 : tensor<?x?xf32>
}
scf.yield %6 : tensor<?x?xf32>
}
return %4 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Tile object at 0x7fb43b6dbc50>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (-d0 + 32)>
#map10 = affine_map<(d0) -> (d0 ceildiv 16)>
#map11 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map12 = affine_map<(d0) -> (-d0 + 16)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0) -> (-d0 + 9)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = affine.apply #map0()[%2]
%5 = affine.apply #map1()[%3]
%6 = linalg.init_tensor [%4, %5, 4, 32, 16, 32] : tensor<?x?x4x32x16x32xf32>
%7 = tensor.cast %6 : tensor<?x?x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%8 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %7) -> (tensor<?x?x?x?x16x32xf32>) {
%12 = affine.apply #map2(%arg3)
%13 = affine.min #map3(%arg3)[%2]
%14 = scf.for %arg5 = %c0 to %3 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg5)
%16 = affine.min #map5(%arg5)[%3]
%17 = scf.for %arg7 = %c0 to %16 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg7)
%19 = affine.apply #map7(%arg7, %arg5)
%20 = affine.min #map8(%arg7, %16)
%21 = affine.apply #map9(%20)
%22 = scf.for %arg9 = %c0 to %13 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%23 = affine.apply #map10(%arg9)
%24 = affine.apply #map7(%arg9, %arg3)
%25 = affine.min #map11(%arg9, %13)
%26 = tensor.extract_slice %arg1[%24, %19] [%25, %20] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%27 = affine.apply #map12(%25)
%28 = linalg.pad_tensor %26 nofold low[%c0, %c0] high[%27, %21] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<16x32xf32>
%29 = tensor.insert_slice %28 into %arg10[%12, %15, %18, %23, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<?x?x?x?x16x32xf32>
scf.yield %29 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %22 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
%9 = linalg.init_tensor [%4, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%10 = tensor.cast %9 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%11 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%12 = affine.min #map13(%arg3)[%1]
%13 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %10) -> (tensor<?x?x?x9x16xf32>) {
%15 = affine.apply #map2(%arg5)
%16 = affine.min #map3(%arg5)[%2]
%17 = scf.for %arg7 = %c0 to %12 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%18 = affine.apply #map14(%arg7)
%19 = affine.apply #map7(%arg7, %arg3)
%20 = affine.min #map15(%arg7, %12)
%21 = affine.apply #map16(%20)
%22 = scf.for %arg9 = %c0 to %16 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%23 = affine.apply #map10(%arg9)
%24 = affine.apply #map7(%arg9, %arg5)
%25 = affine.min #map11(%arg9, %16)
%26 = tensor.extract_slice %arg0[%19, %24] [%20, %25] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%27 = affine.apply #map12(%25)
%28 = linalg.pad_tensor %26 nofold low[%c0, %c0] high[%21, %27] {
^bb0(%arg11: index, %arg12: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x16xf32>
%29 = tensor.insert_slice %28 into %arg10[%15, %18, %23, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<9x16xf32> into tensor<?x?x?x9x16xf32>
scf.yield %29 : tensor<?x?x?x9x16xf32>
}
scf.yield %22 : tensor<?x?x?x9x16xf32>
}
scf.yield %17 : tensor<?x?x?x9x16xf32>
}
%14 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%15 = affine.min #map3(%arg5)[%2]
%16 = affine.apply #map2(%arg5)
%17 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%18 = affine.min #map5(%arg7)[%3]
%19 = tensor.extract_slice %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%20 = affine.apply #map4(%arg7)
%21 = scf.for %arg9 = %c0 to %12 step %c9 iter_args(%arg10 = %19) -> (tensor<?x?xf32>) {
%23 = affine.min #map15(%arg9, %12)
%24 = affine.apply #map14(%arg9)
%25 = affine.apply #map16(%23)
%26 = scf.for %arg11 = %c0 to %18 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%27 = affine.min #map8(%arg11, %18)
%28 = affine.apply #map6(%arg11)
%29 = affine.apply #map9(%27)
%30 = scf.for %arg13 = %c0 to %15 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%31 = tensor.extract_slice %arg14[%arg9, %arg11] [%23, %27] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%32 = affine.apply #map10(%arg13)
%33 = tensor.extract_slice %13[%16, %24, %32, 0, 0] [1, 1, 1, 9, 16] [1, 1, 1, 1, 1] : tensor<?x?x?x9x16xf32> to tensor<9x16xf32>
%34 = tensor.extract_slice %8[%16, %20, %28, %32, 0, 0] [1, 1, 1, 1, 16, 32] [1, 1, 1, 1, 1, 1] : tensor<?x?x?x?x16x32xf32> to tensor<16x32xf32>
%35 = linalg.pad_tensor %31 low[%c0, %c0] high[%25, %29] {
^bb0(%arg15: index, %arg16: index): // no predecessors
linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<9x32xf32>
%36 = linalg.matmul ins(%33, %34 : tensor<9x16xf32>, tensor<16x32xf32>) outs(%35 : tensor<9x32xf32>) -> tensor<9x32xf32>
%37 = tensor.extract_slice %36[0, 0] [%23, %27] [1, 1] : tensor<9x32xf32> to tensor<?x?xf32>
%38 = tensor.insert_slice %37 into %arg14[%arg9, %arg11] [%23, %27] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %38 : tensor<?x?xf32>
}
scf.yield %30 : tensor<?x?xf32>
}
scf.yield %26 : tensor<?x?xf32>
}
%22 = tensor.insert_slice %21 into %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %22 : tensor<?x?xf32>
}
scf.yield %17 : tensor<?x?xf32>
}
scf.yield %14 : tensor<?x?xf32>
}
return %11 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Vectorize object at 0x7fb43b6bdc10>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map12 = affine_map<(d0) -> (d0 ceildiv 9)>
#map13 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map15 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map16 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<?x?xf32> attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = linalg.fill(%cst, %arg2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = affine.apply #map0()[%2]
%5 = affine.apply #map1()[%3]
%6 = linalg.init_tensor [%4, %5, 4, 32, 16, 32] : tensor<?x?x4x32x16x32xf32>
%7 = tensor.cast %6 : tensor<?x?x4x32x16x32xf32> to tensor<?x?x?x?x16x32xf32>
%8 = scf.for %arg3 = %c0 to %2 step %c512 iter_args(%arg4 = %7) -> (tensor<?x?x?x?x16x32xf32>) {
%12 = affine.apply #map2(%arg3)
%13 = affine.min #map3(%arg3)[%2]
%14 = scf.for %arg5 = %c0 to %3 step %c128 iter_args(%arg6 = %arg4) -> (tensor<?x?x?x?x16x32xf32>) {
%15 = affine.apply #map4(%arg5)
%16 = affine.min #map5(%arg5)[%3]
%17 = scf.for %arg7 = %c0 to %16 step %c32 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x?x16x32xf32>) {
%18 = affine.apply #map6(%arg7)
%19 = affine.apply #map7(%arg7, %arg5)
%20 = affine.min #map8(%arg7, %16)
%21 = scf.for %arg9 = %c0 to %13 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x?x16x32xf32>) {
%22 = affine.apply #map9(%arg9)
%23 = affine.apply #map7(%arg9, %arg3)
%24 = affine.min #map10(%arg9, %13)
%25 = tensor.extract_slice %arg1[%23, %19] [%24, %20] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = vector.transfer_read %25[%c0, %c0], %cst : tensor<?x?xf32>, vector<16x32xf32>
%27 = vector.transfer_write %26, %arg10[%12, %15, %18, %22, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, tensor<?x?x?x?x16x32xf32>
scf.yield %27 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %21 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %17 : tensor<?x?x?x?x16x32xf32>
}
scf.yield %14 : tensor<?x?x?x?x16x32xf32>
}
%9 = linalg.init_tensor [%4, 32, 32, 9, 16] : tensor<?x32x32x9x16xf32>
%10 = tensor.cast %9 : tensor<?x32x32x9x16xf32> to tensor<?x?x?x9x16xf32>
%11 = scf.for %arg3 = %c0 to %1 step %c288 iter_args(%arg4 = %0) -> (tensor<?x?xf32>) {
%12 = affine.min #map11(%arg3)[%1]
%13 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %10) -> (tensor<?x?x?x9x16xf32>) {
%15 = affine.apply #map2(%arg5)
%16 = affine.min #map3(%arg5)[%2]
%17 = scf.for %arg7 = %c0 to %12 step %c9 iter_args(%arg8 = %arg6) -> (tensor<?x?x?x9x16xf32>) {
%18 = affine.apply #map12(%arg7)
%19 = affine.apply #map7(%arg7, %arg3)
%20 = affine.min #map13(%arg7, %12)
%21 = scf.for %arg9 = %c0 to %16 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?x?x9x16xf32>) {
%22 = affine.apply #map9(%arg9)
%23 = affine.apply #map7(%arg9, %arg5)
%24 = affine.min #map10(%arg9, %16)
%25 = tensor.extract_slice %arg0[%19, %23] [%20, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = vector.transfer_read %25[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x16xf32>
%27 = vector.transfer_write %26, %arg10[%15, %18, %22, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, tensor<?x?x?x9x16xf32>
scf.yield %27 : tensor<?x?x?x9x16xf32>
}
scf.yield %21 : tensor<?x?x?x9x16xf32>
}
scf.yield %17 : tensor<?x?x?x9x16xf32>
}
%14 = scf.for %arg5 = %c0 to %2 step %c512 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
%15 = affine.min #map3(%arg5)[%2]
%16 = affine.apply #map2(%arg5)
%17 = scf.for %arg7 = %c0 to %3 step %c128 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {
%18 = affine.min #map5(%arg7)[%3]
%19 = tensor.extract_slice %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%20 = affine.apply #map4(%arg7)
%21 = scf.for %arg9 = %c0 to %12 step %c9 iter_args(%arg10 = %19) -> (tensor<?x?xf32>) {
%23 = affine.min #map13(%arg9, %12)
%24 = affine.apply #map12(%arg9)
%25 = scf.for %arg11 = %c0 to %18 step %c32 iter_args(%arg12 = %arg10) -> (tensor<?x?xf32>) {
%26 = affine.min #map8(%arg11, %18)
%27 = affine.apply #map6(%arg11)
%28 = scf.for %arg13 = %c0 to %15 step %c16 iter_args(%arg14 = %arg12) -> (tensor<?x?xf32>) {
%29 = tensor.extract_slice %arg14[%arg9, %arg11] [%23, %26] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%30 = affine.apply #map9(%arg13)
%31 = vector.transfer_read %13[%16, %24, %30, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %8[%16, %20, %27, %30, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<?x?x?x?x16x32xf32>, vector<16x32xf32>
%33 = vector.transfer_read %29[%c0, %c0], %cst : tensor<?x?xf32>, vector<9x32xf32>
%34 = vector.contract {indexing_maps = [#map14, #map15, #map16], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %31, %32, %33 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
%35 = vector.transfer_write %34, %29[%c0, %c0] : vector<9x32xf32>, tensor<?x?xf32>
%36 = tensor.insert_slice %35 into %arg14[%arg9, %arg11] [%23, %26] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %36 : tensor<?x?xf32>
}
scf.yield %28 : tensor<?x?xf32>
}
scf.yield %25 : tensor<?x?xf32>
}
%22 = tensor.insert_slice %21 into %arg8[%arg3, %arg7] [%12, %18] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
scf.yield %22 : tensor<?x?xf32>
}
scf.yield %17 : tensor<?x?xf32>
}
scf.yield %14 : tensor<?x?xf32>
}
return %11 : tensor<?x?xf32>
}
func public @matmul_main(%arg0: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<?x?xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}, %arg3: index) -> tensor<?x?xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = scf.for %arg4 = %c0 to %arg3 step %c1 iter_args(%arg5 = %arg2) -> (tensor<?x?xf32>) {
%1 = call @matmul_on_tensors(%arg0, %arg1, %arg5) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %1 : tensor<?x?xf32>
}
return %0 : tensor<?x?xf32>
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.Bufferize object at 0x7fb43b692110>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map15 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map16 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map17 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.contract {indexing_maps = [#map15, #map16, #map17], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %21, %22, %arg9 : vector<9x16xf32>, vector<16x32xf32> into vector<9x32xf32>
scf.yield %23 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692190>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.transpose %21, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%24 = vector.extract %23[0] : vector<16x9xf32>
%25 = vector.extract %22[0] : vector<16x32xf32>
%26 = vector.outerproduct %24, %25, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%27 = vector.extract %23[1] : vector<16x9xf32>
%28 = vector.extract %22[1] : vector<16x32xf32>
%29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%30 = vector.extract %23[2] : vector<16x9xf32>
%31 = vector.extract %22[2] : vector<16x32xf32>
%32 = vector.outerproduct %30, %31, %29 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%33 = vector.extract %23[3] : vector<16x9xf32>
%34 = vector.extract %22[3] : vector<16x32xf32>
%35 = vector.outerproduct %33, %34, %32 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%36 = vector.extract %23[4] : vector<16x9xf32>
%37 = vector.extract %22[4] : vector<16x32xf32>
%38 = vector.outerproduct %36, %37, %35 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%39 = vector.extract %23[5] : vector<16x9xf32>
%40 = vector.extract %22[5] : vector<16x32xf32>
%41 = vector.outerproduct %39, %40, %38 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%42 = vector.extract %23[6] : vector<16x9xf32>
%43 = vector.extract %22[6] : vector<16x32xf32>
%44 = vector.outerproduct %42, %43, %41 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%45 = vector.extract %23[7] : vector<16x9xf32>
%46 = vector.extract %22[7] : vector<16x32xf32>
%47 = vector.outerproduct %45, %46, %44 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%48 = vector.extract %23[8] : vector<16x9xf32>
%49 = vector.extract %22[8] : vector<16x32xf32>
%50 = vector.outerproduct %48, %49, %47 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%51 = vector.extract %23[9] : vector<16x9xf32>
%52 = vector.extract %22[9] : vector<16x32xf32>
%53 = vector.outerproduct %51, %52, %50 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%54 = vector.extract %23[10] : vector<16x9xf32>
%55 = vector.extract %22[10] : vector<16x32xf32>
%56 = vector.outerproduct %54, %55, %53 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%57 = vector.extract %23[11] : vector<16x9xf32>
%58 = vector.extract %22[11] : vector<16x32xf32>
%59 = vector.outerproduct %57, %58, %56 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%60 = vector.extract %23[12] : vector<16x9xf32>
%61 = vector.extract %22[12] : vector<16x32xf32>
%62 = vector.outerproduct %60, %61, %59 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%63 = vector.extract %23[13] : vector<16x9xf32>
%64 = vector.extract %22[13] : vector<16x32xf32>
%65 = vector.outerproduct %63, %64, %62 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%66 = vector.extract %23[14] : vector<16x9xf32>
%67 = vector.extract %22[14] : vector<16x32xf32>
%68 = vector.outerproduct %66, %67, %65 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%69 = vector.extract %23[15] : vector<16x9xf32>
%70 = vector.extract %22[15] : vector<16x32xf32>
%71 = vector.outerproduct %69, %70, %68 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %71 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6923d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map13 = affine_map<(d0) -> (d0 ceildiv 9)>
#map14 = affine_map<(d0, d1) -> (9, -d0 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%0 = memref.dim %arg0, %c0 : memref<?x?xf32>
%1 = memref.dim %arg0, %c1 : memref<?x?xf32>
%2 = memref.dim %arg1, %c1 : memref<?x?xf32>
%3 = affine.apply #map0()[%1]
%4 = affine.apply #map1()[%2]
%5 = memref.alloc(%3, %4) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %1 step %c512 {
%7 = affine.apply #map2(%arg3)
%8 = affine.min #map3(%arg3)[%1]
scf.for %arg4 = %c0 to %2 step %c128 {
%9 = affine.apply #map4(%arg4)
%10 = affine.min #map5(%arg4)[%2]
scf.for %arg5 = %c0 to %10 step %c32 {
%11 = affine.apply #map6(%arg5)
%12 = affine.apply #map7(%arg5, %arg4)
%13 = affine.min #map8(%arg5, %10)
scf.for %arg6 = %c0 to %8 step %c16 {
%14 = affine.apply #map9(%arg6)
%15 = affine.apply #map7(%arg6, %arg3)
%16 = affine.min #map10(%arg6, %8)
%17 = memref.subview %arg1[%15, %12] [%16, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %18, %5[%7, %9, %11, %14, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%6 = memref.alloc(%3) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %0 step %c288 {
%7 = affine.min #map12(%arg3)[%0]
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.apply #map2(%arg4)
%9 = affine.min #map3(%arg4)[%1]
scf.for %arg5 = %c0 to %7 step %c9 {
%10 = affine.apply #map13(%arg5)
%11 = affine.apply #map7(%arg5, %arg3)
%12 = affine.min #map14(%arg5, %7)
scf.for %arg6 = %c0 to %9 step %c16 {
%13 = affine.apply #map9(%arg6)
%14 = affine.apply #map7(%arg6, %arg4)
%15 = affine.min #map10(%arg6, %9)
%16 = memref.subview %arg0[%11, %14] [%12, %15] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%17 = vector.transfer_read %16[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %17, %6[%8, %10, %13, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %1 step %c512 {
%8 = affine.min #map3(%arg4)[%1]
%9 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %2 step %c128 {
%10 = affine.min #map5(%arg5)[%2]
%11 = memref.subview %arg2[%arg3, %arg5] [%7, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%12 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %7 step %c9 {
%13 = affine.min #map14(%arg6, %7)
%14 = affine.apply #map13(%arg6)
scf.for %arg7 = %c0 to %10 step %c32 {
%15 = affine.min #map8(%arg7, %10)
%16 = affine.apply #map6(%arg7)
%17 = memref.subview %11[%arg6, %arg7] [%13, %15] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%18 = vector.transfer_read %17[%c0, %c0], %cst : memref<?x?xf32, #map11>, vector<9x32xf32>
%19 = scf.for %arg8 = %c0 to %8 step %c16 iter_args(%arg9 = %18) -> (vector<9x32xf32>) {
%20 = affine.apply #map9(%arg8)
%21 = vector.transfer_read %6[%9, %14, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%22 = vector.transfer_read %5[%9, %12, %16, %20, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%23 = vector.transpose %21, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%24 = vector.extract %23[0] : vector<16x9xf32>
%25 = vector.extract %22[0] : vector<16x32xf32>
%26 = vector.outerproduct %24, %25, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%27 = vector.extract %23[1] : vector<16x9xf32>
%28 = vector.extract %22[1] : vector<16x32xf32>
%29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%30 = vector.extract %23[2] : vector<16x9xf32>
%31 = vector.extract %22[2] : vector<16x32xf32>
%32 = vector.outerproduct %30, %31, %29 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%33 = vector.extract %23[3] : vector<16x9xf32>
%34 = vector.extract %22[3] : vector<16x32xf32>
%35 = vector.outerproduct %33, %34, %32 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%36 = vector.extract %23[4] : vector<16x9xf32>
%37 = vector.extract %22[4] : vector<16x32xf32>
%38 = vector.outerproduct %36, %37, %35 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%39 = vector.extract %23[5] : vector<16x9xf32>
%40 = vector.extract %22[5] : vector<16x32xf32>
%41 = vector.outerproduct %39, %40, %38 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%42 = vector.extract %23[6] : vector<16x9xf32>
%43 = vector.extract %22[6] : vector<16x32xf32>
%44 = vector.outerproduct %42, %43, %41 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%45 = vector.extract %23[7] : vector<16x9xf32>
%46 = vector.extract %22[7] : vector<16x32xf32>
%47 = vector.outerproduct %45, %46, %44 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%48 = vector.extract %23[8] : vector<16x9xf32>
%49 = vector.extract %22[8] : vector<16x32xf32>
%50 = vector.outerproduct %48, %49, %47 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%51 = vector.extract %23[9] : vector<16x9xf32>
%52 = vector.extract %22[9] : vector<16x32xf32>
%53 = vector.outerproduct %51, %52, %50 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%54 = vector.extract %23[10] : vector<16x9xf32>
%55 = vector.extract %22[10] : vector<16x32xf32>
%56 = vector.outerproduct %54, %55, %53 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%57 = vector.extract %23[11] : vector<16x9xf32>
%58 = vector.extract %22[11] : vector<16x32xf32>
%59 = vector.outerproduct %57, %58, %56 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%60 = vector.extract %23[12] : vector<16x9xf32>
%61 = vector.extract %22[12] : vector<16x32xf32>
%62 = vector.outerproduct %60, %61, %59 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%63 = vector.extract %23[13] : vector<16x9xf32>
%64 = vector.extract %22[13] : vector<16x32xf32>
%65 = vector.outerproduct %63, %64, %62 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%66 = vector.extract %23[14] : vector<16x9xf32>
%67 = vector.extract %22[14] : vector<16x32xf32>
%68 = vector.outerproduct %66, %67, %65 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%69 = vector.extract %23[15] : vector<16x9xf32>
%70 = vector.extract %22[15] : vector<16x32xf32>
%71 = vector.outerproduct %69, %70, %68 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %71 : vector<9x32xf32>
}
vector.transfer_write %19, %17[%c0, %c0] : vector<9x32xf32>, memref<?x?xf32, #map11>
}
}
}
}
}
memref.dealloc %5 : memref<?x?x4x32x16x32xf32>
memref.dealloc %6 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692210>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%true = arith.constant true
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%27 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%28 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%27, %28) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%29 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %29 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %26, %9[%11, %13, %15, %19, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%26 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%27 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%26, %27) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%28 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
%25 = vector.transfer_read %24[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %25, %10[%12, %14, %18, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%30 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%31 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%30, %31) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%32 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %32 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%27 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %26) -> (vector<9x32xf32>) {
%30 = affine.apply #map9(%arg8)
%31 = vector.transfer_read %10[%13, %18, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %9[%13, %16, %21, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%33 = vector.transpose %31, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%34 = vector.extract %33[0] : vector<16x9xf32>
%35 = vector.extract %32[0] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %33[1] : vector<16x9xf32>
%38 = vector.extract %32[1] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %33[2] : vector<16x9xf32>
%41 = vector.extract %32[2] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %33[3] : vector<16x9xf32>
%44 = vector.extract %32[3] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %33[4] : vector<16x9xf32>
%47 = vector.extract %32[4] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %33[5] : vector<16x9xf32>
%50 = vector.extract %32[5] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %33[6] : vector<16x9xf32>
%53 = vector.extract %32[6] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %33[7] : vector<16x9xf32>
%56 = vector.extract %32[7] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %33[8] : vector<16x9xf32>
%59 = vector.extract %32[8] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %33[9] : vector<16x9xf32>
%62 = vector.extract %32[9] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %33[10] : vector<16x9xf32>
%65 = vector.extract %32[10] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %33[11] : vector<16x9xf32>
%68 = vector.extract %32[11] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %33[12] : vector<16x9xf32>
%71 = vector.extract %32[12] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %33[13] : vector<16x9xf32>
%74 = vector.extract %32[13] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%76 = vector.extract %33[14] : vector<16x9xf32>
%77 = vector.extract %32[14] : vector<16x32xf32>
%78 = vector.outerproduct %76, %77, %75 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%79 = vector.extract %33[15] : vector<16x9xf32>
%80 = vector.extract %32[15] : vector<16x32xf32>
%81 = vector.outerproduct %79, %80, %78 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %81 : vector<9x32xf32>
}
%28 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%30 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
vector.transfer_write %27, %28[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%29 = arith.xori %24, %true : i1
scf.if %29 {
%30 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%31 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%30, %31) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b6921d0>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%true = arith.constant true
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%27 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%28 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%27, %28) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%29 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %29 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<16x32xf32>
vector.transfer_write %26, %9[%11, %13, %15, %19, %c0, %c0] {in_bounds = [true, true]} : vector<16x32xf32>, memref<?x?x4x32x16x32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %1) : f32, memref<9x16xf32>
%26 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%27 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%26, %27) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%28 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %28 : memref<?x?xf32, #map11>
}
%25 = vector.transfer_read %24[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x16xf32>
vector.transfer_write %25, %10[%12, %14, %18, %c0, %c0] {in_bounds = [true, true]} : vector<9x16xf32>, memref<?x32x32x9x16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %2) : f32, memref<9x32xf32>
%30 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%31 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%30, %31) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%32 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %32 : memref<?x?xf32, #map11>
}
%26 = vector.transfer_read %25[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, #map11>, vector<9x32xf32>
%27 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %26) -> (vector<9x32xf32>) {
%30 = affine.apply #map9(%arg8)
%31 = vector.transfer_read %10[%13, %18, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x32x32x9x16xf32>, vector<9x16xf32>
%32 = vector.transfer_read %9[%13, %16, %21, %30, %c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?x4x32x16x32xf32>, vector<16x32xf32>
%33 = vector.transpose %31, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%34 = vector.extract %33[0] : vector<16x9xf32>
%35 = vector.extract %32[0] : vector<16x32xf32>
%36 = vector.outerproduct %34, %35, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%37 = vector.extract %33[1] : vector<16x9xf32>
%38 = vector.extract %32[1] : vector<16x32xf32>
%39 = vector.outerproduct %37, %38, %36 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%40 = vector.extract %33[2] : vector<16x9xf32>
%41 = vector.extract %32[2] : vector<16x32xf32>
%42 = vector.outerproduct %40, %41, %39 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%43 = vector.extract %33[3] : vector<16x9xf32>
%44 = vector.extract %32[3] : vector<16x32xf32>
%45 = vector.outerproduct %43, %44, %42 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%46 = vector.extract %33[4] : vector<16x9xf32>
%47 = vector.extract %32[4] : vector<16x32xf32>
%48 = vector.outerproduct %46, %47, %45 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%49 = vector.extract %33[5] : vector<16x9xf32>
%50 = vector.extract %32[5] : vector<16x32xf32>
%51 = vector.outerproduct %49, %50, %48 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%52 = vector.extract %33[6] : vector<16x9xf32>
%53 = vector.extract %32[6] : vector<16x32xf32>
%54 = vector.outerproduct %52, %53, %51 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%55 = vector.extract %33[7] : vector<16x9xf32>
%56 = vector.extract %32[7] : vector<16x32xf32>
%57 = vector.outerproduct %55, %56, %54 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%58 = vector.extract %33[8] : vector<16x9xf32>
%59 = vector.extract %32[8] : vector<16x32xf32>
%60 = vector.outerproduct %58, %59, %57 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%61 = vector.extract %33[9] : vector<16x9xf32>
%62 = vector.extract %32[9] : vector<16x32xf32>
%63 = vector.outerproduct %61, %62, %60 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%64 = vector.extract %33[10] : vector<16x9xf32>
%65 = vector.extract %32[10] : vector<16x32xf32>
%66 = vector.outerproduct %64, %65, %63 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%67 = vector.extract %33[11] : vector<16x9xf32>
%68 = vector.extract %32[11] : vector<16x32xf32>
%69 = vector.outerproduct %67, %68, %66 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%70 = vector.extract %33[12] : vector<16x9xf32>
%71 = vector.extract %32[12] : vector<16x32xf32>
%72 = vector.outerproduct %70, %71, %69 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%73 = vector.extract %33[13] : vector<16x9xf32>
%74 = vector.extract %32[13] : vector<16x32xf32>
%75 = vector.outerproduct %73, %74, %72 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%76 = vector.extract %33[14] : vector<16x9xf32>
%77 = vector.extract %32[14] : vector<16x32xf32>
%78 = vector.outerproduct %76, %77, %75 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%79 = vector.extract %33[15] : vector<16x9xf32>
%80 = vector.extract %32[15] : vector<16x32xf32>
%81 = vector.outerproduct %79, %80, %78 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %81 : vector<9x32xf32>
}
%28 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%30 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %30 : memref<?x?xf32, #map11>
}
vector.transfer_write %27, %28[%c0, %c0] {in_bounds = [true, true]} : vector<9x32xf32>, memref<?x?xf32, #map11>
%29 = arith.xori %24, %true : i1
scf.if %29 {
%30 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%31 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%30, %31) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692150>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c8 = arith.constant 8 : index
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c15 = arith.constant 15 : index
%c14 = arith.constant 14 : index
%c13 = arith.constant 13 : index
%c12 = arith.constant 12 : index
%c11 = arith.constant 11 : index
%c10 = arith.constant 10 : index
%cst = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%true = arith.constant true
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c288 = arith.constant 288 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst_1, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %0) : f32, memref<16x32xf32>
%42 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%43 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%42, %43) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%44 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %44 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%28 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%30 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%32 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%34 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.load %25[%c9, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%36 = vector.load %25[%c10, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.load %25[%c11, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%38 = vector.load %25[%c12, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.load %25[%c13, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%40 = vector.load %25[%c14, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.load %25[%c15, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
vector.store %26, %9[%11, %13, %15, %19, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %27, %9[%11, %13, %15, %19, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %28, %9[%11, %13, %15, %19, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %29, %9[%11, %13, %15, %19, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %30, %9[%11, %13, %15, %19, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %31, %9[%11, %13, %15, %19, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %32, %9[%11, %13, %15, %19, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %33, %9[%11, %13, %15, %19, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %34, %9[%11, %13, %15, %19, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %35, %9[%11, %13, %15, %19, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %36, %9[%11, %13, %15, %19, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %37, %9[%11, %13, %15, %19, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %38, %9[%11, %13, %15, %19, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %39, %9[%11, %13, %15, %19, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %40, %9[%11, %13, %15, %19, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
vector.store %41, %9[%11, %13, %15, %19, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
}
}
}
}
%10 = memref.alloc(%7) {alignment = 128 : i64} : memref<?x32x32x9x16xf32>
scf.for %arg3 = %c0 to %4 step %c288 {
%11 = affine.min #map13(%arg3)[%4]
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.apply #map2(%arg4)
%13 = affine.min #map3(%arg4)[%5]
scf.for %arg5 = %c0 to %11 step %c9 {
%14 = affine.apply #map14(%arg5)
%15 = affine.apply #map7(%arg5, %arg3)
%16 = affine.min #map15(%arg5, %11)
%17 = arith.cmpi sle, %c9, %16 : index
scf.for %arg6 = %c0 to %13 step %c16 {
%18 = affine.apply #map9(%arg6)
%19 = affine.apply #map7(%arg6, %arg4)
%20 = affine.min #map10(%arg6, %13)
%21 = memref.subview %arg0[%15, %19] [%16, %20] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%22 = arith.cmpi sle, %c16, %20 : index
%23 = arith.andi %17, %22 : i1
%24 = scf.if %23 -> (memref<?x?xf32, #map11>) {
scf.yield %21 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %1) : f32, memref<9x16xf32>
%34 = memref.subview %21[0, 0] [%16, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%35 = memref.subview %1[0, 0] [%16, %20] [1, 1] : memref<9x16xf32> to memref<?x?xf32, #map16>
linalg.copy(%34, %35) : memref<?x?xf32, #map11>, memref<?x?xf32, #map16>
%36 = memref.cast %1 : memref<9x16xf32> to memref<?x?xf32, #map11>
scf.yield %36 : memref<?x?xf32, #map11>
}
%25 = vector.load %24[%c0, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%26 = vector.load %24[%c1, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%27 = vector.load %24[%c2, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%28 = vector.load %24[%c3, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%29 = vector.load %24[%c4, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%30 = vector.load %24[%c5, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%31 = vector.load %24[%c6, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%32 = vector.load %24[%c7, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
%33 = vector.load %24[%c8, %c0] : memref<?x?xf32, #map11>, vector<16xf32>
vector.store %25, %10[%12, %14, %18, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %26, %10[%12, %14, %18, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %27, %10[%12, %14, %18, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %28, %10[%12, %14, %18, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %29, %10[%12, %14, %18, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %30, %10[%12, %14, %18, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %31, %10[%12, %14, %18, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %32, %10[%12, %14, %18, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
vector.store %33, %10[%12, %14, %18, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
}
}
}
scf.for %arg4 = %c0 to %5 step %c512 {
%12 = affine.min #map3(%arg4)[%5]
%13 = affine.apply #map2(%arg4)
scf.for %arg5 = %c0 to %6 step %c128 {
%14 = affine.min #map5(%arg5)[%6]
%15 = memref.subview %arg2[%arg3, %arg5] [%11, %14] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%16 = affine.apply #map4(%arg5)
scf.for %arg6 = %c0 to %11 step %c9 {
%17 = affine.min #map15(%arg6, %11)
%18 = affine.apply #map14(%arg6)
%19 = arith.cmpi sle, %c9, %17 : index
scf.for %arg7 = %c0 to %14 step %c32 {
%20 = affine.min #map8(%arg7, %14)
%21 = affine.apply #map6(%arg7)
%22 = memref.subview %15[%arg6, %arg7] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c32, %20 : index
%24 = arith.andi %19, %23 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst_1, %2) : f32, memref<9x32xf32>
%56 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%57 = memref.subview %2[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%56, %57) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%58 = memref.cast %2 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %58 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.insert %26, %cst_0 [0] : vector<32xf32> into vector<9x32xf32>
%28 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%29 = vector.insert %28, %27 [1] : vector<32xf32> into vector<9x32xf32>
%30 = vector.load %25[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%31 = vector.insert %30, %29 [2] : vector<32xf32> into vector<9x32xf32>
%32 = vector.load %25[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%33 = vector.insert %32, %31 [3] : vector<32xf32> into vector<9x32xf32>
%34 = vector.load %25[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%35 = vector.insert %34, %33 [4] : vector<32xf32> into vector<9x32xf32>
%36 = vector.load %25[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%37 = vector.insert %36, %35 [5] : vector<32xf32> into vector<9x32xf32>
%38 = vector.load %25[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%39 = vector.insert %38, %37 [6] : vector<32xf32> into vector<9x32xf32>
%40 = vector.load %25[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%41 = vector.insert %40, %39 [7] : vector<32xf32> into vector<9x32xf32>
%42 = vector.load %25[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%43 = vector.insert %42, %41 [8] : vector<32xf32> into vector<9x32xf32>
%44 = scf.for %arg8 = %c0 to %12 step %c16 iter_args(%arg9 = %43) -> (vector<9x32xf32>) {
%56 = affine.apply #map9(%arg8)
%57 = vector.load %10[%13, %18, %56, %c0, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%58 = vector.insert %57, %cst [0] : vector<16xf32> into vector<9x16xf32>
%59 = vector.load %10[%13, %18, %56, %c1, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%60 = vector.insert %59, %58 [1] : vector<16xf32> into vector<9x16xf32>
%61 = vector.load %10[%13, %18, %56, %c2, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%62 = vector.insert %61, %60 [2] : vector<16xf32> into vector<9x16xf32>
%63 = vector.load %10[%13, %18, %56, %c3, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%64 = vector.insert %63, %62 [3] : vector<16xf32> into vector<9x16xf32>
%65 = vector.load %10[%13, %18, %56, %c4, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%66 = vector.insert %65, %64 [4] : vector<16xf32> into vector<9x16xf32>
%67 = vector.load %10[%13, %18, %56, %c5, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%68 = vector.insert %67, %66 [5] : vector<16xf32> into vector<9x16xf32>
%69 = vector.load %10[%13, %18, %56, %c6, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%70 = vector.insert %69, %68 [6] : vector<16xf32> into vector<9x16xf32>
%71 = vector.load %10[%13, %18, %56, %c7, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%72 = vector.insert %71, %70 [7] : vector<16xf32> into vector<9x16xf32>
%73 = vector.load %10[%13, %18, %56, %c8, %c0] : memref<?x32x32x9x16xf32>, vector<16xf32>
%74 = vector.insert %73, %72 [8] : vector<16xf32> into vector<9x16xf32>
%75 = vector.load %9[%13, %16, %21, %56, %c0, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%76 = vector.load %9[%13, %16, %21, %56, %c1, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%77 = vector.load %9[%13, %16, %21, %56, %c2, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%78 = vector.load %9[%13, %16, %21, %56, %c3, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%79 = vector.load %9[%13, %16, %21, %56, %c4, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%80 = vector.load %9[%13, %16, %21, %56, %c5, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%81 = vector.load %9[%13, %16, %21, %56, %c6, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%82 = vector.load %9[%13, %16, %21, %56, %c7, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%83 = vector.load %9[%13, %16, %21, %56, %c8, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%84 = vector.load %9[%13, %16, %21, %56, %c9, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%85 = vector.load %9[%13, %16, %21, %56, %c10, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%86 = vector.load %9[%13, %16, %21, %56, %c11, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%87 = vector.load %9[%13, %16, %21, %56, %c12, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%88 = vector.load %9[%13, %16, %21, %56, %c13, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%89 = vector.load %9[%13, %16, %21, %56, %c14, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%90 = vector.load %9[%13, %16, %21, %56, %c15, %c0] : memref<?x?x4x32x16x32xf32>, vector<32xf32>
%91 = vector.transpose %74, [1, 0] : vector<9x16xf32> to vector<16x9xf32>
%92 = vector.extract %91[0] : vector<16x9xf32>
%93 = vector.outerproduct %92, %75, %arg9 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%94 = vector.extract %91[1] : vector<16x9xf32>
%95 = vector.outerproduct %94, %76, %93 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%96 = vector.extract %91[2] : vector<16x9xf32>
%97 = vector.outerproduct %96, %77, %95 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%98 = vector.extract %91[3] : vector<16x9xf32>
%99 = vector.outerproduct %98, %78, %97 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%100 = vector.extract %91[4] : vector<16x9xf32>
%101 = vector.outerproduct %100, %79, %99 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%102 = vector.extract %91[5] : vector<16x9xf32>
%103 = vector.outerproduct %102, %80, %101 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%104 = vector.extract %91[6] : vector<16x9xf32>
%105 = vector.outerproduct %104, %81, %103 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%106 = vector.extract %91[7] : vector<16x9xf32>
%107 = vector.outerproduct %106, %82, %105 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%108 = vector.extract %91[8] : vector<16x9xf32>
%109 = vector.outerproduct %108, %83, %107 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%110 = vector.extract %91[9] : vector<16x9xf32>
%111 = vector.outerproduct %110, %84, %109 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%112 = vector.extract %91[10] : vector<16x9xf32>
%113 = vector.outerproduct %112, %85, %111 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%114 = vector.extract %91[11] : vector<16x9xf32>
%115 = vector.outerproduct %114, %86, %113 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%116 = vector.extract %91[12] : vector<16x9xf32>
%117 = vector.outerproduct %116, %87, %115 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%118 = vector.extract %91[13] : vector<16x9xf32>
%119 = vector.outerproduct %118, %88, %117 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%120 = vector.extract %91[14] : vector<16x9xf32>
%121 = vector.outerproduct %120, %89, %119 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
%122 = vector.extract %91[15] : vector<16x9xf32>
%123 = vector.outerproduct %122, %90, %121 {kind = #vector.kind<add>} : vector<9xf32>, vector<32xf32>
scf.yield %123 : vector<9x32xf32>
}
%45 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
%56 = memref.cast %3 : memref<9x32xf32> to memref<?x?xf32, #map11>
scf.yield %56 : memref<?x?xf32, #map11>
}
%46 = vector.extract %44[0] : vector<9x32xf32>
vector.store %46, %45[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%47 = vector.extract %44[1] : vector<9x32xf32>
vector.store %47, %45[%c1, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%48 = vector.extract %44[2] : vector<9x32xf32>
vector.store %48, %45[%c2, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%49 = vector.extract %44[3] : vector<9x32xf32>
vector.store %49, %45[%c3, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%50 = vector.extract %44[4] : vector<9x32xf32>
vector.store %50, %45[%c4, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%51 = vector.extract %44[5] : vector<9x32xf32>
vector.store %51, %45[%c5, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%52 = vector.extract %44[6] : vector<9x32xf32>
vector.store %52, %45[%c6, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%53 = vector.extract %44[7] : vector<9x32xf32>
vector.store %53, %45[%c7, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%54 = vector.extract %44[8] : vector<9x32xf32>
vector.store %54, %45[%c8, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%55 = arith.xori %24, %true : i1
scf.if %55 {
%56 = memref.subview %3[0, 0] [%17, %20] [1, 1] : memref<9x32xf32> to memref<?x?xf32, #map12>
%57 = memref.subview %22[0, 0] [%17, %20] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
linalg.copy(%56, %57) : memref<?x?xf32, #map12>, memref<?x?xf32, #map11>
}
}
}
}
}
}
memref.dealloc %9 : memref<?x?x4x32x16x32xf32>
memref.dealloc %10 : memref<?x32x32x9x16xf32>
return
}
func public @matmul_main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>, %arg3: index) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
scf.for %arg4 = %c0 to %arg3 step %c1 {
call @matmul_on_tensors(%arg0, %arg1, %arg2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
}
return
}
}
[[[ IR after transform: <google3.third_party.mlir_edge.iree_llvm_sandbox.python.core.transforms.LowerVectors object at 0x7fb43b692390>]]]
#map0 = affine_map<()[s0] -> (s0 ceildiv 512)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 128)>
#map2 = affine_map<(d0) -> (d0 ceildiv 512)>
#map3 = affine_map<(d0)[s0] -> (512, -d0 + s0)>
#map4 = affine_map<(d0) -> (d0 ceildiv 128)>
#map5 = affine_map<(d0)[s0] -> (128, -d0 + s0)>
#map6 = affine_map<(d0) -> (d0 ceildiv 32)>
#map7 = affine_map<(d0, d1) -> (d0 + d1)>
#map8 = affine_map<(d0, d1) -> (32, -d0 + d1)>
#map9 = affine_map<(d0) -> (d0 ceildiv 16)>
#map10 = affine_map<(d0, d1) -> (16, -d0 + d1)>
#map11 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map13 = affine_map<(d0)[s0] -> (288, -d0 + s0)>
#map14 = affine_map<(d0) -> (d0 ceildiv 9)>
#map15 = affine_map<(d0, d1) -> (9, -d0 + d1)>
#map16 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
module {
func @matmul_on_tensors(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {passthrough = ["noinline", ["target-cpu", "broadwell"], ["prefer-vector-width", "256"]]} {
%c16 = arith.constant 16 : index
%c9 = arith.constant 9 : index
%cst = arith.constant 0.000000e+00 : f32
%c288 = arith.constant 288 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : vector<9x32xf32>
%cst_1 = arith.constant dense<0.000000e+00> : vector<9x16xf32>
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c13 = arith.constant 13 : index
%c14 = arith.constant 14 : index
%c15 = arith.constant 15 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c7 = arith.constant 7 : index
%c8 = arith.constant 8 : index
%0 = memref.alloca() {alignment = 32 : i64} : memref<16x32xf32>
%1 = memref.alloca() {alignment = 32 : i64} : memref<9x16xf32>
%2 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
%3 = memref.alloca() {alignment = 32 : i64} : memref<9x32xf32>
linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
%4 = memref.dim %arg0, %c0 : memref<?x?xf32>
%5 = memref.dim %arg0, %c1 : memref<?x?xf32>
%6 = memref.dim %arg1, %c1 : memref<?x?xf32>
%7 = affine.apply #map0()[%5]
%8 = affine.apply #map1()[%6]
%9 = memref.alloc(%7, %8) {alignment = 128 : i64} : memref<?x?x4x32x16x32xf32>
scf.for %arg3 = %c0 to %5 step %c512 {
%11 = affine.apply #map2(%arg3)
%12 = affine.min #map3(%arg3)[%5]
scf.for %arg4 = %c0 to %6 step %c128 {
%13 = affine.apply #map4(%arg4)
%14 = affine.min #map5(%arg4)[%6]
scf.for %arg5 = %c0 to %14 step %c32 {
%15 = affine.apply #map6(%arg5)
%16 = affine.apply #map7(%arg5, %arg4)
%17 = affine.min #map8(%arg5, %14)
%18 = arith.cmpi sle, %c32, %17 : index
scf.for %arg6 = %c0 to %12 step %c16 {
%19 = affine.apply #map9(%arg6)
%20 = affine.apply #map7(%arg6, %arg3)
%21 = affine.min #map10(%arg6, %12)
%22 = memref.subview %arg1[%20, %16] [%21, %17] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map11>
%23 = arith.cmpi sle, %c16, %21 : index
%24 = arith.andi %23, %18 : i1
%25 = scf.if %24 -> (memref<?x?xf32, #map11>) {
scf.yield %22 : memref<?x?xf32, #map11>
} else {
linalg.fill(%cst, %0) : f32, memref<16x32xf32>
%42 = memref.subview %22[0, 0] [%21, %17] [1, 1] : memref<?x?xf32, #map11> to memref<?x?xf32, #map11>
%43 = memref.subview %0[0, 0] [%21, %17] [1, 1] : memref<16x32xf32> to memref<?x?xf32, #map12>
linalg.copy(%42, %43) : memref<?x?xf32, #map11>, memref<?x?xf32, #map12>
%44 = memref.cast %0 : memref<16x32xf32> to memref<?x?xf32, #map11>
scf.yield %44 : memref<?x?xf32, #map11>
}
%26 = vector.load %25[%c0, %c0] : memref<?x?xf32, #map11>, vector<32xf32>
%27 = vector.load %25[%c1, %c0] : memref<?x?xf32, #map11>, vector<3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment