Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Created September 12, 2022 05:25
Show Gist options
  • Save qedawkins/e3023a70ef3e5797545c4b3468e98dc1 to your computer and use it in GitHub Desktop.
Save qedawkins/e3023a70ef3e5797545c4b3468e98dc1 to your computer and use it in GitHub Desktop.
// Indexing maps shared by the three linalg.generic ops in @forward.
// Iteration space is (seg, pos, col):
//   seg - output row (one per entry of the second input vector)
//   pos - position within the first input vector
//   col - output column
// #map0 reads the first input vector at `pos`.
#map0 = affine_map<(seg, pos, col) -> (pos)>
// #map1 reads the second input vector at `seg`.
#map1 = affine_map<(seg, pos, col) -> (seg)>
// #map2 addresses the 2-D output at (seg, col); `pos` does not appear.
#map2 = affine_map<(seg, pos, col) -> (seg, col)>
// Three segment-sum table lookups whose 2x8 results are packed side by side
// into a 2x24 tensor and then reshaped to 2x3x8.
module attributes {torch.debug_module_name = "DLRMShark"} {
  // Arguments come in pairs (row-index vector, segment-start vector):
  //   (%arg1, %arg2), (%arg4, %arg5), (%arg7, %arg8).
  // For each pair, output row `seg` is the sum of the table rows selected by
  // the row indices at positions [start[seg], start[seg + 1]), where the end
  // of the last segment is the length of the row-index vector.
  func.func @forward(%arg1: tensor<4xi64>, %arg2: tensor<2xi64>, %arg4: tensor<4xi64>, %arg5: tensor<2xi64>, %arg7: tensor<3xi64>, %arg8: tensor<2xi64>) -> tensor<2x3x8xf32> {
    // NOTE(review): the hex payloads of the two dense<...> table constants
    // below are truncated in this copy of the file; they must be restored
    // from the original source before this module can be parsed.
    // The first two lookups read %cst_0, the third reads %cst.
    %cst = arith.constant dense<"0xtensor<100x8xf32>
    %cst_0 = arith.constant dense<"0xtensor<100x8xf32>
    %c1_i64 = arith.constant 1 : i64
    %cst_1 = arith.constant 0.000000e+00 : f32
    // Number of segments (length of each segment-start vector).
    %c2_i64 = arith.constant 2 : i64
    // Lengths of the row-index vectors: 4 for %arg1/%arg4, 3 for %arg7.
    %c4_i64 = arith.constant 4 : i64
    %c3_i64 = arith.constant 3 : i64
    %0 = linalg.init_tensor [2, 8] : tensor<2x8xf32>

    // ---- Lookup 1: rows of %cst_0 selected by %arg1, segmented by %arg2 ----
    %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<2x8xf32>) -> tensor<2x8xf32>
    // d1 is a reduction: it is absent from the output map and the body
    // accumulates into the output element across all of its positions.
    %2 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg1, %arg2 : tensor<4xi64>, tensor<2xi64>) outs(%1 : tensor<2x8xf32>) {
    ^bb0(%row: i64, %seg_start: i64, %acc: f32):
      %seg = linalg.index 0 : index
      %seg_i64 = arith.index_cast %seg : index to i64
      %next_i64 = arith.addi %seg_i64, %c1_i64 : i64
      %next = arith.index_cast %next_i64 : i64 to index
      %is_last = arith.cmpi eq, %next_i64, %c2_i64 : i64
      // On the last segment seg + 1 is past the end of the start vector;
      // read a valid element instead (its value is discarded just below).
      %safe_next = arith.select %is_last, %seg, %next : index
      %next_start = tensor.extract %arg2[%safe_next] : tensor<2xi64>
      %seg_end = arith.select %is_last, %c4_i64, %next_start : i64
      %pos = linalg.index 1 : index
      %pos_i64 = arith.index_cast %pos : index to i64
      // Position belongs to this segment iff seg_start <= pos < seg_end.
      %ge_start = arith.cmpi sle, %seg_start, %pos_i64 : i64
      %lt_end = arith.cmpi slt, %pos_i64, %seg_end : i64
      %in_seg = arith.andi %ge_start, %lt_end : i1
      %row_idx = arith.index_cast %row : i64 to index
      %col = linalg.index 2 : index
      %val = tensor.extract %cst_0[%row_idx, %col] : tensor<100x8xf32>
      %sum = arith.addf %val, %acc : f32
      // Positions outside the segment leave the accumulator unchanged.
      %out = arith.select %in_seg, %sum, %acc : f32
      linalg.yield %out : f32
    } -> tensor<2x8xf32>

    // ---- Lookup 2: rows of %cst_0 selected by %arg4, segmented by %arg5 ----
    %3 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<2x8xf32>) -> tensor<2x8xf32>
    %4 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg4, %arg5 : tensor<4xi64>, tensor<2xi64>) outs(%3 : tensor<2x8xf32>) {
    ^bb0(%row: i64, %seg_start: i64, %acc: f32):
      %seg = linalg.index 0 : index
      %seg_i64 = arith.index_cast %seg : index to i64
      %next_i64 = arith.addi %seg_i64, %c1_i64 : i64
      %next = arith.index_cast %next_i64 : i64 to index
      %is_last = arith.cmpi eq, %next_i64, %c2_i64 : i64
      // Keep the read of the start vector in bounds on the last segment.
      %safe_next = arith.select %is_last, %seg, %next : index
      %next_start = tensor.extract %arg5[%safe_next] : tensor<2xi64>
      %seg_end = arith.select %is_last, %c4_i64, %next_start : i64
      %pos = linalg.index 1 : index
      %pos_i64 = arith.index_cast %pos : index to i64
      %ge_start = arith.cmpi sle, %seg_start, %pos_i64 : i64
      %lt_end = arith.cmpi slt, %pos_i64, %seg_end : i64
      %in_seg = arith.andi %ge_start, %lt_end : i1
      %row_idx = arith.index_cast %row : i64 to index
      %col = linalg.index 2 : index
      %val = tensor.extract %cst_0[%row_idx, %col] : tensor<100x8xf32>
      %sum = arith.addf %val, %acc : f32
      %out = arith.select %in_seg, %sum, %acc : f32
      linalg.yield %out : f32
    } -> tensor<2x8xf32>

    // ---- Lookup 3: rows of %cst selected by %arg7, segmented by %arg8 ----
    %5 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<2x8xf32>) -> tensor<2x8xf32>
    %6 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg7, %arg8 : tensor<3xi64>, tensor<2xi64>) outs(%5 : tensor<2x8xf32>) {
    ^bb0(%row: i64, %seg_start: i64, %acc: f32):
      %seg = linalg.index 0 : index
      %seg_i64 = arith.index_cast %seg : index to i64
      %next_i64 = arith.addi %seg_i64, %c1_i64 : i64
      %next = arith.index_cast %next_i64 : i64 to index
      %is_last = arith.cmpi eq, %next_i64, %c2_i64 : i64
      // Keep the read of the start vector in bounds on the last segment.
      %safe_next = arith.select %is_last, %seg, %next : index
      %next_start = tensor.extract %arg8[%safe_next] : tensor<2xi64>
      // The row-index vector for this lookup has 3 entries.
      %seg_end = arith.select %is_last, %c3_i64, %next_start : i64
      %pos = linalg.index 1 : index
      %pos_i64 = arith.index_cast %pos : index to i64
      %ge_start = arith.cmpi sle, %seg_start, %pos_i64 : i64
      %lt_end = arith.cmpi slt, %pos_i64, %seg_end : i64
      %in_seg = arith.andi %ge_start, %lt_end : i1
      %row_idx = arith.index_cast %row : i64 to index
      %col = linalg.index 2 : index
      %val = tensor.extract %cst[%row_idx, %col] : tensor<100x8xf32>
      %sum = arith.addf %val, %acc : f32
      %out = arith.select %in_seg, %sum, %acc : f32
      linalg.yield %out : f32
    } -> tensor<2x8xf32>

    // Concatenate the three 2x8 results along dim 1 (column offsets 0, 8, 16),
    // then split the 24 columns into 3 groups of 8.
    %7 = linalg.init_tensor [2, 24] : tensor<2x24xf32>
    %8 = tensor.insert_slice %2 into %7[0, 0] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<2x24xf32>
    %9 = tensor.insert_slice %4 into %8[0, 8] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<2x24xf32>
    %10 = tensor.insert_slice %6 into %9[0, 16] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<2x24xf32>
    %11 = tensor.expand_shape %10 [[0], [1, 2]] : tensor<2x24xf32> into tensor<2x3x8xf32>
    return %11 : tensor<2x3x8xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment