Created
March 17, 2024 20:10
-
-
Save meshtag/35c1a98e53cbaba71daf671a52ba24d1 to your computer and use it in GitHub Desktop.
Sanity test for channel processing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Affine-map aliases available to the module below.
//
// Earlier variants, kept commented out for reference:
// #map = affine_map<(d0) -> (d0)>
// #map1 = affine_map<(d0) -> (d0 ceildiv 256)>

// (d0, d1) -> d0 + d1 - 1
#map0 = affine_map<(d0, d1) -> (d0 + d1 - 1)>
// (d0, d1, d2) -> d0 + d1 - d2. #map_new and #map are the same expression
// under two names.
#map_new = affine_map<(d0, d1, d2) -> (d0 + d1 - d2)>
#map = affine_map<(d0, d1, d2) -> (d0 + d1 - d2)>
// Identity map.
#map1 = affine_map<(d0) -> (d0)>
module { | |
// External helpers, declared here without a body.
// NOTE(review): these are presumably resolved at link/JIT time from the MLIR
// runner utility libraries - confirm against the build/run command.

// Takes an unranked f32 memref.
func.func private @printMemrefF32(memref<*xf32>)
// Takes a single f64.
func.func private @print_flops(f64)
// Takes no arguments and returns an f64.
func.func private @rtclock() -> f64
// Allocates a rank-4 f32 memref and sets every element to the same value.
//
//   %arg0..%arg3 : extents of dimensions 0..3.
//   %arg4        : value stored into every element.
//
// Returns the freshly allocated buffer. Nothing in this function frees it, so
// the caller is responsible for calling memref.dealloc on the result.
func.func @alloc_4d_filled_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %buffer = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
  // Visit every (i, j, k, l) coordinate; a zero extent simply yields no
  // iterations for that loop level.
  scf.for %i = %c0 to %arg0 step %c1 {
    scf.for %j = %c0 to %arg1 step %c1 {
      scf.for %k = %c0 to %arg2 step %c1 {
        scf.for %l = %c0 to %arg3 step %c1 {
          memref.store %arg4, %buffer[%i, %j, %k, %l] : memref<?x?x?x?xf32>
        }
      }
    }
  }
  return %buffer : memref<?x?x?x?xf32>
}
// 2-D correlation with constant-value padding, accumulated into the output.
//
//   %arg0 : input, indexed as [n, channel, row, col].
//   %arg1 : filter, indexed as [filter, channel, krow, kcol].
//   %arg2 : output, indexed as [n, filter, row, col]. It is read-modify-written
//           (out += in * w), so the caller must initialise it. Its row/col
//           extents are taken to equal those of %arg0; its dimensions are
//           never queried here.
//   %arg3 : column offset subtracted from (col + kcol) to get the input column
//           (presumably the filter anchor column - confirm with the caller).
//   %arg4 : row offset subtracted from (row + krow) to get the input row
//           (presumably the filter anchor row - confirm with the caller).
//   %arg5 : value substituted for every input element that falls outside the
//           image.
//
// Columns are processed in tiles of 4 lanes. For each filter tap with a
// non-zero weight the tile update is
//     out[n, f, row, col..col+3] += in_vec * w
// where in_vec holds the input values for the tile, with %arg5 in every lane
// whose source coordinate is outside the image.
//
// Every memory access is bounded explicitly:
//   * output: an unmasked load/store is used only when the whole tile fits in
//     the row; otherwise both are masked to the remaining columns;
//   * input: an unmasked load is used only when all 4 source columns are
//     inside the image; otherwise a masked load enables exactly the lanes
//     whose source column lies in [0, width).
func.func @corr_2d_nchw_fchw_constant_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>, %arg3: index, %arg4: index, %arg5: f32) attributes {llvm.emit_c_interface} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %zero = arith.constant 0.000000e+00 : f32

  // Problem sizes, read from the input and the filter.
  %batch = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
  %channels = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
  %rows = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
  %cols = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
  %filters = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
  %krows = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
  %kcols = memref.dim %arg1, %c3 : memref<?x?x?x?xf32>

  // Loop-invariant values, computed once.
  // (row + krow) >= %row_limit means the source row is below the image.
  %row_limit = arith.addi %rows, %arg4 : index
  %zero_vec = vector.broadcast %zero : f32 to vector<4xf32>
  %pad_vec = vector.broadcast %arg5 : f32 to vector<4xf32>

  affine.for %n = 0 to %batch {
    affine.for %f = 0 to %filters {
      affine.for %ch = 0 to %channels {
        affine.for %row = 0 to %rows {
          affine.for %krow = 0 to %krows {
            // Source row for this tap; it is padding when above or below the
            // image. %in_row is only used when the row is inside the image.
            %shifted_row = arith.addi %row, %krow : index
            %above = arith.cmpi slt, %shifted_row, %arg4 : index
            %below = arith.cmpi sge, %shifted_row, %row_limit : index
            %row_is_pad = arith.ori %above, %below : i1
            %in_row = arith.subi %shifted_row, %arg4 : index

            affine.for %col = 0 to %cols step 4 {
              // Output tile bounds: number of columns left in this row.
              %lanes_left = arith.subi %cols, %col : index
              %full_tile = arith.cmpi sge, %lanes_left, %c4 : index
              %out_mask = vector.create_mask %lanes_left : vector<4xi1>

              affine.for %kcol = 0 to %kcols {
                %w = memref.load %arg1[%f, %ch, %krow, %kcol] : memref<?x?x?x?xf32>
                // Taps whose weight is zero (or NaN: "one" is an ordered
                // compare) contribute nothing and are skipped.
                %w_nonzero = arith.cmpf one, %w, %zero : f32
                scf.if %w_nonzero {
                  %w_vec = vector.broadcast %w : f32 to vector<4xf32>

                  // Gather the 4 input values for this tap.
                  %in_vec = scf.if %row_is_pad -> (vector<4xf32>) {
                    // Whole source row is outside the image.
                    scf.yield %pad_vec : vector<4xf32>
                  } else {
                    %shifted_col = arith.addi %col, %kcol : index
                    // May be negative when the tile starts left of the image.
                    %in_col = arith.subi %shifted_col, %arg3 : index
                    %in_end = arith.addi %in_col, %c4 : index
                    %starts_inside = arith.cmpi sge, %in_col, %c0 : index
                    %ends_inside = arith.cmpi sle, %in_end, %cols : index
                    %all_inside = arith.andi %starts_inside, %ends_inside : i1
                    %loaded = scf.if %all_inside -> (vector<4xf32>) {
                      %v = vector.load %arg0[%n, %ch, %in_row, %in_col] : memref<?x?x?x?xf32>, vector<4xf32>
                      scf.yield %v : vector<4xf32>
                    } else {
                      // Enable lane i iff 0 <= in_col + i < cols.
                      // %skip lanes lie left of the image, and lanes from
                      // %avail on lie right of it. The skip mask is always a
                      // subset of the avail mask (skip <= avail), so XOR
                      // yields avail-and-not-skip.
                      %skip = arith.subi %c0, %in_col : index
                      %avail = arith.subi %cols, %in_col : index
                      %skip_mask = vector.create_mask %skip : vector<4xi1>
                      %avail_mask = vector.create_mask %avail : vector<4xi1>
                      %in_mask = arith.xori %avail_mask, %skip_mask : vector<4xi1>
                      // Disabled lanes take the padding value.
                      %v = vector.maskedload %arg0[%n, %ch, %in_row, %in_col], %in_mask, %pad_vec : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
                      scf.yield %v : vector<4xf32>
                    }
                    scf.yield %loaded : vector<4xf32>
                  }

                  // Accumulate into the output tile.
                  scf.if %full_tile {
                    %acc = vector.load %arg2[%n, %f, %row, %col] : memref<?x?x?x?xf32>, vector<4xf32>
                    %sum = vector.fma %in_vec, %w_vec, %acc : vector<4xf32>
                    vector.store %sum, %arg2[%n, %f, %row, %col] : memref<?x?x?x?xf32>, vector<4xf32>
                  } else {
                    // Partial tile at the end of the row: touch only the
                    // columns that exist.
                    %acc = vector.maskedload %arg2[%n, %f, %row, %col], %out_mask, %zero_vec : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
                    %sum = vector.fma %in_vec, %w_vec, %acc : vector<4xf32>
                    vector.maskedstore %arg2[%n, %f, %row, %col], %out_mask, %sum : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32>
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  return
}
// Sanity driver: builds an all-ones image and filter and a zeroed output,
// runs the padded correlation once, prints the three buffers and frees them.
//
// Shapes used: filter 5x2x3x3, image 3x2x7x7, output 3x5x7x7. The kernel is
// called with both offsets equal to 1 and a padding value of 0. With all-ones
// inputs each output element should therefore equal 2 (input channels) times
// the number of filter taps that overlap the image: 18 in the interior, 12 on
// an edge and 8 in a corner.
func.func @main() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c5 = arith.constant 5 : index

  // Fill values: 1.0 for image and filter, 0.0 for the output accumulator
  // (0.0 is also passed to the kernel as the padding value).
  %cst_1 = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32

  // Spatial sizes. The image has the same size as the output.
  %current_filter = arith.constant 3 : index
  %current_output = arith.constant 7 : index
  // Alternative sizing that was disabled in favour of the constant below:
  // %current_image = affine.apply #map0(%current_output, %current_filter)
  %current_image = arith.constant 7 : index

  // Filter.
  %filter = func.call @alloc_4d_filled_f32(%c5, %c2, %current_filter, %current_filter, %cst_1) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
  // Image.
  %image = func.call @alloc_4d_filled_f32(%c3, %c2, %current_image, %current_image, %cst_1) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
  // Output.
  %output = func.call @alloc_4d_filled_f32(%c3, %c5, %current_output, %current_output, %cst_0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

  // Number of kernel invocations. The kernel accumulates into %output, so a
  // value above 1 scales the printed result accordingly.
  %reps = arith.constant 1 : index

  // Record start time.
  %t_start = func.call @rtclock() : () -> f64
  affine.for %rep = 0 to %reps {
    // Disabled alternative; @conv is not defined in this file:
    // func.call @conv(%image, %filter, %output) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
    func.call @corr_2d_nchw_fchw_constant_padding(%image, %filter, %output, %c1, %c1, %cst_0) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, index, index, f32) -> ()
  }
  // Record end time.
  %t_end = func.call @rtclock() : () -> f64
  // Total running time. Currently only consumed by the disabled reporting
  // code below.
  %t = arith.subf %t_end, %t_start : f64
  // vector.print %t : f64

  // Print input, filter and output.
  %print_input = memref.cast %image : memref<?x?x?x?xf32> to memref<*xf32>
  func.call @printMemrefF32(%print_input) : (memref<*xf32>) -> ()
  %print_filter = memref.cast %filter : memref<?x?x?x?xf32> to memref<*xf32>
  func.call @printMemrefF32(%print_filter) : (memref<*xf32>) -> ()
  %print_output = memref.cast %output : memref<?x?x?x?xf32> to memref<*xf32>
  func.call @printMemrefF32(%print_output) : (memref<*xf32>) -> ()

  // Disabled FLOPS reporting. The formula counts
  // 2 * [filter size]^2 * [output size]^2 per repetition and does not factor
  // in the batch, channel or filter counts.
  // %flops1 = arith.muli %current_output, %current_output : index
  // %flops2 = arith.muli %current_filter, %current_filter : index
  // %flops3 = arith.muli %c2, %flops2 : index
  // %flops4 = arith.muli %flops1, %flops3 : index
  // %num_flops = arith.muli %reps, %flops4 : index
  // %num_flops_i = arith.index_cast %num_flops : index to i64
  // %num_flops_f = arith.sitofp %num_flops_i : i64 to f64
  // %flops = arith.divf %num_flops_f, %t : f64
  // vector.print %flops : f64

  // Release the buffers returned by @alloc_4d_filled_f32; nothing else frees
  // them.
  memref.dealloc %image : memref<?x?x?x?xf32>
  memref.dealloc %filter : memref<?x?x?x?xf32>
  memref.dealloc %output : memref<?x?x?x?xf32>
  return
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment