From e494c10e7c1e2dcb716e34c5f28ea0017f22e2c5 Mon Sep 17 00:00:00 2001 From: PetroZarytskyi <119341518+PetroZarytskyi@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:04:15 +0100 Subject: [PATCH] Remove excessive FD and request parameters from DeriveVectorMode (#1136) --- .../clad/Differentiator/VectorForwardModeVisitor.h | 5 +---- lib/Differentiator/DerivativeBuilder.cpp | 2 +- lib/Differentiator/VectorForwardModeVisitor.cpp | 12 ++++-------- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index c5bccdeda..f906c6c14 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -30,13 +30,10 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { ///\brief Produces the first derivative of a given function with /// respect to multiple parameters. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns The differentiated and potentially created enclosing /// context. /// - DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload DeriveVectorMode(); /// Builds an overload for the vector mode function that has derived params /// for all the arguments of the requested function and it calls the original diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index ada7153c6..946a68719 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -408,7 +408,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.DerivePushforward(); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this, request); - result = V.DeriveVectorMode(FD, request); + result = V.DeriveVectorMode(); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { VectorPushForwardModeVisitor V(*this, request); result = V.DerivePushforward(); diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index c3756619a..c471c517d 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -53,20 +53,16 @@ void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { m_IndVarCountExpr = IndVarCountExpr; } -DerivativeAndOverload -VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, - const DiffRequest& request) { - assert(m_DiffReq == request); +DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() { + const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::vector_forward_mode); DiffParams args{}; - DiffInputVarsInfo DVI; - DVI = request.DVI; - for (auto dParam : DVI) + for (const auto& dParam : m_DiffReq.DVI) args.push_back(dParam.param); // Generate name for the derivative function. - std::string derivedFnName = request.BaseFunctionName + "_dvec"; + std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec"; if (args.size() != FD->getNumParams()) { for (auto arg : args) { auto it = std::find(FD->param_begin(), FD->param_end(), arg);