From d236dcfb367edc05057d9af601bb6e6dd23bae57 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 3 Nov 2024 16:37:34 +0200 Subject: [PATCH] Fix _r local vars being passed to cuda kernel pullbacks --- lib/Differentiator/ReverseModeVisitor.cpp | 61 ++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 08ea53015..7acccd469 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1895,7 +1895,63 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema); VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); PreCallStmts.push_back(BuildDeclStmt(dArgDecl)); - CallArgDx.push_back(BuildDeclRef(dArgDecl)); + DeclRefExpr* dArgRef = BuildDeclRef(dArgDecl); + if (isa(CE)) { + // Create variables to be allocated on the device and passed to + // kernel. These need to be pointers because cudaMlloc expects a + // double pointer as an arg. + Expr* sizeLiteral = ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, m_Context.getTypeSize(dArgTy)); + dArgTy = m_Context.getPointerType(dArgTy); + VarDecl* dArgDeclCUDA = + BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); + + // Create the cudaMemcpyDeviceToHost argument + IdentifierInfo* idInfo = &m_Context.Idents.get("cudaMemcpyKind"); + LookupResult result(m_Sema, idInfo, SourceLocation(), + Sema::LookupOrdinaryName); + m_Sema.LookupName(result, m_Sema.getCurScope()); + EnumDecl* cudaMemcpyKindDecl = nullptr; + for (NamedDecl* decl : result) + if ((cudaMemcpyKindDecl = dyn_cast(decl))) + break; + assert(cudaMemcpyKindDecl && "cudaMemcpyKind not found"); + QualType cudaMemcpyKindType = + m_Context.getTypeDeclType(cudaMemcpyKindDecl); + EnumConstantDecl* deviceToHostEnumDecl = nullptr; + for (EnumConstantDecl* enumConst : + cudaMemcpyKindDecl->enumerators()) { + if (enumConst->getName() == "cudaMemcpyDeviceToHost") { + deviceToHostEnumDecl = enumConst; + break; + } + } + assert(deviceToHostEnumDecl && "cudaMemcpyDeviceToHost not found"); + DeclRefExpr* deviceToHostDeclRef = + m_Sema.BuildDeclRefExpr(deviceToHostEnumDecl, cudaMemcpyKindType, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + SourceLocation(), nullptr); + + PreCallStmts.push_back(BuildDeclStmt(dArgDeclCUDA)); + Expr* refOp = BuildOp(UO_AddrOf, BuildDeclRef(dArgDeclCUDA)); + llvm::SmallVector mallocArgs = {refOp, sizeLiteral}; + PreCallStmts.push_back(GetFunctionCall("cudaMalloc", "", mallocArgs)); + llvm::SmallVector memsetArgs = { + BuildDeclRef(dArgDeclCUDA), getZeroInit(m_Context.IntTy), + sizeLiteral}; + PreCallStmts.push_back(GetFunctionCall("cudaMemset", "", memsetArgs)); + llvm::SmallVector cudaMemcpyArgs = { + BuildOp(UO_AddrOf, dArgRef), BuildDeclRef(dArgDeclCUDA), + sizeLiteral, deviceToHostDeclRef}; + PostCallStmts.push_back( + GetFunctionCall("cudaMemcpy", "", cudaMemcpyArgs)); + llvm::SmallVector freeArgs = {BuildDeclRef(dArgDeclCUDA)}; + PostCallStmts.push_back(GetFunctionCall("cudaFree", "", freeArgs)); + + // Update arg to be passed to pullback call + dArgRef = BuildDeclRef(dArgDeclCUDA); + } + CallArgDx.push_back(dArgRef); // Visit using uninitialized reference. argDiff = Visit(arg, BuildDeclRef(dArgDecl)); } @@ -2024,7 +2080,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* gradArgExpr = nullptr; QualType paramTy = FD->getParamDecl(idx)->getType(); if (!argDerivative || utils::isArrayOrPointerType(paramTy) || - isCladArrayType(argDerivative->getType())) + isCladArrayType(argDerivative->getType()) || + isa(CE)) gradArgExpr = argDerivative; else gradArgExpr =