@bjacob
Created July 4, 2022 20:26

Consider the well-formed testcase attached here, x.mlir. Run:

bin/mlir-opt /tmp/x.mlir -test-vector-to-vector-lowering -debug

The essential part of the output (just the vector.contract) is recorded here as y.mlir. What's interesting is that it has a reduction dimension (d0) of size 1 that occurs only in the LHS, not in the RHS. This is what currently causes an assertion failure in lowerParallel, and it is what motivated this investigation.
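To make the shape of y.mlir's contraction concrete, here is a small Python sketch. For each iteration dimension, it lists the dimension's size, its iterator type, and which operands' indexing maps use it. The tuple encoding of affine maps and the `classify_dims` helper are hypothetical and written only for this illustration; they are not MLIR APIs.

```python
# Sketch: classify the iteration dims of the vector.contract in y.mlir.
# Assumption: each affine map is encoded as a tuple of dim indices, so
# affine_map<(d0, d1, d2) -> (d0, d2)> becomes (0, 2). classify_dims is a
# hypothetical helper, not an MLIR API.

def classify_dims(maps, shapes, iterator_types):
    """Return {dim: (size, iterator_type, operands whose map uses dim)}."""
    names = ("lhs", "rhs", "acc")
    info = {}
    for dim, kind in enumerate(iterator_types):
        size, users = None, []
        for name, amap, shape in zip(names, maps, shapes):
            if dim in amap:
                users.append(name)
                size = shape[amap.index(dim)]  # extent of dim in this operand
        info[dim] = (size, kind, tuple(users))
    return info


# y.mlir: lhs vector<1x2xi32>, rhs vector<8x2xi32>, acc vector<8xi32>.
maps = [(0, 2), (1, 2), (1,)]
shapes = [(1, 2), (8, 2), (8,)]
iterators = ["reduction", "parallel", "reduction"]

info = classify_dims(maps, shapes, iterators)
for dim, (size, kind, users) in info.items():
    print(f"d{dim}: size {size}, {kind}, used by {', '.join(users)}")

# Reduction dims that do not appear in both LHS and RHS.
unpaired = [d for d, (_, kind, users) in info.items()
            if kind == "reduction" and not {"lhs", "rhs"} <= set(users)]
print("reduction dims missing from LHS or RHS:", unpaired)  # [0]
```

The sketch reports that d0 has size 1, is a reduction, and is used only by the LHS. In contrast, d2 is a reduction shared by LHS and RHS, and d1 is a parallel dimension shared by RHS and the accumulator.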

Excerpt of the -debug output, showing this being introduced by CastAwayContractionLeadingOneDim:

//===-------------------------------------------===//
Processing operation : 'vector.contract'(0x63e5c60) {
  %4 = "vector.contract"(%0, %2, %3) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} : (vector<1x1x2xi32>, vector<1x8x2xi32>, vector<1x8xi32>) -> vector<1x8xi32>


  * Pattern (anonymous namespace)::CastAwayContractionLeadingOneDim : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::CastAwayContractionLeadingOneDim"
    ** Insert  : 'vector.transpose'(0x63e61d0)
    ** Insert  : 'vector.extract'(0x63e9c60)
    ** Insert  : 'vector.extract'(0x63e9cf0)
    ** Insert  : 'vector.contract'(0x63edc50)
    ** Insert  : 'vector.broadcast'(0x63edd20)
    ** Replace : 'vector.contract'(0x63e5c60)
"(anonymous namespace)::CastAwayContractionLeadingOneDim" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @foo(%arg0: vector<1x1x1x2xi32>, %arg1: vector<1x1x8x2xi32>, %arg2: vector<1x1x1x8xi32>) -> vector<1x1x1x8xi32> {
  %0 = vector.extract %arg0[0] : vector<1x1x1x2xi32>
  %1 = vector.extract %arg2[0] : vector<1x1x1x8xi32>
  %2 = vector.extract %arg1[0] : vector<1x1x8x2xi32>
  %3 = vector.extract %1[0] : vector<1x1x8xi32>
  %4 = vector.transpose %0, [1, 0, 2] : vector<1x1x2xi32> to vector<1x1x2xi32>
  %5 = vector.extract %4[0] : vector<1x1x2xi32>
  %6 = vector.extract %3[0] : vector<1x8xi32>
  %7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction"], kind = #vector.kind<add>} %5, %2, %6 : vector<1x2xi32>, vector<1x8x2xi32> into vector<8xi32>
  %8 = vector.broadcast %7 : vector<8xi32> to vector<1x8xi32>
  %9 = vector.broadcast %8 : vector<1x8xi32> to vector<1x1x8xi32>
  %10 = vector.broadcast %9 : vector<1x1x8xi32> to vector<1x1x1x8xi32>
  return %10 : vector<1x1x1x8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//
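The IR dump above can be read as a rewrite of the indexing maps. The following Python sketch is a simplified model inferred only from this dump, not from the C++ source of CastAwayContractionLeadingOneDim. It takes the iteration dimension that indexes the accumulator's leading size-1 dimension and removes it from every operand that uses it (on the LHS, that is the transpose followed by `vector.extract ... [0]`). It then renumbers the remaining dimensions. The helper name `cast_away_acc_leading_dim` and the tuple encoding of maps are hypothetical.

```python
# Simplified model of the rewrite in the IR dump above, inferred from the dump
# only (not from the C++ source of CastAwayContractionLeadingOneDim).
# Assumption: maps are tuples of dim indices; cast_away_acc_leading_dim is a
# hypothetical helper.

def cast_away_acc_leading_dim(maps, shapes, iterator_types):
    acc_map, acc_shape = maps[2], shapes[2]
    assert acc_shape[0] == 1, "accumulator's leading dim must have size 1"
    dropped = acc_map[0]  # iteration dim indexing the leading unit dim

    new_maps, new_shapes = [], []
    for amap, shape in zip(maps, shapes):
        # Operands indexed by `dropped` lose that position (on the LHS this is
        # the transpose [1, 0, 2] + vector.extract [0]); others are unchanged.
        keep = [i for i, d in enumerate(amap) if d != dropped]
        # Dims after `dropped` shift down by one to stay contiguous.
        new_maps.append(tuple(amap[i] - 1 if amap[i] > dropped else amap[i]
                              for i in keep))
        new_shapes.append(tuple(shape[i] for i in keep))
    new_iters = [t for d, t in enumerate(iterator_types) if d != dropped]
    return new_maps, new_shapes, new_iters


# The contraction 0x63e5c60 before the pattern applies.
maps = [(0, 1, 3), (0, 2, 3), (1, 2)]
shapes = [(1, 1, 2), (1, 8, 2), (1, 8)]
iters = ["reduction", "parallel", "parallel", "reduction"]

new_maps, new_shapes, new_iters = cast_away_acc_leading_dim(maps, shapes, iters)
print(new_maps)    # [(0, 2), (0, 1, 2), (1,)]
print(new_shapes)  # [(1, 2), (1, 8, 2), (8,)]
print(new_iters)   # ['reduction', 'parallel', 'reduction']
```

Under this model, the output matches %7 in the dump: maps (d0, d2), (d0, d1, d2), (d1) on operands vector<1x2xi32>, vector<1x8x2xi32> and vector<8xi32>, with iterator types ["reduction", "parallel", "reduction"].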

Now, here is y.mlir causing an assertion failure:

bin/mlir-opt /tmp/y.mlir -test-vector-contraction-lowering
mlir-opt: /usr/local/google/home/benoitjacob/iree/third_party/llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp:1867: mlir::Value mlir::vector::ContractionOpLowering::lowerParallel(vector::ContractionOp, int64_t, int64_t, mlir::PatternRewriter &) const: Assertion `lookup.hasValue() && "parallel index not listed in reduction"' failed.

So, at the moment, -test-vector-to-vector-lowering and -test-vector-contraction-lowering disagree: the former produces IR that the latter asserts on.
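One plausible reading of the assertion is sketched below as a simplified Python model. It is based on our understanding of ContractionOpLowering in VectorTransforms.cpp and is not the real code. It also skips the batch-dimension case, which does not arise in y.mlir.

The idea is that a reduction dimension is only paired as "contracting" when it appears in both LHS and RHS. As a result, d0 is left over as a "free" LHS dimension and handed to lowerParallel, which then fails to find it in the accumulator map. All function names below are hypothetical.

```python
# Simplified model of how, as we understand it, ContractionOpLowering picks a
# dimension to unroll and why lowerParallel then asserts. Not the real C++ code;
# batch dims (parallel in both LHS and RHS) are ignored since y.mlir has none.
# All function names are hypothetical. Maps are tuples of dim indices.

def contracting_dim_map(maps, iterator_types):
    """(lhs_pos, rhs_pos) for each reduction dim that appears in LHS and RHS."""
    lhs, rhs, _ = maps
    return [(lhs.index(d), rhs.index(d))
            for d, kind in enumerate(iterator_types)
            if kind == "reduction" and d in lhs and d in rhs]


def first_free_lhs_dim(maps, iterator_types):
    """First LHS position not paired as a contracting dim, else None."""
    paired = {lhs_pos for lhs_pos, _ in contracting_dim_map(maps, iterator_types)}
    return next((p for p in range(len(maps[0])) if p not in paired), None)


def lower_parallel_result_index(maps, lhs_pos):
    """The unrolled iteration dim must also index the accumulator/result."""
    iter_dim = maps[0][lhs_pos]
    if iter_dim not in maps[2]:
        raise AssertionError("parallel index not listed in reduction")
    return maps[2].index(iter_dim)


# y.mlir: d0 is a size-1 reduction dim that only the LHS uses.
maps_y = [(0, 2), (1, 2), (1,)]
iters = ["reduction", "parallel", "reduction"]
pos = first_free_lhs_dim(maps_y, iters)
print("free LHS position:", pos)  # 0, i.e. d0
try:
    lower_parallel_result_index(maps_y, pos)
except AssertionError as err:
    print("would assert:", err)

# %7 from the -debug dump: d0 is shared by LHS and RHS, so no LHS dim is free.
maps_7 = [(0, 2), (0, 1, 2), (1,)]
print("free LHS position:", first_free_lhs_dim(maps_7, iters))  # None
```

Under this model, the failing case depends only on d0 being a reduction dimension that is absent from the RHS. Its size of 1 does not matter to the dimension selection itself.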

// x.mlir
func.func @foo(%arg0 : vector<1x1x1x2xi32>, %arg1 : vector<1x1x8x2xi32>, %arg2 : vector<1x1x1x8xi32>) -> vector<1x1x1x8xi32> {
  %res = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>,
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>,
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
    ],
    iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>
  } %arg0, %arg1, %arg2 : vector<1x1x1x2xi32>, vector<1x1x8x2xi32> into vector<1x1x1x8xi32>
  return %res : vector<1x1x1x8xi32>
}
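Since x.mlir is meant to be well-formed, one quick sanity check is that every iteration dimension gets a single consistent size from the LHS and RHS shapes. The accumulator's indexing map then determines the accumulator type. Below is a minimal Python sketch of that check. It uses a hypothetical `infer_dim_sizes` helper and encodes maps as tuples of dimension indices.

```python
# Sanity check for x.mlir: every iteration dim must get one consistent size
# from the LHS/RHS shapes, and the accumulator map then fixes the acc type.
# Assumption: maps are tuples of dim indices; infer_dim_sizes is hypothetical.

def infer_dim_sizes(maps, shapes):
    sizes = {}
    for amap, shape in zip(maps, shapes):
        for dim, extent in zip(amap, shape):
            if sizes.setdefault(dim, extent) != extent:
                raise ValueError(f"d{dim}: conflicting sizes {sizes[dim]} and {extent}")
    return sizes


lhs_map, rhs_map, acc_map = (0, 2, 3, 5), (1, 2, 4, 5), (0, 1, 3, 4)
sizes = infer_dim_sizes([lhs_map, rhs_map], [(1, 1, 1, 2), (1, 1, 8, 2)])
acc_shape = tuple(sizes[d] for d in acc_map)
print(dict(sorted(sizes.items())))  # {0: 1, 1: 1, 2: 1, 3: 1, 4: 8, 5: 2}
print(acc_shape)                    # (1, 1, 1, 8)
```

The implied accumulator shape (1, 1, 1, 8) agrees with both the type of %arg2 and the result type, vector<1x1x1x8xi32>.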

// y.mlir
func.func @foo(%arg0 : vector<1x2xi32>, %arg1 : vector<8x2xi32>, %arg2 : vector<8xi32>) -> vector<8xi32> {
  %res = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,
      affine_map<(d0, d1, d2) -> (d1, d2)>,
      affine_map<(d0, d1, d2) -> (d1)>
    ],
    iterator_types = ["reduction", "parallel", "reduction"],
    kind = #vector.kind<add>
  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<8x2xi32> into vector<8xi32>
  return %res : vector<8xi32>
}