Add test on correctly passing non-ref params to cuda kernel pullbacks
kchristin22 committed Nov 3, 2024
1 parent 34815b4 commit b8dd9a0
Showing 1 changed file with 134 additions and 0 deletions: test/CUDA/GradientKernels.cu
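In short, this test drives clad::gradient on a host-side CUDA kernel launcher whose last parameter, N, is a plain by-value int. Derivatives are requested only for the pointer parameters ("out, in"), so no adjoint is supplied for N; the generated gradient instead allocates a scratch device buffer for the kernel pullback's _d_N argument. A minimal sketch of the call pattern used by the test added below (argument names here are illustrative):

    auto launch_kernel_4_test = clad::gradient(launch_add_kernel_4, "out, in");
    // generated signature: (out, in, N, _d_out, _d_in); nothing is passed for N's adjoint
    launch_kernel_4_test.execute(out, in, /*N=*/10, d_out, d_in);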
@@ -503,6 +503,66 @@ double fn_memory(double *out, double *in) {
//CHECK-NEXT: cudaFree(_d_in_dev);
//CHECK-NEXT:}

void launch_add_kernel_4(int *out, int *in, const int N) {
int *in_dev = nullptr;
cudaMalloc(&in_dev, N * sizeof(int));
cudaMemcpy(in_dev, in, N * sizeof(int), cudaMemcpyHostToDevice);
int *out_dev = nullptr;
cudaMalloc(&out_dev, N * sizeof(int));
cudaMemcpy(out_dev, out, N * sizeof(int), cudaMemcpyHostToDevice);

add_kernel_4<<<1, 5>>>(out_dev, in_dev, N);

cudaMemcpy(out, out_dev, N * sizeof(int), cudaMemcpyDeviceToHost);
cudaFree(in_dev);
cudaFree(out_dev);
}
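
// Note: add_kernel_4 itself is defined earlier in this file and is not part of this
// hunk. Reconstructed from the generated pullback below (the exact body is an
// assumption), it presumably looks like:
//
// __global__ void add_kernel_4(int *out, int *in, int N) {
//   int index = threadIdx.x + blockIdx.x * blockDim.x;
//   if (index < N) {
//     int sum = 0;
//     for (int i = index; i < N; i += warpSize)
//       sum += in[i];
//     out[index] = sum;
//   }
// }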

// CHECK: void launch_add_kernel_4_grad_0_1(int *out, int *in, const int N, int *_d_out, int *_d_in) {
//CHECK-NEXT: int *_d_in_dev = nullptr;
//CHECK-NEXT: int *in_dev = nullptr;
//CHECK-NEXT: cudaMalloc(&_d_in_dev, N * sizeof(int));
//CHECK-NEXT: cudaMemset(_d_in_dev, 0, N * sizeof(int));
//CHECK-NEXT: cudaMalloc(&in_dev, N * sizeof(int));
//CHECK-NEXT: cudaMemcpy(in_dev, in, N * sizeof(int), cudaMemcpyHostToDevice);
//CHECK-NEXT: int *_d_out_dev = nullptr;
//CHECK-NEXT: int *out_dev = nullptr;
//CHECK-NEXT: cudaMalloc(&_d_out_dev, N * sizeof(int));
//CHECK-NEXT: cudaMemset(_d_out_dev, 0, N * sizeof(int));
//CHECK-NEXT: cudaMalloc(&out_dev, N * sizeof(int));
//CHECK-NEXT: cudaMemcpy(out_dev, out, N * sizeof(int), cudaMemcpyHostToDevice);
//CHECK-NEXT: add_kernel_4<<<1, 5>>>(out_dev, in_dev, N);
//CHECK-NEXT: cudaMemcpy(out, out_dev, N * sizeof(int), cudaMemcpyDeviceToHost);
//CHECK-NEXT: {
//CHECK-NEXT: unsigned long _r6 = 0UL;
//CHECK-NEXT: cudaMemcpyKind _r7 = static_cast<cudaMemcpyKind>(0U);
//CHECK-NEXT: clad::custom_derivatives::cudaMemcpy_pullback(out, out_dev, N * sizeof(int), cudaMemcpyDeviceToHost, _d_out, _d_out_dev, &_r6, &_r7);
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: int _r4 = 0;
//CHECK-NEXT: int *_r5 = nullptr;
//CHECK-NEXT: cudaMalloc(&_r5, 32);
//CHECK-NEXT: cudaMemset(_r5, 0, 32);
//CHECK-NEXT: add_kernel_4_pullback<<<1, 5>>>(out_dev, in_dev, N, _d_out_dev, _d_in_dev, _r5);
//CHECK-NEXT: cudaMemcpy(&_r4, _r5, 32, cudaMemcpyDeviceToHost);
//CHECK-NEXT: cudaFree(_r5);
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: unsigned long _r2 = 0UL;
//CHECK-NEXT: cudaMemcpyKind _r3 = static_cast<cudaMemcpyKind>(0U);
//CHECK-NEXT: clad::custom_derivatives::cudaMemcpy_pullback(out_dev, out, N * sizeof(int), cudaMemcpyHostToDevice, _d_out_dev, _d_out, &_r2, &_r3);
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: unsigned long _r0 = 0UL;
//CHECK-NEXT: cudaMemcpyKind _r1 = static_cast<cudaMemcpyKind>(0U);
//CHECK-NEXT: clad::custom_derivatives::cudaMemcpy_pullback(in_dev, in, N * sizeof(int), cudaMemcpyHostToDevice, _d_in_dev, _d_in, &_r0, &_r1);
//CHECK-NEXT: }
//CHECK-NEXT: cudaFree(in_dev);
//CHECK-NEXT: cudaFree(_d_in_dev);
//CHECK-NEXT: cudaFree(out_dev);
//CHECK-NEXT: cudaFree(_d_out_dev);
//CHECK-NEXT:}

// CHECK: __attribute__((device)) void device_fn_pullback_1(double in, double val, double _d_y, double *_d_in, double *_d_val) {
//CHECK-NEXT: {
//CHECK-NEXT: *_d_in += _d_y;
@@ -559,6 +619,66 @@ double fn_memory(double *out, double *in) {
//CHECK-NEXT: }
//CHECK-NEXT:}

// CHECK: __attribute__((global)) void add_kernel_4_pullback(int *out, int *in, int N, int *_d_out, int *_d_in, int *_d_N) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: int _d_sum = 0;
//CHECK-NEXT: int sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<int> _t3 = {};
//CHECK-NEXT: clad::tape<int> _t4 = {};
//CHECK-NEXT: int _t5;
//CHECK-NEXT: unsigned int _t1 = blockIdx.x;
//CHECK-NEXT: unsigned int _t0 = blockDim.x;
//CHECK-NEXT: int _d_index = 0;
//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = index0 < N;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: sum = 0;
//CHECK-NEXT: _t2 = 0UL;
//CHECK-NEXT: for (i = index0; ; clad::push(_t3, i) , (i += warpSize)) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i < N))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t4, sum);
//CHECK-NEXT: sum += in[i];
//CHECK-NEXT: }
//CHECK-NEXT: _t5 = out[index0];
//CHECK-NEXT: out[index0] = sum;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: {
//CHECK-NEXT: out[index0] = _t5;
//CHECK-NEXT: int _r_d2 = _d_out[index0];
//CHECK-NEXT: _d_out[index0] = 0;
//CHECK-NEXT: _d_sum += _r_d2;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: for (;; _t2--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t2)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: i = clad::pop(_t3);
//CHECK-NEXT: int _r_d0 = _d_i;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t4);
//CHECK-NEXT: int _r_d1 = _d_sum;
//CHECK-NEXT: atomicAdd(&_d_in[i], _r_d1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _d_index += _d_i;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT:}

#define TEST(F, grid, block, shared_mem, use_stream, x, dx, N) \
{ \
int *fives = (int*)malloc(N * sizeof(int)); \
@@ -816,9 +936,23 @@ int main(void) {
test_memory.execute(dummy_out_double, fives, d_out_double, zeros);
printf("%0.2f, %0.2f, %0.2f\n", zeros[0], zeros[1], zeros[2]); // CHECK-EXEC: 60.00, 0.00, 0.00

auto launch_kernel_4_test = clad::gradient(launch_add_kernel_4, "out, in");
int *out_res = (int*)malloc(10 * sizeof(int));
int *in_res = (int*)calloc(10, sizeof(int));
int *zeros_int = (int*)calloc(10, sizeof(int));
int *fives_int = (int*)malloc(10 * sizeof(int));
for(int i = 0; i < 10; i++) { fives_int[i] = 5; out_res[i] = 5; }

launch_kernel_4_test.execute(zeros_int, fives_int, 10, out_res, in_res);
printf("%d, %d, %d\n", in_res[0], in_res[1], in_res[2]); // CHECK-EXEC: 5, 5 5

free(res);
free(fives);
free(zeros);
free(fives_int);
free(zeros_int);
free(out_res);
free(in_res);
cudaFree(d_out_double);
cudaFree(d_in_double);
cudaFree(val);
