Created November 3, 2021 05:08
-
-
Save joker-eph/8f00e79783688bf6349b6b5d9d44a480 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp | |
index 351a48a7d515..59f54e78c2a3 100644 | |
--- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp | |
+++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp | |
@@ -59,7 +59,7 @@ class ArrayCopyAnalysis { | |
public: | |
using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>; | |
using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>; | |
- using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>; | |
+ using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>; | |
ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { | |
construct(op->getRegions()); | |
@@ -79,15 +79,9 @@ public: | |
/// back to the array load that is the original source of the array value. | |
const OperationUseMapT &getUseMap() const { return useMap; } | |
- /// For ArrayLoad `load`, return the transitive set of all OpOperands. | |
- UseSetT getLoadUseSet(mlir::Operation *load) const { | |
- assert(loadMapSets.count(load) && "analysis missed an array load?"); | |
- return loadMapSets.lookup(load); | |
- } | |
- | |
/// Get all the array value operations that use the original array value | |
/// as passed to `store`. | |
- void arrayAccesses(llvm::SmallVectorImpl<mlir::Operation *> &accesses, | |
+ const llvm::SmallVector<mlir::Operation *> &arrayAccesses( | |
ArrayLoadOp load); | |
private: | |
@@ -269,19 +263,12 @@ private: | |
/// Find all the array operations that access the array value that is loaded by | |
/// the array load operation, `load`. | |
-void ArrayCopyAnalysis::arrayAccesses( | |
- llvm::SmallVectorImpl<mlir::Operation *> &accesses, ArrayLoadOp load) { | |
- accesses.clear(); | |
+const llvm::SmallVector<mlir::Operation *> & ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { | |
auto lmIter = loadMapSets.find(load); | |
if (lmIter != loadMapSets.end()) { | |
- for (auto *opnd : lmIter->second) { | |
- auto *owner = opnd->getOwner(); | |
- if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(owner)) | |
- accesses.push_back(owner); | |
- } | |
- return; | |
+ return lmIter->getSecond(); | |
} | |
- | |
+ llvm::SmallVector<mlir::Operation *> accesses; | |
UseSetT visited; | |
llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig] | |
@@ -357,7 +344,7 @@ void ArrayCopyAnalysis::arrayAccesses( | |
llvm::report_fatal_error("array value reached unexpected op"); | |
} | |
} | |
- loadMapSets.insert({load, visited}); | |
+ return loadMapSets.insert({load, accesses}).first->getSecond(); | |
} | |
/// Is there a conflict between the array value that was updated and to be | |
@@ -443,8 +430,8 @@ void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) { | |
if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) { | |
llvm::SmallVector<Operation *> values; | |
ReachCollector::reachingValues(values, st.sequence()); | |
- llvm::SmallVector<Operation *> accesses; | |
- arrayAccesses(accesses, | |
+ const llvm::SmallVector<Operation *> &accesses = | |
+ arrayAccesses( | |
mlir::cast<ArrayLoadOp>(st.original().getDefiningOp())); | |
if (conflictDetected(values, accesses, st)) { | |
LLVM_DEBUG(llvm::dbgs() | |
@@ -459,8 +446,8 @@ void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) { | |
<< "map: adding {" << *ld << " -> " << st << "}\n"); | |
useMap.insert({ld, &op}); | |
} else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) { | |
- llvm::SmallVector<mlir::Operation *> accesses; | |
- arrayAccesses(accesses, load); | |
+ const llvm::SmallVector<mlir::Operation *> &accesses = | |
+ arrayAccesses(load); | |
LLVM_DEBUG(llvm::dbgs() << "process load: " << load | |
<< ", accesses: " << accesses.size() << '\n'); | |
for (auto *acc : accesses) { |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.