Skip to content

Instantly share code, notes, and snippets.

@ftynse
Created March 8, 2021 12:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ftynse/0dda302701aca13e0508f341a69a3055 to your computer and use it in GitHub Desktop.
diff --git a/LowerToLLVM.h b/LowerToLLVM.h
index 6342c71..49c8ddb 100644
--- a/LowerToLLVM.h
+++ b/LowerToLLVM.h
@@ -21,7 +21,7 @@
class TestOpLowering : public mlir::OpConversionPattern<TestOp>
{
public:
- TestOpLowering(mlir::MLIRContext* ctx, TypeConverter& typeConverter)
+ TestOpLowering(mlir::MLIRContext* ctx, mlir::TypeConverter& typeConverter) // forwards the converter and context to OpConversionPattern<TestOp>; the trailing 1 is presumably the pattern benefit
: mlir::OpConversionPattern<TestOp>(typeConverter, ctx, 1)
{
}
@@ -75,9 +75,51 @@ public:
registry.insert<mlir::LLVM::LLVMDialect>();
}
+ mlir::LogicalResult step1(mlir::ModuleOp module) { // Stage 1: lower TestOp via TestOpLowering in a partial conversion that maps BooleanType to i1.
+ mlir::ConversionTarget target(getContext());
+ target.addLegalDialect<mlir::StandardOpsDialect, mlir::BuiltinDialect>(); // only these dialects are declared legal; nothing is marked illegal here
+ // Type conversion rules: builtin-dialect types map to themselves, BooleanType maps to i1.
+ mlir::TypeConverter typeConverter;
+ typeConverter.addConversion([](mlir::Type type) ->llvm::Optional<mlir::Type> { // fallback: identity for builtin-dialect types, llvm::None (not handled) otherwise
+ if (llvm::isa<mlir::BuiltinDialect>(type.getDialect()))
+ return type;
+ return llvm::None;
+ });
+ typeConverter.addConversion([this](BooleanType type) { // BooleanType -> 1-bit IntegerType; NOTE(review): relies on later-registered conversions being tried before the fallback - confirm
+ return mlir::IntegerType::get(&getContext(), 1);
+ });
+ typeConverter.addTargetMaterialization([](mlir::OpBuilder &builder, // target materialization: BooleanType value -> IntegerType value
+ mlir::IntegerType resultType,
+ mlir::ValueRange inputs,
+ mlir::Location loc)
+ -> llvm::Optional<mlir::Value> {
+ if (inputs.size() != 1 || !inputs[0].getType().isa<BooleanType>()) // only a single BooleanType input is handled
+ return llvm::None;
+ // Bridge the types with an unrealized cast; its single result is the materialized value.
+ return builder.create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs[0]).getResult(0);
+ });
+ typeConverter.addSourceMaterialization([](mlir::OpBuilder &builder, // source materialization: IntegerType value -> BooleanType value
+ BooleanType resultType,
+ mlir::ValueRange inputs,
+ mlir::Location loc)
+ -> llvm::Optional<mlir::Value> {
+ if (inputs.size() != 1 || !inputs[0].getType().isa<mlir::IntegerType>()) // any single IntegerType input; its bit width is not checked
+ return llvm::None;
+ // Same unrealized-cast bridge, in the opposite direction.
+ return builder.create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs[0]).getResult(0);
+ });
+ // Apply TestOpLowering (constructed with a reference to typeConverter) in a partial conversion of the module.
+ mlir::OwningRewritePatternList patterns;
+ patterns.insert<TestOpLowering>(&getContext(), typeConverter);
+ return applyPartialConversion(module, target, std::move(patterns)); // a failed conversion is returned to the caller as failure()
+ }
+
void runOnOperation() final {
auto module = getOperation();
+ if (failed(step1(module)))
+ return signalPassFailure();
+#if 0
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
target.addIllegalOp<mlir::LLVM::DialectCastOp>();
@@ -104,11 +146,12 @@ public:
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our "illegal"
// operations were not converted successfully.
- if (failed(applyFullConversion(module, target, std::move(patterns))))
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
{
mlir::emitError(module.getLoc(), "Error in converting to LLVM dialect\n");
signalPassFailure();
}
+#endif
}
};
diff --git a/main.cpp b/main.cpp
index e1db88e..94715da 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,5 +1,7 @@
#include <iostream>
#include <llvm/ADT/SmallVector.h>
+#include <llvm/Support/InitLLVM.h>
+#include <llvm/Support/CommandLine.h>
#include <mlir/Conversion/Passes.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
@@ -14,7 +16,12 @@
#include "TypeConverter.h"
#include "Types.h"
-int main() {
+int main(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+ mlir::registerMLIRContextCLOptions();
+ mlir::registerPassManagerCLOptions();
+ llvm::cl::ParseCommandLineOptions(argc, argv, "foo");
+
mlir::MLIRContext context;
context.loadDialect<MyDialect>();
context.loadDialect<mlir::StandardOpsDialect>();
@@ -63,7 +70,11 @@ int main() {
// Convert the module to LLVM IR
mlir::PassManager passManager(&context);
passManager.addPass(createMyDialectToLLVMLoweringPass());
- passManager.run(module);
+ passManager.addPass(mlir::createLowerToLLVMPass());
+ if (failed(passManager.run(module)))
+ return 1;
+
+ module.dump();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment