diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 7781d7eb2..840420349 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -451,6 +451,16 @@ namespace clad { /// \returns The atomicAdd call expression. clang::Expr* BuildCallToCudaAtomicAdd(clang::Expr* LHS, clang::Expr* RHS); + /// Check whether this is an assignment to a malloc or realloc call for a + /// derivative variable and build a call to calloc instead, to properly + /// intialize the memory to zero. Currently these configurations of size are + /// supported in malloc or realloc: + /// 1. x * sizeof(T) + /// 2. sizeof(T) * x + /// \param[in] RHS The right-hand side expression of the assignment. + /// @returns The call to calloc if the condition is met, otherwise nullptr. + clang::Expr* CheckAndBuildCallToCalloc(clang::Expr* RHS); + static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 9a7be174c..210f82112 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -546,15 +546,6 @@ namespace clad { bool useRefQualifiedThisObj = false, const clang::CXXScopeSpec* SS = nullptr); - /// Build a call to a free function. Search for it using its name and args. - /// - /// \param[in] funcName function name - /// \param[in] argExprs function arguments expressions - /// \returns Built call expression - clang::Expr* - BuildCallExprToFunction(std::string funcName, - llvm::MutableArrayRef args); - /// Build a call to templated free function inside the clad namespace. /// /// \param[in] name name of the function diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e77158afd..84b2079d3 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1736,11 +1736,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, llvm::MutableArrayRef(CallArgs), Loc) .get(); - Expr* call_dx = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, - llvm::MutableArrayRef(DerivedCallArgs), Loc) - .get(); + Expr* call_dx = nullptr; + if (FD->getNameAsString() == "malloc") + call_dx = CheckAndBuildCallToCalloc(Clone(CE)); + if (!call_dx) + call_dx = m_Sema + .ActOnCallExpr( + getCurrentScope(), Clone(CE->getCallee()), Loc, + llvm::MutableArrayRef(DerivedCallArgs), Loc) + .get(); return StmtDiff(call, call_dx); } // For calls to C-style memory deallocation functions, we do not need to @@ -2856,6 +2860,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); addToCurrentBlock(BuildOp(opCode, derivedL, derivedR), direction::forward); + if (opCode == BO_Assign && derivedR) + if (Expr* callocCall = + CheckAndBuildCallToCalloc(derivedR->IgnoreParenCasts())) { + Expr* cast = + m_Sema + .BuildCStyleCastExpr( + SourceLocation(), + m_Context.getTrivialTypeSourceInfo(derivedL->getType()), + SourceLocation(), callocCall) + .get(); + addToCurrentBlock(BuildOp(BO_Assign, derivedL, cast), + direction::forward); + } } } return StmtDiff(op, ResultRef, nullptr, valueForRevPass); @@ -3209,8 +3226,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool dxInForward = false; if (auto* callExpr = dyn_cast_or_null(stmtDx)) if (auto* FD = dyn_cast(callExpr->getCalleeDecl())) - if (utils::IsMemoryFunction(FD)) + if (utils::IsMemoryFunction(FD)) { + printf("%s\n", FD->getNameAsString().c_str()); dxInForward = true; + } if (stmtDx) { if (dxInForward) addToCurrentBlock(stmtDx, direction::forward); @@ -3224,6 +3243,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(SDiff.getStmt(), ReverseResult); } + Expr* ReverseModeVisitor::CheckAndBuildCallToCalloc(Expr* RHS) { + Expr* size = nullptr; + if (auto* callExpr = dyn_cast(RHS)) + if (auto* implCast = dyn_cast(callExpr->getCallee())) + if (auto* declRef = dyn_cast(implCast->getSubExpr())) + if (auto* FD = dyn_cast(declRef->getDecl())) { + if (FD->getNameAsString() == "malloc") + size = callExpr->getArg(0); + else if (FD->getNameAsString() == "realloc") + size = callExpr->getArg(1); + } + + if (size) { + llvm::SmallVector args; + if (auto BinOp = dyn_cast(size)) { + if (BinOp->getOpcode() == BO_Mul) { + Expr* lhs = BinOp->getLHS(); + Expr* rhs = BinOp->getRHS(); + if (auto* sizeofCall = dyn_cast(rhs)) + args = {lhs, sizeofCall}; + else if (auto* sizeofCall = dyn_cast(lhs)) + args = {rhs, sizeofCall}; + if (!args.empty()) + return GetFunctionCall("calloc", "", args); + } + } + } + + return {}; + } + std::pair ReverseModeVisitor::DifferentiateSingleExpr(const Expr* E, Expr* dfdE) { beginBlock(direction::forward); @@ -3335,21 +3385,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, else { VarDecl* VDDerived = VDDiff.getDecl_dx(); declsDiff.push_back(VDDerived); - if (auto* cast = dyn_cast(VDDerived->getInit())) - if (auto* callExpr = - dyn_cast(cast->getSubExpr()->IgnoreCasts())) - if (auto* implCast = - dyn_cast(callExpr->getCallee())) - if (auto* declRef = - dyn_cast(implCast->getSubExpr())) - if (auto* FD = dyn_cast(declRef->getDecl())) - if (FD->getNameAsString() == "malloc") { - llvm::SmallVector memsetArgs{ - BuildDeclRef(VDDerived), - getZeroInit(m_Context.IntTy), callExpr->getArg(0)}; - callToMemsets.push_back( - BuildCallExprToFunction("memset", memsetArgs)); - } } } } else if (auto* SAD = dyn_cast(D)) { diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index af2133ae1..d955ed6df 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -531,11 +531,14 @@ namespace clad { Expr* VisitorBase::GetFunctionCall(const std::string& funcName, const std::string& nmspace, llvm::SmallVectorImpl& callArgs) { - NamespaceDecl* NSD = - utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true); - DeclContext* DC = NSD; + NamespaceDecl* NSD = nullptr; CXXScopeSpec SS; - SS.Extend(m_Context, NSD, noLoc, noLoc); + + if (!nmspace.empty()) { + NSD = utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true); + SS.Extend(m_Context, NSD, noLoc, noLoc); + } + DeclContext* DC = NSD; IdentifierInfo* II = &m_Context.Idents.get(funcName); DeclarationName name(II); @@ -544,6 +547,8 @@ namespace clad { if (DC) m_Sema.LookupQualifiedName(R, DC); + else + m_Sema.LookupQualifiedName(R, m_Context.getTranslationUnitDecl()); Expr* UnresolvedLookup = nullptr; if (!R.empty()) UnresolvedLookup = @@ -683,35 +688,6 @@ namespace clad { return call; } - clang::Expr* - VisitorBase::BuildCallExprToFunction(std::string funcName, - llvm::MutableArrayRef args) { - DeclarationName Id = &m_Context.Idents.get(funcName); - LookupResult lookupResult(m_Sema, Id, SourceLocation(), - Sema::LookupOrdinaryName); - m_Sema.LookupQualifiedName(lookupResult, - m_Context.getTranslationUnitDecl()); - - CXXScopeSpec SS; - Expr* UnresolvedLookup = - m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get(); - for (auto arg : args) - arg->dump(); - - assert(!m_Builder.noOverloadExists(UnresolvedLookup, args) && - "memset function not found"); - - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - /*Fn=*/UnresolvedLookup, - /*LParenLoc=*/noLoc, - /*ArgExprs=*/args, - /*RParenLoc=*/m_DiffReq->getLocation()) - .get(); - - return call; - } - Expr* VisitorBase::BuildCallExprToCladFunction( llvm::StringRef name, llvm::MutableArrayRef argExprs, llvm::ArrayRef templateArgs, diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu index ae32330f0..f111d0d9b 100644 --- a/test/CUDA/GradientKernels.cu +++ b/test/CUDA/GradientKernels.cu @@ -447,7 +447,7 @@ double fn_memory(double *out, double *in) { //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: kernel_call<<<1, 10>>>(out, in); //CHECK-NEXT: cudaDeviceSynchronize(); -//CHECK-NEXT: double *_d_out_host = (double *)malloc(10 * sizeof(double)); +//CHECK-NEXT: double *_d_out_host = (double *)calloc(10, sizeof(double)); //CHECK-NEXT: double *out_host = (double *)malloc(10 * sizeof(double)); //CHECK-NEXT: cudaMemcpy(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost); //CHECK-NEXT: double _d_res = 0.; diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 21d98ecaf..a7f17e67f 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -407,7 +407,7 @@ double cStyleMemoryAlloc(double x, size_t n) { // CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, double *_d_x) { // CHECK-NEXT: size_t _d_n = 0UL; -// CHECK-NEXT: T *_d_t = (T *)malloc(n * sizeof(T)); +// CHECK-NEXT: T *_d_t = (T *)calloc(n, sizeof(T)); // CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T)); // CHECK-NEXT: memset(_d_t, 0, n * sizeof(T)); // CHECK-NEXT: memset(t, 0, n * sizeof(T)); @@ -422,6 +422,7 @@ double cStyleMemoryAlloc(double x, size_t n) { // CHECK-NEXT: double *_t2 = p; // CHECK-NEXT: double *_t3 = _d_p; // CHECK-NEXT: _d_p = (double *)realloc(_d_p, 2 * sizeof(double)); +// CHECK-NEXT: _d_p = (double *)calloc(2, sizeof(double)); // CHECK-NEXT: p = (double *)realloc(p, 2 * sizeof(double)); // CHECK-NEXT: double _t4 = p[1]; // CHECK-NEXT: p[1] = 2 * x;