From d8a87985ae041d146e7f2b7749ddb5061954c801 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Fri, 27 Dec 2024 11:18:58 +0000 Subject: [PATCH] Move the computation of parameters in SetupDerivativeParameters. --- .../Differentiator/BaseForwardModeVisitor.h | 13 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 132 +++++++++--------- 2 files changed, 74 insertions(+), 71 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index ad9ce8126..a5ccda51a 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -8,6 +8,8 @@ #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" +#include "llvm/ADT/SmallVector.h" + #include #include #include @@ -37,9 +39,6 @@ class BaseForwardModeVisitor DerivativeAndOverload DerivePushforward(); - /// Computes the return type of the derivative in `m_DiffReq->Function`. - clang::QualType ComputeDerivativeFunctionType(); - virtual void ExecuteInsidePushforwardFunctionBlock(); static bool IsDifferentiableType(clang::QualType T); @@ -148,6 +147,14 @@ class BaseForwardModeVisitor const clang::CXXConstructExpr* CE, llvm::SmallVectorImpl& clonedArgs, llvm::SmallVectorImpl& derivedArgs); + +private: + /// Computes the return type of the derivative in `m_DiffReq->Function`. + clang::QualType ComputeDerivativeFunctionType(); + + /// Prepares the derivative function parameters. + void + SetupDerivativeParameters(llvm::SmallVectorImpl& params); }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index bc59f3529..6b1f5631b 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -154,32 +154,16 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { FunctionDecl* derivedFD = result.first; m_Derivative = derivedFD; - llvm::SmallVector params; - const ParmVarDecl* PVD = nullptr; - // Function declaration scope beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | Scope::DeclScope); m_Sema.PushFunctionScope(); m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) { - PVD = FD->getParamDecl(i); - auto* newPVD = CloneParmVarDecl(PVD, PVD->getIdentifier(), - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - - // Make m_IndependentVar to point to the argument of the newly created - // derivedFD. - if (PVD == m_IndependentVar) - m_IndependentVar = newPVD; - - params.push_back(newPVD); - } + llvm::SmallVector params; + SetupDerivativeParameters(params); + derivedFD->setParams(params); - llvm::ArrayRef paramsRef = - clad_compat::makeArrayRef(params.data(), params.size()); - derivedFD->setParams(paramsRef); derivedFD->setBody(nullptr); if (!m_DiffReq.DeclarationOnly) { @@ -386,6 +370,64 @@ QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() { return m_Context.getFunctionType(dRetTy, FnTypes, EPI); } +void BaseForwardModeVisitor::SetupDerivativeParameters( + llvm::SmallVectorImpl& params) { + const FunctionDecl* FD = m_DiffReq.Function; + for (ParmVarDecl* PVD : FD->parameters()) { + IdentifierInfo* PVDII = PVD->getIdentifier(); + // Implicitly created special member functions have no parameter names. + if (!PVD->getDeclName()) + PVDII = CreateUniqueIdentifier("param"); + + auto* newPVD = CloneParmVarDecl(PVD, PVDII, + /*pushOnScopeChains=*/true, + /*cloneDefaultArg=*/false); + + // Point m_IndependentVar to the argument of the newly created param. + if (PVD == m_IndependentVar) + m_IndependentVar = newPVD; + + if (!PVD->getDeclName()) // We can't use lookup-based replacements + m_DeclReplacements[PVD] = newPVD; + + params.push_back(newPVD); + } + + if (m_DiffReq.Mode == DiffMode::forward) + return; + + bool HasThis = false; + // If we are differentiating an instance member function then create a + // parameter for representing derivative of `this` pointer with respect to the + // independent parameter. + if (const auto* MD = dyn_cast(FD)) { + if (MD->isInstance()) { + IdentifierInfo* dThisII = CreateUniqueIdentifier("_d_this"); + auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII, + MD->getThisType()); + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); + params.push_back(dPVD); + // FIXME: Replace m_ThisExprDerivative in favor of lookups of _d_this. + m_ThisExprDerivative = BuildDeclRef(dPVD); + HasThis = true; + } + } + + for (size_t i = 0, e = params.size() - HasThis; i < e; ++i) { + const ParmVarDecl* PVD = params[i]; + + if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + continue; + + IdentifierInfo* II = CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto* dPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, II, GetPushForwardDerivativeType(PVD->getType()), + PVD->getStorageClass()); + params.push_back(dPVD); + m_Variables[PVD] = BuildDeclRef(dPVD); + } +} + void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt(); auto* CS = cast(bodyDiff); @@ -417,61 +459,15 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first; - llvm::SmallVector params; - llvm::SmallVector derivedParams; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | Scope::DeclScope); m_Sema.PushFunctionScope(); m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - // If we are differentiating an instance member function then - // create a parameter for representing derivative of - // `this` pointer with respect to the independent parameter. - if (const auto* MFD = dyn_cast(FD)) { - if (MFD->isInstance()) { - auto thisType = MFD->getThisType(); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); - auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, - derivedPVDII, thisType); - m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), - /*AddToContext=*/false); - derivedParams.push_back(derivedPVD); - m_ThisExprDerivative = BuildDeclRef(derivedPVD); - } - } - - std::size_t numParamsOriginalFn = m_DiffReq->getNumParams(); - for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { - const auto* PVD = m_DiffReq->getParamDecl(i); - // Some of the special member functions created implicitly by compilers - // have missing parameter identifier. - bool identifierMissing = false; - IdentifierInfo* PVDII = PVD->getIdentifier(); - if (!PVDII || PVDII->getLength() == 0) { - PVDII = CreateUniqueIdentifier("param"); - identifierMissing = true; - } - auto* newPVD = CloneParmVarDecl(PVD, PVDII, - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - params.push_back(newPVD); - - if (identifierMissing) - m_DeclReplacements[PVD] = newPVD; - - if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - continue; - auto derivedPVDName = "_d_" + std::string(PVDII->getName()); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); - auto* derivedPVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, derivedPVDII, - GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass()); - derivedParams.push_back(derivedPVD); - m_Variables[newPVD] = BuildDeclRef(derivedPVD); - } - - params.insert(params.end(), derivedParams.begin(), derivedParams.end()); + llvm::SmallVector params; + SetupDerivativeParameters(params); m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); if (!m_DiffReq.DeclarationOnly) {