Skip to content

Commit

Permalink
Improve CallExpr analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Dec 12, 2024
1 parent f3eeeaf commit fac1aee
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 193 deletions.
10 changes: 2 additions & 8 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ struct DiffRequest {
} m_TbrRunInfo;

mutable struct ActivityRunInfo {
std::set<const clang::VarDecl*> ToBeRecorded;
bool HasAnalysisRun = false;
} m_ActivityRunInfo;

public:
/// All varied declarations.
static std::set<const clang::VarDecl*> AllVariedDecls;
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
Expand Down Expand Up @@ -144,13 +145,6 @@ struct DiffRequest {

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;

void setToBeRecorded(std::set<const clang::VarDecl*> init) {
this->m_ActivityRunInfo.ToBeRecorded = init;
}
std::set<const clang::VarDecl*> getToBeRecorded() const {
return this->m_ActivityRunInfo.ToBeRecorded;
}
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
39 changes: 19 additions & 20 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
FunctionDecl* FD = CE->getDirectCallee();
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
bool restoreMarking = m_Marking;
bool restoreVaried = m_Varied;
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
Expand All @@ -130,25 +132,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
while (innermostType->isPointerType())
innermostType = innermostType->getPointeeType();

if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = true;
m_Varied = true;
}

m_Varied = false;
m_Marking = false;
TraverseStmt(par);
if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = false; //?
m_Varied = false;
}

if ((m_Varied || !innermostType.isConstQualified()))
if (m_Varied)
m_VariedDecls.insert(FDparam[i]);
else if ((parType->isReferenceType() ||
(utils::isArrayOrPointerType(parType) &&
!innermostType.isConstQualified()))) {
m_Varied = true;
m_Marking = true;
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
}
m_Varied = restoreVaried;
m_Marking = restoreMarking;
}

return true;
}

Expand All @@ -159,12 +160,10 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
QualType innermost = VDTy;
while (innermost->isPointerType())
innermost = innermost->getPointeeType();
if (VDTy->isPointerType() && !innermost.isConstQualified()) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
} else if (VDTy->isArrayType()) {
if (VDTy->isArrayType() ||
(VDTy->isPointerType() && !innermost.isConstQualified())) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
m_Varied = true;
}

if (Expr* init = cast<VarDecl>(D)->getInit()) {
Expand Down
Loading

0 comments on commit fac1aee

Please sign in to comment.