Skip to content

Commit

Permalink
Modify pullback function name w.r.t. global args
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 12, 2024
1 parent 3d7c32f commit d579393
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct DiffRequest {
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of global GPU args of function as a subset of Args.
std::unordered_set<size_t> GlobalArgsIndexes;
std::vector<size_t> GlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
Expand Down
25 changes: 17 additions & 8 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
for (auto index : m_DiffReq.GlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -1929,6 +1931,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
std::vector<size_t> globalCallArgs;
if (!OverloadedDerivedFn) {
size_t idx = 0;

Expand Down Expand Up @@ -1994,12 +1997,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullback);

// Try to find it in builtin derivatives
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
// Add the indexes of the global args to the custom pullback name
if (!m_GlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) {
customPullback += "_" + std::to_string(i);
globalCallArgs.emplace_back(i);
}

if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";

OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
Expand Down Expand Up @@ -2035,12 +2049,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
if (!m_GlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
if (m_GlobalArgs.find(param) != m_GlobalArgs.end())
pullbackRequest.GlobalArgsIndexes.emplace(i);
pullbackRequest.GlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
Expand Down
36 changes: 25 additions & 11 deletions test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ __global__ void kernel_with_device_call(double *out, double *in, double val) {
//CHECK-NEXT: _d_out[index0] = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: device_fn_pullback(in[index0], val, _r_d0, &_r0, &_r1);
//CHECK-NEXT: device_fn_pullback_1(in[index0], val, _r_d0, &_r0, &_r1);
//CHECK-NEXT: atomicAdd(_d_val, _r1);
//CHECK-NEXT: }
//CHECK-NEXT:}
Expand All @@ -323,6 +323,11 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) {
out[index] = device_fn_2(in, val);
}

// Test kernel that issues the same device call as kernel_with_device_call_2
// (out[index] = device_fn_2(in, val)). main() differentiates it with
// clad::gradient w.r.t. "out, val" to check that the pullback of device_fn_2
// is not regenerated for a second kernel making the same call.
//
// Each thread writes one element of `out`, selected by threadIdx.x only
// (blockIdx is not used here), and there is no bounds check.
// NOTE(review): assumes `out` holds at least blockDim.x elements and that the
// kernel is launched with a single block — confirm against the launch site.
__global__ void dup_kernel_with_device_call_2(double *out, double *in, double val) {
// Per-thread output slot, taken from the thread index within the block.
int index = threadIdx.x;
// `in` is forwarded as a pointer; any indexing of it happens in device_fn_2.
out[index] = device_fn_2(in, val);
}

// CHECK: void kernel_with_device_call_2_grad_0_2(double *out, double *in, double val, double *_d_out, double *_d_val) {
//CHECK-NEXT: int _d_index = 0;
//CHECK-NEXT: int index0 = threadIdx.x;
Expand All @@ -333,7 +338,7 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) {
//CHECK-NEXT: double _r_d0 = _d_out[index0];
//CHECK-NEXT: _d_out[index0] = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: device_fn_2_pullback(in, val, _r_d0, &_r0);
//CHECK-NEXT: device_fn_2_pullback_0_1(in, val, _r_d0, &_r0);
//CHECK-NEXT: atomicAdd(_d_val, _r0);
//CHECK-NEXT: }
//CHECK-NEXT:}
Expand All @@ -349,7 +354,7 @@ __global__ void kernel_with_device_call_2(double *out, double *in, double val) {
//CHECK-NEXT: double _r_d0 = _d_out[index0];
//CHECK-NEXT: _d_out[index0] = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: device_fn_2_pullback(in, val, _r_d0, _d_in, &_r0);
//CHECK-NEXT: device_fn_2_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0);
//CHECK-NEXT: _d_val += _r0;
//CHECK-NEXT: }
//CHECK-NEXT:}
Expand All @@ -373,7 +378,7 @@ __global__ void kernel_with_device_call_3(double *out, double *in, double *val)
//CHECK-NEXT: out[index0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_out[index0];
//CHECK-NEXT: _d_out[index0] = 0.;
//CHECK-NEXT: device_fn_3_pullback(in, val, _r_d0, _d_in, _d_val);
//CHECK-NEXT: device_fn_3_pullback_0_1_3_4(in, val, _r_d0, _d_in, _d_val);
//CHECK-NEXT: }
//CHECK-NEXT:}

Expand Down Expand Up @@ -402,27 +407,27 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v
//CHECK-NEXT: double _r_d0 = _d_out[index0];
//CHECK-NEXT: _d_out[index0] = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: device_with_device_call_pullback(in, val, _r_d0, _d_in, &_r0);
//CHECK-NEXT: device_with_device_call_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0);
//CHECK-NEXT: _d_val += _r0;
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_pullback(double in, double val, double _d_y, double *_d_in, double *_d_val) {
// CHECK: __attribute__((device)) void device_fn_pullback_1(double in, double val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: {
//CHECK-NEXT: *_d_in += _d_y;
//CHECK-NEXT: *_d_val += _d_y;
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_2_pullback(double *in, double val, double _d_y, double *_d_val) {
// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1(double *in, double val, double _d_y, double *_d_val) {
//CHECK-NEXT: unsigned int _t1 = blockIdx.x;
//CHECK-NEXT: unsigned int _t0 = blockDim.x;
//CHECK-NEXT: int _d_index = 0;
//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0;
//CHECK-NEXT: *_d_val += _d_y;
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_2_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: unsigned int _t1 = blockIdx.x;
//CHECK-NEXT: unsigned int _t0 = blockDim.x;
//CHECK-NEXT: int _d_index = 0;
Expand All @@ -433,7 +438,7 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_3_pullback(double *in, double *val, double _d_y, double *_d_in, double *_d_val) {
// CHECK: __attribute__((device)) void device_fn_3_pullback_0_1_3_4(double *in, double *val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: unsigned int _t1 = blockIdx.x;
//CHECK-NEXT: unsigned int _t0 = blockDim.x;
//CHECK-NEXT: int _d_index = 0;
Expand All @@ -444,15 +449,15 @@ __global__ void kernel_with_nested_device_call(double *out, double *in, double v
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_with_device_call_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
// CHECK: __attribute__((device)) void device_with_device_call_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: device_fn_4_pullback(in, val, _d_y, _d_in, &_r0);
//CHECK-NEXT: *_d_val += _r0;
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_4_pullback(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
// CHECK: __attribute__((device)) void device_fn_4_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: unsigned int _t1 = blockIdx.x;
//CHECK-NEXT: unsigned int _t0 = blockDim.x;
//CHECK-NEXT: int _d_index = 0;
Expand Down Expand Up @@ -690,6 +695,15 @@ int main(void) {

INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);

auto check_dup = clad::gradient(dup_kernel_with_device_call_2, "out, val"); // check that the pullback function is not regenerated
check_dup.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_val);
cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
printf("%s\n", cudaGetErrorString(cudaGetLastError())); // CHECK-EXEC: no error
printf("%0.2f\n", *res); // CHECK-EXEC: 50.00

INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val);

auto test_device_3 = clad::gradient(kernel_with_device_call_2, "out, in");
test_device_3.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_in_double);
cudaDeviceSynchronize();
Expand Down

0 comments on commit d579393

Please sign in to comment.