Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created June 14, 2024 08:52
Show Gist options
  • Save pashu123/83ca1f519aa39f1ce7a035122bbb7e54 to your computer and use it in GitHub Desktop.
Save pashu123/83ca1f519aa39f1ce7a035122bbb7e54 to your computer and use it in GitHub Desktop.
//func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32> {
// %c0 = arith.constant 0 : index
// %0 = tensor.empty() : tensor<2x24x1178x1178xf32>
// %1 = linalg.softmax dimension(3) ins(%arg0 : tensor<2x24x1178x1178xf32>) outs(%0 : tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178x1178xf32>
// return %1 : tensor<2x24x1178x1178xf32>
//}
// Row-wise maximum over the innermost dimension (d3) of the input tensor.
// Despite its name, this function does not compute a full softmax: it only
// performs a max-reduction, producing one value per (d0, d1, d2) position.
// The accumulator is seeded with -3.40282347E+38 (the most negative finite
// f32) rather than -inf, so a row made up entirely of -inf reduces to that
// finite value instead of -inf.
func.func @softmax(%arg0: tensor<2x24x1178x1178xf32>) -> tensor<2x24x1178xf32> {
  %neg_max = arith.constant -3.40282347E+38 : f32
  %empty = tensor.empty() : tensor<2x24x1178xf32>
  %init = linalg.fill ins(%neg_max : f32) outs(%empty : tensor<2x24x1178xf32>) -> tensor<2x24x1178xf32>
  // d0..d2 are parallel; d3 is dropped by the output map and reduced.
  %row_max = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<2x24x1178x1178xf32>)
      outs(%init : tensor<2x24x1178xf32>) {
  ^bb0(%elem: f32, %acc: f32):
    // arith.maximumf returns NaN if either operand is NaN.
    %m = arith.maximumf %elem, %acc : f32
    linalg.yield %m : f32
  } -> tensor<2x24x1178xf32>
  return %row_max : tensor<2x24x1178xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment