Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 12, 2024
1 parent ea63937 commit 3d7c32f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct DiffRequest {
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of global args of function as a subset of Args.
/// Indexes of global GPU args of function as a subset of Args.
std::unordered_set<size_t> GlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace clad {
/// that will be put immediately in the beginning of derivative function
/// block.
Stmts m_Globals;
/// Global args of the function.
/// Global GPU args of the function.
std::unordered_set<const clang::ParmVarDecl*> m_GlobalArgs;
/// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
Expand Down
14 changes: 12 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!m_GlobalArgs.empty())
if (const auto* DRE = dyn_cast<DeclRefExpr>(E))
if (const auto* PVD = dyn_cast<ParmVarDecl>(DRE->getDecl()))
// we need to check whether this param is in global memory of the GPU
// We need to check whether this param is in the global memory of the
// GPU.
return m_GlobalArgs.find(PVD) != m_GlobalArgs.end();

return false;
Expand Down Expand Up @@ -450,6 +451,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU.
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);
Expand Down Expand Up @@ -601,9 +604,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
// Match the global arguments of the call to the device function to the
// pullback function's parameters.
if (!m_DiffReq.GlobalArgsIndexes.empty())
for (auto index : m_DiffReq.GlobalArgsIndexes)
m_GlobalArgs.emplace(m_Derivative->getParamDecl(index));
// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);
Expand Down Expand Up @@ -1753,6 +1760,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
if (auto* DRE = dyn_cast<DeclRefExpr>(ArgDiff.getExpr())) {
// If the arg is used for differentiation of the function, then we
// cannot free it in the end as it's the result to be returned to the
// user.
if (m_ParamVarsWithDiff.find(DRE->getDecl()) ==
m_ParamVarsWithDiff.end())
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
Expand Down Expand Up @@ -2024,7 +2034,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature argument.
// call has a different name than the function's signature parameter.
if (!m_GlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
Expand Down

0 comments on commit 3d7c32f

Please sign in to comment.