diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index efd3d629c..b9901e20f 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -178,7 +178,7 @@ static inline IfStmt* IfStmt_Create(const ASTContext &Ctx, #endif } -// Compatibility helper function for creation CallExpr. +// Compatibility helper function for creation CallExpr and CUDAKernelCallExpr. // Clang 12 and above use one extra param. #if CLANG_VERSION_MAJOR < 12 @@ -188,6 +188,15 @@ static inline CallExpr* CallExpr_Create(const ASTContext &Ctx, Expr *Fn, ArrayRe { return CallExpr::Create(Ctx, Fn, Args, Ty, VK, RParenLoc, MinNumArgs, UsesADL); } + +static inline CUDAKernelCallExpr* +CUDAKernelCallExpr_Create(const ASTContext& Ctx, Expr* Fn, CallExpr* Config, + ArrayRef Args, QualType Ty, ExprValueKind VK, + SourceLocation RParenLoc, unsigned MinNumArgs = 0, + CallExpr::ADLCallKind UsesADL = CallExpr::NotADL) { + return CUDAKernelCallExpr::Create(Ctx, Fn, Config, Args, Ty, VK, RParenLoc, + MinNumArgs); +} #elif CLANG_VERSION_MAJOR >= 12 static inline CallExpr* CallExpr_Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr *> Args, QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures, @@ -195,6 +204,16 @@ static inline CallExpr* CallExpr_Create(const ASTContext &Ctx, Expr *Fn, ArrayRe { return CallExpr::Create(Ctx, Fn, Args, Ty, VK, RParenLoc, FPFeatures, MinNumArgs, UsesADL); } + +static inline CUDAKernelCallExpr* +CUDAKernelCallExpr_Create(const ASTContext& Ctx, Expr* Fn, CallExpr* Config, + ArrayRef Args, QualType Ty, ExprValueKind VK, + SourceLocation RParenLoc, + FPOptionsOverride FPFeatures, unsigned MinNumArgs = 0, + CallExpr::ADLCallKind UsesADL = CallExpr::NotADL) { + return CUDAKernelCallExpr::Create(Ctx, Fn, Config, Args, Ty, VK, RParenLoc, + FPFeatures, MinNumArgs); +} #endif // Clang 12 and above use one extra param. diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index 00c901cfa..83d91599c 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -104,6 +104,7 @@ namespace utils { DECLARE_CLONE_FN(ExtVectorElementExpr) DECLARE_CLONE_FN(UnaryExprOrTypeTraitExpr) DECLARE_CLONE_FN(CallExpr) + DECLARE_CLONE_FN(CUDAKernelCallExpr) DECLARE_CLONE_FN(ShuffleVectorExpr) DECLARE_CLONE_FN(ExprWithCleanups) DECLARE_CLONE_FN(CXXOperatorCallExpr) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 381547745..91c77fc4d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1930,15 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diag(DiagnosticsEngine::Error, CE->getEndLoc(), "Failed to create cudaMemcpy call; cudaMemcpyDeviceToHost not " "found. Creating kernel pullback aborted."); - for (std::size_t a = 0; a < CE->getNumArgs(); ++a) - CallArgs.push_back( - Clone(CE->getArg(a))); // create a non-const copy - Expr* call = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - Loc, CallArgs, Loc, CUDAExecConfig) - .get(); - return StmtDiff(call); + return Clone(CE); } CXXScopeSpec SS; Expr* deviceToHostExpr = @@ -1947,20 +1939,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /*ADL=*/false) .get(); if (!deviceToHostExpr) { - diag( - DiagnosticsEngine::Error, CE->getEndLoc(), - "Failed to create cudaMemcpy call; Failed to create expression " - "for cudaMemcpyDeviceToHost. Creating kernel pullback " - "aborted."); - for (std::size_t a = 0; a < CE->getNumArgs(); ++a) - CallArgs.push_back( - Clone(CE->getArg(a))); // create a non-const copy - Expr* call = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - Loc, CallArgs, Loc, CUDAExecConfig) - .get(); - return StmtDiff(call); + diag(DiagnosticsEngine::Error, CE->getEndLoc(), + "Failed to create cudaMemcpy call; Failed to create " + "expression " + "for cudaMemcpyDeviceToHost. Creating kernel pullback " + "aborted."); + return Clone(CE); } // Add calls to cudaMalloc, cudaMemset, cudaMemcpy, and cudaFree diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 15b32aebe..08bef1ad0 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -327,6 +327,21 @@ Stmt* StmtClone::VisitCallExpr(CallExpr* Node) { return result; } +Stmt* StmtClone::VisitCUDAKernelCallExpr(CUDAKernelCallExpr* Node) { + CUDAKernelCallExpr* result = clad_compat::CUDAKernelCallExpr_Create( + Ctx, Clone(Node->getCallee()), Clone(Node->getConfig()), + llvm::ArrayRef(), CloneType(Node->getType()), Node->getValueKind(), + Node->getRParenLoc() CLAD_COMPAT_CLANG8_CallExpr_ExtraParams); + result->setNumArgsUnsafe(Node->getNumArgs()); + for (unsigned i = 0, e = Node->getNumArgs(); i < e; ++i) + result->setArg(i, Clone(Node->getArg(i))); + + // Copy Value and Type dependent + clad_compat::ExprSetDeps(result, Node); + + return result; +} + Stmt* StmtClone::VisitUnresolvedLookupExpr(UnresolvedLookupExpr* Node) { TemplateArgumentListInfo TemplateArgs; if (Node->hasExplicitTemplateArgs())