diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 11c0c1981..108dc431a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1766,6 +1766,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, llvm::MutableArrayRef(DerivedCallArgs), Loc) .get(); + if (FD->getNameAsString() == "cudaMalloc") { + if (auto addrOp = dyn_cast(DerivedCallArgs[0])) { + if (addrOp->getOpcode() == UO_AddrOf) { + DerivedCallArgs[0] = + addrOp->getSubExpr(); // *x -> &x in cudaMalloc args + } + } else { // **x -> x in cudaMalloc args + DerivedCallArgs[0] = BuildOp(UO_Deref, DerivedCallArgs[0]); + } + llvm::SmallVector args = {DerivedCallArgs[0], + getZeroInit(m_Context.IntTy), + DerivedCallArgs[1]}; + addToCurrentBlock(call_dx, direction::forward); + addToCurrentBlock(GetFunctionCall("cudaMemset", "", args)); + call_dx = nullptr; + } return StmtDiff(call, call_dx); } // For calls to C-style memory deallocation functions, we do not need to diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu index 92f69d92c..fcbae3b5b 100644 --- a/test/CUDA/GradientKernels.cu +++ b/test/CUDA/GradientKernels.cu @@ -451,6 +451,7 @@ double fn_memory(double *out, double *in) { //CHECK-NEXT: double *_d_in_dev = nullptr; //CHECK-NEXT: double *in_dev = nullptr; //CHECK-NEXT: cudaMalloc(&_d_in_dev, 10 * sizeof(double)); +//CHECK-NEXT: cudaMemset(_d_in_dev, 0, 10 * sizeof(double)); //CHECK-NEXT: cudaMalloc(&in_dev, 10 * sizeof(double)); //CHECK-NEXT: cudaMemcpy(in_dev, in, 10 * sizeof(double), cudaMemcpyHostToDevice); //CHECK-NEXT: kernel_call<<<1, 10>>>(out, in_dev);