Skip to content

Commit

Permalink
Add support of kernel pullback functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 18, 2024
1 parent 2d08ce1 commit e5fa81d
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 45 deletions.
12 changes: 12 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ ValueAndPushforward<int, int> cudaDeviceSynchronize_pushforward()
__attribute__((host)) {
return {cudaDeviceSynchronize(), 0};
}

void cudaMemcpy_pullback(void* destPtr, void* srcPtr, size_t count,
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr,
size_t* d_count, cudaMemcpyKind* d_kind)
__attribute__((host)) {
if (kind == cudaMemcpyDeviceToHost)
*d_kind = cudaMemcpyHostToDevice;
else if (kind == cudaMemcpyHostToDevice)
*d_kind = cudaMemcpyDeviceToHost;
cudaMemcpy(d_srcPtr, d_destPtr, count, *d_kind);
}

#endif

CUDA_HOST_DEVICE inline ValueAndPushforward<float, float>
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ namespace clad {
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true,
// Optional CUDA execution configuration. It is handed to
// Sema::ActOnCallExpr as the trailing argument when the call is built;
// callers pass the cloned config of a CUDAKernelCallExpr, and the default
// of nullptr is used for calls that are not kernel launches.
clang::Expr* CUDAExecConfig = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
/// Shorthand to issues a warning or error.
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ namespace clad {
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
llvm::SmallVectorImpl<clang::Expr*>& outputArgs,
clang::Expr* CUDAExecConfig = nullptr);

public:
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,8 @@ namespace clad {
/// \returns The derivative function call.
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args,
// Optional CUDA execution configuration; the definition passes it through
// unchanged as the last argument of the call it builds. Defaults to
// nullptr, so existing callers are unaffected.
clang::Expr* CUDAExecConfig = nullptr);

/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
Expand Down
5 changes: 4 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ namespace clad {
}

bool IsMemoryFunction(const clang::FunctionDecl* FD) {

if (FD->getNameAsString() == "cudaMalloc")
return true;
#if CLANG_VERSION_MAJOR > 12
if (FD->getBuiltinID() == Builtin::BImalloc)
return true;
Expand All @@ -703,6 +704,8 @@ namespace clad {
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
if (FD->getNameAsString() == "cudaFree")
return true;
#if CLANG_VERSION_MAJOR > 12
return FD->getBuiltinID() == Builtin::ID::BIfree;
#else
Expand Down
9 changes: 6 additions & 3 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
Expr* CUDAExecConfig /*=nullptr*/) {
CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand All @@ -265,8 +266,10 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
OverloadedFn = m_Sema
.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc,
CUDAExecConfig)
.get();

// Add the custom derivative to the set of derivatives.
// This is required in case the definition of the custom derivative
Expand Down
57 changes: 36 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
// if the function is a global kernel, all the adjoint parameters reside in
// the global memory of the GPU. To facilitate the process, all the params
// of the kernel are added to the set.
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);
Expand Down Expand Up @@ -610,7 +611,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!m_DiffReq.CUDAGlobalArgsIndexes.empty())
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index));

// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_CUDAGlobalArgs.emplace(param);
m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -1646,6 +1651,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

Expr* CUDAExecConfig = nullptr;
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
CUDAExecConfig = Clone(KCE->getConfig());

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand All @@ -1654,10 +1663,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ClonedArgs.push_back(Clone(CE->getArg(i)));

SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
.get();
Expr* Call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc, CUDAExecConfig)
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
Expand Down Expand Up @@ -1804,7 +1814,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
llvm::MutableArrayRef<Expr*>(CallArgs), Loc,
CUDAExecConfig)
.get();
return call;
}
Expand Down Expand Up @@ -1916,7 +1927,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
CUDAExecConfig);
if (OverloadedDerivedFn)
asGrad = false;
}
Expand Down Expand Up @@ -2018,7 +2030,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
CUDAExecConfig);
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
Expand All @@ -2034,10 +2047,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();

OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
Loc, pullbackCallArgs, Loc)
.get();
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, Loc,
pullbackCallArgs, Loc, CUDAExecConfig)
.get();
} else {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
Expand Down Expand Up @@ -2089,14 +2103,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn = GetSingleArgCentralDiffCall(
Clone(CE->getCallee()), DerivedCallArgs[0],
/*targetPos=*/0,
/*numArgs=*/1, DerivedCallArgs);
/*numArgs=*/1, DerivedCallArgs, CUDAExecConfig);
asGrad = !OverloadedDerivedFn;
} else {
auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema);
OverloadedDerivedFn = GetMultiArgCentralDiffCall(
Clone(CE->getCallee()), CEType.getCanonicalType(),
CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
DerivedCallArgs, CallArgDx);
DerivedCallArgs, CallArgDx, CUDAExecConfig);
}
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!OverloadedDerivedFn) {
Expand All @@ -2114,7 +2128,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc)
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
}
Expand Down Expand Up @@ -2227,7 +2241,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
call = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(calleeFnForwPassFD), Loc,
CallArgs, Loc)
CallArgs, Loc, CUDAExecConfig)
.get();
}
auto* callRes = StoreAndRef(call);
Expand Down Expand Up @@ -2261,7 +2275,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
CallArgs, Loc)
CallArgs, Loc, CUDAExecConfig)
.get();
return StmtDiff(call);

Expand All @@ -2273,7 +2287,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<Expr*>& args,
llvm::SmallVectorImpl<Expr*>& outputArgs) {
llvm::SmallVectorImpl<Expr*>& outputArgs,
Expr* CUDAExecConfig /*=nullptr*/) {
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
Expand Down Expand Up @@ -2314,7 +2329,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
/*namespaceShouldExist=*/false, CUDAExecConfig);
}

StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ namespace clad {

Expr* VisitorBase::GetSingleArgCentralDiffCall(
Expr* targetFuncCall, Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args) {
unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args,
Expr* CUDAExecConfig /*=nullptr*/) {
QualType argType = targetArg->getType();
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
bool isSupported = argType->isArithmeticType();
Expand All @@ -786,7 +787,7 @@ namespace clad {
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
/*namespaceShouldExist=*/false, CUDAExecConfig);
}

void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
Expand Down
Loading

0 comments on commit e5fa81d

Please sign in to comment.