From c893c317cb7889b2f68412fb63c983a6af47e405 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Fri, 27 Dec 2024 17:06:51 +0000 Subject: [PATCH] Move the seed initialization into a separate function. --- .../Differentiator/BaseForwardModeVisitor.h | 7 + lib/Differentiator/BaseForwardModeVisitor.cpp | 257 +++++++++--------- 2 files changed, 140 insertions(+), 124 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index a5ccda51a..107955895 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -155,6 +155,13 @@ class BaseForwardModeVisitor /// Prepares the derivative function parameters. void SetupDerivativeParameters(llvm::SmallVectorImpl& params); + + /// Generate a seed initializing each independent argument with 1 and 0 + /// otherwise: + /// double f_darg0(double x, double y) { + /// double _d_x = 1; + /// double _d_y = 0; + void GenerateSeeds(const clang::FunctionDecl* dFD); }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 6b1f5631b..29cf3d309 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -9,13 +9,17 @@ #include "ConstantFolder.h" #include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" #include "clang/AST/ASTContext.h" #include "clang/AST/ASTLambda.h" +#include "clang/AST/Decl.h" #include "clang/AST/Expr.h" +#include "clang/AST/OperationKinds.h" #include "clang/AST/TemplateBase.h" +#include "clang/AST/Type.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -29,6 +33,7 @@ #include #include +#include #include "clad/Differentiator/Compatibility.h" @@ -171,131 +176,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { beginScope(Scope::FnScope | Scope::DeclScope); m_DerivativeFnScope = getCurrentScope(); beginBlock(); - // For each function parameter variable, store its derivative value. - for (auto* param : params) { - // We cannot create derivatives of reference type since seed value is - // always a constant (r-value). We assume that all the arguments have no - // relation among them, thus it is safe (correct) to use the corresponding - // non-reference type for creating the derivatives. - QualType dParamType = param->getType().getNonReferenceType(); - // We do not create derived variable for array/pointer parameters. - if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || - utils::isArrayOrPointerType(dParamType)) - continue; - Expr* dParam = nullptr; - if (dParamType->isRealType()) { - // If param is independent variable, its derivative is 1, otherwise 0. - int dValue = (param == m_IndependentVar); - dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, - dValue); - } - // For each function arg, create a variable _d_arg to store derivatives - // of potential reassignments, e.g.: - // double f_darg0(double x, double y) { - // double _d_x = 1; - // double _d_y = 0; - // ... - auto* dParamDecl = - BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam); - addToCurrentBlock(BuildDeclStmt(dParamDecl)); - dParam = BuildDeclRef(dParamDecl); - if (dParamType->isRecordType() && param == m_IndependentVar) { - llvm::SmallVector ref(diffVarInfo.fields.begin(), - diffVarInfo.fields.end()); - Expr* memRef = - utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref); - assert(memRef->getType()->isRealType() && - "Forward mode can only differentiate w.r.t builtin scalar " - "numerical types."); - addToCurrentBlock(BuildOp( - BinaryOperatorKind::BO_Assign, memRef, - ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1))); - } - // Memorize the derivative of param, i.e. whenever the param is visited - // in the future, it's derivative dParam is found (unless reassigned with - // something new). - m_Variables[param] = dParam; - } - - if (const auto* MD = dyn_cast(FD)) { - // We cannot create derivative of lambda yet because lambdas default - // constructor is deleted. - if (MD->isInstance() && !MD->getParent()->isLambda()) { - QualType thisObjectType = - clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD); - QualType thisType = MD->getThisType(); - // Here we are effectively doing: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // We are not creating `this` expression derivative using `new` because - // then we would be responsible for freeing the memory as well and its - // more convenient to let compiler handle the object lifecycle. - VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj"); - DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD); - VarDecl* thisExprDerivativeVD = - BuildVarDecl(thisType, "_d_this", - BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE)); - addToCurrentBlock(BuildDeclStmt(derivativeVD)); - addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD)); - m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD); - } - } - // Create derived variable for each member variable if we are - // differentiating a call operator. - if (m_DiffReq.Functor) { - for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) { - Expr* dInitializer = nullptr; - QualType fieldType = fieldDecl->getType(); - - if (const auto* arrType = - dyn_cast(fieldType.getTypePtr())) { - if (!arrType->getElementType()->isRealType()) - continue; - - auto arrSize = arrType->getSize().getZExtValue(); - std::vector dArrVal; - - // Create an initializer list to initialize derived variable created - // for array member variable. - // For example, if we are differentiating wrt arr[3], then - // ``` - // double arr[7]; - // ``` - // will get differentiated to, - // - // ``` - // double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0}; - // ``` - for (size_t i = 0; i < arrSize; ++i) { - int dValue = - (fieldDecl == m_IndependentVar && i == m_IndependentVarIndex); - auto* dValueLiteral = ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, dValue); - dArrVal.push_back(dValueLiteral); - } - dInitializer = - m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get(); - } else if (const auto* ptrType = - dyn_cast(fieldType.getTypePtr())) { - if (!ptrType->getPointeeType()->isRealType()) - continue; - // Pointer member variables should be initialised by `nullptr`. - dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get(); - } else { - int dValue = (fieldDecl == m_IndependentVar); - dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy, - m_Context, dValue); - } - VarDecl* derivedFieldDecl = - BuildVarDecl(fieldType.getNonReferenceType(), - "_d_" + fieldDecl->getNameAsString(), dInitializer); - addToCurrentBlock(BuildDeclStmt(derivedFieldDecl)); - m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl)); - } - } + if (m_DiffReq.Mode == DiffMode::forward) + GenerateSeeds(derivedFD); Stmt* BodyDiff = Visit(FD->getBody()).getStmt(); if (auto* CS = dyn_cast(BodyDiff)) @@ -428,6 +311,132 @@ void BaseForwardModeVisitor::SetupDerivativeParameters( } } +void BaseForwardModeVisitor::GenerateSeeds(const clang::FunctionDecl* dFD) { + // For each function parameter variable, store its derivative value. + for (const ParmVarDecl* param : dFD->parameters()) { + // We cannot create derivatives of reference type since seed value is + // always a constant (r-value). We assume that all the arguments have no + // relation among them, thus it is safe (correct) to use the corresponding + // non-reference type for creating the derivatives. + QualType dParamType = param->getType().getNonReferenceType(); + // We do not create derived variable for array/pointer parameters. + if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || + utils::isArrayOrPointerType(dParamType)) + continue; + Expr* dParam = nullptr; + if (dParamType->isRealType()) { + // If param is independent variable, its derivative is 1, otherwise 0. + int dValue = (param == m_IndependentVar); + dParam = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue); + } + // For each function arg, create a variable _d_arg to store derivatives + // of potential reassignments, e.g.: + // double f_darg0(double x, double y) { + // double _d_x = 1; + // double _d_y = 0; + // ... + auto* dParamDecl = + BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam); + addToCurrentBlock(BuildDeclStmt(dParamDecl)); + dParam = BuildDeclRef(dParamDecl); + if (dParamType->isRecordType() && param == m_IndependentVar) { + DiffInputVarInfo diffVarInfo = m_DiffReq.DVI.back(); + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + Expr* memRef = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref); + assert(memRef->getType()->isRealType() && + "Forward mode can only differentiate w.r.t builtin scalar " + "numerical types."); + addToCurrentBlock(BuildOp(BinaryOperatorKind::BO_Assign, memRef, + ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, /*val=*/1))); + } + // Memorize the derivative of param, i.e. whenever the param is visited + // in the future, it's derivative dParam is found (unless reassigned with + // something new). + m_Variables[param] = dParam; + } + if (const auto* MD = dyn_cast(dFD)) { + // We cannot create derivative of lambda yet because lambdas default + // constructor is deleted. + if (MD->isInstance() && !MD->getParent()->isLambda()) { + QualType thisObjectType = + clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD); + QualType thisType = MD->getThisType(); + // Here we are effectively doing: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // We are not creating `this` expression derivative using `new` because + // then we would be responsible for freeing the memory as well and its + // more convenient to let compiler handle the object lifecycle. + VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj"); + DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD); + VarDecl* thisExprDerivativeVD = + BuildVarDecl(thisType, "_d_this", + BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE)); + addToCurrentBlock(BuildDeclStmt(derivativeVD)); + addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD)); + m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD); + } + } + SourceLocation validLoc{m_DiffReq->getLocation()}; + // Create derived variable for each member variable if we are + // differentiating a call operator. + if (m_DiffReq.Functor) { + for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) { + Expr* dInitializer = nullptr; + QualType fieldType = fieldDecl->getType(); + + if (const auto* arrType = dyn_cast(fieldType)) { + if (!arrType->getElementType()->isRealType()) + continue; + + auto arrSize = arrType->getSize().getZExtValue(); + std::vector dArrVal; + + // Create an initializer list to initialize derived variable created + // for array member variable. + // For example, if we are differentiating wrt arr[3], then + // ``` + // double arr[7]; + // ``` + // will get differentiated to, + // + // ``` + // double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0}; + // ``` + for (size_t i = 0; i < arrSize; ++i) { + int dValue = + (fieldDecl == m_IndependentVar && i == m_IndependentVarIndex); + auto* dValueLiteral = ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, dValue); + dArrVal.push_back(dValueLiteral); + } + dInitializer = m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get(); + } else if (const auto* ptrType = + dyn_cast(fieldType.getTypePtr())) { + if (!ptrType->getPointeeType()->isRealType()) + continue; + // Pointer member variables should be initialised by `nullptr`. + dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get(); + } else { + int dValue = (fieldDecl == m_IndependentVar); + dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy, + m_Context, dValue); + } + VarDecl* derivedFieldDecl = + BuildVarDecl(fieldType.getNonReferenceType(), + "_d_" + fieldDecl->getNameAsString(), dInitializer); + addToCurrentBlock(BuildDeclStmt(derivedFieldDecl)); + m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl)); + } + } +} + void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt(); auto* CS = cast(bodyDiff);