From a3415b850c36d28eb06e1f00fdfb9dfb1a3a68ff Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 29 Nov 2023 01:17:04 +0200 Subject: [PATCH] Move name collision handling to ReferencesUpdater. --- include/clad/Differentiator/StmtClone.h | 5 ++++- lib/Differentiator/ReverseModeVisitor.cpp | 6 +----- lib/Differentiator/StmtClone.cpp | 15 +++++++++++++-- lib/Differentiator/VisitorBase.cpp | 4 ++-- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index cf32b813f..6e5ae7f7d 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -153,9 +153,12 @@ namespace utils { clang::Sema& m_Sema; // We don't own. clang::Scope* m_CurScope; // We don't own. const clang::FunctionDecl* m_Function; // We don't own. + const std::unordered_map& m_DeclReplacements; // We don't own. public: ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S, - const clang::FunctionDecl* FD); + const clang::FunctionDecl* FD, + const std::unordered_map& DeclReplacements); bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); bool VisitStmt(clang::Stmt* S); /// Used to update the size expression of QT diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4a9ca6c11..d7d3438d2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1328,11 +1328,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Check if referenced Decl was "replaced" with another identifier inside // the derivative if (const auto* VD = dyn_cast(DRE->getDecl())) { - auto it = m_DeclReplacements.find(VD); - if (it != std::end(m_DeclReplacements)) - clonedDRE = BuildDeclRef(it->second); - else - clonedDRE = cast(Clone(DRE)); + clonedDRE = cast(Clone(DRE)); // If current context is different than the context of the original // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index e14c7fa78..ed26cbfd3 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -502,8 +502,10 @@ Stmt* StmtClone::VisitStmt(Stmt*) { } ReferencesUpdater::ReferencesUpdater(Sema& SemaRef, Scope* S, - const FunctionDecl* FD) - : m_Sema(SemaRef), m_CurScope(S), m_Function(FD) {} + const FunctionDecl* FD, + const std::unordered_map& DeclReplacements) + : m_Sema(SemaRef), m_CurScope(S), m_Function(FD), m_DeclReplacements(DeclReplacements) {} bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { // We should only update references of the declarations that were inside @@ -511,6 +513,15 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { // Original function = function that we are currently differentiating. if (!DRE->getDecl()->getDeclContext()->Encloses(m_Function)) return true; + + // Replace the declaration if it is present in `m_DeclReplacements`. + if (VarDecl* VD = dyn_cast(DRE->getDecl())) { + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) { + DRE->setDecl(it->second); + } + } + DeclarationNameInfo DNI = DRE->getNameInfo(); LookupResult R(m_Sema, DNI, Sema::LookupOrdinaryName); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 3ed0b24fb..391022301 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -84,7 +84,7 @@ namespace clad { } void VisitorBase::updateReferencesOf(Stmt* InSubtree) { - utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function); + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, m_DeclReplacements); up.TraverseStmt(InSubtree); } @@ -304,7 +304,7 @@ namespace clad { QualType VisitorBase::CloneType(const QualType QT) { auto clonedType = m_Builder.m_NodeCloner->CloneType(QT); - utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function); + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, m_DeclReplacements); up.updateType(clonedType); return clonedType; }