Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 12, 2024
1 parent 50a63aa commit bf22e01
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 37 deletions.
13 changes: 8 additions & 5 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3551,15 +3551,18 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
}

Function *EnzymeLogic::CreatePrimalAndGradient(
RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA,
RequestContext context, const ReverseCacheKey &&prevkey, TypeAnalysis &TA,
const AugmentedReturn *augmenteddata, bool omp) {

TimeTraceScope timeScope("CreatePrimalAndGradient", key.todiff->getName());
TimeTraceScope timeScope("CreatePrimalAndGradient",
prevkey.todiff->getName());

assert(key.mode == DerivativeMode::ReverseModeCombined ||
key.mode == DerivativeMode::ReverseModeGradient);
assert(prevkey.mode == DerivativeMode::ReverseModeCombined ||
prevkey.mode == DerivativeMode::ReverseModeGradient);

FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(key.typeInfo, key.todiff);
FnTypeInfo oldTypeInfo =
preventTypeAnalysisLoops(prevkey.typeInfo, prevkey.todiff);
auto key = prevkey.replaceTypeInfo(oldTypeInfo);

if (key.retType != DIFFE_TYPE::CONSTANT)
assert(!key.todiff->getReturnType()->isVoidTy());
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ struct ReverseCacheKey {
const FnTypeInfo typeInfo;
bool runtimeActivity;

ReverseCacheKey replaceTypeInfo(const FnTypeInfo &newTypeInfo) const {
return {todiff, retType, constant_args, overwritten_args,
returnUsed, shadowReturnUsed, mode, width,
freeMemory, AtomicAdd, additionalType, forceAnonymousTape,
newTypeInfo, runtimeActivity};
}
/*
inline bool operator==(const ReverseCacheKey& rhs) const {
return todiff == rhs.todiff &&
Expand Down
89 changes: 57 additions & 32 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3850,57 +3850,82 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst,
todo.push_back(BB);
}
}
SmallVector<std::pair<Value *, size_t>, 1> todo = {
std::make_pair(V, checkLoadCaptures)};
std::set<std::pair<Value *, size_t>> seen;
SmallVector<std::tuple<Instruction *, size_t, Value *>, 1> todo;
for (auto U : V->users()) {
todo.emplace_back(cast<Instruction>(U), checkLoadCaptures, V);
}
std::set<std::tuple<Value *, size_t, Value *>> seen;
while (todo.size()) {
auto pair = todo.pop_back_val();
if (seen.count(pair))
continue;
auto cur = pair.first;
for (auto U : cur->users()) {
auto UI = dyn_cast<Instruction>(U);
if (!regionBetween.count(UI->getParent()))
auto UI = std::get<0>(pair);
auto level = std::get<1>(pair);
auto prev = std::get<2>(pair);
if (!regionBetween.count(UI->getParent()))
continue;
if (UI->getParent() == VI->getParent()) {
if (UI->comesBefore(VI))
continue;
if (UI->getParent() == VI->getParent()) {
if (UI->comesBefore(VI))
continue;
}
if (UI->getParent() == inst->getParent())
if (inst->comesBefore(UI))
continue;

if (isPointerArithmeticInst(UI, /*includephi*/ true,
/*includebin*/ true)) {
for (auto U2 : UI->users()) {
auto UI2 = cast<Instruction>(U2);
todo.emplace_back(UI2, level, UI);
}
if (UI->getParent() == inst->getParent())
if (inst->comesBefore(UI))
continue;
continue;
}

if (isa<MemSetInst>(UI))
continue;

if (isPointerArithmeticInst(UI, /*includephi*/ true,
/*includebin*/ true)) {
todo.emplace_back(UI, pair.second);
if (isa<MemTransferInst>(UI)) {
if (level == 0)
continue;
}
if (UI->getOperand(1) != prev)
continue;
}

if (auto CI = dyn_cast<CallBase>(UI)) {
if (auto CI = dyn_cast<CallBase>(UI)) {
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0, size = CI->arg_size(); i < size; i++)
for (size_t i = 0, size = CI->arg_size(); i < size; i++)
#else
for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++)
for (size_t i = 0, size = CI->getNumArgOperands(); i < size; i++)
#endif
{
if (cur == CI->getArgOperand(i)) {
if (isNoCapture(CI, i))
continue;
return false;
}
{
if (prev == CI->getArgOperand(i)) {
if (isNoCapture(CI, i) && level == 0)
continue;
return false;
}
return true;
}
return true;
}

if (isa<CmpInst>(UI)) {
continue;
if (isa<CmpInst>(UI)) {
continue;
}
if (isa<LoadInst>(UI)) {
if (level) {
for (auto U2 : UI->users()) {
auto UI2 = cast<Instruction>(U2);
todo.emplace_back(UI2, level - 1, UI);
}
}
if (isa<LoadInst>(UI) && pair.second) {
todo.emplace_back(UI, pair.second - 1);
continue;
}
// storing into it.
if (auto SI = dyn_cast<StoreInst>(UI)) {
if (SI->getValueOperand() != prev) {
continue;
}
return false;
}
return false;
}
return true;
}
Expand Down

0 comments on commit bf22e01

Please sign in to comment.