From e5fa81d45d7ae436f690ec326b01b6d9b5943681 Mon Sep 17 00:00:00 2001
From: Christina Koutsou
Date: Tue, 15 Oct 2024 19:27:27 +0000
Subject: [PATCH] Add support for kernel pullback functions

---
 .../clad/Differentiator/BuiltinDerivatives.h |  12 ++
 .../clad/Differentiator/DerivativeBuilder.h  |   3 +-
 .../clad/Differentiator/ReverseModeVisitor.h |   3 +-
 include/clad/Differentiator/VisitorBase.h    |   3 +-
 lib/Differentiator/CladUtils.cpp             |   5 +-
 lib/Differentiator/DerivativeBuilder.cpp     |   9 +-
 lib/Differentiator/ReverseModeVisitor.cpp    |  57 +++++----
 lib/Differentiator/VisitorBase.cpp           |   5 +-
 test/CUDA/GradientKernels.cu                 | 113 +++++++++++++++---
 9 files changed, 165 insertions(+), 45 deletions(-)

diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h
index 557274a56..1d3f96aba 100644
--- a/include/clad/Differentiator/BuiltinDerivatives.h
+++ b/include/clad/Differentiator/BuiltinDerivatives.h
@@ -82,6 +82,18 @@ ValueAndPushforward<cudaError_t, cudaError_t> cudaDeviceSynchronize_pushforward()
     __attribute__((host)) {
   return {cudaDeviceSynchronize(), 0};
 }
+
+void cudaMemcpy_pullback(void* destPtr, void* srcPtr, size_t count,
+                         cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr,
+                         size_t* d_count, cudaMemcpyKind* d_kind)
+    __attribute__((host)) {
+  if (kind == cudaMemcpyDeviceToHost)
+    *d_kind = cudaMemcpyHostToDevice;
+  else if (kind == cudaMemcpyHostToDevice)
+    *d_kind = cudaMemcpyDeviceToHost;
+  cudaMemcpy(d_srcPtr, d_destPtr, count, *d_kind);
+}
+
 #endif
 
 CUDA_HOST_DEVICE inline ValueAndPushforward
diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h
index 9ac7165bf..5e9d54ac2 100644
--- a/include/clad/Differentiator/DerivativeBuilder.h
+++ b/include/clad/Differentiator/DerivativeBuilder.h
@@ -118,7 +118,8 @@ namespace clad {
     clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
         const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
         clang::Scope* S, clang::DeclContext* originalFnDC,
-        bool forCustomDerv = true, bool namespaceShouldExist = true);
+        bool forCustomDerv = true, bool namespaceShouldExist = true,
+        clang::Expr* CUDAExecConfig = nullptr);
     bool noOverloadExists(clang::Expr* UnresolvedLookup,
                           llvm::MutableArrayRef<clang::Expr*> ARargs);
     /// Shorthand to issues a warning or error.
diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h
index dabcfd256..7781d7eb2 100644
--- a/include/clad/Differentiator/ReverseModeVisitor.h
+++ b/include/clad/Differentiator/ReverseModeVisitor.h
@@ -354,7 +354,8 @@ namespace clad {
         clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
         llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
         llvm::SmallVectorImpl<clang::Expr*>& args,
-        llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
+        llvm::SmallVectorImpl<clang::Expr*>& outputArgs,
+        clang::Expr* CUDAExecConfig = nullptr);
 
   public:
     ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h
index dba1540a2..210f82112 100644
--- a/include/clad/Differentiator/VisitorBase.h
+++ b/include/clad/Differentiator/VisitorBase.h
@@ -603,7 +603,8 @@ namespace clad {
     /// \returns The derivative function call.
     clang::Expr* GetSingleArgCentralDiffCall(
         clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
-        unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);
+        unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args,
+        clang::Expr* CUDAExecConfig = nullptr);
 
     /// Emits diagnostic messages on differentiation (or lack thereof) for
     /// call expressions.
diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp
index 39e98f12f..350eeea07 100644
--- a/lib/Differentiator/CladUtils.cpp
+++ b/lib/Differentiator/CladUtils.cpp
@@ -679,7 +679,8 @@ namespace clad {
   }
 
   bool IsMemoryFunction(const clang::FunctionDecl* FD) {
-
+    if (FD->getNameAsString() == "cudaMalloc")
+      return true;
 #if CLANG_VERSION_MAJOR > 12
     if (FD->getBuiltinID() == Builtin::BImalloc)
       return true;
@@ -703,6 +704,8 @@ namespace clad {
   }
 
   bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
+    if (FD->getNameAsString() == "cudaFree")
+      return true;
 #if CLANG_VERSION_MAJOR > 12
     return FD->getBuiltinID() == Builtin::ID::BIfree;
 #else
diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp
index e4ba3f99a..ada7153c6 100644
--- a/lib/Differentiator/DerivativeBuilder.cpp
+++ b/lib/Differentiator/DerivativeBuilder.cpp
@@ -246,7 +246,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
   Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
       const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
       clang::Scope* S, clang::DeclContext* originalFnDC,
-      bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
+      bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
+      Expr* CUDAExecConfig /*=nullptr*/) {
     CXXScopeSpec SS;
     LookupResult R = LookupCustomDerivativeOrNumericalDiff(
         Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
@@ -265,8 +266,10 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
     if (noOverloadExists(UnresolvedLookup, MARargs))
       return nullptr;
 
-    OverloadedFn =
-        m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
+    OverloadedFn = m_Sema
+                       .ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc,
+                                      CUDAExecConfig)
+                       .get();
 
     // Add the custom derivative to the set of derivatives.
     // This is required in case the definition of the custom derivative
diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp
index 26746f5b0..ffdfea1b8 100644
--- a/lib/Differentiator/ReverseModeVisitor.cpp
+++ b/lib/Differentiator/ReverseModeVisitor.cpp
@@ -450,8 +450,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     if (m_ExternalSource)
       m_ExternalSource->ActAfterCreatingDerivedFnParams(params);
 
-    // if the function is a global kernel, all its parameters reside in the
-    // global memory of the GPU
+    // If the function is a global kernel, all the adjoint parameters reside
+    // in the global memory of the GPU. To facilitate this, all the kernel's
+    // parameters are added to the set of global args (m_CUDAGlobalArgs).
     if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
       for (auto* param : params)
         m_CUDAGlobalArgs.emplace(param);
@@ -610,7 +611,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     if (!m_DiffReq.CUDAGlobalArgsIndexes.empty())
       for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
         m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index));
-
+    // If the function is a global kernel, all its parameters reside in the
+    // global memory of the GPU
+    else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
+      for (auto param : params)
+        m_CUDAGlobalArgs.emplace(param);
     m_Derivative->setBody(nullptr);
 
     if (!m_DiffReq.DeclarationOnly) {
@@ -1646,6 +1651,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       return StmtDiff(Clone(CE));
     }
 
+    Expr* CUDAExecConfig = nullptr;
+    if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
+      CUDAExecConfig = Clone(KCE->getConfig());
+
     // If the function is non_differentiable, return zero derivative.
     if (clad::utils::hasNonDifferentiableAttribute(CE)) {
       // Calling the function without computing derivatives
@@ -1654,10 +1663,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         ClonedArgs.push_back(Clone(CE->getArg(i)));
 
       SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
-      Expr* Call = m_Sema
-                       .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
-                                      validLoc, ClonedArgs, validLoc)
-                       .get();
+      Expr* Call =
+          m_Sema
+              .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
+                             validLoc, ClonedArgs, validLoc, CUDAExecConfig)
+              .get();
       // Creating a zero derivative
       auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
                                                      /*val=*/0);
@@ -1804,7 +1814,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       Expr* call =
           m_Sema
              .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
-                            llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
+                            llvm::MutableArrayRef<Expr*>(CallArgs), Loc,
+                            CUDAExecConfig)
              .get();
       return call;
     }
@@ -1916,7 +1927,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         OverloadedDerivedFn =
             m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
                 customPushforward, pushforwardCallArgs, getCurrentScope(),
-                const_cast<DeclContext*>(FD->getDeclContext()));
+                const_cast<DeclContext*>(FD->getDeclContext()), true, true,
+                CUDAExecConfig);
         if (OverloadedDerivedFn)
           asGrad = false;
       }
@@ -2018,7 +2030,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
           OverloadedDerivedFn =
               m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
                   customPullback, pullbackCallArgs, getCurrentScope(),
-                  const_cast<DeclContext*>(FD->getDeclContext()));
+                  const_cast<DeclContext*>(FD->getDeclContext()), true, true,
+                  CUDAExecConfig);
           if (baseDiff.getExpr())
             pullbackCallArgs.erase(pullbackCallArgs.begin());
         }
@@ -2034,10 +2047,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
                                CXXScopeSpec(), m_Derivative->getNameInfo(),
                                m_Derivative)
                 .get();
-        OverloadedDerivedFn = m_Sema
-                                  .ActOnCallExpr(getCurrentScope(), selfRef,
-                                                 Loc, pullbackCallArgs, Loc)
-                                  .get();
+        OverloadedDerivedFn =
+            m_Sema
+                .ActOnCallExpr(getCurrentScope(), selfRef, Loc,
+                               pullbackCallArgs, Loc, CUDAExecConfig)
+                .get();
       } else {
         if (m_ExternalSource)
           m_ExternalSource->ActBeforeDifferentiatingCallExpr(
@@ -2089,14 +2103,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
           OverloadedDerivedFn = GetSingleArgCentralDiffCall(
              Clone(CE->getCallee()), DerivedCallArgs[0], /*targetPos=*/0,
-             /*numArgs=*/1, DerivedCallArgs);
+             /*numArgs=*/1, DerivedCallArgs, CUDAExecConfig);
           asGrad = !OverloadedDerivedFn;
         } else {
           auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema);
           OverloadedDerivedFn =
              GetMultiArgCentralDiffCall(
                  Clone(CE->getCallee()), CEType.getCanonicalType(),
                  CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
-                 DerivedCallArgs, CallArgDx);
+                 DerivedCallArgs, CallArgDx, CUDAExecConfig);
         }
         CallExprDiffDiagnostics(FD, CE->getBeginLoc());
         if (!OverloadedDerivedFn) {
@@ -2114,7 +2128,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
           OverloadedDerivedFn =
               m_Sema
                   .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
-                                 Loc, pullbackCallArgs, Loc)
+                                 Loc, pullbackCallArgs, Loc, CUDAExecConfig)
                   .get();
         }
       }
@@ -2227,7 +2241,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       call = m_Sema
                  .ActOnCallExpr(getCurrentScope(),
                                 BuildDeclRef(calleeFnForwPassFD), Loc,
-                                CallArgs, Loc)
+                                CallArgs, Loc, CUDAExecConfig)
                  .get();
     }
     auto* callRes = StoreAndRef(call);
@@ -2261,7 +2275,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     call =
         m_Sema
             .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
-                           CallArgs, Loc)
+                           CallArgs, Loc, CUDAExecConfig)
            .get();
 
     return StmtDiff(call);
@@ -2273,7 +2287,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
       llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
       llvm::SmallVectorImpl<Expr*>& args,
-      llvm::SmallVectorImpl<Expr*>& outputArgs) {
+      llvm::SmallVectorImpl<Expr*>& outputArgs,
+      Expr* CUDAExecConfig /*=nullptr*/) {
     int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
     llvm::SmallVector<Expr*, 16> NumDiffArgs = {};
     NumDiffArgs.push_back(targetFuncCall);
@@ -2314,7 +2329,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         Name, NumDiffArgs, getCurrentScope(),
         /*OriginalFnDC=*/nullptr,
         /*forCustomDerv=*/false,
-        /*namespaceShouldExist=*/false);
+        /*namespaceShouldExist=*/false, CUDAExecConfig);
   }
 
   StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp
index 63ae3c369..4ff317bfd 100644
--- a/lib/Differentiator/VisitorBase.cpp
+++ b/lib/Differentiator/VisitorBase.cpp
@@ -763,7 +763,8 @@ namespace clad {
 
   Expr* VisitorBase::GetSingleArgCentralDiffCall(
       Expr* targetFuncCall, Expr* targetArg, unsigned targetPos,
-      unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args) {
+      unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args,
+      Expr* CUDAExecConfig /*=nullptr*/) {
     QualType argType = targetArg->getType();
     int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
     bool isSupported = argType->isArithmeticType();
@@ -786,7 +787,7 @@ namespace clad {
         Name, NumDiffArgs, getCurrentScope(),
         /*OriginalFnDC=*/nullptr,
         /*forCustomDerv=*/false,
-        /*namespaceShouldExist=*/false);
+        /*namespaceShouldExist=*/false, CUDAExecConfig);
   }
 
   void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu
index a60604fa6..87a8daa28 100644
--- a/test/CUDA/GradientKernels.cu
+++ b/test/CUDA/GradientKernels.cu
@@ -412,6 +412,81 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v
 //CHECK-NEXT:    }
 //CHECK-NEXT:}
 
+__global__ void kernel_call(double *a, double *b) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  a[index] = *b;
+}
+
+void fn(double *out, double *in) {
+  kernel_call<<<1, 10>>>(out, in);
+}
+
+// CHECK: void fn_grad(double *out, double *in, double *_d_out, double *_d_in) {
+//CHECK-NEXT:    kernel_call<<<1, 10>>>(out, in);
+//CHECK-NEXT:    kernel_call_pullback<<<1, 10>>>(out, in, _d_out, _d_in);
+//CHECK-NEXT: }
+
+double fn_memory(double *out, double *in) {
+  kernel_call<<<1, 10>>>(out, in);
+  cudaDeviceSynchronize();
+  double *out_host = (double*)malloc(10 * sizeof(double));
+  cudaMemcpy(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost);
+  double res = 0;
+  for (int i=0; i < 10; ++i) {
+    res += out_host[i];
+  }
+  free(out_host);
+  cudaFree(out);
+  cudaFree(in);
+  return res;
+}
+
+// CHECK: void fn_memory_grad(double *out, double *in, double *_d_out, double *_d_in) {
+//CHECK-NEXT:    int _d_i = 0;
+//CHECK-NEXT:    int i = 0;
+//CHECK-NEXT:    clad::tape<double> _t1 = {};
+//CHECK-NEXT:    kernel_call<<<1, 10>>>(out, in);
+//CHECK-NEXT:    cudaDeviceSynchronize();
+//CHECK-NEXT:    double *_d_out_host = (double *)malloc(10 * sizeof(double));
+//CHECK-NEXT:    double *out_host = (double *)malloc(10 * sizeof(double));
+//CHECK-NEXT:    cudaMemcpy(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost);
+//CHECK-NEXT:    double _d_res = 0.;
+//CHECK-NEXT:    double res = 0;
+//CHECK-NEXT:    unsigned long _t0 = 0UL;
+//CHECK-NEXT:    for (i = 0; ; ++i) {
+//CHECK-NEXT:        {
+//CHECK-NEXT:            if (!(i < 10))
+//CHECK-NEXT:                break;
+//CHECK-NEXT:        }
+//CHECK-NEXT:        _t0++;
+//CHECK-NEXT:        clad::push(_t1, res);
+//CHECK-NEXT:        res += out_host[i];
+//CHECK-NEXT:    }
+//CHECK-NEXT:    _d_res += 1;
+//CHECK-NEXT:    for (;; _t0--) {
+//CHECK-NEXT:        {
+//CHECK-NEXT:            if (!_t0)
+//CHECK-NEXT:                break;
+//CHECK-NEXT:        }
+//CHECK-NEXT:        --i;
+//CHECK-NEXT:        {
+//CHECK-NEXT:            res = clad::pop(_t1);
+//CHECK-NEXT:            double _r_d0 = _d_res;
+//CHECK-NEXT:            _d_out_host[i] += _r_d0;
+//CHECK-NEXT:        }
+//CHECK-NEXT:    }
+//CHECK-NEXT:    {
+//CHECK-NEXT:        unsigned long _r0 = 0UL;
+//CHECK-NEXT:        cudaMemcpyKind _r1 = static_cast<cudaMemcpyKind>(0U);
+//CHECK-NEXT:        clad::custom_derivatives::cudaMemcpy_pullback(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost, _d_out_host, _d_out, &_r0, &_r1);
+//CHECK-NEXT:    }
+//CHECK-NEXT:    kernel_call_pullback<<<1, 10>>>(out, in, _d_out, _d_in);
+//CHECK-NEXT:    free(out_host);
+//CHECK-NEXT:    free(_d_out_host);
+//CHECK-NEXT:    cudaFree(out);
+//CHECK-NEXT:    cudaFree(in);
+//CHECK-NEXT:}
+
 // CHECK: __attribute__((device)) void device_fn_pullback_1(double in, double val, double _d_y, double *_d_in, double *_d_val) {
 //CHECK-NEXT:    {
 //CHECK-NEXT:        *_d_in += _d_y;
 //CHECK-NEXT:        *_d_val += _d_y;
 //CHECK-NEXT:    }
 //CHECK-NEXT:}
 
@@ -609,22 +684,12 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v
 
 #define INIT(x, y, val, dx, dy, d_val)                                   \
   {                                                                      \
-    double *fives = (double*)malloc(10 * sizeof(double));                \
-    for(int i = 0; i < 10; i++) {                                        \
-      fives[i] = 5;                                                      \
-    }                                                                    \
-    double *zeros = (double*)malloc(10 * sizeof(double));                \
-    for(int i = 0; i < 10; i++) {                                        \
-      zeros[i] = 0;                                                      \
-    }                                                                    \
    cudaMemcpy(x, fives, 10 * sizeof(double), cudaMemcpyHostToDevice);    \
    cudaMemcpy(y, zeros, 10 * sizeof(double), cudaMemcpyHostToDevice);    \
    cudaMemcpy(val, fives, sizeof(double), cudaMemcpyHostToDevice);       \
    cudaMemcpy(dx, zeros, 10 * sizeof(double), cudaMemcpyHostToDevice);   \
    cudaMemcpy(dy, fives, 10 * sizeof(double), cudaMemcpyHostToDevice);   \
    cudaMemcpy(d_val, zeros, sizeof(double), cudaMemcpyHostToDevice);     \
-    free(fives);                                                         \
-    free(zeros);                                                         \
   }
 
 int main(void) {
@@ -644,7 +709,6 @@ int main(void) {
   cudaFree(a);
   cudaFree(d_a);
 
-
   int *dummy_in, *dummy_out, *d_out, *d_in;
   cudaMalloc(&dummy_in, 10 * sizeof(int));
   cudaMalloc(&dummy_out, 10 * sizeof(int));
@@ -671,10 +735,13 @@ int main(void) {
   TEST_2_D(add_kernel_7, dim3(1), dim3(5, 1, 1), 0, false, "a, b", dummy_out_double, dummy_in_double, d_out_double, d_in_double, 10); // CHECK-EXEC: 50.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00
 
-  double *val;
+  double *val, *d_val;
   cudaMalloc(&val, sizeof(double));
-  double *d_val;
   cudaMalloc(&d_val, sizeof(double));
+
+  double *fives = (double*)malloc(10 * sizeof(double));
+  double *zeros = (double*)malloc(10 * sizeof(double));
+  for(int i = 0; i < 10; i++) { fives[i] = 5; zeros[i] = 0; }
 
   INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);
 
@@ -723,15 +790,31 @@ int main(void) {
 
   INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);
 
+  auto test_kernel_call = clad::gradient(fn);
+  test_kernel_call.execute(dummy_out_double, dummy_in_double, d_out_double, d_in_double);
+  cudaDeviceSynchronize();
+  cudaMemcpy(res, d_in_double, sizeof(double), cudaMemcpyDeviceToHost);
+  printf("%0.2f\n", *res); // CHECK-EXEC: 50.00
+
+  INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);
+
   auto nested_device = clad::gradient(kernel_with_nested_device_call, "out, in");
   nested_device.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_in_double);
   cudaDeviceSynchronize();
   cudaMemcpy(res, d_in_double, 10 * sizeof(double), cudaMemcpyDeviceToHost);
   printf("%0.2f, %0.2f, %0.2f\n", res[0], res[1], res[2]); // CHECK-EXEC: 5.00, 5.00, 5.00
 
+  INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);
+
+  auto test_memory = clad::gradient(fn_memory);
+  test_memory.execute(dummy_out_double, dummy_in_double, d_out_double, d_in_double);
+  cudaDeviceSynchronize();
+  cudaMemcpy(res, d_in_double, 10 * sizeof(double), cudaMemcpyDeviceToHost);
+  printf("%0.2f, %0.2f, %0.2f\n", res[0], res[1], res[2]); // CHECK-EXEC: 10.00, 0.00, 0.00
+
   free(res);
-  cudaFree(dummy_in_double);
-  cudaFree(dummy_out_double);
+  free(fives);
+  free(zeros);
   cudaFree(d_out_double);
   cudaFree(d_in_double);
   cudaFree(val);
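
Illustrative usage (a minimal sketch, not part of the patch): the snippet below shows how the kernel-pullback support exercised by fn/fn_memory in test/CUDA/GradientKernels.cu is meant to be driven from user code. clad::gradient is applied to a host function that launches a kernel and copies the result back with cudaMemcpy; the generated gradient then launches the kernel's pullback and reverses the copy through cudaMemcpy_pullback. The kernel and function names here (scale_kernel, host_sum) are hypothetical, and the sketch assumes clad is built with CUDA support and a GPU with double-precision atomicAdd (sm_60+), which the generated kernel pullbacks typically rely on for accumulating adjoints.

#include "clad/Differentiator/Differentiator.h"
#include <cstdio>
#include <cstdlib>

__global__ void scale_kernel(double *out, double *in) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  out[i] = 2 * in[i]; // each thread writes one element
}

// Host function to differentiate: kernel launch + device-to-host copy + reduction.
double host_sum(double *out, double *in) {
  scale_kernel<<<1, 10>>>(out, in);
  cudaDeviceSynchronize();
  double *out_host = (double*)malloc(10 * sizeof(double));
  cudaMemcpy(out_host, out, 10 * sizeof(double), cudaMemcpyDeviceToHost);
  double res = 0;
  for (int i = 0; i < 10; ++i)
    res += out_host[i];
  free(out_host);
  return res;
}

int main() {
  double *in, *out, *d_in, *d_out;
  cudaMalloc(&in, 10 * sizeof(double));
  cudaMalloc(&out, 10 * sizeof(double));
  cudaMalloc(&d_in, 10 * sizeof(double));
  cudaMalloc(&d_out, 10 * sizeof(double));
  cudaMemset(in, 0, 10 * sizeof(double));   // all-bits-zero is 0.0 for doubles
  cudaMemset(d_in, 0, 10 * sizeof(double)); // adjoints must start at zero
  cudaMemset(d_out, 0, 10 * sizeof(double));

  auto grad = clad::gradient(host_sum);  // host function, as in the fn_memory test
  grad.execute(out, in, d_out, d_in);    // res = sum(2 * in[i]) => d_in[i] == 2.0
  cudaDeviceSynchronize();

  double d_in_host[10];
  cudaMemcpy(d_in_host, d_in, 10 * sizeof(double), cudaMemcpyDeviceToHost);
  printf("%0.2f %0.2f\n", d_in_host[0], d_in_host[1]); // expected: 2.00 2.00

  cudaFree(in); cudaFree(out); cudaFree(d_in); cudaFree(d_out);
  return 0;
}

As in the tests above, the adjoint buffers are device pointers, since the kernel pullback accumulates directly into GPU global memory.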