From 8ed27077850651b699992327170b446323e2cc68 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 3 Oct 2024 23:23:24 +0300 Subject: [PATCH] Remove m_Functor from VisitorBase since it is already stored in the diff request --- include/clad/Differentiator/VisitorBase.h | 2 -- lib/Differentiator/BaseForwardModeVisitor.cpp | 20 +++++++++---------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index e78149d3d..917088c42 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -132,8 +132,6 @@ namespace clad { std::vector m_Blocks; /// Stores output variables for vector-valued functions VectorOutputs m_VectorOutput; - /// The functor type that is currently being differentiated, if any. - const clang::CXXRecordDecl* m_Functor = nullptr; /// Stores derivative expression of the implicit `this` pointer. /// /// In the forward mode, `this` pointer derivative expression is of pointer diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 0d0bbcb47..fe14227b7 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -60,7 +60,6 @@ bool IsRealNonReferenceType(QualType T) { } DerivativeAndOverload BaseForwardModeVisitor::Derive() { - m_Functor = m_DiffReq.Functor; const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::forward); assert(!m_DerivativeInFlight && @@ -138,11 +137,11 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { // class defining the call operator. // Thus, we need to find index of the member variable instead. unsigned argIndex = ~0; - if (m_DiffReq->param_empty() && m_Functor) - argIndex = - std::distance(m_Functor->field_begin(), - std::find(m_Functor->field_begin(), - m_Functor->field_end(), m_IndependentVar)); + const CXXRecordDecl* functor = m_DiffReq.Functor; + if (m_DiffReq->param_empty() && functor) + argIndex = std::distance(functor->field_begin(), + std::find(functor->field_begin(), + functor->field_end(), m_IndependentVar)); else argIndex = std::distance( FD->param_begin(), @@ -296,8 +295,8 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { // Create derived variable for each member variable if we are // differentiating a call operator. - if (m_Functor) { - for (FieldDecl* fieldDecl : m_Functor->fields()) { + if (m_DiffReq.Functor) { + for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) { Expr* dInitializer = nullptr; QualType fieldType = fieldDecl->getType(); @@ -400,7 +399,6 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { const FunctionDecl* FD = m_DiffReq.Function; - m_Functor = m_DiffReq.Functor; assert(m_DiffReq.Mode == GetPushForwardMode()); assert(!m_DerivativeInFlight && "Doesn't support recursive diff. Use DiffPlan."); @@ -884,7 +882,7 @@ StmtDiff BaseForwardModeVisitor::VisitMemberExpr(const MemberExpr* ME) { auto clonedME = dyn_cast(Clone(ME)); // Currently, we only differentiate member variables if we are // differentiating a call operator. - if (m_Functor) { + if (m_DiffReq.Functor) { if (isa(ME->getBase()->IgnoreParenImpCasts())) { // Try to find the derivative of the member variable wrt independent // variable @@ -956,7 +954,7 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { ValueDecl* VD = nullptr; // Derived variables for member variables are also created when we are // differentiating a call operator. - if (m_Functor) { + if (m_DiffReq.Functor) { if (auto ME = dyn_cast(clonedBase->IgnoreParenImpCasts())) { ValueDecl* decl = ME->getMemberDecl(); auto it = m_Variables.find(decl);