Skip to content

Commit

Permalink
Make m_ParamVarsWithDiff a set instead of a map
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 10, 2024
1 parent 7070376 commit d19aa31
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ struct DiffRequest {
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of global args of function as a subset of Args.
std::unordered_set<size_t> GlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
Expand Down
8 changes: 5 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ namespace clad {
// several private/protected members of the visitor classes.
friend class ErrorEstimationHandler;
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Map used to keep track of parameter variables w.r.t which the
/// Set used to keep track of parameter variables w.r.t. which
/// the derivative (gradient) is being computed. This is separate from the
/// m_Variables map because all other intermediate variables will
/// not be stored here.
std::unordered_map<const clang::ValueDecl*, clang::Expr*> m_ParamVariables;
std::unordered_set<const clang::ValueDecl*> m_ParamVarsWithDiff;
/// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
Expand All @@ -56,6 +56,8 @@ namespace clad {
/// that will be put immediately in the beginning of derivative function
/// block.
Stmts m_Globals;
/// Global args of the function.
std::unordered_set<const clang::ParmVarDecl*> m_GlobalArgs;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
Expand Down Expand Up @@ -439,7 +441,7 @@ namespace clad {

/// Helper function that checks whether the function to be derived
/// is meant to be executed only by the GPU
bool shouldUseCudaAtomicOps();
bool shouldUseCudaAtomicOps(const clang::Expr* E);

/// Add call to cuda::atomicAdd for the given LHS and RHS expressions.
///
Expand Down
63 changes: 52 additions & 11 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTapeResult{*this, PushExpr, PopExpr, TapeRef};
}

bool ReverseModeVisitor::shouldUseCudaAtomicOps() {
return m_DiffReq->hasAttr<clang::CUDAGlobalAttr>();
bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) {
  // Atomic updates are only required when the expression refers to a kernel
  // parameter living in GPU global memory; m_GlobalArgs being non-empty is
  // equivalent to the derived function being executed by the GPU.
  if (m_GlobalArgs.empty())
    return false;

  const auto* DRE = dyn_cast<DeclRefExpr>(E);
  if (!DRE)
    return false;

  const auto* PVD = dyn_cast<ParmVarDecl>(DRE->getDecl());
  if (!PVD)
    return false;

  // The parameter needs an atomic add only if it was recorded as global.
  return m_GlobalArgs.find(PVD) != m_GlobalArgs.end();
}

clang::Expr* ReverseModeVisitor::BuildCallToCudaAtomicAdd(clang::Expr* LHS,
Expand All @@ -121,8 +128,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get();

Expr* finalLHS = LHS;
if (isa<ArraySubscriptExpr>(LHS))
if (auto* UO = dyn_cast<UnaryOperator>(LHS)) {
if (UO->getOpcode() == UnaryOperatorKind::UO_Deref)
finalLHS = UO->getSubExpr()->IgnoreImplicit();
} else if (!LHS->getType()->isPointerType() &&
!LHS->getType()->isReferenceType())
finalLHS = BuildOp(UnaryOperatorKind::UO_AddrOf, LHS);

llvm::SmallVector<Expr*, 2> atomicArgs = {finalLHS, RHS};

assert(!m_Builder.noOverloadExists(UnresolvedLookup, atomicArgs) &&
Expand Down Expand Up @@ -438,6 +450,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
gradientFD->setParams(paramsRef);
Expand Down Expand Up @@ -585,6 +601,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
if (!m_DiffReq.GlobalArgsIndexes.empty())
for (auto index : m_DiffReq.GlobalArgsIndexes)
m_GlobalArgs.emplace(m_Derivative->getParamDecl(index));
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);
m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -1517,7 +1539,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildArraySubscript(target, forwSweepDerivativeIndices);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (shouldUseCudaAtomicOps(target)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(result, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -1577,8 +1599,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: not sure if this is generic.
// Don't update derivatives of record types.
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
addToCurrentBlock(add_assign, direction::reverse);
Expr* base = it->second;
if (auto* UO = dyn_cast<UnaryOperator>(it->second))
base = UO->getSubExpr()->IgnoreImpCasts();
if (shouldUseCudaAtomicOps(base)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
} else {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
addToCurrentBlock(add_assign, direction::reverse);
}
}
}
return StmtDiff(clonedDRE, it->second, it->second);
Expand Down Expand Up @@ -1721,9 +1752,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
if (m_ParamVariables.find(
if (m_ParamVarsWithDiff.find(
dyn_cast<clang::DeclRefExpr>(ArgDiff.getExpr())->getDecl()) ==
m_ParamVariables.end())
m_ParamVarsWithDiff.end())
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
}
Expr* call =
Expand Down Expand Up @@ -1990,6 +2021,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature argument.
if (!m_GlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
if (m_GlobalArgs.find(param) != m_GlobalArgs.end())
pullbackRequest.GlobalArgsIndexes.emplace(i);

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
Expand Down Expand Up @@ -2603,7 +2644,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i);
Expr* gradElem = BuildArraySubscript(gradRef, {idx});
Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem);
if (shouldUseCudaAtomicOps())
if (shouldUseCudaAtomicOps(outputArgs[i]))
PostCallStmts.push_back(
BuildCallToCudaAtomicAdd(outputArgs[i], gradExpr));
else
Expand Down Expand Up @@ -2720,7 +2761,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (shouldUseCudaAtomicOps(diff_dx)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -4932,7 +4973,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Variables[*it] =
utils::BuildParenExpr(m_Sema, m_Variables[*it]);
}
m_ParamVariables[*it] = m_Variables[*it];
m_ParamVarsWithDiff.emplace(*it);
}
}
}
Expand Down

0 comments on commit d19aa31

Please sign in to comment.