From b340d05302d270d1b5afc410e7310fb97f7c5d08 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 18 Mar 2024 09:33:19 +0200 Subject: [PATCH] Move GetMultiArgCentralDiffCall to ReverseModeVisitor --- .../clad/Differentiator/ReverseModeVisitor.h | 22 ++++++++ include/clad/Differentiator/VisitorBase.h | 21 -------- lib/Differentiator/ReverseModeVisitor.cpp | 53 ++++++++++++++++++ lib/Differentiator/VisitorBase.cpp | 54 ------------------- 4 files changed, 75 insertions(+), 75 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 10564dbbf..6c7d1fe71 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -315,6 +315,28 @@ namespace clad { CladTapeResult MakeCladTapeFor(clang::Expr* E, llvm::StringRef prefix = "_t"); + /// A function to get the multi-argument "central_difference" + /// call expression for the given arguments. + /// + /// \param[in] targetFuncCall The function to get the derivative for. + /// \param[in] retType The return type of the target call expression. + /// \param[in] dfdx The dfdx corresponding to this call expression. + /// \param[in] numArgs The total number of 'args'. + /// \param[in] PreCallStmts The built statements to add to block + /// before the call to the derived function. + /// \param[in] PostCallStmts The built statements to add to block + /// after the call to the derived function. + /// \param[in] args All the arguments to the target function. + /// \param[in] outputArgs The output gradient arguments. + /// + /// \returns The derivative function call. + clang::Expr* GetMultiArgCentralDiffCall( + clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, + clang::Expr* dfdx, llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, + llvm::SmallVectorImpl& args, + llvm::SmallVectorImpl& outputArgs); + public: ReverseModeVisitor(DerivativeBuilder& builder); virtual ~ReverseModeVisitor(); diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index a588b2b97..9feeb6624 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -585,27 +585,6 @@ namespace clad { clang::Expr* GetSingleArgCentralDiffCall( clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos, unsigned numArgs, llvm::SmallVectorImpl& args); - /// A function to get the multi-argument "central_difference" - /// call expression for the given arguments. - /// - /// \param[in] targetFuncCall The function to get the derivative for. - /// \param[in] retType The return type of the target call expression. - /// \param[in] dfdx The dfdx corresponding to this call expression. - /// \param[in] numArgs The total number of 'args'. - /// \param[in] PreCallStmts The built statements to add to block - /// before the call to the derived function. - /// \param[in] PostCallStmts The built statements to add to block - /// after the call to the derived function. - /// \param[in] args All the arguments to the target function. - /// \param[in] outputArgs The output gradient arguments. - /// - /// \returns The derivative function call. - clang::Expr* GetMultiArgCentralDiffCall( - clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, - clang::Expr* dfdx, llvm::SmallVectorImpl& PreCallStmts, - llvm::SmallVectorImpl& PostCallStmts, - llvm::SmallVectorImpl& args, - llvm::SmallVectorImpl& outputArgs); /// Emits diagnostic messages on differentiation (or lack thereof) for /// call expressions. /// diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0038a6bbd..ed8683483 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1957,6 +1957,59 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {}; } + Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall( + Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx, + llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, + llvm::SmallVectorImpl& args, + llvm::SmallVectorImpl& outputArgs) { + int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); + llvm::SmallVector NumDiffArgs = {}; + NumDiffArgs.push_back(targetFuncCall); + // build the clad::tape> = {}; + QualType RefType = GetCladArrayRefOfType(retType); + QualType TapeType = GetCladTapeOfType(RefType); + auto* VD = BuildVarDecl( + TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, + /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); + PreCallStmts.push_back(BuildDeclStmt(VD)); + Expr* TapeRef = BuildDeclRef(VD); + NumDiffArgs.push_back(TapeRef); + NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, printErrorInf)); + + // Build the tape push expressions. + VD->setLocation(m_Function->getLocation()); + m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false); + CXXScopeSpec CSS; + CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); + LookupResult& Push = GetCladTapePush(); + Expr* PushDRE = + m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); + for (unsigned i = 0, e = numArgs; i < e; i++) { + QualType argTy = args[i]->getType(); + VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); + PreCallStmts.push_back(BuildDeclStmt(gradVar)); + Expr* PushExpr = BuildDeclRef(gradVar); + if (!isCladArrayType(argTy)) + PushExpr = BuildOp(UO_AddrOf, PushExpr); + std::array callArgs = {TapeRef, PushExpr}; + Stmt* PushStmt = + m_Sema + .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) + .get(); + PreCallStmts.push_back(PushStmt); + Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); + PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); + NumDiffArgs.push_back(args[i]); + } + std::string Name = "central_difference"; + return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + /*forCustomDerv=*/false, + /*namespaceShouldExist=*/false); + } + StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { auto opCode = UnOp->getOpcode(); Expr* valueForRevPass = nullptr; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 86ea4e9cc..c242bad2d 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -771,60 +771,6 @@ namespace clad { /*namespaceShouldExist=*/false); } - Expr* VisitorBase::GetMultiArgCentralDiffCall( - Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx, - llvm::SmallVectorImpl& PreCallStmts, - llvm::SmallVectorImpl& PostCallStmts, - llvm::SmallVectorImpl& args, - llvm::SmallVectorImpl& outputArgs) { - int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); - llvm::SmallVector NumDiffArgs = {}; - NumDiffArgs.push_back(targetFuncCall); - // build the clad::tape> = {}; - QualType RefType = GetCladArrayRefOfType(retType); - QualType TapeType = GetCladTapeOfType(RefType); - auto VD = BuildVarDecl( - TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, - /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); - PreCallStmts.push_back(BuildDeclStmt(VD)); - Expr* TapeRef = BuildDeclRef(VD); - NumDiffArgs.push_back(TapeRef); - NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(m_Context.IntTy, - m_Context, - printErrorInf)); - - // Build the tape push expressions. - VD->setLocation(m_Function->getLocation()); - m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false); - CXXScopeSpec CSS; - CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); - LookupResult& Push = GetCladTapePush(); - Expr* PushDRE = - m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); - for (unsigned i = 0, e = numArgs; i < e; i++) { - QualType argTy = args[i]->getType(); - VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); - PreCallStmts.push_back(BuildDeclStmt(gradVar)); - Expr* PushExpr = BuildDeclRef(gradVar); - if (!isCladArrayType(argTy)) - PushExpr = BuildOp(UO_AddrOf, PushExpr); - std::array callArgs = {TapeRef, PushExpr}; - Stmt* PushStmt = - m_Sema - .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) - .get(); - PreCallStmts.push_back(PushStmt); - Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); - PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); - NumDiffArgs.push_back(args[i]); - } - std::string Name = "central_difference"; - return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, - /*forCustomDerv=*/false, - /*namespaceShouldExist=*/false); - } - void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName, SourceLocation srcLoc, bool isDerived){ if (!isDerived) {