Skip to content

Instantly share code, notes, and snippets.

@meshtag
Created March 17, 2024 20:10
Show Gist options
  • Save meshtag/35c1a98e53cbaba71daf671a52ba24d1 to your computer and use it in GitHub Desktop.
Sanity test for channel processing.
// Affine map aliases used by the functions below. The commented-out maps are
// retained from an earlier experiment.
// #map = affine_map<(d0) -> (d0)>
#map0 = affine_map<(d0, d1) -> (d0 + d1 - 1)>
// #map1 = affine_map<(d0) -> (d0 ceildiv 256)>
// NOTE(review): #map_new duplicates #map below and is not referenced anywhere
// in this file — candidate for removal.
#map_new = affine_map<(d0, d1, d2) -> (d0 + d1 - d2)>
#map = affine_map<(d0, d1, d2) -> (d0 + d1 - d2)>
#map1 = affine_map<(d0) -> (d0)>
module {
// External utilities supplied by the MLIR runner libraries
// (mlir_runner_utils / mlir_c_runner_utils).
func.func private @printMemrefF32(memref<*xf32>)
func.func private @print_flops(f64)
func.func private @rtclock() -> f64
// Allocates a dynamically-sized 4-D f32 buffer of shape
// (%arg0 x %arg1 x %arg2 x %arg3) and fills every element with %arg4.
// The caller owns the returned memref and must dealloc it.
func.func @alloc_4d_filled_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref<?x?x?x?xf32> {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %buf = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
  // Walk the buffer in row-major order and store the fill value everywhere.
  scf.for %i = %zero to %arg0 step %one {
    scf.for %j = %zero to %arg1 step %one {
      scf.for %k = %zero to %arg2 step %one {
        scf.for %l = %zero to %arg3 step %one {
          memref.store %arg4, %buf[%i, %j, %k, %l] : memref<?x?x?x?xf32>
        }
      }
    }
  }
  return %buf : memref<?x?x?x?xf32>
}
// Vectorized 2-D cross-correlation with constant padding.
//   %arg0: input,  NCHW layout (N x C x H x W)
//   %arg1: filter, FCHW layout (F x C x Hf x Wf)
//   %arg2: output, accumulated in place at [%arg6, %arg7, %arg9, %arg11]
//   %arg3: horizontal (left/right) padding amount
//   %arg4: vertical (top/bottom) padding amount
//   %arg5: constant value used for padded elements
// The width dimension is processed 4 lanes at a time; boundary columns use
// masked vector loads/stores. Zero filter weights are skipped entirely.
func.func @corr_2d_nchw_fchw_constant_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>, %arg3: index, %arg4: index, %arg5: f32) attributes {llvm.emit_c_interface} {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// Input extents: %dim = N, %dim_0 = C, %dim_1 = H, %dim_2 = W.
%dim = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
%dim_0 = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
%dim_1 = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
%dim_2 = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
// Filter extents: %dim_3 = F, %dim_4 = Hf, %dim_5 = Wf.
%dim_3 = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
%dim_4 = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
%dim_5 = memref.dim %arg1, %c3 : memref<?x?x?x?xf32>
// %0 = H + vertical pad, %1 = W + horizontal pad: one past the last row/col
// of the (padded) input that still maps onto real data plus one pad border.
%0 = arith.addi %dim_1, %arg4 : index
%1 = arith.addi %dim_2, %arg3 : index
%cst = arith.constant 0.000000e+00 : f32
// Zero splat used as the pass-through value for masked loads.
%2 = vector.broadcast %cst : f32 to vector<4xf32>
// %3 = W + Wf - 1 (#map computes d0 + d1 - d2).
%3 = affine.apply #map(%dim_2, %dim_5, %c1)
// Loop order: batch %arg6, output channel %arg7, input channel %arg8,
// output row %arg9, filter row %arg10, output col %arg11 (step 4, vector
// width), filter col %arg12.
affine.for %arg6 = #map1(%c0) to #map1(%dim) {
affine.for %arg7 = #map1(%c0) to #map1(%dim_3) {
affine.for %arg8 = #map1(%c0) to #map1(%dim_0) {
affine.for %arg9 = #map1(%c0) to #map1(%dim_1) {
affine.for %arg10 = #map1(%c0) to #map1(%dim_4) {
affine.for %arg11 = #map1(%c0) to #map1(%dim_2) step 4 {
affine.for %arg12 = #map1(%c0) to #map1(%dim_5) {
// %4/%5: row/col in padded coordinates touched by this (output, filter) pair.
%4 = arith.addi %arg9, %arg10 : index
%5 = arith.addi %arg11, %arg12 : index
// Scalar filter weight, broadcast across the 4 lanes.
%6 = memref.load %arg1[%arg7, %arg8, %arg10, %arg12] : memref<?x?x?x?xf32>
%7 = vector.broadcast %6 : f32 to vector<4xf32>
// %8/%9: the same row/col translated back into unpadded input coordinates.
%8 = arith.subi %4, %arg4 : index
%9 = arith.subi %5, %arg3 : index
// %10: one past the last padded column covered by this 4-wide vector.
%10 = arith.addi %5, %c4 : index
// %11: row falls in the top padding region.
%11 = arith.cmpi slt, %4, %arg4 : index
// %12: skip the multiply-accumulate when the filter weight is exactly 0.0
// ("one" = ordered not-equal).
%12 = arith.cmpf one, %6, %cst : f32
scf.if %12 {
scf.if %11 {
// Top padding row: the contribution is the pad constant %arg5.
// %13: the vector also starts in the left padding region.
%13 = arith.cmpi slt, %5, %arg3 : index
scf.if %13 {
// NOTE(review): %15-%17 below are computed but never used, and arith.subi
// on vector<4xi1> looks like leftover mask arithmetic — confirm intent
// before reusing this branch.
%14 = arith.subi %arg3, %5 : index
%15 = vector.create_mask %14 : vector<4xi1>
%16 = vector.create_mask %c4 : vector<4xi1>
%17 = arith.subi %16, %15 : vector<4xi1>
%18 = vector.broadcast %arg5 : f32 to vector<4xf32>
%19 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%20 = vector.fma %18, %7, %19 : vector<4xf32>
vector.store %20, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
// %14: the whole vector stays left of the right padded edge.
%14 = arith.cmpi slt, %10, %1 : index
scf.if %14 {
%15 = vector.broadcast %arg5 : f32 to vector<4xf32>
%16 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%17 = vector.fma %15, %7, %16 : vector<4xf32>
vector.store %17, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
// Vector straddles the right padded edge: mask off the overhang.
%15 = arith.subi %10, %1 : index
%16 = arith.subi %c4, %15 : index
%17 = vector.create_mask %16 : vector<4xi1>
%18 = vector.broadcast %arg5 : f32 to vector<4xf32>
// Tail check: %19 = 4 + Wf - 1; a full 4-wide store is used only while
// %3 - %arg11 leaves at least that much room — presumably the last
// partial output vector otherwise. TODO confirm the bound.
%19 = affine.apply #map(%c4, %dim_5, %c1)
%20 = arith.subi %3, %arg11 : index
%21 = arith.cmpi sge, %20, %19 : index
scf.if %21 {
%22 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%23 = vector.fma %18, %7, %22 : vector<4xf32>
vector.store %23, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
// Masked tail: only the remaining W - %arg11 output lanes are touched.
%22 = arith.subi %dim_2, %arg11 : index
%23 = vector.create_mask %22 : vector<4xi1>
%24 = vector.maskedload %arg2[%arg6, %arg7, %arg9, %arg11], %23, %2 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%25 = vector.fma %18, %7, %24 : vector<4xf32>
vector.maskedstore %arg2[%arg6, %arg7, %arg9, %arg11], %23, %25 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32>
}
}
}
} else {
// %13: row is inside the padded vertical range (not bottom padding).
%13 = arith.cmpi slt, %4, %0 : index
scf.if %13 {
// In-bounds row: read actual input data, masking padded columns.
%14 = arith.cmpi slt, %5, %arg3 : index
scf.if %14 {
// Vector starts in the left padding: lanes covering padding take the
// broadcast pad value, the rest come from the input.
%15 = arith.subi, %arg3, %5 : index
%16 = vector.create_mask %15 : vector<4xi1>
%17 = vector.create_mask %c4 : vector<4xi1>
%18 = arith.subi %17, %16 : vector<4xi1>
%19 = vector.broadcast %arg5 : f32 to vector<4xf32>
// NOTE(review): %20 is a negative column offset here; the masked-off
// lanes are presumably expected to cover it — verify this cannot read
// out of bounds.
%20 = arith.subi %c0, %15 : index
%21 = vector.maskedload %arg0[%arg6, %arg8, %8, %20], %18, %19 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%22 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%23 = vector.fma %21, %7, %22 : vector<4xf32>
vector.store %23, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
%15 = arith.cmpi slt, %10, %1 : index
scf.if %15 {
// Fully interior vector: plain load / fma / store.
%16 = vector.load %arg0[%arg6, %arg8, %8, %9] : memref<?x?x?x?xf32>, vector<4xf32>
%17 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%18 = vector.fma %16, %7, %17 : vector<4xf32>
vector.store %18, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
// Right edge: masked input load, then full or masked output update
// depending on the same tail check as above.
%16 = arith.subi %10, %1 : index
%17 = arith.subi %c4, %16 : index
%18 = vector.create_mask %17 : vector<4xi1>
%19 = vector.broadcast %arg5 : f32 to vector<4xf32>
%20 = vector.maskedload %arg0[%arg6, %arg8, %8, %9], %18, %19 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%21 = affine.apply #map(%c4, %dim_5, %c1)
%22 = arith.subi %3, %arg11 : index
%23 = arith.cmpi sge, %22, %21 : index
scf.if %23 {
%24 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%25 = vector.fma %20, %7, %24 : vector<4xf32>
vector.store %25, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
%24 = arith.subi %dim_2, %arg11 : index
%25 = vector.create_mask %24 : vector<4xi1>
%26 = vector.maskedload %arg2[%arg6, %arg7, %arg9, %arg11], %25, %2 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%27 = vector.fma %20, %7, %26 : vector<4xf32>
vector.maskedstore %arg2[%arg6, %arg7, %arg9, %arg11], %25, %27 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32>
}
}
}
} else {
// Bottom padding row: contribution is the pad constant, mirroring the
// top-padding cases above.
%14 = arith.cmpi slt, %5, %arg3 : index
scf.if %14 {
// NOTE(review): %15-%19 below are computed but unused in this branch
// (same pattern as the top-padding case) — likely dead code.
%15 = arith.subi %dim_1, %c1 : index
%16 = arith.subi %arg3, %5 : index
%17 = vector.create_mask %16 : vector<4xi1>
%18 = vector.create_mask %c4 : vector<4xi1>
%19 = arith.subi %18, %17 : vector<4xi1>
%20 = vector.broadcast %arg5 : f32 to vector<4xf32>
%21 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%22 = vector.fma %20, %7, %21 : vector<4xf32>
vector.store %22, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
%15 = arith.cmpi slt, %10, %1 : index
scf.if %15 {
// NOTE(review): %16 is unused here.
%16 = arith.subi %dim_1, %c1 : index
%17 = vector.broadcast %arg5 : f32 to vector<4xf32>
%18 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%19 = vector.fma %17, %7, %18 : vector<4xf32>
vector.store %19, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
%16 = arith.subi %10, %1 : index
%17 = arith.subi %c4, %16 : index
%18 = vector.create_mask %17 : vector<4xi1>
// NOTE(review): %19 and %20 are unused here.
%19 = arith.subi %dim_1, %c1 : index
%20 = arith.subi %dim_2, %c1 : index
%21 = vector.broadcast %arg5 : f32 to vector<4xf32>
%22 = affine.apply #map(%c4, %dim_5, %c1)
%23 = arith.subi %3, %arg11 : index
%24 = arith.cmpi sge, %23, %22 : index
scf.if %24 {
%25 = vector.load %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
%26 = vector.fma %21, %7, %25 : vector<4xf32>
vector.store %26, %arg2[%arg6, %arg7, %arg9, %arg11] : memref<?x?x?x?xf32>, vector<4xf32>
} else {
%25 = arith.subi %dim_2, %arg11 : index
%26 = vector.create_mask %25 : vector<4xi1>
%27 = vector.maskedload %arg2[%arg6, %arg7, %arg9, %arg11], %26, %2 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%28 = vector.fma %21, %7, %27 : vector<4xf32>
vector.maskedstore %arg2[%arg6, %arg7, %arg9, %arg11], %26, %28 : memref<?x?x?x?xf32>, vector<4xi1>, vector<4xf32>
}
}
}
}
}
}
}
}
}
}
}
}
}
return
}
// Sanity-test driver: builds a 3x2x7x7 image and 5x2x3x3 filter (both filled
// with 1.0) and a zero-filled 3x5x7x7 output, runs the padded 2-D
// correlation once with padding 1/1 and pad value 0.0, prints all three
// buffers, then releases them.
// FIX: the three memref.dealloc calls were commented out, leaking every
// buffer allocated by @alloc_4d_filled_f32 — they are restored here.
func.func @main() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c5 = arith.constant 5 : index
  // Image and Output value.
  %cst = arith.constant 2.000000e+00 : f32
  %cst_1 = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %current_filter = arith.constant 3 : index
  %current_output = arith.constant 7 : index
  // %current_image = affine.apply #map0(%current_output, %current_filter)
  %current_image = arith.constant 7 : index
  // Filter: F=5, C=2, Hf=Wf=3, filled with 1.0.
  %filter = call @alloc_4d_filled_f32(%c5, %c2, %current_filter, %current_filter, %cst_1) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
  // Image: N=3, C=2, H=W=7, filled with 1.0.
  %image = call @alloc_4d_filled_f32(%c3, %c2, %current_image, %current_image, %cst_1) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
  // Output: N=3, F=5, H=W=7, zero-initialized (kernel accumulates in place).
  %output = call @alloc_4d_filled_f32(%c3, %c5, %current_output, %current_output, %cst_0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
  // Execution times.
  %reps = arith.constant 1 : index
  // Record start time.
  %t_start = call @rtclock() : () -> f64
  // Execute convolution for specific times.
  affine.for %arg0 = 0 to %reps {
    //func.call @conv(%image, %filter, %output) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
    func.call @corr_2d_nchw_fchw_constant_padding(%image, %filter, %output, %c1, %c1, %cst_0) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, index, index, f32) -> ()
  }
  // Record end time.
  %t_end = call @rtclock() : () -> f64
  // Get the total running time.
  %t = arith.subf %t_end, %t_start : f64
  // vector.print %t : f64
  %print_input = memref.cast %image : memref<?x?x?x?xf32> to memref<*xf32>
  call @printMemrefF32(%print_input) : (memref<*xf32>) -> ()
  %print_filter = memref.cast %filter : memref<?x?x?x?xf32> to memref<*xf32>
  call @printMemrefF32(%print_filter) : (memref<*xf32>) -> ()
  // Print output.
  %print_output = memref.cast %output : memref<?x?x?x?xf32> to memref<*xf32>
  call @printMemrefF32(%print_output) : (memref<*xf32>) -> ()
  // // 2 * [filter size]^2 * [output size]^2.
  // %flops1 = arith.muli %current_output, %current_output : index
  // %flops2 = arith.muli %current_filter, %current_filter : index
  // %flops3 = arith.muli %c2, %flops2 : index
  // %flops4 = arith.muli %flops1, %flops3 : index
  // // Calculate FLOPS.
  // %num_flops = arith.muli %reps, %flops4 : index
  // %num_flops_i = arith.index_cast %num_flops : index to i64
  // %num_flops_f = arith.sitofp %num_flops_i : i64 to f64
  // %flops = arith.divf %num_flops_f, %t : f64
  // // Print the FLOPS.
  // // vector.print %flops : f64
  // Release all buffers allocated above.
  memref.dealloc %image : memref<?x?x?x?xf32>
  memref.dealloc %filter : memref<?x?x?x?xf32>
  memref.dealloc %output : memref<?x?x?x?xf32>
  return
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment