diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index fe37046c4..663b24b47 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -47,7 +47,7 @@ struct DiffRequest { /// Args provided to the call to clad::gradient/differentiate. const clang::Expr* Args = nullptr; /// Indexes of global GPU args of function as a subset of Args. - std::unordered_set<size_t> GlobalArgsIndexes; + std::vector<size_t> GlobalArgsIndexes; /// Requested differentiation mode, forward or reverse. DiffMode Mode = DiffMode::unknown; /// If function appears in the call to clad::gradient/differentiate, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4af15553e..69123df5c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -563,6 +563,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; + for (auto index : m_DiffReq.GlobalArgsIndexes) + derivativeName += "_" + std::to_string(index); auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); auto paramTypes = ComputeParamTypes(args); @@ -1929,6 +1931,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If it has more args or f_darg0 was not found, we look for its pullback // function. 
const auto* MD = dyn_cast<CXXMethodDecl>(FD); + std::vector<size_t> globalCallArgs; if (!OverloadedDerivedFn) { size_t idx = 0; @@ -1994,12 +1997,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullback); // Try to find it in builtin derivatives + std::string customPullback = + clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; + // Append the index of every call argument that names a global arg to + // the custom pullback name and record it in globalCallArgs. This runs + // before the base derivative is prepended below, so `i` excludes it. + if (!m_GlobalArgs.empty()) + for (size_t i = 0; i < pullbackCallArgs.size(); i++) + if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i])) + if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl())) + if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) { + customPullback += "_" + std::to_string(i); + globalCallArgs.emplace_back(i); + } + if (baseDiff.getExpr()) pullbackCallArgs.insert( pullbackCallArgs.begin(), BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr())); - std::string customPullback = - clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; + OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), @@ -2035,12 +2051,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Mark the indexes of the global args. Necessary if the argument of the // call has a different name than the function's signature parameter. 
- if (!m_GlobalArgs.empty()) - for (size_t i = 0; i < pullbackCallArgs.size(); i++) - if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i])) - if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl())) - if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) - pullbackRequest.GlobalArgsIndexes.emplace(i); + pullbackRequest.GlobalArgsIndexes = globalCallArgs; pullbackRequest.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu index 5d8d1696a..70ecd96c0 100644 --- a/test/CUDA/GradientKernels.cu +++ b/test/CUDA/GradientKernels.cu @@ -308,7 +308,7 @@ __global__ void kernel_with_device_call(double *out, double *in, double val) { //CHECK-NEXT: _d_out[index0] = 0.; //CHECK-NEXT: double _r0 = 0.; //CHECK-NEXT: double _r1 = 0.; -//CHECK-NEXT: device_fn_pullback(in[index0], val, _r_d0, &_r0, &_r1); +//CHECK-NEXT: device_fn_pullback_1(in[index0], val, _r_d0, &_r0, &_r1); //CHECK-NEXT: atomicAdd(_d_val, _r1); //CHECK-NEXT: } //CHECK-NEXT:} @@ -323,6 +323,11 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) { out[index] = device_fn_2(in, val); } +__global__ void dup_kernel_with_device_call_2(double *out, double *in, double val) { + int index = threadIdx.x; + out[index] = device_fn_2(in, val); +} + // CHECK: void kernel_with_device_call_2_grad_0_2(double *out, double *in, double val, double *_d_out, double *_d_val) { //CHECK-NEXT: int _d_index = 0; //CHECK-NEXT: int index0 = threadIdx.x; @@ -333,7 +338,7 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) { //CHECK-NEXT: double _r_d0 = _d_out[index0]; //CHECK-NEXT: _d_out[index0] = 0.; //CHECK-NEXT: double _r0 = 0.; -//CHECK-NEXT: device_fn_2_pullback(in, val, _r_d0, &_r0); +//CHECK-NEXT: device_fn_2_pullback_0_1(in, val, _r_d0, &_r0); //CHECK-NEXT: atomicAdd(_d_val, _r0); //CHECK-NEXT: } //CHECK-NEXT:} @@ -349,7 +354,7 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) { 
//CHECK-NEXT: double _r_d0 = _d_out[index0]; //CHECK-NEXT: _d_out[index0] = 0.; //CHECK-NEXT: double _r0 = 0.; -//CHECK-NEXT: device_fn_2_pullback(in, val, _r_d0, _d_in, &_r0); +//CHECK-NEXT: device_fn_2_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0); //CHECK-NEXT: _d_val += _r0; //CHECK-NEXT: } //CHECK-NEXT:} @@ -373,7 +378,7 @@ __global__ void kernel_with_device_call_3(double *out, double *in, double *val) //CHECK-NEXT: out[index0] = _t0; //CHECK-NEXT: double _r_d0 = _d_out[index0]; //CHECK-NEXT: _d_out[index0] = 0.; -//CHECK-NEXT: device_fn_3_pullback(in, val, _r_d0, _d_in, _d_val); +//CHECK-NEXT: device_fn_3_pullback_0_1_3_4(in, val, _r_d0, _d_in, _d_val); //CHECK-NEXT: } //CHECK-NEXT:} @@ -402,19 +407,19 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v //CHECK-NEXT: double _r_d0 = _d_out[index0]; //CHECK-NEXT: _d_out[index0] = 0.; //CHECK-NEXT: double _r0 = 0.; -//CHECK-NEXT: device_with_device_call_pullback(in, val, _r_d0, _d_in, &_r0); +//CHECK-NEXT: device_with_device_call_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0); //CHECK-NEXT: _d_val += _r0; //CHECK-NEXT: } //CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_fn_pullback(double in, double val, double _d_y, double *_d_in, double *_d_val) { +// CHECK: __attribute__((device)) void device_fn_pullback_1(double in, double val, double _d_y, double *_d_in, double *_d_val) { //CHECK-NEXT: { //CHECK-NEXT: *_d_in += _d_y; //CHECK-NEXT: *_d_val += _d_y; //CHECK-NEXT: } //CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_fn_2_pullback(double *in, double val, double _d_y, double *_d_val) { +// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1(double *in, double val, double _d_y, double *_d_val) { //CHECK-NEXT: unsigned int _t1 = blockIdx.x; //CHECK-NEXT: unsigned int _t0 = blockDim.x; //CHECK-NEXT: int _d_index = 0; @@ -422,7 +427,7 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v //CHECK-NEXT: *_d_val += _d_y; 
//CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_fn_2_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { //CHECK-NEXT: unsigned int _t1 = blockIdx.x; //CHECK-NEXT: unsigned int _t0 = blockDim.x; //CHECK-NEXT: int _d_index = 0; @@ -433,7 +438,7 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v //CHECK-NEXT: } //CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_fn_3_pullback(double *in, double *val, double _d_y, double *_d_in, double *_d_val) { +// CHECK: __attribute__((device)) void device_fn_3_pullback_0_1_3_4(double *in, double *val, double _d_y, double *_d_in, double *_d_val) { //CHECK-NEXT: unsigned int _t1 = blockIdx.x; //CHECK-NEXT: unsigned int _t0 = blockDim.x; //CHECK-NEXT: int _d_index = 0; @@ -444,7 +449,7 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v //CHECK-NEXT: } //CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_with_device_call_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +// CHECK: __attribute__((device)) void device_with_device_call_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.; //CHECK-NEXT: device_fn_4_pullback(in, val, _d_y, _d_in, &_r0); @@ -452,7 +457,7 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v //CHECK-NEXT: } //CHECK-NEXT:} -// CHECK: __attribute__((device)) void device_fn_4_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +// CHECK: __attribute__((device)) void device_fn_4_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { //CHECK-NEXT: unsigned int _t1 = blockIdx.x; //CHECK-NEXT: unsigned int _t0 = blockDim.x; //CHECK-NEXT: int _d_index = 0; @@ -690,6 
+695,15 @@ int main(void) { INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + auto check_dup = clad::gradient(dup_kernel_with_device_call_2, "out, val"); // check that the pullback function is not regenerated + check_dup.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_val); + cudaDeviceSynchronize(); // wait for the kernel before reading its result, as the next test does + cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost); + printf("%s\n", cudaGetErrorString(cudaGetLastError())); // CHECK-EXEC: no error + printf("%0.2f\n", *res); // CHECK-EXEC: 50.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + auto test_device_3 = clad::gradient(kernel_with_device_call_2, "out, in"); test_device_3.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_in_double); cudaDeviceSynchronize();