Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 12, 2024
1 parent 9aa0fd4 commit f2f0ae4
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 83 deletions.
84 changes: 1 addition & 83 deletions enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,88 +59,6 @@ using namespace llvm;
#define DEBUG_TYPE "jl-inst-simplify"
namespace {

// Return true if guaranteed not to alias
// Return false if guaranteed to alias [with possible offset depending on flag].
// Return {} if no information is given.
#if LLVM_VERSION_MAJOR >= 16
std::optional<bool>
#else
llvm::Optional<bool>
#endif
arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA,
llvm::LoopInfo &LI, llvm::Value *op0,
llvm::Value *op1, bool offsetAllowed = false) {
auto lhs = getBaseObject(op0, offsetAllowed);
auto rhs = getBaseObject(op1, offsetAllowed);

if (lhs == rhs) {
return false;
}
if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy())
return {};

bool noalias_lhs = isNoAlias(lhs);
bool noalias_rhs = isNoAlias(rhs);

bool noalias[2] = {noalias_lhs, noalias_rhs};

for (int i = 0; i < 2; i++) {
Value *start = (i == 0) ? lhs : rhs;
Value *end = (i == 0) ? rhs : lhs;
if (noalias[i]) {
if (noalias[1 - i]) {
return true;
}
if (isa<Argument>(end)) {
return true;
}
if (auto endi = dyn_cast<Instruction>(end)) {
if (notCapturedBefore(start, endi, 0)) {
return true;
}
}
}
if (auto ld = dyn_cast<LoadInst>(start)) {
auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false);
if (isAllocationCall(base, TLI)) {
if (isa<Argument>(end))
return true;
if (auto endi = dyn_cast<Instruction>(end))
if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) {
Instruction *starti = dyn_cast<Instruction>(start);
if (!starti) {
if (!isa<Argument>(start))
continue;
starti =
&cast<Argument>(start)->getParent()->getEntryBlock().front();
}

bool overwritten = false;
allInstructionsBetween(
LI, starti, endi, [&](Instruction *I) -> bool {
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

if (writesToMemoryReadBy(nullptr, AA, TLI,
/*maybeReader*/ ld,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
return /*earlyBreak*/ false;
});

if (!overwritten) {
return true;
}
}
}
}
}

return {};
}

bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
llvm::AAResults &AA, llvm::LoopInfo &LI) {
bool changed = false;
Expand Down Expand Up @@ -198,7 +116,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
changed = true;
continue;
}
}
}
}

return changed;
Expand Down
95 changes: 95 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2218,7 +2218,10 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
const SCEV *StoreBegin = SE.getCouldNotCompute();
const SCEV *StoreEnd = SE.getCouldNotCompute();

Value *loadPtr = nullptr;
Value *storePtr = nullptr;
if (auto LI = dyn_cast<LoadInst>(maybeReader)) {
loadPtr = LI->getPointerOperand();
LoadBegin = SE.getSCEV(LI->getPointerOperand());
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
Expand All @@ -2236,6 +2239,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
}
}
if (auto SI = dyn_cast<StoreInst>(maybeWriter)) {
storePtr = SI->getPointerOperand();
StoreBegin = SE.getSCEV(SI->getPointerOperand());
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
Expand All @@ -2255,6 +2259,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
}
}
if (auto MS = dyn_cast<MemSetInst>(maybeWriter)) {
storePtr = MS->getArgOperand(0);
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
Expand All @@ -2269,6 +2274,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
}
}
if (auto MS = dyn_cast<MemTransferInst>(maybeWriter)) {
storePtr = MS->getArgOperand(0);
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
Expand All @@ -2283,6 +2289,7 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
}
}
if (auto MS = dyn_cast<MemTransferInst>(maybeReader)) {
loadPtr = MS->getArgOperand(1);
LoadBegin = SE.getSCEV(MS->getArgOperand(1));
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
Expand All @@ -2297,6 +2304,12 @@ bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
}
}

if (loadPtr && storePtr)
if (auto alias =
arePointersGuaranteedNoAlias(TLI, AA, LI, loadPtr, storePtr, true))
if (alias.getValue())
return false;

if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd,
maybeWriter, StoreBegin, StoreEnd, scope))
return false;
Expand Down Expand Up @@ -3887,3 +3900,85 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst,
}
return true;
}

// Return true if guaranteed not to alias
// Return false if guaranteed to alias [with possible offset depending on flag].
// Return {} if no information is given.
#if LLVM_VERSION_MAJOR >= 16
std::optional<bool>
#else
llvm::Optional<bool>
#endif
arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA,
llvm::LoopInfo &LI, llvm::Value *op0,
llvm::Value *op1, bool offsetAllowed) {
auto lhs = getBaseObject(op0, offsetAllowed);
auto rhs = getBaseObject(op1, offsetAllowed);

if (lhs == rhs) {
return false;
}
if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy())
return {};

bool noalias_lhs = isNoAlias(lhs);
bool noalias_rhs = isNoAlias(rhs);

bool noalias[2] = {noalias_lhs, noalias_rhs};

for (int i = 0; i < 2; i++) {
Value *start = (i == 0) ? lhs : rhs;
Value *end = (i == 0) ? rhs : lhs;
if (noalias[i]) {
if (noalias[1 - i]) {
return true;
}
if (isa<Argument>(end)) {
return true;
}
if (auto endi = dyn_cast<Instruction>(end)) {
if (notCapturedBefore(start, endi, 0)) {
return true;
}
}
}
if (auto ld = dyn_cast<LoadInst>(start)) {
auto base = getBaseObject(ld->getOperand(0), /*offsetAllowed*/ false);
if (isAllocationCall(base, TLI)) {
if (isa<Argument>(end))
return true;
if (auto endi = dyn_cast<Instruction>(end))
if (isNoAlias(end) || (notCapturedBefore(start, endi, 1))) {
Instruction *starti = dyn_cast<Instruction>(start);
if (!starti) {
if (!isa<Argument>(start))
continue;
starti =
&cast<Argument>(start)->getParent()->getEntryBlock().front();
}

bool overwritten = false;
allInstructionsBetween(
LI, starti, endi, [&](Instruction *I) -> bool {
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

if (writesToMemoryReadBy(nullptr, AA, TLI,
/*maybeReader*/ ld,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
return /*earlyBreak*/ false;
});

if (!overwritten) {
return true;
}
}
}
}
}

return {};
}
12 changes: 12 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2104,4 +2104,16 @@ bool isNVLoad(const llvm::Value *V);
bool notCapturedBefore(llvm::Value *V, llvm::Instruction *inst,
size_t checkLoadCaptured);

// Return true if guaranteed not to alias
// Return false if guaranteed to alias [with possible offset depending on flag].
// Return {} if no information is given.
#if LLVM_VERSION_MAJOR >= 16
std::optional<bool>
#else
llvm::Optional<bool>
#endif
arePointersGuaranteedNoAlias(llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA,
llvm::LoopInfo &LI, llvm::Value *op0,
llvm::Value *op1, bool offsetAllowed = false);

#endif // ENZYME_UTILS_H

0 comments on commit f2f0ae4

Please sign in to comment.