Skip to content

Commit

Permalink
Fix _r local vars being passed to cuda kernel pullbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 3, 2024
1 parent dd1a37e commit d236dcf
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1895,7 +1895,63 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
PreCallStmts.push_back(BuildDeclStmt(dArgDecl));
CallArgDx.push_back(BuildDeclRef(dArgDecl));
DeclRefExpr* dArgRef = BuildDeclRef(dArgDecl);
if (isa<CUDAKernelCallExpr>(CE)) {
// Create variables to be allocated on the device and passed to
// kernel. These need to be pointers because cudaMalloc expects a
// double pointer as an arg.
Expr* sizeLiteral = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, m_Context.getTypeSize(dArgTy));
dArgTy = m_Context.getPointerType(dArgTy);
VarDecl* dArgDeclCUDA =
BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));

// Create the cudaMemcpyDeviceToHost argument
IdentifierInfo* idInfo = &m_Context.Idents.get("cudaMemcpyKind");
LookupResult result(m_Sema, idInfo, SourceLocation(),
Sema::LookupOrdinaryName);
m_Sema.LookupName(result, m_Sema.getCurScope());
EnumDecl* cudaMemcpyKindDecl = nullptr;
for (NamedDecl* decl : result)
if ((cudaMemcpyKindDecl = dyn_cast<EnumDecl>(decl)))
break;
assert(cudaMemcpyKindDecl && "cudaMemcpyKind not found");
QualType cudaMemcpyKindType =
m_Context.getTypeDeclType(cudaMemcpyKindDecl);
EnumConstantDecl* deviceToHostEnumDecl = nullptr;
for (EnumConstantDecl* enumConst :
cudaMemcpyKindDecl->enumerators()) {
if (enumConst->getName() == "cudaMemcpyDeviceToHost") {
deviceToHostEnumDecl = enumConst;
break;
}
}
assert(deviceToHostEnumDecl && "cudaMemcpyDeviceToHost not found");
DeclRefExpr* deviceToHostDeclRef =
m_Sema.BuildDeclRefExpr(deviceToHostEnumDecl, cudaMemcpyKindType,
CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
SourceLocation(), nullptr);

PreCallStmts.push_back(BuildDeclStmt(dArgDeclCUDA));
Expr* refOp = BuildOp(UO_AddrOf, BuildDeclRef(dArgDeclCUDA));
llvm::SmallVector<Expr*, 3> mallocArgs = {refOp, sizeLiteral};
PreCallStmts.push_back(GetFunctionCall("cudaMalloc", "", mallocArgs));
llvm::SmallVector<Expr*, 3> memsetArgs = {
BuildDeclRef(dArgDeclCUDA), getZeroInit(m_Context.IntTy),
sizeLiteral};
PreCallStmts.push_back(GetFunctionCall("cudaMemset", "", memsetArgs));
llvm::SmallVector<Expr*, 4> cudaMemcpyArgs = {
BuildOp(UO_AddrOf, dArgRef), BuildDeclRef(dArgDeclCUDA),
sizeLiteral, deviceToHostDeclRef};
PostCallStmts.push_back(
GetFunctionCall("cudaMemcpy", "", cudaMemcpyArgs));
llvm::SmallVector<Expr*, 3> freeArgs = {BuildDeclRef(dArgDeclCUDA)};
PostCallStmts.push_back(GetFunctionCall("cudaFree", "", freeArgs));

// Update arg to be passed to pullback call
dArgRef = BuildDeclRef(dArgDeclCUDA);
}
CallArgDx.push_back(dArgRef);
// Visit using uninitialized reference.
argDiff = Visit(arg, BuildDeclRef(dArgDecl));
}
Expand Down Expand Up @@ -2024,7 +2080,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* gradArgExpr = nullptr;
QualType paramTy = FD->getParamDecl(idx)->getType();
if (!argDerivative || utils::isArrayOrPointerType(paramTy) ||
isCladArrayType(argDerivative->getType()))
isCladArrayType(argDerivative->getType()) ||
isa<CUDAKernelCallExpr>(CE))
gradArgExpr = argDerivative;
else
gradArgExpr =
Expand Down

0 comments on commit d236dcf

Please sign in to comment.