Skip to content

Instantly share code, notes, and snippets.

@j2kun
Created September 19, 2023 21:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save j2kun/ef74a9064cc1e5948586ff92e9c1ef11 to your computer and use it in GitHub Desktop.
Save j2kun/ef74a9064cc1e5948586ff92e9c1ef11 to your computer and use it in GitHub Desktop.
Generated rewrite pattern for lifting a complex conjugate through polynomial evaluation
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Rewriters *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/
/* Generated from:
lib/Dialect/Poly/PolyPatterns.td:8
*/
/// Rewrite pattern that lifts a complex conjugate through polynomial
/// evaluation:
///
///   poly.eval(input, complex.conj(point))
///     ==>  complex.conj(poly.eval(input, point))
struct LiftConjThroughEval : public ::mlir::RewritePattern {
  /// Roots the pattern at "poly.eval" with benefit 2. The trailing list names
  /// the ops this rewrite creates ("complex.conj" and "poly.eval").
  LiftConjThroughEval(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("poly.eval", 2, context,
                               {"complex.conj", "poly.eval"}) {}

  /// Matches an eval op whose operand 1 is produced by a conj op, then
  /// replaces it with a conj of a new eval taken at the un-conjugated point.
  /// Reports a match failure (and changes nothing) when operand 1 has no
  /// defining op or that op is not a `complex::ConjOp`.
  ::mlir::LogicalResult
  matchAndRewrite(::mlir::Operation *op0,
                  ::mlir::PatternRewriter &rewriter) const override {
    // ---- Match ----
    auto evalOp = ::llvm::dyn_cast<::mlir::tutorial::poly::EvalOp>(op0);
    ::mlir::Operation::operand_range inputOperands = evalOp.getODSOperands(0);

    // Operand 1 must be the result of some operation...
    ::mlir::Value pointValue = *evalOp.getODSOperands(1).begin();
    ::mlir::Operation *pointDef = pointValue.getDefiningOp();
    if (pointDef == nullptr) {
      return rewriter.notifyMatchFailure(evalOp, [&](::mlir::Diagnostic &diag) {
        diag << "There's no operation that defines operand 1 of castedOp0";
      });
    }

    // ...and that operation must be a complex conjugate.
    auto conjOp = ::llvm::dyn_cast<::mlir::complex::ConjOp>(pointDef);
    if (!conjOp) {
      return rewriter.notifyMatchFailure(pointDef, [&](::mlir::Diagnostic &diag) {
        diag << "castedOp1 is not ::mlir::complex::ConjOp type";
      });
    }
    ::mlir::Operation::operand_range pointOperands = conjOp.getODSOperands(0);

    // ---- Rewrite ----
    // New ops carry the fused location of both matched ops.
    auto fusedLoc = rewriter.getFusedLoc({op0->getLoc(), pointDef->getLoc()});

    // Evaluate at the value that was fed into the conjugate.
    ::mlir::Value inputValue = *inputOperands.begin();
    ::mlir::Value innerPoint = *pointOperands.begin();
    ::mlir::tutorial::poly::EvalOp liftedEval =
        rewriter.create<::mlir::tutorial::poly::EvalOp>(fusedLoc,
                                                        /*input=*/inputValue,
                                                        /*point=*/innerPoint);

    // Conjugate the result of the new evaluation.
    ::llvm::SmallVector<::mlir::Value, 4> conjOperands;
    ::llvm::SmallVector<::mlir::NamedAttribute, 4> conjAttrs;
    conjOperands.push_back(*liftedEval.getODSResults(0).begin());
    ::mlir::complex::ConjOp liftedConj =
        rewriter.create<::mlir::complex::ConjOp>(fusedLoc, conjOperands,
                                                 conjAttrs);

    // Swap the original eval for the results of the new conjugate.
    ::llvm::SmallVector<::mlir::Value, 4> replacements;
    for (::mlir::Value result : liftedConj.getODSResults(0)) {
      replacements.push_back(result);
    }
    rewriter.replaceOp(op0, replacements);
    return ::mlir::success();
  }
};
/// Adds every pattern generated in this file (currently only
/// `LiftConjThroughEval`) to `patterns`, constructing each with the set's own
/// context.
void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<LiftConjThroughEval>(context);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment