Skip to content

Instantly share code, notes, and snippets.

@flaub
Created January 5, 2020 22:58
Show Gist options
  • Save flaub/3a2448daf2ed50df2f25e77c7fb42d77 to your computer and use it in GitHub Desktop.
Save flaub/3a2448daf2ed50df2f25e77c7fb42d77 to your computer and use it in GitHub Desktop.
Experiment with DRR + DialectConversion
/// Module pass that runs the DRR-generated rewrite patterns
/// (`populateWithGenerated`) through MLIR's dialect-conversion driver in
/// partial-conversion mode. Signals pass failure if the conversion fails.
struct LoweringPass : public mlir::ModulePass<LoweringPass> {
  void runOnModule() final {
    // Set up target (i.e. what is legal).
    // TODO: mark the source ops/dialects illegal and the lowered ones legal.
    // NOTE(review): with nothing declared illegal, a partial conversion does
    // not fail because of ops that remain unconverted — confirm intended.
    ConversionTarget target(getContext());

    // Set up rewrite patterns. `populateWithGenerated` is declared to take
    // the pattern list by pointer, so pass its address.
    OwningRewritePatternList patterns;
    populateWithGenerated(&getContext(), &patterns);

    // Run the conversion; no TypeConverter is supplied (nullptr).
    if (failed(applyPartialConversion(getModule(), target, patterns, nullptr))) {
      signalPassFailure();
    }
  }
};
/// Lowers the element-wise op that defines `from` into an AffineParallelForOp
/// that computes the result into a freshly allocated buffer.
///
/// Intended to be invoked from a DRR `NativeCodeCall`, so `builder` is the
/// rewriter driving the current conversion. The emitted IR is roughly:
///   %buf = alloc()                                  // converted result type
///   parallel_for (%i...) over the static result shape {
///     %s_k = affine.load <remapped operand k>[%i...]  // one per operand
///     %r   = OpType(%s_0, ..., %s_n)
///     affine.store %r, %buf[%i...]
///   }
///
/// @tparam OpType  op created on the loaded scalars (e.g. AddFOp).
/// @param builder  must actually be a ConversionPatternRewriter; operands are
///                 read through its value remapping.
/// @param from     a value produced by the op being lowered.
/// @return the allocated result buffer, or a null Value on failure.
template <typename OpType>
Value eltwiseOpToParallelFor(OpBuilder& builder, Value from) {
  // Operand remapping is only available on a ConversionPatternRewriter.
  auto rewriter = dynamic_cast<mlir::ConversionPatternRewriter*>(&builder);
  if (!rewriter) {
    // FIXME: how to report failure?
    return {};
  }

  // A block argument has no defining op and nothing to lower.
  auto op = from.getDefiningOp();
  if (!op) {
    return {};
  }
  auto loc = op->getLoc();

  // NOTE(review): a default-constructed TypeConverter is used here; confirm
  // it maps the op's result type to a MemRefType. Bail out rather than assert.
  TypeConverter typeConverter;
  auto resultType = op->getResult(0).getType();
  auto resultMemRefType = typeConverter.convertType(resultType).dyn_cast_or_null<MemRefType>();
  if (!resultMemRefType) {
    return {};
  }
  // Loop ranges are taken from the static shape and no dynamic ranges are
  // passed, so dynamic dimensions cannot be represented.
  if (!resultMemRefType.hasStaticShape()) {
    return {};
  }

  auto resultMemRef = builder.create<AllocOp>(loc, resultMemRefType).getResult();
  auto ranges = builder.getI64ArrayAttr(resultMemRefType.getShape());
  auto dynamicRanges = ArrayRef<Value>();
  auto forOp = builder.create<AffineParallelForOp>(loc, ranges, dynamicRanges);

  {
    // createBlock() moves the insertion point into the loop body; restore the
    // caller's insertion point afterwards so anything the caller builds next
    // lands after the loop rather than inside its body.
    OpBuilder::InsertionGuard guard(builder);

    auto body = builder.createBlock(&forOp.inner());
    SmallVector<Value, 8> idxs;
    for (size_t i = 0, e = ranges.size(); i < e; ++i) {
      idxs.push_back(body->addArgument(builder.getIndexType()));
    }

    // Load one scalar per operand from its already-converted value.
    // NOTE(review): every operand is indexed with the result's induction
    // variables, so operands are assumed to share the result's shape.
    SmallVector<Value, 4> scalars;
    for (auto operand : op->getOperands()) {
      auto memref = rewriter->getRemappedValue(operand);
      scalars.push_back(builder.create<AffineLoadOp>(loc, memref, idxs));
    }

    auto attrs = ArrayRef<NamedAttribute>{};
    auto elementType = resultMemRefType.getElementType();
    auto resultTypes = llvm::makeArrayRef(elementType);
    auto result = builder.create<OpType>(loc, resultTypes, scalars, attrs);
    builder.create<AffineStoreOp>(loc, result, resultMemRef, idxs);
    builder.create<AffineTerminatorOp>(loc);
  }
  return resultMemRef;
}
#include "rewrites.cc.inc"
#pragma once
// Forward declarations only: keeps this header free of MLIR includes.
namespace mlir {
class OwningRewritePatternList;
}  // namespace mlir
namespace mlir {
class MLIRContext;
}  // namespace mlir

/// Adds the rewrite patterns produced from the DRR (TableGen) pattern
/// definitions to `patternList`.
///
/// @param ctx          context in which the patterns are created.
/// @param patternList  pattern list that receives the generated patterns.
void populateWithGenerated(mlir::MLIRContext* ctx,
                           mlir::OwningRewritePatternList* patternList);
#ifndef __PML_CONVERSION_ELTWISE_TO_PXA__
#define __PML_CONVERSION_ELTWISE_TO_PXA__
include "pmlc/dialect/eltwise/ops.td"
include "mlir/Dialect/StandardOps/Ops.td"
// DRR result helper: expands to a call of the C++ template
// `eltwiseOpToParallelFor<...>($_builder, $0)`, where `$_builder` is the
// rewriter and `$0` is the first argument given to this helper in a pattern's
// result DAG.
// NOTE(review): relies on `#` pasting the `op` record's name into the string;
// that name must match a C++ op class visible where the generated code is
// included — confirm for each `op` used.
class EltwiseOpToParallelFor<Op op> : NativeCodeCall<
"eltwiseOpToParallelFor<" # op # ">($_builder, $0)">;
// Pattern template: matches a two-operand `from` op, binding it as `$op`
// (operands as `$lhs`/`$rhs`), when `$op` satisfies the `cons` type
// constraint, and replaces it with the result of EltwiseOpToParallelFor<into>.
// `$lhs`/`$rhs` are bound but not referenced in the result; the C++ helper
// reads the operands from the matched op itself.
class EltwiseOpConversionPat<Op from, Op into, TypeConstraint cons> : Pat<
(from:$op $lhs, $rhs),
(EltwiseOpToParallelFor<into> $op),
[(cons $op)]>;
// Lower EW_AddOp whose result satisfies EltwiseFloat into a parallel loop
// applying AddFOp element-wise.
def : EltwiseOpConversionPat<EW_AddOp, AddFOp, EltwiseFloat>;
#endif // __PML_CONVERSION_ELTWISE_TO_PXA__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment