Skip to content

Instantly share code, notes, and snippets.

@jpienaar
Created March 16, 2023 12:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jpienaar/82c79ae73b2f46be4c9f365334bf8dc3 to your computer and use it in GitHub Desktop.
Save jpienaar/82c79ae73b2f46be4c9f365334bf8dc3 to your computer and use it in GitHub Desktop.
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index a9843d1a1249..d78b6e95f8d4 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -27,6 +27,7 @@ add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(SPIRV)
add_subdirectory(Shape)
+add_subdirectory(Shaped)
add_subdirectory(SparseTensor)
add_subdirectory(Tensor)
add_subdirectory(Tosa)
diff --git a/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt b/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shaped/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt
new file mode 100644
index 000000000000..8c2c257896de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shaped/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(ShapedOps shaped)
+add_mlir_doc(ShapedOps ShapedDialectOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td b/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td
new file mode 100644
index 000000000000..a6d6e9b6cf60
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedBase.td
@@ -0,0 +1,36 @@
+//===- ShapedBase.td ---------------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Base definitions for the `shaped` dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SHAPED_BASE_TD
+#define SHAPED_BASE_TD
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Shaped type dialect definitions
+//===----------------------------------------------------------------------===//
+
+def ShapedDialect : Dialect {
+ let name = "shaped";
+
+ let summary = "Types and operations for shaped dialect";
+ let description = [{
+ This dialect contains operations for shaped type interaction.
+ }];
+
+ let dependentDialects = [
+ "arith::ArithDialect"
+ ];
+ let cppNamespace = "::mlir::shaped";
+}
+
+#endif // SHAPED_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h b/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h
new file mode 100644
index 000000000000..4621437f8e3f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedDialect.h
@@ -0,0 +1,27 @@
+//===- Shaped.h - MLIR Shaped dialect ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the shaped dialect that is used to interact with ShaepdType
+// generally.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPED_IR_SHAPEDDIALECT_H
+#define MLIR_DIALECT_SHAPED_IR_SHAPEDDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Shaped/IR/ShapedOps.h.inc"
+
+#include "mlir/Dialect/Shaped/IR/ShapedOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_SHAPED_IR_SHAPEDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td b/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td
new file mode 100644
index 000000000000..8059f6c206bc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shaped/IR/ShapedOps.td
@@ -0,0 +1,51 @@
+//===- ShapedOps.td - Shaped operations definition ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Shaped dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SHAPED_OPS
+#define SHAPED_OPS
+
+include "mlir/Dialect/Shaped/IR/ShapedBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/IR/OpAsmInterface.td"
+
+// Base class for the operation in this dialect
+class Shaped_Op<string mnemonic, list<Trait> traits = []> :
+ Op<ShapedDialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+def Shaped_DimOp : Shaped_Op<"dim", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ let summary = "Gets the specified extent from the shape of a shaped input";
+ let description = [{
+ Gets the extent indexed by `dim` from the shape of the `value` operand. If
+ the index is error or out-of-bound then behavior is undefined.
+
+ Note: This op is likely to be canonicalized away to the dim op of the
+ particular ShapedType operand.
+ }];
+ let arguments = (ins AnyShaped:$value, Index:$index);
+ let results = (outs Index:$extent);
+
+ let assemblyFormat = [{
+ attr-dict $value `,` $index `:` type($value)
+ }];
+
+ let builders = [
+ // Builder that allows passing a constant dimension as a simple integer.
+ OpBuilder<(ins "Value":$value, "int64_t":$index)>
+ ];
+}
+
+#endif // SHAPED_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index adbbb847adfb..99e76c6ed540 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -57,6 +57,7 @@
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h"
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
@@ -109,6 +110,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
shape::ShapeDialect,
+ shaped::ShapedDialect,
sparse_tensor::SparseTensorDialect,
tensor::TensorDialect,
transform::TransformDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 5cf5c125a7cb..fd2c189b92af 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -26,6 +26,7 @@ add_subdirectory(PDLInterp)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
+add_subdirectory(Shaped)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Tensor)
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index f9228380c4f2..519d518834e6 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRShapedOpInterfaces
+ MLIRShapedDialect
MLIRSideEffectInterfaces
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d3a1ae1663c0..df442fd0bf24 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1152,9 +1153,19 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
} // namespace
+static LogicalResult replaceShapedDimOp(shaped::DimOp dimOp, PatternRewriter& rewriter) {
+ llvm::errs() << __FILE__ << ":" <<__LINE__ << " attempting\n";
+ if (dimOp.getValue().getType().isa<MemRefType>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(dimOp, dimOp.getValue(), dimOp.getIndex());
+ return success();
+ }
+ return failure();
+}
+
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape>(context);
+ results.add(replaceShapedDimOp);
}
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Shaped/CMakeLists.txt b/mlir/lib/Dialect/Shaped/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/lib/Dialect/Shaped/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt b/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt
new file mode 100644
index 000000000000..8a13da48995c
--- /dev/null
+++ b/mlir/lib/Dialect/Shaped/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRShapedDialect
+ Shaped.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shaped
+
+ DEPENDS
+ MLIRShapedOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRDialect
+ MLIRInferTypeOpInterface
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/Shaped/IR/Shaped.cpp b/mlir/lib/Dialect/Shaped/IR/Shaped.cpp
new file mode 100644
index 000000000000..a4e081b5b938
--- /dev/null
+++ b/mlir/lib/Dialect/Shaped/IR/Shaped.cpp
@@ -0,0 +1,56 @@
+//===- Shaped.cpp - MLIR Shaped Operations --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <utility>
+
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::shaped;
+
+#include "mlir/Dialect/Shaped/IR/ShapedOpsDialect.cpp.inc"
+
+void ShapedDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Shaped/IR/ShapedOps.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+
+void mlir::shaped::DimOp::build(OpBuilder &builder, OperationState &result, Value source,
+ int64_t index) {
+ auto loc = result.location;
+ Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
+ build(builder, result, source, indexValue);
+}
+
+void mlir::shaped::DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "dim");
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Shaped/IR/ShapedOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index fad4478ee787..a4bd532de17d 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -30,6 +30,7 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRInferTypeOpInterface
MLIRParallelCombiningOpInterface
MLIRShapedOpInterfaces
+ MLIRShapedDialect
MLIRSideEffectInterfaces
MLIRSupport
MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4515d711f72b..e35bc2e1a057 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Shaped/IR/ShapedDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -490,9 +491,19 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
};
} // namespace
+LogicalResult tensorReplaceShapedDimOp(shaped::DimOp dimOp, PatternRewriter& rewriter) {
+ llvm::errs() << __FILE__ << ":" <<__LINE__ << " attempting\n";
+ if (dimOp.getValue().getType().isa<TensorType>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(dimOp, dimOp.getValue(), dimOp.getIndex());
+ return success();
+ }
+ return failure();
+}
+
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfCastOp>(context);
+ results.add(tensorReplaceShapedDimOp);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3d9f71e26055..dc5b2ed58a09 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -894,3 +894,14 @@ func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
to memref<?xf32, strided<[?], offset: ?>>
return %1 : memref<?xf32, strided<[?], offset: ?>>
}
+
+// -----
+
+// Test case: Folding to memref.dim.
+// CHECK-LABEL: func @shape_dim(
+// CHECK: memref.dim
+func.func @shaped_dim(%arg0: memref<2x?x4x?x5xindex>) -> index {
+ %c3 = arith.constant 3 : index
+ %1 = shaped.dim %arg0, %c3 : memref<2x?x4x?x5xindex>
+ return %1 : index
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f4706fc439b9..c9a236069589 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1774,3 +1774,14 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
%packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
return %packed : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// Test case: Folding to tensor.dim.
+// CHECK-LABEL: func @shaped_dim(
+// CHECK: tensor.dim
+func.func @shaped_dim(%arg0: tensor<2x?x4x?x5xindex>) -> index {
+ %c3 = arith.constant 3 : index
+ %1 = shaped.dim %arg0, %c3 : tensor<2x?x4x?x5xindex>
+ return %1 : index
+}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment