From e46c34d2856b69bc60d8e23676563f8aa275b3f4 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Fri, 6 Oct 2023 23:53:11 +0300 Subject: [PATCH] Fix and simplify reference vars analysis. --- include/clad/Differentiator/TBRAnalyzer.h | 61 +++--- lib/Differentiator/TBRAnalyzer.cpp | 237 ++++++++++------------ 2 files changed, 130 insertions(+), 168 deletions(-) diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index f2943ab23..b2677e4e7 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -116,12 +116,11 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// less space. ObjMap* objData; ArrMap* arrData; - VarData* refData; + Expr* refData; VarDataValue() : fundData(false) {} }; VarDataType type; VarDataValue val; - bool isReferenced = false; VarData() = default; VarData(const VarData&) = delete; @@ -140,34 +139,29 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { for (auto& pair : *val.arrData) delete pair.second; } - - /// Recursively sets all the leaves' bools to isReq. - void setIsRequired(bool isReq = true); - /// Returns true if there is at least one required to store node among - /// child nodes. - bool findReq() const; - /// Whenever an array element with a non-constant index is set to required - /// this function is used to set to required all the array elements that - /// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required - /// when 'a[k].y' is set to required). Takes unwrapped sequence of - /// indices/members of the expression being overlaid and the index of of the - /// current index/member. - void overlay(llvm::SmallVector& IdxAndMemberSequence, - size_t i); - /// Used to merge together VarData for one variable from two branches - /// (e.g. after an if-else statements). Look at the Control Flow section for - /// more information. - void merge(VarData* mergeData); - /// Used to recursively copy VarData when separating into different branches - /// (e.g. when entering an if-else statements). Look at the Control Flow - /// section for more information. refVars stores copied nodes for - /// corresponding original nodes in case those are referenced (a referenced - /// node is a child to multiple nodes, therefore, we need to make sure we - /// don't make multiple copies of it). - VarData* copy(); - VarData* copy(std::unordered_map& refVars); - void restoreRefs(std::unordered_map& refVars); }; + /// Recursively sets all the leaves' bools to isReq. + void setIsRequired(VarData* varData, bool isReq = true); + /// Whenever an array element with a non-constant index is set to required + /// this function is used to set to required all the array elements that + /// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required + /// when 'a[k].y' is set to required). Takes unwrapped sequence of + /// indices/members of the expression being overlaid and the index of of the + /// current index/member. + void overlay(VarData* targetData, + llvm::SmallVector& IdxAndMemberSequence, + size_t i); + /// Returns true if there is at least one required to store node among + /// child nodes. + bool findReq(const VarData* varData); + /// Used to merge together VarData for one variable from two branches + /// (e.g. after an if-else statements). Look at the Control Flow section for + /// more information. + void merge(VarData* targetData, VarData* mergeData); + /// Used to recursively copy VarData when separating into different branches + /// (e.g. when entering an if-else statements). Look at the Control Flow + /// section for more information. + VarData* copy(VarData* copyData); clang::CFGBlock* getCFGBlockByID(unsigned ID); @@ -215,12 +209,11 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { void emplace(std::pair pair) { data.emplace(pair); } - - std::unique_ptr - collectDataFromPredecessors(VarsData* limit = nullptr); - VarsData* findLowestCommonAncestor(VarsData* other); - void merge(VarsData* mergeData); }; + std::unique_ptr + collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr); + VarsData* findLowestCommonAncestor(VarsData* varsData1, VarsData* varsData2); + void merge(VarsData* targetData, VarsData* mergeData); /// Used to find DeclRefExpr's that will be used in the backwards pass. /// In order to be marked as required, a variables has to appear in a place diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 7045dac4d..62f826781 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -4,123 +4,100 @@ using namespace clang; namespace clad { -void TBRAnalyzer::VarData::setIsRequired(bool isReq) { - if (type == FUND_TYPE) - val.fundData = isReq; - else if (type == OBJ_TYPE) - for (auto& pair : *val.objData) - pair.second->setIsRequired(isReq); - else if (type == ARR_TYPE) - for (auto& pair : *val.arrData) - pair.second->setIsRequired(isReq); - else if (type == REF_TYPE && val.refData) - val.refData->setIsRequired(isReq); -} - -void TBRAnalyzer::VarData::merge(VarData* mergeData) { - if (this->type == FUND_TYPE) { - this->val.fundData = this->val.fundData || mergeData->val.fundData; - } else if (this->type == OBJ_TYPE) { - for (auto& pair : *this->val.objData) - pair.second->merge((*mergeData->val.objData)[pair.first]); - } else if (this->type == ARR_TYPE) { +void TBRAnalyzer::setIsRequired(VarData* varData, bool isReq) { + if (varData->type == VarData::FUND_TYPE) + varData->val.fundData = isReq; + else if (varData->type == VarData::OBJ_TYPE) + for (auto& pair : *varData->val.objData) + setIsRequired(pair.second, isReq); + else if (varData->type == VarData::ARR_TYPE) + for (auto& pair : *varData->val.arrData) + setIsRequired(pair.second, isReq); + else if (varData->type == VarData::REF_TYPE && varData->val.refData) + setIsRequired(getExprVarData(varData->val.refData), isReq); +} + +void TBRAnalyzer::merge(VarData* targetData, VarData* mergeData) { + if (targetData->type == VarData::FUND_TYPE) { + targetData->val.fundData = + targetData->val.fundData || mergeData->val.fundData; + } else if (targetData->type == VarData::OBJ_TYPE) { + for (auto& pair : *targetData->val.objData) + merge(pair.second, (*mergeData->val.objData)[pair.first]); + } else if (targetData->type == VarData::ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. - for (auto& pair : *this->val.arrData) { + for (auto& pair : *targetData->val.arrData) { auto it = mergeData->val.arrData->find(pair.first); if (it != mergeData->val.arrData->end()) - pair.second->merge(it->second); + merge(pair.second, it->second); } for (auto& pair : *mergeData->val.arrData) { - auto it = this->val.arrData->find(pair.first); - if (it == mergeData->val.arrData->end()) { - std::unordered_map refVars; - (*this->val.arrData)[pair.first] = pair.second->copy(refVars); - } + auto it = targetData->val.arrData->find(pair.first); + if (it == mergeData->val.arrData->end()) + (*targetData->val.arrData)[pair.first] = copy(pair.second); } - } else if (this->type == REF_TYPE && this->val.refData) { - /// FIXME: add support for merging references. - // this->val.refData->merge(mergeData->val.refData); } + /// This might be useful in future if used to analyse pointers. However, for + /// now it's only used for references for which merging doesn't make sense. + // else if (this->type == VarData::REF_TYPE) {} } -TBRAnalyzer::VarData* TBRAnalyzer::VarData::copy() { - std::unordered_map refVars; - VarData* res = copy(refVars); - res->restoreRefs(refVars); - return res; -} - -TBRAnalyzer::VarData* -TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { +TBRAnalyzer::VarData* TBRAnalyzer::copy(VarData* copyData) { auto* res = new VarData(); - /// The child node of a reference node should be copied only once. Hence, - /// we use refVars to match original referenced nodes to corresponding copies. - if (isReferenced) - refVars[this] = res; - res->type = this->type; - if (this->type == FUND_TYPE) { - res->val.fundData = this->val.fundData; - } else if (this->type == OBJ_TYPE) { + res->type = copyData->type; + if (copyData->type == VarData::FUND_TYPE) { + res->val.fundData = copyData->val.fundData; + } else if (copyData->type == VarData::OBJ_TYPE) { res->val.objData = new ObjMap(); - for (auto& pair : *this->val.objData) - (*res->val.objData)[pair.first] = pair.second->copy(refVars); - } else if (this->type == ARR_TYPE) { + for (auto& pair : *copyData->val.objData) + (*res->val.objData)[pair.first] = copy(pair.second); + } else if (copyData->type == VarData::ARR_TYPE) { res->val.arrData = new ArrMap(); - for (auto& pair : *this->val.arrData) - (*res->val.arrData)[pair.first] = pair.second->copy(refVars); - } else if (this->type == REF_TYPE && this->val.refData) { - res->val.refData = this->val.refData; + for (auto& pair : *copyData->val.arrData) + (*res->val.arrData)[pair.first] = copy(pair.second); + } else if (copyData->type == VarData::REF_TYPE && copyData->val.refData) { + res->val.refData = copyData->val.refData; } return res; } -void TBRAnalyzer::VarData::restoreRefs( - std::unordered_map& refVars) { - if (this->type == OBJ_TYPE) - for (auto& pair : *val.objData) - pair.second->restoreRefs(refVars); - else if (this->type == ARR_TYPE) - for (auto& pair : *this->val.arrData) - pair.second->restoreRefs(refVars); - else if (this->type == REF_TYPE && this->val.refData) - this->val.refData = refVars[this->val.refData]; -} - -bool TBRAnalyzer::VarData::findReq() const { - if (type == FUND_TYPE) - return val.fundData; - if (type == OBJ_TYPE) { - for (auto& pair : *val.objData) - if (pair.second->findReq()) +bool TBRAnalyzer::findReq(const VarData* varData) { + if (varData->type == VarData::FUND_TYPE) + return varData->val.fundData; + if (varData->type == VarData::OBJ_TYPE) { + for (auto& pair : *varData->val.objData) + if (findReq(pair.second)) return true; - } else if (type == ARR_TYPE) { - for (auto& pair : *val.arrData) - if (pair.second->findReq()) + } else if (varData->type == VarData::ARR_TYPE) { + for (auto& pair : *varData->val.arrData) + if (findReq(pair.second)) return true; - } else if (type == REF_TYPE && val.refData) { - if (val.refData->findReq()) + } else if (varData->type == VarData::REF_TYPE && varData->val.refData) { + if (findReq(getExprVarData(varData->val.refData))) return true; } return false; } -void TBRAnalyzer::VarData::overlay( +void TBRAnalyzer::overlay( + VarData* targetData, llvm::SmallVector& IdxAndMemberSequence, size_t i) { if (i == 0) { - setIsRequired(); + setIsRequired(targetData); return; } --i; IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i]; if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) { - (*val.objData)[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); + overlay((*targetData->val.objData)[curIdxOrMember.val.field], + IdxAndMemberSequence, i); } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { auto idx = curIdxOrMember.val.index; if (eqAPInt(idx, llvm::APInt(2, -1, true))) - for (auto& pair : *val.arrData) - pair.second->overlay(IdxAndMemberSequence, i); + for (auto& pair : *targetData->val.arrData) + overlay(pair.second, IdxAndMemberSequence, i); else - (*val.arrData)[idx]->overlay(IdxAndMemberSequence, i); + overlay((*targetData->val.arrData)[idx], IdxAndMemberSequence, i); } } @@ -129,9 +106,6 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, if (const auto* FD = dyn_cast(ME->getMemberDecl())) { const auto* base = ME->getBase(); VarData* baseData = getExprVarData(base); - /// If the VarData is ref type just go to the VarData being referenced. - if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) - baseData = baseData->val.refData; if (!baseData) return nullptr; @@ -161,9 +135,6 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, const auto* base = ASE->getBase()->IgnoreImpCasts(); VarData* baseData = getExprVarData(base); - /// If the VarData is ref type just go to the VarData being referenced. - if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) - baseData = baseData->val.refData; if (!baseData) return nullptr; @@ -182,8 +153,7 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, /// Since -1 represents non-const indices, whenever we add a new index we /// have to copy the VarData of -1's element (if an element with undefined /// index was used this might be our current element). - std::unordered_map dummy; - idxData = (*baseArrMap)[llvm::APInt(2, -1, true)]->copy(dummy); + idxData = copy((*baseArrMap)[llvm::APInt(2, -1, true)]); return idxData; } @@ -201,7 +171,6 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, /// ``this`` does not have a declaration so it is represented with nullptr. if (const auto* DRE = dyn_cast(E)) VD = dyn_cast(DRE->getDecl()); - auto* branch = &getCurBranch(); while (branch) { auto it = branch->find(VD); @@ -217,15 +186,18 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, if (const auto* ASE = dyn_cast(E)) EData = getArrSubVarData(ASE, addNonConstIdx); + if (EData && EData->type == VarData::REF_TYPE && EData->val.refData) + EData = getExprVarData(EData->val.refData); + return EData; } TBRAnalyzer::VarData::VarData(const QualType QT) { if (QT->isReferenceType()) { - type = VarData::VarDataType::REF_TYPE; + type = VarData::REF_TYPE; val.refData = nullptr; } else if (utils::isArrayOrPointerType(QT)) { - type = VarData::VarDataType::ARR_TYPE; + type = VarData::ARR_TYPE; val.arrData = new ArrMap(); const Type* elemType; if (const auto pointerType = llvm::dyn_cast(QT)) @@ -235,10 +207,10 @@ TBRAnalyzer::VarData::VarData(const QualType QT) { auto& idxData = (*val.arrData)[llvm::APInt(2, -1, true)]; idxData = new VarData(QualType::getFromOpaquePtr(elemType)); } else if (QT->isBuiltinType()) { - type = VarData::VarDataType::FUND_TYPE; + type = VarData::FUND_TYPE; val.fundData = false; } else if (QT->isRecordType()) { - type = VarData::VarDataType::OBJ_TYPE; + type = VarData::OBJ_TYPE; const auto* recordDecl = QT->getAs()->getDecl(); auto& newObjMap = val.objData; newObjMap = new ObjMap(); @@ -276,8 +248,8 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { /// Overlay on all the VarData's recursively. if (const auto* VD = dyn_cast(innermostDRE->getDecl())) { - getCurBranch()[VD]->overlay(IdxAndMemberSequence, - IdxAndMemberSequence.size()); + overlay(getCurBranch()[VD], IdxAndMemberSequence, + IdxAndMemberSequence.size()); } } @@ -292,7 +264,7 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { while (branch) { auto it = branch->find(VD); if (it != branch->end()) { - curBranch[VD] = it->second->copy(); + curBranch[VD] = copy(it->second); return; } branch = branch->prev; @@ -332,7 +304,7 @@ void TBRAnalyzer::markLocation(const clang::Expr* E) { /// required to be stored (when passing *= operator) but then marked as not /// required to be stored (when passing = operator). Current method of /// marking locations does not allow to differentiate between these two. - ToBeRec = ToBeRec || data->findReq(); + ToBeRec = ToBeRec || findReq(data); } else /// If the current branch is going to be deleted then there is not point in /// storing anything in it. @@ -344,7 +316,7 @@ void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { (modeStack.back() == (Mode::markingMode | Mode::nonLinearMode))) { VarData* data = getExprVarData(E, /*addNonConstIdx=*/isReq); if (isReq || !nonConstIndexFound) - data->setIsRequired(isReq); + setIsRequired(data, isReq); /// If an array element with a non-const element is set to required /// all the elements of that array should be set to required. if (isReq && nonConstIndexFound) @@ -410,7 +382,7 @@ void TBRAnalyzer::VisitCFGBlock(CFGBlock* block) { varsData = new VarsData(); varsData->prev = blockData[block->getBlockID()]; } else if (varsData->prev != blockData[block->getBlockID()]) { - varsData->merge(blockData[block->getBlockID()]); + merge(varsData, blockData[block->getBlockID()]); } if (notLastPass) { CFGQueue.insert(succ->getBlockID()); @@ -426,21 +398,22 @@ CFGBlock* TBRAnalyzer::getCFGBlockByID(unsigned ID) { } TBRAnalyzer::VarsData* -TBRAnalyzer::VarsData::findLowestCommonAncestor(TBRAnalyzer::VarsData* other) { - VarsData* pred1 = this; - VarsData* pred2 = other; +TBRAnalyzer::findLowestCommonAncestor(VarsData* varsData1, + VarsData* varsData2) { + VarsData* pred1 = varsData1; + VarsData* pred2 = varsData2; while (true) { if (pred1 == pred2) return pred1; - auto branch = this; + auto branch = varsData1; while (branch != pred1) { if (branch == pred2) return branch; branch = branch->prev; } - branch = other; + branch = varsData2; while (branch != pred2) { if (branch == pred1) return branch; @@ -451,7 +424,7 @@ TBRAnalyzer::VarsData::findLowestCommonAncestor(TBRAnalyzer::VarsData* other) { pred1 = pred1->prev; /// This ensures we don't get an infinite loop because of VarsData being /// connected in a loop themselves. - if (pred1 == this) + if (pred1 == varsData1) return nullptr; } else { /// pred1 not having a predecessor means it is corresponds to the entry @@ -463,7 +436,7 @@ TBRAnalyzer::VarsData::findLowestCommonAncestor(TBRAnalyzer::VarsData* other) { pred2 = pred2->prev; /// This ensures we don't get an infinite loop because of VarsData being /// connected in a loop themselves. - if (pred2 == other) + if (pred2 == varsData2) return nullptr; } else { /// pred2 not having a predecessor means it is corresponds to the entry @@ -476,11 +449,11 @@ TBRAnalyzer::VarsData::findLowestCommonAncestor(TBRAnalyzer::VarsData* other) { } std::unique_ptr -TBRAnalyzer::VarsData::collectDataFromPredecessors( - TBRAnalyzer::VarsData* limit) { - auto result = std::unique_ptr(new VarsData(*this)); - if (this != limit) { - auto pred = this->prev; +TBRAnalyzer::collectDataFromPredecessors(VarsData* varsData, + TBRAnalyzer::VarsData* limit) { + auto result = std::unique_ptr(new VarsData(*varsData)); + if (varsData != limit) { + auto pred = varsData->prev; while (pred != limit) { for (auto pair : *pred) if (result->find(pair.first) == result->end()) @@ -492,21 +465,21 @@ TBRAnalyzer::VarsData::collectDataFromPredecessors( return result; } -void TBRAnalyzer::VarsData::merge(TBRAnalyzer::VarsData* mergeData) { - auto* LCA = this->findLowestCommonAncestor(mergeData); +void TBRAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { + auto* LCA = findLowestCommonAncestor(targetData, mergeData); auto collectedMergeData = - mergeData->collectDataFromPredecessors(/*limit=*/LCA); + collectDataFromPredecessors(mergeData, /*limit=*/LCA); for (auto& pair : *collectedMergeData) { VarData* found = nullptr; - auto elemSearch = this->find(pair.first); - if (elemSearch == this->end()) { - auto* branch = this->prev; + auto elemSearch = targetData->find(pair.first); + if (elemSearch == targetData->end()) { + auto* branch = targetData->prev; while (branch) { auto it = branch->find(pair.first); if (it != branch->end()) { - found = it->second->copy(); - this->emplace(pair.first, found); + found = copy(it->second); + targetData->emplace(pair.first, found); break; } branch = branch->prev; @@ -516,20 +489,20 @@ void TBRAnalyzer::VarsData::merge(TBRAnalyzer::VarsData* mergeData) { } if (found) - found->merge(pair.second); + merge(found, pair.second); else - this->emplace(pair.first, pair.second->copy()); + targetData->emplace(pair.first, copy(pair.second)); } - auto collectedThis = this->collectDataFromPredecessors(/*limit=*/LCA); + auto collectedThis = collectDataFromPredecessors(targetData, /*limit=*/LCA); for (auto& pair : *collectedThis) { auto elemSearch = mergeData->find(pair.first); - if (elemSearch == this->end()) { + if (elemSearch == targetData->end()) { auto* branch = LCA; while (branch) { auto it = branch->find(pair.first); if (it != branch->end()) { - pair.second->merge(it->second); + merge(pair.second, it->second); break; } branch = branch->prev; @@ -590,12 +563,8 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { /// if the declared variable is ref type attach its VarData* to the /// VarData* of the RHS variable. auto returnExprs = utils::GetInnermostReturnExpr(init); - if (VDExpr->type == VarData::VarDataType::REF_TYPE && - !returnExprs.empty()) { - auto* RHSExpr = getExprVarData(returnExprs[0]); - VDExpr->val.refData = RHSExpr; - RHSExpr->isReferenced = true; - } + if (VDExpr->type == VarData::REF_TYPE && !returnExprs.empty()) + VDExpr->val.refData = returnExprs[0]; } } } @@ -617,7 +586,7 @@ void TBRAnalyzer::VisitConditionalOperator( blockData[curBlockID] = elseBranch; Visit(CO->getFalseExpr()); - elseBranch->merge(thenBranch); + merge(elseBranch, thenBranch); delete thenBranch; }