Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created June 27, 2024 15:47
Show Gist options
  • Save pashu123/e835b51bdeeff2b0df54f4bd2d0fee85 to your computer and use it in GitHub Desktop.
Save pashu123/e835b51bdeeff2b0df54f4bd2d0fee85 to your computer and use it in GitHub Desktop.
//util.func public @matmul_broad(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
// %cst = arith.constant 0.000000e+00 : f32
// %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
// %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
// %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
// %3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
// %4 = tensor.empty() : tensor<540x3200x16x1xf16>
// %pack = tensor.pack %3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %4 : tensor<8640x3200xf16> -> tensor<540x3200x16x1xf16>
// %collapsed = tensor.collapse_shape %pack [[0], [1], [2, 3]] : tensor<540x3200x16x1xf16> into tensor<540x3200x16xf16>
// %5 = tensor.empty(%0) : tensor<?x540x3200x16xf16>
// %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed : tensor<540x3200x16xf16>) outs(%5 : tensor<?x540x3200x16xf16>) {
// ^bb0(%in: f16, %out: f16):
// linalg.yield %in : f16
// } -> tensor<?x540x3200x16xf16>
// %expanded = tensor.expand_shape %6 [[0], [1], [2], [3, 4]] output_shape [%0, 540, 3200, 16, 1] : tensor<?x540x3200x16xf16> into tensor<?x540x3200x16x1xf16>
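// NOTE: %pack_0 tiles dim 1 of %2 (extent %1) by 16, so its outer extent is ceildiv(%1, 16), not ceildiv(%0, 16).
// %7 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]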
// %8 = tensor.empty(%0, %7) : tensor<?x?x3200x16x1xf32>
// %pack_0 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %8 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
// %9 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
// %10 = tensor.empty(%0, %9) : tensor<?x?x540x16x16xf32>
// %11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
// %12 = linalg.batch_mmt4d ins(%pack_0, %expanded : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%11 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
// %13 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
// %unpack = tensor.unpack %12 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %13 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
// %14 = hal.tensor.export %unpack "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
// util.return %14 : !hal.buffer_view
//}
// Reads a 2-D operand at (d1, d2): the leading loop dimension d0 is ignored,
// which is what turns the generic below into a broadcast along d0.
#bcast_rhs_map = affine_map<(d0, d1, d2) -> (d1, d2)>
// Identity access over the full 3-D iteration space.
#identity_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  // Computes relu(batch_matmul_transpose_b(%arg0, broadcast(%arg1))).
  //   %arg0 : tensor<?x?x3200xf32>  -- dims referred to as B and M below.
  //   %arg1 : tensor<8640x3200xf16> -- the same 2-D operand is reused for every b in B.
  // Result: tensor<?x?x8640xf32> with shape B x M x 8640.
  util.func public @broadcast_matmul_relu(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
    %idx0 = arith.constant 0 : index
    %idx1 = arith.constant 1 : index
    %zero = arith.constant 0.000000e+00 : f32

    // Dynamic extents of %arg0.
    %batch = tensor.dim %arg0, %idx0 : tensor<?x?x3200xf32>
    %rows = tensor.dim %arg0, %idx1 : tensor<?x?x3200xf32>

    // Copy %arg1 into every batch slice: rhs_bcast[b, n, k] = arg1[n, k].
    %rhs_init = tensor.empty(%batch) : tensor<?x8640x3200xf16>
    %rhs_bcast = linalg.generic {indexing_maps = [#bcast_rhs_map, #identity_map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%rhs_init : tensor<?x8640x3200xf16>) {
    ^bb0(%w: f16, %unused: f16):
      linalg.yield %w : f16
    } -> tensor<?x8640x3200xf16>

    // Zero-initialized f32 accumulator of shape B x M x 8640.
    %out_init = tensor.empty(%batch, %rows) : tensor<?x?x8640xf32>
    %acc = linalg.fill ins(%zero : f32) outs(%out_init : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>

    // mm[b, m, n] = acc[b, m, n] + sum_k arg0[b, m, k] * rhs_bcast[b, n, k].
    // The f16 operand is cast to the f32 accumulator type by the named op's implicit cast.
    %mm = linalg.batch_matmul_transpose_b ins(%arg0, %rhs_bcast : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%acc : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>

    // Elementwise ReLU: max(x, 0.0). The init operand is the unfilled empty tensor;
    // its contents are never read because the region yields only from the input.
    %relu = linalg.generic {indexing_maps = [#identity_map, #identity_map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%mm : tensor<?x?x8640xf32>) outs(%out_init : tensor<?x?x8640xf32>) {
    ^bb0(%x: f32, %unused: f32):
      %clamped = arith.maximumf %x, %zero : f32
      linalg.yield %clamped : f32
    } -> tensor<?x?x8640xf32>
    util.return %relu : tensor<?x?x8640xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment