Skip to content

Instantly share code, notes, and snippets.

@cathyzhyi
Created April 26, 2022 16:24
Show Gist options
  • Save cathyzhyi/0dbee44a9bb91754ab3404660a6d2433 to your computer and use it in GitHub Desktop.
Save cathyzhyi/0dbee44a9bb91754ab3404660a6d2433 to your computer and use it in GitHub Desktop.
// -----// IR Dump After FuncBufferize //----- //
// Top-level module holding a single function. In this dump the function
// signature already uses memref types, while the body still consists of
// tensor-dialect ops bridged by bufferization.to_tensor / to_memref.
module {
// Takes a rank-4 fully dynamic f32 memref plus three index values and returns
// a rank-3 memref whose leading dimension is statically 2.
//   %arg1 - used as the slice offset for dimensions 2 and 3
//   %arg2 - used as the slice size for dimensions 0 and 1
//   %arg3 - used as the slice size for dimensions 2 and 3
func @collapse_dynamic_shape_of_slice(%arg0: memref<?x?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> memref<2x?x?xf32> {
// Bridge the memref argument into the tensor ops below.
%0 = bufferization.to_tensor %arg0 : memref<?x?x?x?xf32>
// Slice with offsets [0, 0, %arg1, %arg1], sizes [%arg2, %arg2, %arg3, %arg3]
// and unit strides in every dimension.
%1 = tensor.extract_slice %0[0, 0, %arg1, %arg1] [%arg2, %arg2, %arg3, %arg3] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
// Refine the type so dimension 0 becomes the static size 2; the remaining
// dimensions stay dynamic.
%2 = tensor.cast %1 : tensor<?x?x?x?xf32> to tensor<2x?x?x?xf32>
// Merge dimensions 1 and 2 into one dimension (rank 4 -> rank 3);
// dimensions 0 and 3 are carried through unchanged.
%3 = tensor.collapse_shape %2 [[0], [1, 2], [3]] : tensor<2x?x?x?xf32> into tensor<2x?x?xf32>
// Bridge the tensor result back to a memref to match the declared return type.
%4 = bufferization.to_memref %3 : memref<2x?x?xf32>
return %4 : memref<2x?x?xf32>
}
}
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)
// -----// IR Dump After TensorBufferize //----- //
// Same function after the tensor ops have been rewritten to memref ops.
// Takes a rank-4 fully dynamic f32 memref plus three index values and returns
// a rank-3 memref whose leading dimension is statically 2.
// Several bufferization.to_tensor results (%0, %2, %8, %10) remain in the body
// but are not referenced by any other op in this function.
func @collapse_dynamic_shape_of_slice(%arg0: memref<?x?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> memref<2x?x?xf32> {
// Not referenced below: the subview reads %arg0 directly.
%0 = bufferization.to_tensor %arg0 : memref<?x?x?x?xf32>
// View into %arg0 with offsets [0, 0, %arg1, %arg1], sizes
// [%arg2, %arg2, %arg3, %arg3] and unit strides. The result type carries an
// explicit layout map with a symbolic offset (s0) and symbolic strides
// (s1, s2, s3) for the three outer dimensions.
%1 = memref.subview %arg0[0, 0, %arg1, %arg1] [%arg2, %arg2, %arg3, %arg3] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
// Not referenced below.
%2 = bufferization.to_tensor %1 : memref<?x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
// Refine dimension 0 to the static size 2; the layout map is kept as-is.
%3 = memref.cast %1 : memref<?x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>> to memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
// Query the runtime sizes of dimensions 1, 2 and 3 of the view; they become
// the dynamic-size operands of the allocation below.
%c1 = arith.constant 1 : index
%4 = memref.dim %3, %c1 : memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
%c2 = arith.constant 2 : index
%5 = memref.dim %3, %c2 : memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
%c3 = arith.constant 3 : index
%6 = memref.dim %3, %c3 : memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
// Fresh buffer with the same shape as the view but no explicit layout map in
// its type, requested with 128-byte alignment.
// NOTE(review): no matching dealloc appears in this function; the returned
// value is derived from this buffer, so ownership presumably passes to the
// caller - confirm.
%7 = memref.alloc(%4, %5, %6) {alignment = 128 : i64} : memref<2x?x?x?xf32>
// Not referenced below.
%8 = bufferization.to_tensor %3 : memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>>
// Copy the contents of the strided view into the new buffer.
memref.copy %3, %7 : memref<2x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>> to memref<2x?x?x?xf32>
// Merge dimensions 1 and 2 (rank 4 -> rank 3). The collapse is applied to the
// copied buffer %7, not to the strided view %3.
%9 = memref.collapse_shape %7 [[0], [1, 2], [3]] : memref<2x?x?x?xf32> into memref<2x?x?xf32>
// Not referenced below: the return uses %9 directly.
%10 = bufferization.to_tensor %9 : memref<2x?x?xf32>
return %9 : memref<2x?x?xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment