Skip to content

Instantly share code, notes, and snippets.

@Max191
Created May 21, 2024 19:33
Show Gist options
  • Save Max191/a32c07b72272e74cf625cd810ae09c0a to your computer and use it in GitHub Desktop.
Save Max191/a32c07b72272e74cf625cd810ae09c0a to your computer and use it in GitHub Desktop.
Bad codegen for tensor.pack and tensor.unpack with a non-identity outer_dims_perm
module {
  // tensor.pack with a non-identity outer_dims_perm (the "bad" codegen case).
  //
  // Packed dims (inner_dims_pos = [0, 1], inner_tiles = [16, 2]):
  //   dim 0: 29241 tiled by 16 -> 1828 tiles (16 * 1828 = 29248, so the last
  //          tile holds 7 padding rows filled with %pad)
  //   dim 1:   128 tiled by  2 ->   64 tiles
  //   dim 2:    64 (not tiled)
  // Outer dims [1828, 64, 64] are reordered by outer_dims_perm = [2, 0, 1]
  // into [64, 1828, 64]; the inner tile sizes [16, 2] are appended.
  func.func @pack_bad(%arg0: tensor<29241x128x64xbf16>) -> tensor<64x1828x64x16x2xbf16> {
    // Needed because 29241 is not a multiple of the 16-wide tile; without a
    // padding value the partial last tile would be undefined behavior.
    %pad = arith.constant 0.000000e+00 : bf16
    %dest = tensor.empty() : tensor<64x1828x64x16x2xbf16>
    %pack = tensor.pack %arg0 padding_value(%pad : bf16)
        outer_dims_perm = [2, 0, 1]
        inner_dims_pos = [0, 1]
        inner_tiles = [16, 2]
        into %dest : tensor<29241x128x64xbf16> -> tensor<64x1828x64x16x2xbf16>
    return %pack : tensor<64x1828x64x16x2xbf16>
  }

  // Comparison case with the same packed result type as @pack_bad. The source
  // here (64x29241x128) is @pack_bad's source with its dims reordered by
  // [2, 0, 1], so the tiled dims move to positions [1, 2] and outer_dims_perm
  // becomes the identity.
  func.func @pack_good(%arg0: tensor<64x29241x128xbf16>) -> tensor<64x1828x64x16x2xbf16> {
    // Same padding requirement as @pack_bad: 29241 is not a multiple of 16.
    %pad = arith.constant 0.000000e+00 : bf16
    %dest = tensor.empty() : tensor<64x1828x64x16x2xbf16>
    %pack = tensor.pack %arg0 padding_value(%pad : bf16)
        outer_dims_perm = [0, 1, 2]
        inner_dims_pos = [1, 2]
        inner_tiles = [16, 2]
        into %dest : tensor<64x29241x128xbf16> -> tensor<64x1828x64x16x2xbf16>
    return %pack : tensor<64x1828x64x16x2xbf16>
  }

  // tensor.unpack with a non-identity outer_dims_perm (the "bad" codegen case).
  //
  // The destination 29241x128x64 tiled on dims [0, 1] by [16, 16] has outer
  // dims [1828, 8, 64]. outer_dims_perm = [2, 0, 1] reorders them to
  // [64, 1828, 8], which matches the leading dims of %arg0.
  // The 7 trailing rows of the last 16-row tile on dim 0 do not fit in 29241
  // and are dropped by the unpack. No padding value is involved, so the
  // original's unused bf16 constant has been removed.
  func.func @unpack_bad(%arg0: tensor<64x1828x8x16x16xf32>) -> tensor<29241x128x64xf32> {
    %dest = tensor.empty() : tensor<29241x128x64xf32>
    %unpack = tensor.unpack %arg0
        outer_dims_perm = [2, 0, 1]
        inner_dims_pos = [0, 1]
        inner_tiles = [16, 16]
        into %dest : tensor<64x1828x8x16x16xf32> -> tensor<29241x128x64xf32>
    return %unpack : tensor<29241x128x64xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment