From 96ad6a1d2da2ae6a7c769716200844ae501528a6 Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 18 Oct 2024 19:52:07 +0300 Subject: [PATCH] Fix comment and format --- lib/Differentiator/ReverseModeVisitor.cpp | 38 +++++++++++++---------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f2857183a..ffdfea1b8 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -450,8 +450,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParams(params); - // if the function is a global kernel, all its parameters reside in the - // global memory of the GPU + // if the function is a global kernel, all the adjoint parameters reside in + // the global memory of the GPU. To facilitate the process, all the params + // of the kernel are added to the set. if (m_DiffReq->hasAttr()) for (auto* param : params) m_CUDAGlobalArgs.emplace(param); @@ -1651,9 +1652,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* CUDAExecConfig = nullptr; - if (auto KCE = dyn_cast(CE)) { + if (auto KCE = dyn_cast(CE)) CUDAExecConfig = Clone(KCE->getConfig()); - } // If the function is non_differentiable, return zero derivative. if (clad::utils::hasNonDifferentiableAttribute(CE)) { @@ -1663,10 +1663,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ClonedArgs.push_back(Clone(CE->getArg(i))); SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema); - Expr* Call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - validLoc, ClonedArgs, validLoc, CUDAExecConfig) - .get(); + Expr* Call = + m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, ClonedArgs, validLoc, CUDAExecConfig) + .get(); // Creating a zero derivative auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, /*val=*/0); @@ -1813,7 +1814,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, - llvm::MutableArrayRef(CallArgs), Loc, CUDAExecConfig) + llvm::MutableArrayRef(CallArgs), Loc, + CUDAExecConfig) .get(); return call; } @@ -1925,7 +1927,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, pushforwardCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), true, true, CUDAExecConfig); + const_cast(FD->getDeclContext()), true, true, + CUDAExecConfig); if (OverloadedDerivedFn) asGrad = false; } @@ -2027,7 +2030,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), true, true, CUDAExecConfig); + const_cast(FD->getDeclContext()), true, true, + CUDAExecConfig); if (baseDiff.getExpr()) pullbackCallArgs.erase(pullbackCallArgs.begin()); } @@ -2043,10 +2047,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) .get(); - OverloadedDerivedFn = m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, - Loc, pullbackCallArgs, Loc, CUDAExecConfig) - .get(); + OverloadedDerivedFn = + m_Sema + .ActOnCallExpr(getCurrentScope(), selfRef, Loc, + pullbackCallArgs, Loc, CUDAExecConfig) + .get(); } else { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( @@ -2282,7 +2287,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVectorImpl& PreCallStmts, llvm::SmallVectorImpl& PostCallStmts, llvm::SmallVectorImpl& args, - llvm::SmallVectorImpl& outputArgs, Expr* CUDAExecConfig /*=nullptr*/) { + llvm::SmallVectorImpl& outputArgs, + Expr* CUDAExecConfig /*=nullptr*/) { int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); llvm::SmallVector NumDiffArgs = {}; NumDiffArgs.push_back(targetFuncCall);