-
-
Save jpienaar/82c79ae73b2f46be4c9f365334bf8dc3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt | |
index a9843d1a1249..d78b6e95f8d4 100644 | |
--- a/mlir/include/mlir/Dialect/CMakeLists.txt | |
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt | |
@@ -27,6 +27,7 @@ add_subdirectory(Quant) | |
add_subdirectory(SCF) | |
add_subdirectory(SPIRV) | |
add_subdirectory(Shape) | |
+add_subdirectory(Shaped) | |
add_subdirectory(SparseTensor) | |
add_subdirectory(Tensor) | |
add_subdirectory(Tosa) | |
diff --git a/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt b/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt | |
new file mode 100644 | |
index 000000000000..f33061b2d87c | |
--- /dev/null | |
+++ b/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt | |
@@ -0,0 +1 @@ | |
+add_subdirectory(IR) | |
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt | |
new file mode 100644 | |
index 000000000000..8c2c257896de | |
--- /dev/null | |
+++ b/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt | |
@@ -0,0 +1,2 @@ | |
+add_mlir_dialect(ShapedOps shaped) | |
+add_mlir_doc(ShapedOps ShapedDialectOps Dialects/ -gen-dialect-doc) | |
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td b/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td | |
new file mode 100644 | |
index 000000000000..a6d6e9b6cf60 | |
--- /dev/null | |
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td | |
@@ -0,0 +1,38 @@ | |
+//===- ShapedBase.td ---------------------------------------*- tablegen -*-===// | |
+// | |
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
+// See https://llvm.org/LICENSE.txt for license information. | |
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
+// | |
+//===----------------------------------------------------------------------===// | |
+// | |
+// Base definitions for the `shaped` dialect. | |
+// | |
+//===----------------------------------------------------------------------===// | |
+ | |
+#ifndef SHAPED_BASE_TD | |
+#define SHAPED_BASE_TD | |
+ | |
+include "mlir/IR/OpBase.td" | |
+ | |
+//===----------------------------------------------------------------------===// | |
+// Shaped dialect definitions | |
+//===----------------------------------------------------------------------===// | |
+ | |
+def ShapedDialect : Dialect { | |
+  let name = "shaped"; | |
+ | |
+  let summary = "Operations for interacting with values of ShapedType"; | |
+  let description = [{ | |
+    This dialect contains generic operations on values of any ShapedType. | |
+  }]; | |
+ | |
+  // `shaped.dim` has a builder that materializes its index as an | |
+  // `arith.constant`, so the arith dialect must be loaded with this one. | |
+  let dependentDialects = [ | |
+    "arith::ArithDialect" | |
+  ]; | |
+  let cppNamespace = "::mlir::shaped"; | |
+} | |
+ | |
+#endif // SHAPED_BASE_TD | |
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h b/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h | |
new file mode 100644 | |
index 000000000000..4621437f8e3f | |
--- /dev/null | |
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h | |
@@ -0,0 +1,27 @@ | |
+//===- ShapedDialect.h - MLIR Shaped dialect --------------------*- C++ -*-===// | |
+// | |
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
+// See https://llvm.org/LICENSE.txt for license information. | |
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
+// | |
+//===----------------------------------------------------------------------===// | |
+// | |
+// This file declares the shaped dialect, which provides generic operations on | |
+// values of any ShapedType. | |
+// | |
+//===----------------------------------------------------------------------===// | |
+ | |
+#ifndef MLIR_DIALECT_SHAPED_IR_SHAPEDDIALECT_H | |
+#define MLIR_DIALECT_SHAPED_IR_SHAPEDDIALECT_H | |
+ | |
+#include "mlir/IR/Dialect.h" | |
+#include "mlir/IR/OpDefinition.h" | |
+#include "mlir/IR/OpImplementation.h" | |
+#include "mlir/Interfaces/InferTypeOpInterface.h" | |
+ | |
+#include "mlir/Dialect/Shaped/IR/ShapedOpsDialect.h.inc" | |
+ | |
+#define GET_OP_CLASSES | |
+#include "mlir/Dialect/Shaped/IR/ShapedOps.h.inc" | |
+ | |
+#endif // MLIR_DIALECT_SHAPED_IR_SHAPEDDIALECT_H | |
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td b/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td | |
new file mode 100644 | |
index 000000000000..8059f6c206bc | |
--- /dev/null | |
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td | |
@@ -0,0 +1,58 @@ | |
+//===- ShapedOps.td - Shaped operations definition ---------*- tablegen -*-===// | |
+// | |
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
+// See https://llvm.org/LICENSE.txt for license information. | |
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
+// | |
+//===----------------------------------------------------------------------===// | |
+// | |
+// This is the operation definition file for Shaped dialect operations. | |
+// | |
+//===----------------------------------------------------------------------===// | |
+ | |
+#ifndef SHAPED_OPS | |
+#define SHAPED_OPS | |
+ | |
+include "mlir/Dialect/Shaped/IR/ShapedBase.td" | |
+include "mlir/Interfaces/InferTypeOpInterface.td" | |
+include "mlir/IR/OpAsmInterface.td" | |
+ | |
+// Base class for operations in this dialect. | |
+class Shaped_Op<string mnemonic, list<Trait> traits = []> : | |
+    Op<ShapedDialect, mnemonic, traits>; | |
+ | |
+//===----------------------------------------------------------------------===// | |
+// DimOp | |
+//===----------------------------------------------------------------------===// | |
+ | |
+def Shaped_DimOp : Shaped_Op<"dim", [ | |
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { | |
+  let summary = "Gets the specified extent from the shape of a shaped input"; | |
+  let description = [{ | |
+    Gets the extent indexed by `index` from the shape of the `value` operand. | |
+    If the index is negative or out of bounds, the behavior is undefined. | |
+ | |
+    Note: This op is expected to be canonicalized into the dim op of the | |
+    dialect that owns the operand type (e.g. `tensor.dim` or `memref.dim`). | |
+ | |
+    Example: | |
+ | |
+    ```mlir | |
+    %c1 = arith.constant 1 : index | |
+    %d = shaped.dim %t, %c1 : tensor<4x?xf32> | |
+    ``` | |
+  }]; | |
+  let arguments = (ins AnyShaped:$value, Index:$index); | |
+  let results = (outs Index:$extent); | |
+ | |
+  let assemblyFormat = [{ | |
+    attr-dict $value `,` $index `:` type($value) | |
+  }]; | |
+ | |
+  let builders = [ | |
+    // Builder that allows passing a constant dimension as a simple integer. | |
+    OpBuilder<(ins "Value":$value, "int64_t":$index)> | |
+  ]; | |
+} | |
+ | |
+#endif // SHAPED_OPS | |
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h | |
index adbbb847adfb..99e76c6ed540 100644 | |
--- a/mlir/include/mlir/InitAllDialects.h | |
+++ b/mlir/include/mlir/InitAllDialects.h | |
@@ -57,6 +57,7 @@ | |
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" | |
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | |
#include "mlir/Dialect/Shape/IR/Shape.h" | |
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" | |
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h" | |
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" | |
#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" | |
@@ -109,6 +110,7 @@ inline void registerAllDialects(DialectRegistry &registry) { | |
                  NVVM::NVVMDialect, | |
                  ROCDL::ROCDLDialect, | |
                  shape::ShapeDialect, | |
+                  shaped::ShapedDialect, | |
                  sparse_tensor::SparseTensorDialect, | |
                  tensor::TensorDialect, | |
                  transform::TransformDialect, | |
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt | |
index 5cf5c125a7cb..fd2c189b92af 100644 | |
--- a/mlir/lib/Dialect/CMakeLists.txt | |
+++ b/mlir/lib/Dialect/CMakeLists.txt | |
@@ -26,6 +26,7 @@ add_subdirectory(PDLInterp) | |
add_subdirectory(Quant) | |
add_subdirectory(SCF) | |
add_subdirectory(Shape) | |
+add_subdirectory(Shaped) | |
add_subdirectory(SparseTensor) | |
add_subdirectory(SPIRV) | |
add_subdirectory(Tensor) | |
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | |
index f9228380c4f2..519d518834e6 100644 | |
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | |
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | |
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect | |
  MLIRInferTypeOpInterface | |
  MLIRIR | |
+  MLIRShapedDialect | |
  MLIRShapedOpInterfaces | |
  MLIRSideEffectInterfaces | |
  MLIRViewLikeInterface | |
  ) | |
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | |
index d3a1ae1663c0..df442fd0bf24 100644 | |
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | |
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | |
@@ -10,6 +10,7 @@ | |
#include "mlir/Dialect/Arith/Utils/Utils.h" | |
#include "mlir/Dialect/MemRef/IR/MemRef.h" | |
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" | |
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h" | |
#include "mlir/Dialect/Utils/StaticValueUtils.h" | |
#include "mlir/IR/AffineMap.h" | |
#include "mlir/IR/Builders.h" | |
@@ -1152,9 +1153,29 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { | |
} // namespace | |
+/// Canonicalizes a `shaped.dim` whose operand is a memref (ranked or | |
+/// unranked) into the equivalent `memref.dim`. | |
+/// | |
+/// `memref.dim` does not accept 0-d memrefs, so those are left untouched | |
+/// instead of being rewritten into an op that fails verification. | |
+static LogicalResult replaceShapedDimOp(shaped::DimOp dimOp, | |
+                                        PatternRewriter &rewriter) { | |
+  Value source = dimOp.getValue(); | |
+  auto memrefType = source.getType().dyn_cast<BaseMemRefType>(); | |
+  if (!memrefType) | |
+    return failure(); | |
+  // Any dimension query on a 0-d memref is out of bounds (undefined). | |
+  if (memrefType.hasRank() && memrefType.getRank() == 0) | |
+    return failure(); | |
+  rewriter.replaceOpWithNewOp<DimOp>(dimOp, source, dimOp.getIndex()); | |
+  return success(); | |
+} | |
+ | |
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
                                        MLIRContext *context) { | |
  results.add<DimOfMemRefReshape>(context); | |
+  // Also rewrite generic `shaped.dim` ops on memrefs into `memref.dim`. | |
+  results.add(replaceShapedDimOp); | |
} | |
// --------------------------------------------------------------------------- | |
diff --git a/mlir/lib/Dialect/Shaped/CMakeLists.txt b/mlir/lib/Dialect/Shaped/CMakeLists.txt | |
new file mode 100644 | |
index 000000000000..f33061b2d87c | |
--- /dev/null | |
+++ b/mlir/lib/Dialect/Shaped/CMakeLists.txt | |
@@ -0,0 +1 @@ | |
+add_subdirectory(IR) | |
diff --git a/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt b/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt | |
new file mode 100644 | |
index 000000000000..8a13da48995c | |
--- /dev/null | |
+++ b/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt | |
@@ -0,0 +1,15 @@ | |
+add_mlir_dialect_library(MLIRShapedDialect | |
+ Shaped.cpp | |
+ | |
+ ADDITIONAL_HEADER_DIRS | |
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shaped | |
+ | |
+ DEPENDS | |
+ MLIRShapedOpsIncGen | |
+ | |
+ LINK_LIBS PUBLIC | |
+ MLIRArithDialect | |
+ MLIRDialect | |
+ MLIRInferTypeOpInterface | |
+ MLIRIR | |
+ ) | |
diff --git a/mlir/lib/Dialect/Shaped/IR/Shaped.cpp b/mlir/lib/Dialect/Shaped/IR/Shaped.cpp | |
new file mode 100644 | |
index 000000000000..a4e081b5b938 | |
--- /dev/null | |
+++ b/mlir/lib/Dialect/Shaped/IR/Shaped.cpp | |
@@ -0,0 +1,46 @@ | |
+//===- Shaped.cpp - MLIR Shaped Operations --------------------------------===// | |
+// | |
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
+// See https://llvm.org/LICENSE.txt for license information. | |
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
+// | |
+//===----------------------------------------------------------------------===// | |
+ | |
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h" | |
+ | |
+#include "mlir/Dialect/Arith/IR/Arith.h" | |
+#include "mlir/IR/Builders.h" | |
+#include "mlir/IR/BuiltinTypes.h" | |
+ | |
+using namespace mlir; | |
+using namespace mlir::shaped; | |
+ | |
+#include "mlir/Dialect/Shaped/IR/ShapedOpsDialect.cpp.inc" | |
+ | |
+void ShapedDialect::initialize() { | |
+  addOperations< | |
+#define GET_OP_LIST | |
+#include "mlir/Dialect/Shaped/IR/ShapedOps.cpp.inc" | |
+      >(); | |
+} | |
+ | |
+//===----------------------------------------------------------------------===// | |
+// DimOp | |
+//===----------------------------------------------------------------------===// | |
+ | |
+/// Builds a `shaped.dim` from a constant dimension index by materializing the | |
+/// index as an `arith.constant` at the location of the op being built. | |
+void DimOp::build(OpBuilder &builder, OperationState &result, Value source, | |
+                  int64_t index) { | |
+  Value indexValue = | |
+      builder.create<arith::ConstantIndexOp>(result.location, index); | |
+  build(builder, result, source, indexValue); | |
+} | |
+ | |
+/// Suggests `dim` as the printed name of the result value. | |
+void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { | |
+  setNameFn(getResult(), "dim"); | |
+} | |
+ | |
+#define GET_OP_CLASSES | |
+#include "mlir/Dialect/Shaped/IR/ShapedOps.cpp.inc" | |
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt | |
index fad4478ee787..a4bd532de17d 100644 | |
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt | |
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt | |
@@ -30,6 +30,7 @@ add_mlir_dialect_library(MLIRTensorDialect | |
  MLIRInferTypeOpInterface | |
  MLIRParallelCombiningOpInterface | |
+  MLIRShapedDialect | |
  MLIRShapedOpInterfaces | |
  MLIRSideEffectInterfaces | |
  MLIRSupport | |
  MLIRViewLikeInterface | |
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | |
index 4515d711f72b..e35bc2e1a057 100644 | |
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | |
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | |
@@ -10,6 +10,7 @@ | |
#include "mlir/Dialect/Arith/IR/Arith.h" | |
#include "mlir/Dialect/Arith/Utils/Utils.h" | |
#include "mlir/Dialect/Complex/IR/Complex.h" | |
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h" | |
#include "mlir/Dialect/Tensor/IR/Tensor.h" | |
#include "mlir/Dialect/Utils/IndexingUtils.h" | |
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" | |
@@ -490,9 +491,29 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> { | |
}; | |
} // namespace | |
+/// Canonicalizes a `shaped.dim` whose operand is a tensor (ranked or | |
+/// unranked) into the equivalent `tensor.dim`. | |
+/// | |
+/// `tensor.dim` does not accept 0-d tensors, so those are left untouched | |
+/// instead of being rewritten into an op that fails verification. | |
+static LogicalResult tensorReplaceShapedDimOp(shaped::DimOp dimOp, | |
+                                              PatternRewriter &rewriter) { | |
+  Value source = dimOp.getValue(); | |
+  auto tensorType = source.getType().dyn_cast<TensorType>(); | |
+  if (!tensorType) | |
+    return failure(); | |
+  // Any dimension query on a 0-d tensor is out of bounds (undefined). | |
+  if (tensorType.hasRank() && tensorType.getRank() == 0) | |
+    return failure(); | |
+  rewriter.replaceOpWithNewOp<DimOp>(dimOp, source, dimOp.getIndex()); | |
+  return success(); | |
+} | |
+ | |
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
                                        MLIRContext *context) { | |
  results.add<DimOfCastOp>(context); | |
+  // Also rewrite generic `shaped.dim` ops on tensors into `tensor.dim`. | |
+  results.add(tensorReplaceShapedDimOp); | |
} | |
//===----------------------------------------------------------------------===// | |
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir | |
index 3d9f71e26055..dc5b2ed58a09 100644 | |
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir | |
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir | |
@@ -894,3 +894,14 @@ func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>, | |
to memref<?xf32, strided<[?], offset: ?>> | |
return %1 : memref<?xf32, strided<[?], offset: ?>> | |
} | |
+ | |
+// ----- | |
+ | |
+// Test case: Canonicalizing shaped.dim to memref.dim. | |
+// CHECK-LABEL: func @shaped_dim( | |
+// CHECK:         memref.dim %{{.*}}, %{{.*}} : memref<2x?x4x?x5xindex> | |
+func.func @shaped_dim(%arg0: memref<2x?x4x?x5xindex>) -> index { | |
+  %c3 = arith.constant 3 : index | |
+  %1 = shaped.dim %arg0, %c3 : memref<2x?x4x?x5xindex> | |
+  return %1 : index | |
+} | |
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir | |
index f4706fc439b9..c9a236069589 100644 | |
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir | |
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir | |
@@ -1774,3 +1774,14 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde | |
%packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32> | |
return %packed : tensor<?x?x?x?xf32> | |
} | |
+ | |
+// ----- | |
+ | |
+// Test case: Canonicalizing shaped.dim to tensor.dim. | |
+// CHECK-LABEL: func @shaped_dim( | |
+// CHECK:         tensor.dim %{{.*}}, %{{.*}} : tensor<2x?x4x?x5xindex> | |
+func.func @shaped_dim(%arg0: tensor<2x?x4x?x5xindex>) -> index { | |
+  %c3 = arith.constant 3 : index | |
+  %1 = shaped.dim %arg0, %c3 : tensor<2x?x4x?x5xindex> | |
+  return %1 : index | |
+} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment