Skip to content

Instantly share code, notes, and snippets.

@gysit
gysit / fusion.mlir
Created April 11, 2022 19:11
fusion lowering
// 1) initial IR.
func @matmul_bias_add(%arg0: tensor<601x513xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<513x321xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<321xf32> {linalg.buffer_layout = affine_map<(d0) -> (d0)>, linalg.inplaceable = false}, %arg3: tensor<601x321xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<601x321xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg3 : tensor<601x321xf32>) -> tensor<601x321xf32>
%1 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<601x513xf32>, tensor<513x321xf32>) outs(%0 : tensor<601x321xf32>) -> tensor<601x321xf32>
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)
@gysit
gysit / conv2d_canonicalization.mlir
Created December 10, 2021 15:25
conv2d effect of canonicalization
// IR after applyPatternsAndFoldGreedily in LinalgStrategyVectorizePass (after the first execution)
%0 = vector.transfer_write %cst_0, %arg2[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<8x16x15x64xf32>, tensor<8x16x15x64xf32>
%1 = scf.for %arg3 = %c0 to %c8 step %c1 iter_args(%arg4 = %0) -> (tensor<8x16x15x64xf32>) {
%2 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %arg4) -> (tensor<8x16x15x64xf32>) {
%3 = scf.for %arg7 = %c0 to %c15 step %c8 iter_args(%arg8 = %arg6) -> (tensor<8x16x15x64xf32>) {
%4 = affine.min affine_map<(d0) -> (8, -d0 + 15)>(%arg7)
%5 = affine.apply affine_map<(d0) -> (-d0 + 8)>(%4)
%6 = scf.for %arg9 = %c0 to %c64 step %c32 iter_args(%arg10 = %arg8) -> (tensor<8x16x15x64xf32>) {
%7 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<8x16x15x64xf32>) {
@gysit
gysit / conv2d_2x_vectorize.mlir
Created December 10, 2021 15:13
conv2d lowering with two vectorization passes
// IR after Vectorize(fun_name, "", vectorize_paddings=False)
func @conv_2d_nhwc_hwcf_main(%arg0: tensor<8x18x17x32xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = false}, %arg1: tensor<3x3x32x64xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = false}, %arg2: tensor<8x16x15x64xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = true}) -> tensor<8x16x15x64xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
%cst = arith.constant dense<0.000000e+00> : vector<8x16x15x64xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%c15 = arith.constant 15 : index
%c64 = arith.constant 64 : index
@gysit
gysit / conv2d_1x_vectorize.mlir
Last active December 10, 2021 15:08
conv2d lowering with single vectorization pass
// IR before running vectorization
func @conv_2d_nhwc_hwcf_main(%arg0: tensor<8x18x17x32xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = false}, %arg1: tensor<3x3x32x64xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = false}, %arg2: tensor<8x16x15x64xf32> {linalg.buffer_layout = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, linalg.inplaceable = true}) -> tensor<8x16x15x64xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c64 = arith.constant 64 : index
%c15 = arith.constant 15 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index