Pull from cuda-pullback
kchristin22 committed Oct 15, 2024
2 parents a4aff07 + 8accef3 commit 22399c4
Showing 12 changed files with 697 additions and 52 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DiffPlanner.h
@@ -47,7 +47,7 @@ struct DiffRequest {
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of global GPU args of function as a subset of Args.
- std::vector<size_t> GlobalArgsIndexes;
+ std::vector<size_t> CUDAGlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
332 changes: 332 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
@@ -57,7 +57,7 @@ namespace clad {
/// block.
Stmts m_Globals;
/// Global GPU args of the function.
- std::unordered_set<const clang::ParmVarDecl*> m_GlobalArgs;
+ std::unordered_set<const clang::ParmVarDecl*> m_CUDAGlobalArgs;
/// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
13 changes: 11 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
@@ -362,8 +362,17 @@ namespace clad {
/// \param[in] D The declaration to build a DeclRefExpr for.
/// \param[in] SS The scope specifier for the declaration.
/// \returns the DeclRefExpr for the given declaration.
- clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D,
-                                  const clang::CXXScopeSpec* SS = nullptr);
+ clang::DeclRefExpr*
+ BuildDeclRef(clang::DeclaratorDecl* D,
+              const clang::CXXScopeSpec* SS = nullptr,
+              clang::ExprValueKind VK = clang::VK_LValue);
+ /// Builds a DeclRefExpr to a given Decl, adding proper nested name
+ /// qualifiers.
+ /// \param[in] D The declaration to build a DeclRefExpr for.
+ /// \param[in] NNS The nested name specifier to use.
+ clang::DeclRefExpr*
+ BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS,
+              clang::ExprValueKind VK = clang::VK_LValue);

/// Stores the result of an expression in a temporary variable (of the same
/// type as is the result of the expression) and returns a reference to it.
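A minimal usage sketch of the new qualifier-aware overload (hypothetical caller code; the identifiers `DRE` and `VD` are illustrative, not from this header):

    // Rebuild a reference to a variable declared in another context,
    // preserving its nested name qualifier (e.g. `ns::x` instead of `x`):
    clang::NestedNameSpecifier* NNS = DRE->getQualifier();
    clang::DeclRefExpr* newDRE = BuildDeclRef(VD, NNS, DRE->getValueKind());

This is the pattern the visitors below follow when cloning DeclRefExprs whose declarations live outside the current Sema context.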
7 changes: 4 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
@@ -1036,8 +1036,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// Sema::BuildDeclRefExpr is responsible for adding captured fields
// to the underlying struct of a lambda.
if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) {
- auto referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
- clonedDRE = cast<DeclRefExpr>(BuildDeclRef(referencedDecl));
+ NestedNameSpecifier* NNS = DRE->getQualifier();
+ auto* referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
+ clonedDRE = BuildDeclRef(referencedDecl, NNS);
}
} else
clonedDRE = cast<DeclRefExpr>(Clone(DRE));
@@ -1052,7 +1053,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
if (auto dVarDRE = dyn_cast<DeclRefExpr>(dExpr)) {
auto dVar = cast<VarDecl>(dVarDRE->getDecl());
if (dVar->getDeclContext() != m_Sema.CurContext)
- dExpr = BuildDeclRef(dVar);
+ dExpr = BuildDeclRef(dVar, DRE->getQualifier());
}
return StmtDiff(clonedDRE, dExpr);
}
2 changes: 1 addition & 1 deletion lib/Differentiator/ConstantFolder.cpp
@@ -150,7 +150,7 @@ namespace clad {
SourceLocation noLoc;
Expr* cast = CXXStaticCastExpr::Create(
C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
- clang::CastKind::CK_IntegralCast, Result, nullptr,
+ clang::CastKind::CK_IntegralCast, Result, /*CXXCastPath=*/nullptr,
C.getTrivialTypeSourceInfo(QT, noLoc)
CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO,
noLoc, noLoc, SourceRange());
74 changes: 48 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
@@ -106,12 +106,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) {
// Same as checking whether this is a function executed by the GPU
- if (!m_GlobalArgs.empty())
+ if (!m_CUDAGlobalArgs.empty())
if (const auto* DRE = dyn_cast<DeclRefExpr>(E))
if (const auto* PVD = dyn_cast<ParmVarDecl>(DRE->getDecl()))
- // we need to check whether this param is in the global memory of the
- // GPU
- return m_GlobalArgs.find(PVD) != m_GlobalArgs.end();
+ // Check whether this param is in the global memory of the GPU
+ return m_CUDAGlobalArgs.find(PVD) != m_CUDAGlobalArgs.end();

return false;
}
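To see why this check matters, consider a sketch of a generated reverse pass for a kernel (hypothetical example, not produced by this commit): every GPU thread may accumulate into the same adjoint slot of a global-memory parameter, so a plain `+=` would race:

    __global__ void kernel_grad(double* in, double* out,
                                double* _d_in, double* _d_out) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      // All threads contribute to the adjoint of the shared input;
      // atomicAdd(double*, double) (available on sm_60+) avoids the data
      // race that `_d_in[0] += _d_out[i];` would introduce:
      atomicAdd(&_d_in[0], _d_out[i]);
    }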
@@ -454,8 +453,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
- for (auto param : params)
- m_GlobalArgs.emplace(param);
+ for (auto* param : params)
+ m_CUDAGlobalArgs.emplace(param);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
@@ -563,7 +562,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
- for (auto index : m_DiffReq.GlobalArgsIndexes)
+ for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

@@ -608,14 +607,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Derivative->setParams(params);
// Match the global arguments of the call to the device function to the
// pullback function's parameters.
- if (!m_DiffReq.GlobalArgsIndexes.empty())
- for (auto index : m_DiffReq.GlobalArgsIndexes)
- m_GlobalArgs.emplace(m_Derivative->getParamDecl(index));
+ if (!m_DiffReq.CUDAGlobalArgsIndexes.empty())
+ for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
+ m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index));
// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
- m_GlobalArgs.emplace(param);
+ m_CUDAGlobalArgs.emplace(param);
m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
@@ -1573,8 +1572,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// with Sema::BuildDeclRefExpr. This is required in some cases, e.g.
// Sema::BuildDeclRefExpr is responsible for adding captured fields
// to the underlying struct of a lambda.
- if (VD->getDeclContext() != m_Sema.CurContext)
- clonedDRE = cast<DeclRefExpr>(BuildDeclRef(VD));
+ if (VD->getDeclContext() != m_Sema.CurContext) {
+ auto* ccDRE = dyn_cast<DeclRefExpr>(clonedDRE);
+ NestedNameSpecifier* NNS = DRE->getQualifier();
+ auto* referencedDecl = cast<VarDecl>(ccDRE->getDecl());
+ clonedDRE = BuildDeclRef(referencedDecl, NNS, DRE->getValueKind());
+ }
// This case happens when ref-type variables have to become function
// global. Ref-type declarations cannot be moved to the function global
// scope because they can't be separated from their inits.
@@ -1900,9 +1903,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

Expr* OverloadedDerivedFn = nullptr;
- // If the function has a single arg and does not returns a reference or take
+ // If the function has a single arg and does not return a reference or take
// arg by reference, we look for a derivative w.r.t. to this arg using the
- // forward mode(it is unlikely that we need gradient of a one-dimensional'
+ // forward mode(it is unlikely that we need gradient of a one-dimensional
// function).
bool asGrad = true;

@@ -2000,11 +2003,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
// Add the indexes of the global args to the custom pullback name
- if (!m_GlobalArgs.empty())
+ if (!m_CUDAGlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
- if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) {
+ if (m_CUDAGlobalArgs.find(param) != m_CUDAGlobalArgs.end()) {
customPullback += "_" + std::to_string(i);
globalCallArgs.emplace_back(i);
}
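For example (hypothetical device function, not from this commit): if `f(double* p, double* q)` is called from a kernel and only the argument at index 1 points into global GPU memory, the custom pullback is looked up under a suffixed name, which keeps pullbacks for different global-argument layouts from colliding:

    // Hypothetical call inside a __global__ kernel:
    //   f(p, q);       // p: thread-local buffer, q: kernel param (global memory)
    // Pullback name derived by the loop above:
    //   f_pullback_1   // "_1" marks q (index 1) as a global-memory arg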
@@ -2049,7 +2052,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
- pullbackRequest.GlobalArgsIndexes = globalCallArgs;
+ pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
@@ -2214,8 +2217,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff argDiff = Visit(arg);
CallArgs.push_back(argDiff.getExpr_dx());
}
- if (baseDiff.getExpr()) {
- Expr* baseE = baseDiff.getExpr();
+ if (Expr* baseE = baseDiff.getExpr()) {
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, Loc);
} else {
@@ -2232,6 +2234,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, resAdjoint, resAdjoint);
} // Recreate the original call expression.

+ if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
+ auto* FD = const_cast<CXXMethodDecl*>(
+ dyn_cast<CXXMethodDecl>(OCE->getCalleeDecl()));
+
+ NestedNameSpecifierLoc NNS(FD->getQualifier(),
+ /*Data=*/nullptr);
+ auto DAP = DeclAccessPair::make(FD, FD->getAccess());
+ auto* memberExpr = MemberExpr::Create(
+ m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc,
+ FD, DAP, FD->getNameInfo(),
+ /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy,
+ CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
+ ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams(
+ NOUR_None));
+ call = m_Sema
+ .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc,
+ CallArgs, Loc)
+ .get();
+ return StmtDiff(call);
+ }
+
call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
CallArgs, Loc)
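The effect of the new branch is visible in the updated Lambdas.C test further down: for an operator call such as `_f(i + j, i)` on a lambda `_f`, the original call is now recreated as an explicit member call rather than a free call (sketch):

    // Previously recreated as a free call (ill-formed outside the class):
    //   double x = operator()(i + j, i);
    // Now rebuilt through BuildCallToMemberFunction:
    //   double x = _f.operator()(i + j, i);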
@@ -2668,11 +2692,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i);
Expr* gradElem = BuildArraySubscript(gradRef, {idx});
Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem);
- if (shouldUseCudaAtomicOps(outputArgs[i]))
- PostCallStmts.push_back(
- BuildCallToCudaAtomicAdd(outputArgs[i], gradExpr));
- else
- PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
+ // Inputs were not pointers, so the output args are not in global GPU
+ // memory. Hence, no need to use atomic ops.
+ PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
Expand Down Expand Up @@ -2779,7 +2801,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else {
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
- if (dfdx()) {
+ if (dfdx() && derivedE) {
if (shouldUseCudaAtomicOps(diff_dx)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx());
// Add it to the body statements.
31 changes: 29 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
@@ -236,11 +236,38 @@ namespace clad {
}

DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D,
- const CXXScopeSpec* SS /*=nullptr*/) {
+ const CXXScopeSpec* SS /*=nullptr*/,
+ ExprValueKind VK /*=VK_LValue*/) {
QualType T = D->getType();
T = T.getNonReferenceType();
return cast<DeclRefExpr>(clad_compat::GetResult<Expr*>(
- m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS)));
+ m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), SS)));
}

+ DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D,
+ NestedNameSpecifier* NNS,
+ ExprValueKind VK /*=VK_LValue*/) {
+ std::vector<NestedNameSpecifier*> NNChain;
+ CXXScopeSpec CSS;
+ while (NNS) {
+ NNChain.push_back(NNS);
+ NNS = NNS->getPrefix();
+ }
+
+ std::reverse(NNChain.begin(), NNChain.end());
+
+ for (size_t i = 0; i < NNChain.size(); ++i) {
+ NNS = NNChain[i];
+ // FIXME: this needs to be extended to support more NNS kinds. An
+ // inspiration can be taken from getFullyQualifiedNestedNameSpecifier in
+ // llvm-project/clang/lib/AST/QualTypeNames.cpp
+ if (NNS->getKind() == NestedNameSpecifier::Namespace) {
+ NamespaceDecl* NS = NNS->getAsNamespace();
+ CSS.Extend(m_Context, NS, noLoc, noLoc);
+ }
+ }
+
+ return BuildDeclRef(D, &CSS, VK);
+ }

IdentifierInfo*
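A sketch of what the new overload produces (hypothetical declarations, not from this diff):

    // Given:
    //   namespace outer { namespace inner { double w; } }
    // and a DeclRefExpr for `w` whose qualifier is `inner::` with prefix
    // `outer::`, the while-loop collects [inner, outer], std::reverse yields
    // [outer, inner], and CSS.Extend is applied outermost-first, so the
    // rebuilt reference prints as `outer::inner::w`.

Per the FIXME above, only namespace specifiers are handled so far; other qualifier kinds pass through without extending the scope spec.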
6 changes: 3 additions & 3 deletions test/Gradient/Lambdas.C
@@ -13,7 +13,7 @@ double f1(double i, double j) {
}

// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const;
- // CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) {
+ // CHECK: void f1_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: auto _f = []{{ ?}}(double t) {
// CHECK-NEXT: return t * t + 1.;
// CHECK-NEXT: }{{;?}}
@@ -34,12 +34,12 @@ double f2(double i, double j) {
}

// CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const;
- // CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) {
+ // CHECK: void f2_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) {
// CHECK-NEXT: return t + k;
// CHECK-NEXT: }{{;?}}
// CHECK: double _d_x = 0.;
- // CHECK-NEXT: double x = operator()(i + j, i);
+ // CHECK-NEXT: double x = _f.operator()(i + j, i);
// CHECK-NEXT: _d_x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;