From 3d7c32fd0f9bfc369c14eb361e53e3f23cfb7cde Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 12 Oct 2024 10:55:05 +0300 Subject: [PATCH] Add comments --- include/clad/Differentiator/DiffPlanner.h | 2 +- include/clad/Differentiator/ReverseModeVisitor.h | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 14 ++++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 728b18504..fe37046c4 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -46,7 +46,7 @@ struct DiffRequest { clang::CallExpr* CallContext = nullptr; /// Args provided to the call to clad::gradient/differentiate. const clang::Expr* Args = nullptr; - /// Indexes of global args of function as a subset of Args. + /// Indexes of global GPU args of function as a subset of Args. std::unordered_set GlobalArgsIndexes; /// Requested differentiation mode, forward or reverse. DiffMode Mode = DiffMode::unknown; diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 83b2b9ec1..e58d77398 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -56,7 +56,7 @@ namespace clad { /// that will be put immediately in the beginning of derivative function /// block. Stmts m_Globals; - /// Global args of the function. + /// Global GPU args of the function. std::unordered_set m_GlobalArgs; //// A reference to the output parameter of the gradient function. clang::Expr* m_Result; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a203cabf7..4af15553e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -109,7 +109,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!m_GlobalArgs.empty()) if (const auto* DRE = dyn_cast(E)) if (const auto* PVD = dyn_cast(DRE->getDecl())) - // we need to check whether this param is in global memory of the GPU + // we need to check whether this param is in the global memory of the + // GPU return m_GlobalArgs.find(PVD) != m_GlobalArgs.end(); return false; @@ -450,6 +451,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParams(params); + // if the function is a global kernel, all its parameters reside in the + // global memory of the GPU if (m_DiffReq->hasAttr()) for (auto param : params) m_GlobalArgs.emplace(param); @@ -601,9 +604,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterCreatingDerivedFnParams(params); m_Derivative->setParams(params); + // Match the global arguments of the call to the device function to the + // pullback function's parameters. if (!m_DiffReq.GlobalArgsIndexes.empty()) for (auto index : m_DiffReq.GlobalArgsIndexes) m_GlobalArgs.emplace(m_Derivative->getParamDecl(index)); + // If the function is a global kernel, all its parameters reside in the + // global memory of the GPU else if (m_DiffReq->hasAttr()) for (auto param : params) m_GlobalArgs.emplace(param); @@ -1753,6 +1760,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ArgDiff = Visit(Arg, dfdx()); CallArgs.push_back(ArgDiff.getExpr()); if (auto* DRE = dyn_cast(ArgDiff.getExpr())) { + // If the arg is used for differentiation of the function, then we + // cannot free it in the end as it's the result to be returned to the + // user. if (m_ParamVarsWithDiff.find(DRE->getDecl()) == m_ParamVarsWithDiff.end()) DerivedCallArgs.push_back(ArgDiff.getExpr_dx()); @@ -2024,7 +2034,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.Function = FD; // Mark the indexes of the global args. Necessary if the argument of the - // call has a different name than the function's signature argument. + // call has a different name than the function's signature parameter. if (!m_GlobalArgs.empty()) for (size_t i = 0; i < pullbackCallArgs.size(); i++) if (auto* DRE = dyn_cast(pullbackCallArgs[i]))