Skip to content

Commit

Permalink
Fix comment and format
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 18, 2024
1 parent ae71115 commit 96ad6a1
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
// if the function is a global kernel, all the adjoint parameters reside in
// the global memory of the GPU. To facilitate the process, all the params
// of the kernel are added to the set.
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);
Expand Down Expand Up @@ -1651,9 +1652,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

Expr* CUDAExecConfig = nullptr;
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE)) {
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
CUDAExecConfig = Clone(KCE->getConfig());
}

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
Expand All @@ -1663,10 +1663,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ClonedArgs.push_back(Clone(CE->getArg(i)));

SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc, CUDAExecConfig)
.get();
Expr* Call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc, CUDAExecConfig)
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
Expand Down Expand Up @@ -1813,7 +1814,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc, CUDAExecConfig)
llvm::MutableArrayRef<Expr*>(CallArgs), Loc,
CUDAExecConfig)
.get();
return call;
}
Expand Down Expand Up @@ -1925,7 +1927,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()), true, true, CUDAExecConfig);
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
CUDAExecConfig);
if (OverloadedDerivedFn)
asGrad = false;
}
Expand Down Expand Up @@ -2027,7 +2030,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()), true, true, CUDAExecConfig);
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
CUDAExecConfig);
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
Expand All @@ -2043,10 +2047,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();

OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, Loc,
pullbackCallArgs, Loc, CUDAExecConfig)
.get();
} else {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
Expand Down Expand Up @@ -2282,7 +2287,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<Expr*>& args,
llvm::SmallVectorImpl<Expr*>& outputArgs, Expr* CUDAExecConfig /*=nullptr*/) {
llvm::SmallVectorImpl<Expr*>& outputArgs,
Expr* CUDAExecConfig /*=nullptr*/) {
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
Expand Down

0 comments on commit 96ad6a1

Please sign in to comment.