Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Created September 1, 2022 04:50
Show Gist options
  • Save qedawkins/f8e899a93a9ae8265cc5c11a07271e55 to your computer and use it in GitHub Desktop.
Save qedawkins/f8e899a93a9ae8265cc5c11a07271e55 to your computer and use it in GitHub Desktop.
// Indexing maps over a two-dimensional iteration space (d0, d1).

// Identity: the operand is indexed as [d0, d1].
#map0 = affine_map<(d0, d1) -> (d0, d1)>
// Swapped dimensions: the operand is indexed as [d1, d0] (transposed access).
#map1 = affine_map<(d0, d1) -> (d1, d0)>
// First index pinned to the constant 0: the operand is always read at row 0,
// whatever the value of d0.
#map2 = affine_map<(d0, d1) -> (0, d1)>
// Projection onto d1: indexes a rank-1 operand by the second loop dimension,
// so the same element is reused for every d0.
#map3 = affine_map<(d0, d1) -> (d1)>
module attributes {torch.debug_module_name = "DLRM_Net"} {
  // Applies two consecutive "matmul + broadcast add + clamp-at-zero" stages to
  // %arg0 (1x4 -> 1x3 -> 1x2). %arg1 through %arg4 are part of the signature
  // but are not referenced anywhere in the body.
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<3x1xi64>, %arg2: tensor<3xi64>, %arg3: tensor<1xi64>, %arg4: tensor<1xi64>) -> tensor<1x2xf32> {
    // Scalar zero: used both as the matmul accumulator fill value and as the
    // threshold / replacement value of the clamp.
    %zero = arith.constant 0.000000e+00 : f32

    // ---- Stage 1: 1x4 input -> 1x3 ----
    %stage1_w = arith.constant dense<[[0.237254873, 0.178356424, 0.798618853, -0.109661706], [0.167341724, -0.456533372, -1.36463046, 0.349373847], [0.462060571, -0.396703899, 1.2132349, -0.777391135]]> : tensor<3x4xf32>
    %stage1_b = arith.constant dense<[0.0264186915, -0.108070649, 0.884950518]> : tensor<3xf32>

    // Transpose the 3x4 constant into a 4x3 tensor: the input is read through
    // #map0 and the output is written through the dimension-swapping #map1.
    %stage1_wt_init = linalg.init_tensor [4, 3] : tensor<4x3xf32>
    %stage1_wt = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%stage1_w : tensor<3x4xf32>) outs(%stage1_wt_init : tensor<4x3xf32>) {
    ^bb0(%src: f32, %dst: f32):
      linalg.yield %src : f32
    } -> tensor<4x3xf32>

    // %arg0 (1x4) times the transposed constant (4x3), accumulated into a
    // zero-filled 1x3 tensor.
    %stage1_init = linalg.init_tensor [1, 3] : tensor<1x3xf32>
    %stage1_acc = linalg.fill ins(%zero : f32) outs(%stage1_init : tensor<1x3xf32>) -> tensor<1x3xf32>
    %stage1_mm = linalg.matmul ins(%arg0, %stage1_wt : tensor<1x4xf32>, tensor<4x3xf32>) outs(%stage1_acc : tensor<1x3xf32>) -> tensor<1x3xf32>

    // Element-wise add of the rank-1 constant, which #map3 indexes by the
    // column dimension only.
    %stage1_sum = linalg.generic {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%stage1_mm, %stage1_b : tensor<1x3xf32>, tensor<3xf32>) outs(%stage1_init : tensor<1x3xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %dst: f32):
      %total = arith.addf %lhs, %rhs : f32
      linalg.yield %total : f32
    } -> tensor<1x3xf32>

    // Keep an element when it compares "unordered or greater than" zero,
    // otherwise replace it with zero. Because the predicate is `ugt`, a NaN
    // element is kept rather than replaced.
    %stage1_out = linalg.generic {indexing_maps = [#map2, #map0], iterator_types = ["parallel", "parallel"]} ins(%stage1_sum : tensor<1x3xf32>) outs(%stage1_init : tensor<1x3xf32>) {
    ^bb0(%elem: f32, %dst: f32):
      %keep = arith.cmpf ugt, %elem, %zero : f32
      %clamped = arith.select %keep, %elem, %zero : f32
      linalg.yield %clamped : f32
    } -> tensor<1x3xf32>

    // ---- Stage 2: 1x3 -> 1x2, same sequence of operations as stage 1 ----
    %stage2_w = arith.constant dense<[[0.929304063, 0.0979973599, 0.239170983], [-5.614850e-01, -1.25276566, -0.220038965]]> : tensor<2x3xf32>
    %stage2_b = arith.constant dense<[0.110555418, 0.869946897]> : tensor<2xf32>

    // Transpose the 2x3 constant into a 3x2 tensor.
    %stage2_wt_init = linalg.init_tensor [3, 2] : tensor<3x2xf32>
    %stage2_wt = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%stage2_w : tensor<2x3xf32>) outs(%stage2_wt_init : tensor<3x2xf32>) {
    ^bb0(%src: f32, %dst: f32):
      linalg.yield %src : f32
    } -> tensor<3x2xf32>

    // Stage-1 result (1x3) times the transposed constant (3x2), accumulated
    // into a zero-filled 1x2 tensor.
    %stage2_init = linalg.init_tensor [1, 2] : tensor<1x2xf32>
    %stage2_acc = linalg.fill ins(%zero : f32) outs(%stage2_init : tensor<1x2xf32>) -> tensor<1x2xf32>
    %stage2_mm = linalg.matmul ins(%stage1_out, %stage2_wt : tensor<1x3xf32>, tensor<3x2xf32>) outs(%stage2_acc : tensor<1x2xf32>) -> tensor<1x2xf32>

    // Element-wise add of the rank-1 constant along the column dimension.
    %stage2_sum = linalg.generic {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "parallel"]} ins(%stage2_mm, %stage2_b : tensor<1x2xf32>, tensor<2xf32>) outs(%stage2_init : tensor<1x2xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %dst: f32):
      %total = arith.addf %lhs, %rhs : f32
      linalg.yield %total : f32
    } -> tensor<1x2xf32>

    // Same `ugt`-based clamp at zero as in stage 1.
    %stage2_out = linalg.generic {indexing_maps = [#map2, #map0], iterator_types = ["parallel", "parallel"]} ins(%stage2_sum : tensor<1x2xf32>) outs(%stage2_init : tensor<1x2xf32>) {
    ^bb0(%elem: f32, %dst: f32):
      %keep = arith.cmpf ugt, %elem, %zero : f32
      %clamped = arith.select %keep, %elem, %zero : f32
      linalg.yield %clamped : f32
    } -> tensor<1x2xf32>

    return %stage2_out : tensor<1x2xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment