Skip to content

Commit

Permalink
Add cudaMemset call after cudaMalloc for derivative pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 3, 2024
1 parent cddc21d commit 885ce5a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
16 changes: 16 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
if (FD->getNameAsString() == "cudaMalloc") {
if (auto addrOp = dyn_cast<UnaryOperator>(DerivedCallArgs[0])) {
if (addrOp->getOpcode() == UO_AddrOf) {
DerivedCallArgs[0] =
addrOp->getSubExpr(); // *x -> &x in cudaMalloc args
}
} else { // **x -> x in cudaMalloc args
DerivedCallArgs[0] = BuildOp(UO_Deref, DerivedCallArgs[0]);
}
llvm::SmallVector<Expr*, 3> args = {DerivedCallArgs[0],
getZeroInit(m_Context.IntTy),
DerivedCallArgs[1]};
addToCurrentBlock(call_dx, direction::forward);
addToCurrentBlock(GetFunctionCall("cudaMemset", "", args));
call_dx = nullptr;
}
return StmtDiff(call, call_dx);
}
// For calls to C-style memory deallocation functions, we do not need to
Expand Down
1 change: 1 addition & 0 deletions test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ double fn_memory(double *out, double *in) {
//CHECK-NEXT: double *_d_in_dev = nullptr;
//CHECK-NEXT: double *in_dev = nullptr;
//CHECK-NEXT: cudaMalloc(&_d_in_dev, 10 * sizeof(double));
//CHECK-NEXT: cudaMemset(_d_in_dev, 0, 10 * sizeof(double));
//CHECK-NEXT: cudaMalloc(&in_dev, 10 * sizeof(double));
//CHECK-NEXT: cudaMemcpy(in_dev, in, 10 * sizeof(double), cudaMemcpyHostToDevice);
//CHECK-NEXT: kernel_call<<<1, 10>>>(out, in_dev);
Expand Down

0 comments on commit 885ce5a

Please sign in to comment.