diff --git a/enzyme/Enzyme/JLInstSimplify.cpp b/enzyme/Enzyme/JLInstSimplify.cpp index debf0883da3..3c369e8c7af 100644 --- a/enzyme/Enzyme/JLInstSimplify.cpp +++ b/enzyme/Enzyme/JLInstSimplify.cpp @@ -59,88 +59,6 @@ using namespace llvm; #define DEBUG_TYPE "jl-inst-simplify" namespace { -// Return true if guaranteed not to alias -// Return false if guaranteed to alias [with possible offset depending on flag]. -// Return {} if no information is given. -#if LLVM_VERSION_MAJOR >= 16 -std::optional -#else -llvm::Optional -#endif -arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, - llvm::LoopInfo &LI, llvm::Value *op0, - llvm::Value *op1, bool offsetAllowed = false) { - auto lhs = getBaseObject(op0, offsetAllowed); - auto rhs = getBaseObject(op1, offsetAllowed); - - if (lhs == rhs) { - return false; - } - if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy()) - return {}; - - bool noalias_lhs = isNoAlias(lhs); - bool noalias_rhs = isNoAlias(rhs); - - bool noalias[2] = {noalias_lhs, noalias_rhs}; - - for (int i = 0; i < 2; i++) { - Value *start = (i == 0) ? lhs : rhs; - Value *end = (i == 0) ? rhs : lhs; - if (noalias[i]) { - if (noalias[1 - i]) { - return true; - } - if (isa(end)) { - return true; - } - if (auto endi = dyn_cast(end)) { - if (notCapturedBefore(start, endi, 0)) { - return true; - } - } - } - if (auto ld = dyn_cast(start)) { - auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false); - if (isAllocationCall(base, TLI)) { - if (isa(end)) - return true; - if (auto endi = dyn_cast(end)) - if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) { - Instruction *starti = dyn_cast(start); - if (!starti) { - if (!isa(start)) - continue; - starti = - &cast(start)->getParent()->getEntryBlock().front(); - } - - bool overwritten = false; - allInstructionsBetween( - LI, starti, endi, [&](Instruction *I) -> bool { - if (!I->mayWriteToMemory()) - return /*earlyBreak*/ false; - - if (writesToMemoryReadBy(nullptr, AA, TLI, - /*maybeReader*/ ld, - /*maybeWriter*/ I)) { - overwritten = true; - return /*earlyBreak*/ true; - } - return /*earlyBreak*/ false; - }); - - if (!overwritten) { - return true; - } - } - } - } - } - - return {}; -} - bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, llvm::AAResults &AA, llvm::LoopInfo &LI) { bool changed = false; @@ -198,7 +116,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI, changed = true; continue; } - } + } } return changed; diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index f6a8db04453..2cc7d68ddea 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2218,7 +2218,10 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, const SCEV *StoreBegin = SE.getCouldNotCompute(); const SCEV *StoreEnd = SE.getCouldNotCompute(); + Value *loadPtr = nullptr; + Value *storePtr = nullptr; if (auto LI = dyn_cast(maybeReader)) { + loadPtr = LI->getPointerOperand(); LoadBegin = SE.getSCEV(LI->getPointerOperand()); if (LoadBegin != SE.getCouldNotCompute() && !LoadBegin->getType()->isIntegerTy()) { @@ -2236,6 +2239,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto SI = dyn_cast(maybeWriter)) { + storePtr = SI->getPointerOperand(); StoreBegin = SE.getSCEV(SI->getPointerOperand()); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2255,6 +2259,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeWriter)) { + storePtr = MS->getArgOperand(0); StoreBegin = SE.getSCEV(MS->getArgOperand(0)); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2269,6 +2274,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeWriter)) { + storePtr = MS->getArgOperand(0); StoreBegin = SE.getSCEV(MS->getArgOperand(0)); if (StoreBegin != SE.getCouldNotCompute() && !StoreBegin->getType()->isIntegerTy()) { @@ -2283,6 +2289,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } if (auto MS = dyn_cast(maybeReader)) { + loadPtr = MS->getArgOperand(1); LoadBegin = SE.getSCEV(MS->getArgOperand(1)); if (LoadBegin != SE.getCouldNotCompute() && !LoadBegin->getType()->isIntegerTy()) { @@ -2297,6 +2304,12 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA, } } + if (loadPtr && storePtr) + if (auto alias = + arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true)) + if (alias.getValue()) + return false; + if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd, maybeWriter, StoreBegin, StoreEnd, scope)) return false; @@ -3887,3 +3900,85 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst, } return true; } + +// Return true if guaranteed not to alias +// Return false if guaranteed to alias [with possible offset depending on flag]. +// Return {} if no information is given. +#if LLVM_VERSION_MAJOR >= 16 +std::optional +#else +llvm::Optional +#endif +arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA, + llvm::LoopInfo &LI, llvm::Value *op0, + llvm::Value *op1, bool offsetAllowed) { + auto lhs = getBaseObject(op0, offsetAllowed); + auto rhs = getBaseObject(op1, offsetAllowed); + + if (lhs == rhs) { + return false; + } + if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy()) + return {}; + + bool noalias_lhs = isNoAlias(lhs); + bool noalias_rhs = isNoAlias(rhs); + + bool noalias[2] = {noalias_lhs, noalias_rhs}; + + for (int i = 0; i < 2; i++) { + Value *start = (i == 0) ? lhs : rhs; + Value *end = (i == 0) ? rhs : lhs; + if (noalias[i]) { + if (noalias[1 - i]) { + return true; + } + if (isa(end)) { + return true; + } + if (auto endi = dyn_cast(end)) { + if (notCapturedBefore(start, endi, 0)) { + return true; + } + } + } + if (auto ld = dyn_cast(start)) { + auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false); + if (isAllocationCall(base, TLI)) { + if (isa(end)) + return true; + if (auto endi = dyn_cast(end)) + if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) { + Instruction *starti = dyn_cast(start); + if (!starti) { + if (!isa(start)) + continue; + starti = + &cast(start)->getParent()->getEntryBlock().front(); + } + + bool overwritten = false; + allInstructionsBetween( + LI, starti, endi, [&](Instruction *I) -> bool { + if (!I->mayWriteToMemory()) + return /*earlyBreak*/ false; + + if (writesToMemoryReadBy(nullptr, AA, TLI, + /*maybeReader*/ ld, + /*maybeWriter*/ I)) { + overwritten = true; + return /*earlyBreak*/ true; + } + return /*earlyBreak*/ false; + }); + + if (!overwritten) { + return true; + } + } + } + } + } + + return {}; +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 2fbe3c8db69..62493de9046 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -2104,4 +2104,16 @@ bool isNVLoad(const llvm::Value *V); bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst, size_t checkLoadCaptured); +// Return true if guaranteed not to alias +// Return false if guaranteed to alias [with possible offset depending on flag]. +// Return {} if no information is given. +#if LLVM_VERSION_MAJOR >= 16 +std::optional +#else +llvm::Optional +#endif +arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, + llvm::LoopInfo &LI, llvm::Value *op0, + llvm::Value *op1, bool offsetAllowed = false); + #endif // ENZYME_UTILS_H