Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AmosLewis/8618bc5e191674c30d03c50c4a921839 to your computer and use it in GitHub Desktop.
Save AmosLewis/8618bc5e191674c30d03c50c4a921839 to your computer and use it in GitHub Desktop.
module { // Torch-dialect IR for a reduce-prod test (per the function name: default axes, keepdims): one prod.dim_int per axis of %arg0, then a reshape to [1,1,1].
func.func @test_reduce_prod_default_axes_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%int0 = torch.constant.int 0 // zero used as the right-hand side of every "axis < 0" comparison below
%int0_0 = torch.constant.int 0 // axis 0
%int1 = torch.constant.int 1 // axis 1
%int2 = torch.constant.int 2 // axis 2
%0 = torch.aten.dim %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.int // number of dimensions of %arg0
%1 = torch.aten.lt.int %int0_0, %int0 : !torch.int, !torch.int -> !torch.bool // axis 0: is the axis negative?
%2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int // bool -> int
%3 = torch.aten.mul.int %2, %0 : !torch.int, !torch.int -> !torch.int // (axis < 0) * rank
%4 = torch.aten.add.int %int0_0, %3 : !torch.int, !torch.int -> !torch.int // normalized axis 0 = axis + (axis < 0) * rank; a computed value, not a torch.constant.int
%5 = torch.aten.lt.int %int1, %int0 : !torch.int, !torch.int -> !torch.bool // axis 1: same normalization sequence
%6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
%7 = torch.aten.mul.int %6, %0 : !torch.int, !torch.int -> !torch.int
%8 = torch.aten.add.int %int1, %7 : !torch.int, !torch.int -> !torch.int // normalized axis 1
%9 = torch.aten.lt.int %int2, %int0 : !torch.int, !torch.int -> !torch.bool // axis 2: same normalization sequence
%10 = torch.aten.Int.bool %9 : !torch.bool -> !torch.int
%11 = torch.aten.mul.int %10, %0 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.add.int %int2, %11 : !torch.int, !torch.int -> !torch.int // normalized axis 2
%true = torch.constant.bool true // passed as the bool operand (keepdim) of each prod.dim_int
%none = torch.constant.none // passed as the last operand (dtype) of each prod.dim_int
%13 = torch.aten.prod.dim_int %arg0, %4, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // product along %4; result type is fully dynamic
%14 = torch.aten.prod.dim_int %13, %8, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // product along %8
%15 = torch.aten.prod.dim_int %14, %12, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // product along %12
%int-1 = torch.constant.int -1 // three separate -1 constants, one per entry of the size list
%int-1_1 = torch.constant.int -1
%int-1_2 = torch.constant.int -1
%16 = torch.prim.ListConstruct %int-1, %int-1_1, %int-1_2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // size list [-1, -1, -1]
%17 = torch.aten.reshape %15, %16 : !torch.vtensor<[?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[1,1,1],f32> // reshape the dynamic-shaped result to the static [1,1,1] return type
return %17 : !torch.vtensor<[1,1,1],f32>
}
}
@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2024

FIXED ISSUE.

 torch-mlir-opt --convert-torch-to-linalg ../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir       
../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir:22:11: error: failed to legalize operation 'torch.aten.prod.dim_int' that was explicitly marked illegal
    %13 = torch.aten.prod.dim_int %arg0, %4, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
          ^
../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir:22:11: note: see current operation: %19 = "torch.aten.prod.dim_int"(%arg0, %8, %17, %18) : (!torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none) -> !torch.vtensor<[?,?,?],f32>

If we add op.getDim().dump(); just before https://github.com/llvm/torch-mlir/blob/e48fe4588631e7a37a2899f9d4cd5c4cbc967481/lib/Conversion/TorchToLinalg/Reduction.cpp#L467
it prints %5 = torch.aten.add.int %int0_0, %4 : !torch.int, !torch.int -> !torch.int
so the branch if (matchPattern(op.getDim(), m_TorchConstantInt(&dim))) is not taken: the dim operand is the result of a torch.aten.add.int rather than a torch.constant.int, so the conversion returns failure.

Fixed by adding --torch-decompose-complex-ops --cse --canonicalize

@AmosLewis
Copy link
Author

AmosLewis commented Feb 28, 2024

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize ../onnx_torch_import_20231121/node_test_reduce_prod_default_axes_keepdims_random_model.mlir

module { // Output of the torch-mlir-opt command quoted just above (--convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize).
  func.func @test_reduce_prod_default_axes_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %int-1 = torch.constant.int -1 // single -1 constant, used for all three size-list entries below
    %none = torch.constant.none // last operand (dtype) of each prod.dim_int
    %true = torch.constant.bool true // bool operand (keepdim) of each prod.dim_int
    %int0 = torch.constant.int 0 // axis 0
    %int1 = torch.constant.int 1 // axis 1
    %int2 = torch.constant.int 2 // axis 2
    %0 = torch.aten.prod.dim_int %arg0, %int0, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // dim operand is now a torch.constant.int directly
    %1 = torch.aten.prod.dim_int %0, %int1, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // product along axis 1
    %2 = torch.aten.prod.dim_int %1, %int2, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> // product along axis 2
    %3 = torch.prim.ListConstruct %int-1, %int-1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // size list [-1, -1, -1]; the ConvertAtenViewOp log quoted below reports "at most one element in size list is allowed to be -1"
    %4 = torch.aten.view %2, %3 : !torch.vtensor<[?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[1,1,1],f32> // the op that the --convert-torch-to-linalg run quoted below fails to legalize
    return %4 : !torch.vtensor<[1,1,1],f32>
  }
}

torch-mlir-opt --convert-torch-to-linalg ../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir

torch-mlir-opt --convert-torch-to-linalg ../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir
../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir:13:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
    %4 = torch.aten.view %2, %3 : !torch.vtensor<[?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[1,1,1],f32>
         ^
../test/prod/node_test_reduce_prod_default_axes_keepdims_random_model_torch.mlir:13:10: note: see current operation: %51 = "torch.aten.view"(%49, %50) : (!torch.vtensor<[?,?,?],f32>, !torch.list<int>) -> !torch.vtensor<[1,1,1],f32>
torch-mlir-opt: /home/chi/src/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
  * Pattern : 'torch.aten.view -> ()' {
Trying to match "(anonymous namespace)::ConvertAtenViewOp"
    ** Insert  : 'arith.constant'(0x55e7d0b01740)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::TensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneOperand<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface::Trait<Empty>)
    ** Insert  : 'tensor.dim'(0x55e7d0b00860)
    ** Insert  : 'arith.constant'(0x55e7d0b017b0)
    ** Insert  : 'tensor.dim'(0x55e7d0b00910)
    ** Insert  : 'arith.constant'(0x55e7d0b00e30)
    ** Insert  : 'tensor.dim'(0x55e7d0b00ea0)
    ** Failure : at most one element in size list is allowed to be -1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment