Consider the well-formed testcase attached here, x.mlir. Run:
bin/mlir-opt /tmp/x.mlir -test-vector-to-vector-lowering -debug
The essential part of the output (just the vector.contract) is recorded here as y.mlir. What's interesting is that it has a reduction dimension (d0) of size 1 that occurs only on the LHS, not on the RHS. That is what currently causes an assert failure in lowerParallel, which motivated this investigation.
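The condition described above can be checked mechanically from a contraction's indexing maps. The sketch below is a hypothetical helper, not part of MLIR: it models each affine map as a tuple of dim positions (so (d0, d1, d2, d3) -> (d0, d1, d3) becomes (0, 1, 3)) and reports reduction dims that index only one operand. The second example uses made-up maps of the described shape; it is not copied from y.mlir.

```python
# Hypothetical helper (not an MLIR API): find reduction dims of a
# vector.contract-style op that index only one of the two operands.
# Each indexing map is modeled as a tuple of iteration-dim positions.

def one_sided_reductions(lhs_map, rhs_map, iterator_types):
    """Return (lhs_only, rhs_only): sorted reduction dim positions that
    appear in exactly one of the LHS / RHS indexing maps."""
    lhs, rhs = set(lhs_map), set(rhs_map)
    reductions = {d for d, kind in enumerate(iterator_types) if kind == "reduction"}
    lhs_only = sorted(d for d in reductions if d in lhs and d not in rhs)
    rhs_only = sorted(d for d in reductions if d in rhs and d not in lhs)
    return lhs_only, rhs_only


# Original contraction from the -debug excerpt (x.mlir):
# LHS (d0, d1, d3), RHS (d0, d2, d3), iterators [red, par, par, red].
print(one_sided_reductions((0, 1, 3), (0, 2, 3),
                           ["reduction", "parallel", "parallel", "reduction"]))
# -> ([], [])

# Hypothetical maps of the described shape: LHS (d0, d2), RHS (d1, d2),
# acc (d1); d0 is a reduction dim that indexes the LHS only.
print(one_sided_reductions((0, 2), (1, 2), ["reduction", "parallel", "reduction"]))
# -> ([0], [])
```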
Excerpt of the -debug output showing this being introduced by CastAwayContractionLeadingOneDim:
//===-------------------------------------------===//
Processing operation : 'vector.contract'(0x63e5c60) {
%4 = "vector.contract"(%0, %2, %3) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} : (vector<1x1x2xi32>, vector<1x8x2xi32>, vector<1x8xi32>) -> vector<1x8xi32>
* Pattern (anonymous namespace)::CastAwayContractionLeadingOneDim : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::CastAwayContractionLeadingOneDim"
** Insert : 'vector.transpose'(0x63e61d0)
** Insert : 'vector.extract'(0x63e9c60)
** Insert : 'vector.extract'(0x63e9cf0)
** Insert : 'vector.contract'(0x63edc50)
** Insert : 'vector.broadcast'(0x63edd20)
** Replace : 'vector.contract'(0x63e5c60)
"(anonymous namespace)::CastAwayContractionLeadingOneDim" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @foo(%arg0: vector<1x1x1x2xi32>, %arg1: vector<1x1x8x2xi32>, %arg2: vector<1x1x1x8xi32>) -> vector<1x1x1x8xi32> {
%0 = vector.extract %arg0[0] : vector<1x1x1x2xi32>
%1 = vector.extract %arg2[0] : vector<1x1x1x8xi32>
%2 = vector.extract %arg1[0] : vector<1x1x8x2xi32>
%3 = vector.extract %1[0] : vector<1x1x8xi32>
%4 = vector.transpose %0, [1, 0, 2] : vector<1x1x2xi32> to vector<1x1x2xi32>
%5 = vector.extract %4[0] : vector<1x1x2xi32>
%6 = vector.extract %3[0] : vector<1x8xi32>
%7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction"], kind = #vector.kind<add>} %5, %2, %6 : vector<1x2xi32>, vector<1x8x2xi32> into vector<8xi32>
%8 = vector.broadcast %7 : vector<8xi32> to vector<1x8xi32>
%9 = vector.broadcast %8 : vector<1x8xi32> to vector<1x1x8xi32>
%10 = vector.broadcast %9 : vector<1x1x8xi32> to vector<1x1x1x8xi32>
return %10 : vector<1x1x1x8xi32>
}
} -> success : pattern matched
//===-------------------------------------------===//
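The indexing maps of the new contraction (%7) in the log can be reproduced by deleting the unit parallel dim d1 (which indexes the leading size-1 dim of the accumulator vector<1x8xi32>) from every map and renumbering the remaining dims. The sketch below is a hypothetical model of that map rewrite only, with maps as tuples of dim positions; it is not the pattern's C++ implementation, which also emits the transpose/extract/broadcast ops listed above.

```python
# Hypothetical model (not MLIR code) of dropping one iteration dim from a
# contraction's indexing maps, as observed in the -debug log above.

def drop_dim(maps, iterator_types, dim):
    """Remove iteration dim `dim` from every indexing map and renumber the
    remaining dims so they stay contiguous. Maps are tuples of positions."""
    def renumber(d):
        # Dims after the removed one shift down by one position.
        return d - 1 if d > dim else d
    new_maps = [tuple(renumber(d) for d in m if d != dim) for m in maps]
    new_iters = [t for i, t in enumerate(iterator_types) if i != dim]
    return new_maps, new_iters


# Maps of the original contraction: LHS, RHS, acc.
maps = [(0, 1, 3), (0, 2, 3), (1, 2)]
iters = ["reduction", "parallel", "parallel", "reduction"]
new_maps, new_iters = drop_dim(maps, iters, 1)
print(new_maps)   # [(0, 2), (0, 1, 2), (1,)]
print(new_iters)  # ['reduction', 'parallel', 'reduction']
```

The output matches %7: LHS (d0, d2), RHS (d0, d1, d2), acc (d1), with iterator_types ["reduction", "parallel", "reduction"].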
Running -test-vector-contraction-lowering on y.mlir then hits an assert failure:
bin/mlir-opt /tmp/y.mlir -test-vector-contraction-lowering
mlir-opt: /usr/local/google/home/benoitjacob/iree/third_party/llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp:1867: mlir::Value mlir::vector::ContractionOpLowering::lowerParallel(vector::ContractionOp, int64_t, int64_t, mlir::PatternRewriter &) const: Assertion `lookup.hasValue() && "parallel index not listed in reduction"' failed.
So at the moment, -test-vector-to-vector-lowering and -test-vector-contraction-lowering disagree: the former produces IR that the latter asserts on.
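For intuition about what such IR would mean, here is a plain-Python reference sketch. It assumes (this is an assumption, not something verified against the vector dialect spec) that the usual sum-of-products semantics of vector.contract with kind = #vector.kind<add> extends to a reduction dim that indexes only the LHS, i.e. that dim is simply summed over. Under that assumption, a size-1 such dim contributes exactly one term, so dropping it gives the same result. The shapes (1x2 LHS, 8x2 RHS, length-8 acc) and function names are hypothetical.

```python
# Reference semantics in plain Python (not MLIR). Assumed meaning:
#   out[n] = acc[n] + sum over d, k of lhs[d][k] * rhs[n][k]
# where d is a reduction dim that indexes the LHS only.

def contract_with_lhs_only_dim(lhs, rhs, acc):
    """lhs: D x K, rhs: N x K, acc: length N. Sums over both d and k."""
    out = list(acc)
    for n in range(len(rhs)):
        for d in range(len(lhs)):
            for k in range(len(rhs[n])):
                out[n] += lhs[d][k] * rhs[n][k]
    return out


def contract_dropped(lhs_row, rhs, acc):
    """Same contraction with the d dim removed: lhs_row has length K."""
    return [a + sum(x * y for x, y in zip(lhs_row, row)) for a, row in zip(acc, rhs)]


lhs = [[3, 5]]                        # 1x2: d has size 1
rhs = [[i, i + 1] for i in range(8)]  # 8x2
acc = [10] * 8
print(contract_with_lhs_only_dim(lhs, rhs, acc))
# -> [15, 23, 31, 39, 47, 55, 63, 71]
print(contract_with_lhs_only_dim(lhs, rhs, acc) == contract_dropped(lhs[0], rhs, acc))
# -> True
```

With D > 1 the two functions differ (the first sums the LHS rows first), so the equivalence relies on the dim having size 1.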