From 7205668f076dd242666b5a083c8d4a900fc7dec4 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 26 Dec 2024 20:15:40 +0000 Subject: [PATCH] Move computing of the effective derivative name in DiffRequest. NFC. --- include/clad/Differentiator/DiffPlanner.h | 1 + include/clad/Differentiator/VisitorBase.h | 3 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 37 +++--------------- lib/Differentiator/DiffPlanner.cpp | 38 +++++++++++++++++++ 4 files changed, 46 insertions(+), 33 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 54c4ee3eb..33359aaec 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -147,6 +147,7 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + std::string ComputeDerivativeName() const; }; using DiffInterval = std::vector; diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 0abd868ca..e271d3bd0 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -9,6 +9,7 @@ #include "Compatibility.h" #include "DerivativeBuilder.h" +#include "clad/Differentiator/CladUtils.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtVisitor.h" @@ -199,7 +200,7 @@ namespace clad { } /// For a qualtype QT returns if it's type is Array or Pointer Type static bool isArrayOrPointerType(const clang::QualType QT) { - return QT->isArrayType() || QT->isPointerType(); + return utils::isArrayOrPointerType(QT); } clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index e242d0ccd..bc59f3529 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/SaveAndRestore.h" #include +#include #include "clad/Differentiator/Compatibility.h" @@ -98,7 +99,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { // performed. Mathematically, independent variables are all the function // parameters, thus, does not convey the intendend meaning. m_IndependentVar = DVI.back().param; - std::string derivativeSuffix(""); // If param is not real (i.e. floating point or integral), a pointer to a // real type, or an array of a real type we cannot differentiate it. // FIXME: we should support custom numeric types in the future. @@ -113,7 +113,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { return {}; } m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start; - derivativeSuffix = "_" + std::to_string(m_IndependentVarIndex); } else { QualType T = m_IndependentVar->getType(); bool isField = false; @@ -132,32 +131,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { } } - // If we are differentiating a call operator, that has no parameters, - // then the specified independent argument is a member variable of the - // class defining the call operator. - // Thus, we need to find index of the member variable instead. - unsigned argIndex = ~0; - const CXXRecordDecl* functor = m_DiffReq.Functor; - if (m_DiffReq->param_empty() && functor) - argIndex = std::distance(functor->field_begin(), - std::find(functor->field_begin(), - functor->field_end(), m_IndependentVar)); - else - argIndex = std::distance( - FD->param_begin(), - std::find(FD->param_begin(), FD->param_end(), m_IndependentVar)); - - std::string argInfo = std::to_string(argIndex); - for (auto field : diffVarInfo.fields) - argInfo += "_" + field; - - std::string s; - if (m_DiffReq.CurrentDerivativeOrder > 1) - s = std::to_string(m_DiffReq.CurrentDerivativeOrder); - // Check if the function is already declared as a custom derivative. - std::string gradientName = m_DiffReq.BaseFunctionName + "_d" + s + "arg" + - argInfo + derivativeSuffix; + std::string gradientName = m_DiffReq.ComputeDerivativeName(); + // FIXME: We should not use const_cast to get the decl context here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); @@ -428,12 +404,9 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD); - - IdentifierInfo* derivedFnII = &m_Context.Idents.get( - originalFnEffectiveName + GetPushForwardFunctionSuffix()); + IdentifierInfo* II = &m_Context.Idents.get(m_DiffReq.ComputeDerivativeName()); SourceLocation loc{FD->getLocation()}; - DeclarationNameInfo derivedFnName(derivedFnII, loc); + DeclarationNameInfo derivedFnName(II, loc); // FIXME: We should not use const_cast to get the decl context here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index cd90e0c01..e8285bda3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -669,6 +669,44 @@ namespace clad { return found != m_ActivityRunInfo.ToBeRecorded.end(); } + std::string DiffRequest::ComputeDerivativeName() const { + if (Mode != DiffMode::forward && Mode != DiffMode::reverse) + return BaseFunctionName + "_" + DiffModeToString(Mode); + + if (DVI.empty()) + return ""; + + DiffInputVarInfo VarInfo = DVI.back(); + const ValueDecl* IndependentVar = VarInfo.param; + unsigned argIndex = ~0; + // If we are differentiating a call operator, that has no parameters, + // then the specified independent argument is a member variable of the + // class defining the call operator. + // Thus, we need to find index of the member variable instead. + if (Function->param_empty() && Functor) + argIndex = std::distance(Functor->field_begin(), + std::find(Functor->field_begin(), + Functor->field_end(), IndependentVar)); + else + argIndex = + std::distance(Function->param_begin(), + std::find(Function->param_begin(), + Function->param_end(), IndependentVar)); + + std::string argInfo = std::to_string(argIndex); + for (const std::string& field : VarInfo.fields) + argInfo += "_" + field; + + std::string s; + if (CurrentDerivativeOrder > 1) + s = std::to_string(CurrentDerivativeOrder); + + if (utils::isArrayOrPointerType(IndependentVar->getType())) + argInfo += "_" + std::to_string(VarInfo.paramIndexInterval.Start); + + return BaseFunctionName + "_d" + s + "arg" + argInfo; + } + ///\returns true on error. static bool ProcessInvocationArgs(Sema& S, SourceLocation endLoc, const RequestOptions& ReqOpts,