Skip to content

Instantly share code, notes, and snippets.

@Max191
Created October 16, 2023 20:23
Show Gist options
  • Save Max191/908486a43bd86c83d865d7d25face75f to your computer and use it in GitHub Desktop.
Save Max191/908486a43bd86c83d865d7d25face75f to your computer and use it in GitHub Desktop.
transpose batch_matmul fusion
// Test input: a linalg.generic that swaps the first two dimensions of %arg0,
// whose result is consumed as the second input of
// linalg.batch_matmul_transpose_b.

// Identity map; indexes the generic's input operand.
#identity3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Swaps d0 and d1; indexes the generic's output operand.
#swap01 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>

module {
  // Takes %arg0 : ?x32x128 and %arg1 : 32x1x128. Returns the 32x1x? result of
  // the batch matmul together with %arg0 itself, unchanged.
  // NOTE(review): despite the "concat" in the name, no concatenation op
  // appears in this function.
  func.func @concat_batchMM(%arg0: tensor<?x32x128xf32>, %arg1: tensor<32x1x128xf32>) -> (tensor<32x1x?xf32>, tensor<?x32x128xf32>) {
    %idx0 = arith.constant 0 : index
    %zero = arith.constant 0.000000e+00 : f32

    // Size of the dynamic leading dimension of %arg0; it sizes both
    // dynamically-shaped tensors created below.
    %dyn = tensor.dim %arg0, %idx0 : tensor<?x32x128xf32>

    // Copy %arg0 (?x32x128) into a 32x?x128 tensor with the first two
    // dimensions exchanged. The body yields the input element as-is.
    %transposed_init = tensor.empty(%dyn) : tensor<32x?x128xf32>
    %transposed = linalg.generic {
        indexing_maps = [#identity3d, #swap01],
        iterator_types = ["parallel", "parallel", "parallel"]
      } ins(%arg0 : tensor<?x32x128xf32>)
        outs(%transposed_init : tensor<32x?x128xf32>) {
    ^bb0(%elem: f32, %unused: f32):
      linalg.yield %elem : f32
    } -> tensor<32x?x128xf32>

    // Zero-filled 32x1x? tensor handed to the matmul as its outs operand.
    %result_init = tensor.empty(%dyn) : tensor<32x1x?xf32>
    %result_zeroed = linalg.fill ins(%zero : f32) outs(%result_init : tensor<32x1x?xf32>) -> tensor<32x1x?xf32>

    // Inputs are 32x1x128 and 32x?x128; the result is 32x1x?.
    %product = linalg.batch_matmul_transpose_b
        ins(%arg1, %transposed : tensor<32x1x128xf32>, tensor<32x?x128xf32>)
        outs(%result_zeroed : tensor<32x1x?xf32>) -> tensor<32x1x?xf32>

    return %product, %arg0 : tensor<32x1x?xf32>, tensor<?x32x128xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment