From 4f2a378c68aa4a159be51bc885292540b9add7f4 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 8 Oct 2024 10:58:41 +0200 Subject: [PATCH 1/8] Fix the generation of invalid code in some common cases (#1088) This commit fixes the way Clad generates code. Specifically, it addresses the way operators appear in the generated code in the reverse mode and the way nested name qualifiers are built in both modes. This partially addresses #1050. Fixes: #1087 --- include/clad/Differentiator/VisitorBase.h | 13 +++- lib/Differentiator/BaseForwardModeVisitor.cpp | 7 +- lib/Differentiator/ReverseModeVisitor.cpp | 37 +++++++-- lib/Differentiator/VisitorBase.cpp | 31 +++++++- test/Gradient/Lambdas.C | 6 +- test/ValidCodeGen/ValidCodeGen.C | 76 +++++++++++++++++++ 6 files changed, 154 insertions(+), 16 deletions(-) create mode 100644 test/ValidCodeGen/ValidCodeGen.C diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 917088c42..dba1540a2 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -362,8 +362,17 @@ namespace clad { /// \param[in] D The declaration to build a DeclRefExpr for. /// \param[in] SS The scope specifier for the declaration. /// \returns the DeclRefExpr for the given declaration. - clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - const clang::CXXScopeSpec* SS = nullptr); + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, + const clang::CXXScopeSpec* SS = nullptr, + clang::ExprValueKind VK = clang::VK_LValue); + /// Builds a DeclRefExpr to a given Decl, adding proper nested name + /// qualifiers. + /// \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] NNS The nested name specifier to use. + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, + clang::ExprValueKind VK = clang::VK_LValue); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index fe14227b7..8015b8fdb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1036,8 +1036,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { - auto referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto* referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS); } } else clonedDRE = cast(Clone(DRE)); @@ -1052,7 +1053,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto dVarDRE = dyn_cast(dExpr)) { auto dVar = cast(dVarDRE->getDecl()); if (dVar->getDeclContext() != m_Sema.CurContext) - dExpr = BuildDeclRef(dVar); + dExpr = BuildDeclRef(dVar, DRE->getQualifier()); } return StmtDiff(clonedDRE, dExpr); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e90036593..2a7a16aa4 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1544,8 +1544,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. - if (VD->getDeclContext() != m_Sema.CurContext) - clonedDRE = cast(BuildDeclRef(VD)); + if (VD->getDeclContext() != m_Sema.CurContext) { + auto* ccDRE = dyn_cast(clonedDRE); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto* referencedDecl = cast(ccDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS, DRE->getValueKind()); + } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. @@ -1852,9 +1856,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* OverloadedDerivedFn = nullptr; - // If the function has a single arg and does not returns a reference or take + // If the function has a single arg and does not return a reference or take // arg by reference, we look for a derivative w.r.t. to this arg using the - // forward mode(it is unlikely that we need gradient of a one-dimensional' + // forward mode(it is unlikely that we need gradient of a one-dimensional // function). bool asGrad = true; @@ -2149,8 +2153,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff argDiff = Visit(arg); CallArgs.push_back(argDiff.getExpr_dx()); } - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, Loc); } else { @@ -2167,6 +2170,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. + + if (const auto* OCE = dyn_cast(CE)) { + auto* FD = const_cast( + dyn_cast(OCE->getCalleeDecl())); + + NestedNameSpecifierLoc NNS(FD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(FD, FD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc, + FD, DAP, FD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + call = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + CallArgs, Loc) + .get(); + return StmtDiff(call); + } + call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, CallArgs, Loc) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index f739abd2f..b6156397c 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -236,11 +236,38 @@ namespace clad { } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - const CXXScopeSpec* SS /*=nullptr*/) { + const CXXScopeSpec* SS /*=nullptr*/, + ExprValueKind VK /*=VK_LValue*/) { QualType T = D->getType(); T = T.getNonReferenceType(); return cast(clad_compat::GetResult( - m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); + m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), SS))); + } + + DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, + NestedNameSpecifier* NNS, + ExprValueKind VK /*=VK_LValue*/) { + std::vector NNChain; + CXXScopeSpec CSS; + while (NNS) { + NNChain.push_back(NNS); + NNS = NNS->getPrefix(); + } + + std::reverse(NNChain.begin(), NNChain.end()); + + for (size_t i = 0; i < NNChain.size(); ++i) { + NNS = NNChain[i]; + // FIXME: this needs to be extended to support more NNS kinds. An + // inspiration can be take from getFullyQualifiedNestedNameSpecifier in + // llvm-project/clang/lib/AST/QualTypeNames.cpp + if (NNS->getKind() == NestedNameSpecifier::Namespace) { + NamespaceDecl* NS = NNS->getAsNamespace(); + CSS.Extend(m_Context, NS, noLoc, noLoc); + } + } + + return BuildDeclRef(D, &CSS, VK); } IdentifierInfo* diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index f9b06aeeb..35776e2d6 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -13,7 +13,7 @@ double f1(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; -// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f1_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t) { // CHECK-NEXT: return t * t + 1.; // CHECK-NEXT: }{{;?}} @@ -34,12 +34,12 @@ double f2(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; -// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f2_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = operator()(i + j, i); +// CHECK-NEXT: double x = _f.operator()(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C new file mode 100644 index 000000000..11f1f6fc7 --- /dev/null +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -0,0 +1,76 @@ +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oValidCodeGen.out 2>&1 | %filecheck %s +// RUN: ./ValidCodeGen.out | %filecheck_exec %s +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oValidCodeGenWithTBR.out +// RUN: ./ValidCodeGenWithTBR.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include "../TestUtils.h" +#include "../PrintOverloads.h" + +namespace TN { + int coefficient = 3; + + template + struct Test2 { + T operator[](T x) { + return 4*x; + } + }; +} + +namespace clad { +namespace custom_derivatives { +namespace class_functions { + template + void operator_subscript_pullback(::TN::Test2* obj, T x, T d_u, ::TN::Test2* d_obj, T* d_x) { + (*d_x) += 4*d_u; + } +}}} + +double fn(double x) { + // fwd and rvs mode test + return x*TN::coefficient; // in this test, it's important that this nested name is copied into the generated code properly in both modes +} + +double fn2(double x, double y) { + // rvs mode test + TN::Test2 t; // this type needs to be copied into the derived code properly + auto q = t[x]; // in this test, it's important that this operator call is copied into the generated code properly and that the pullback function is called with all the needed namespace prefixes + return q; +} + +int main() { + double dx, dy; + INIT_DIFFERENTIATE(fn, "x"); + INIT_GRADIENT(fn); + INIT_GRADIENT(fn2); + + TEST_GRADIENT(fn, /*numOfDerivativeArgs=*/1, 3, &dx); // CHECK-EXEC: {3.00} + TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, 3, 4, &dx, &dy); // CHECK-EXEC: {4.00, 0.00} + TEST_DIFFERENTIATE(fn, 3) // CHECK-EXEC: {3.00} +} + +//CHECK: double fn_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: return _d_x * TN::coefficient + x * 0; +//CHECK-NEXT: } + +//CHECK: void fn_grad(double x, double *_d_x) { +//CHECK-NEXT: *_d_x += 1 * TN::coefficient; +//CHECK-NEXT: } + +//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: TN::Test2 _d_t({}); +//CHECK-NEXT: TN::Test2 t; +//CHECK-NEXT: TN::Test2 _t0 = t; +//CHECK-NEXT: double _d_q = 0.; +//CHECK-NEXT: double q = t.operator[](x); +//CHECK-NEXT: _d_q += 1; +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t0, x, _d_q, &_d_t, &_r0); +//CHECK-NEXT: *_d_x += _r0; +//CHECK-NEXT: } +//CHECK-NEXT: } From a14a3f61b28892233ae4f821de22bdaf3c4aacad Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 9 Oct 2024 08:45:10 +0200 Subject: [PATCH 2/8] Add support for Kokkos::View element access in the rvs mode (#1068) --- include/clad/Differentiator/KokkosBuiltins.h | 216 +++++++++++++++++++ unittests/Kokkos/ViewAccess.cpp | 9 +- 2 files changed, 220 insertions(+), 5 deletions(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index 1a6253027..d17d33a69 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -30,6 +30,37 @@ constructor_pushforward( Kokkos::View( "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; } +template +clad::ValueAndAdjoint<::Kokkos::View, + ::Kokkos::View> +constructor_reverse_forw( + clad::ConstructorReverseForwTag<::Kokkos::View>, + const ::std::string& name, const size_t& idx0, const size_t& idx1, + const size_t& idx2, const size_t& idx3, const size_t& idx4, + const size_t& idx5, const size_t& idx6, const size_t& idx7, + const ::std::string& /*d_name*/, const size_t& /*d_idx0*/, + const size_t& /*d_idx1*/, const size_t& /*d_idx2*/, + const size_t& /*d_idx3*/, const size_t& /*d_idx4*/, + const size_t& /*d_idx5*/, const size_t& /*d_idx6*/, + const size_t& /*d_idx7*/) { + return {::Kokkos::View(name, idx0, idx1, idx2, idx3, + idx4, idx5, idx6, idx7), + ::Kokkos::View( + "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; +} +template +void constructor_pullback(::Kokkos::View* v, + const ::std::string& name, const size_t& idx0, + const size_t& idx1, const size_t& idx2, + const size_t& idx3, const size_t& idx4, + const size_t& idx5, const size_t& idx6, + const size_t& idx7, + ::Kokkos::View* d_v, + const ::std::string* /*d_name*/, + const size_t& /*d_idx0*/, const size_t* /*d_idx1*/, + const size_t* /*d_idx2*/, const size_t* /*d_idx3*/, + const size_t* /*d_idx4*/, const size_t* /*d_idx5*/, + const size_t* /*d_idx6*/, const size_t* /*d_idx7*/) {} /// View indexing template @@ -107,6 +138,191 @@ operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7), (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)}; } +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx i0, + const ::Kokkos::View* d_v, + Idx /*d_i0*/) { + return {(*v)(i0), (*d_v)(i0)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx i0, Diff d_y, + ::Kokkos::View* d_v, + dIdx* /*d_i0*/) { + (*d_v)(i0) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/) { + return {(*v)(i0, i1), (*d_v)(i0, i1)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/) { + (*d_v)(i0, i1) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/) { + return {(*v)(i0, i1, i2), (*d_v)(i0, i1, i2)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/) { + (*d_v)(i0, i1, i2) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/) { + return {(*v)(i0, i1, i2, i3), (*d_v)(i0, i1, i2, i3)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/) { + return {(*v)(i0, i1, i2, i3, i4), (*d_v)(i0, i1, i2, i3, i4)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i4*/) { + (*d_v)(i0, i1, i2, i3, i4) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/) { + return {(*v)(i0, i1, i2, i3, i4, i5), (*d_v)(i0, i1, i2, i3, i4, i5)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i4*/, dIdx5* /*d_i5*/) { + (*d_v)(i0, i1, i2, i3, i4, i5) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + Idx6 i6, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/, + Idx6 /*d_i6*/) { + return {(*v)(i0, i1, i2, i3, i4, i5, i6), (*d_v)(i0, i1, i2, i3, i4, i5, i6)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Idx6 i6, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i3*/, dIdx5* /*d_i3*/, + dIdx6* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3, i4, i5, i6) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + Idx6 i6, Idx7 i7, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/, + Idx6 /*d_i6*/, Idx7 /*d_i7*/) { + return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7), + (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Idx6 i6, Idx7 i7, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i3*/, dIdx5* /*d_i3*/, + dIdx6* /*d_i3*/, dIdx7* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7) += d_y; +} } // namespace class_functions /// Kokkos functions (view utils) diff --git a/unittests/Kokkos/ViewAccess.cpp b/unittests/Kokkos/ViewAccess.cpp index e77b278f0..12cc355d1 100644 --- a/unittests/Kokkos/ViewAccess.cpp +++ b/unittests/Kokkos/ViewAccess.cpp @@ -60,11 +60,10 @@ TEST(ViewAccess, Test2) { double dx_f_2_FD = finite_difference_tangent(f_2_tmp, 3., epsilon); EXPECT_NEAR(f_2_x.execute(3, 4), dx_f_2_FD, tolerance * dx_f_2_FD); - // TODO: uncomment this once it has been implemented - // auto f_grad_exe = clad::gradient(f); - // double dx, dy; - // f_grad_exe.execute(3., 4., &dx, &dy); - // EXPECT_NEAR(f_x.execute(3, 4),dx,tolerance*dx); + auto f_grad_exe = clad::gradient(f); + double dx, dy; + f_grad_exe.execute(3., 4., &dx, &dy); + EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx); // double dx_2, dy_2; // auto f_2_grad_exe = clad::gradient(f_2); From 04b353bea8be5dfac66d186fcd0a7a0020952958 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Sat, 12 Oct 2024 19:30:09 +0200 Subject: [PATCH 3/8] Add support for `Kokkos::deep_copy` in the rvs mode --- include/clad/Differentiator/KokkosBuiltins.h | 116 +++++++++++++++++++ unittests/Kokkos/ViewAccess.cpp | 42 +++++-- 2 files changed, 150 insertions(+), 8 deletions(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index d17d33a69..51824a004 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -334,6 +334,122 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param, deep_copy(dst, src); deep_copy(d_dst, d_src); } +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) {} +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for("iterate_over_all_view_elements", v.extent(0), func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for("iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<2>>( + {0, 0}, {v.extent(0), v.extent(1)}), + func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for( + "iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<3>>( + {0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2)}), + func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for( + "iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<4>>( + {0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3)}), + func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for( + "iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<5>>( + {0, 0, 0, 0, 0}, + {v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4)}), + func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for( + "iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<6>>( + {0, 0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), + v.extent(3), v.extent(4), v.extent(5)}), + func); + } +}; +template struct iterate_over_all_view_elements { + template static void run(const View& v, F func) { + ::Kokkos::parallel_for( + "iterate_over_all_view_elements", + ::Kokkos::MDRangePolicy<::Kokkos::Rank<7>>( + {0, 0, 0, 0, 0, 0, 0}, + {v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4), + v.extent(5), v.extent(6)}), + func); + } +}; +template +void deep_copy_pullback( + const ::Kokkos::View& dst, + typename ::Kokkos::ViewTraits::const_value_type& /*value*/, + ::std::enable_if_t<::std::is_same< + typename ::Kokkos::ViewTraits::specialize, void>::value>*, + ::Kokkos::View* d_dst, + typename ::Kokkos::ViewTraits::value_type* d_value, + ::std::enable_if_t< + ::std::is_same::specialize, + void>::value>*) { + typename ::Kokkos::ViewTraits::value_type res = 0; + + iterate_over_all_view_elements< + ::Kokkos::View, + ::Kokkos::ViewTraits::rank>::run(dst, + [&res, + &d_dst](auto&&... args) { + res += (*d_dst)(args...); + (*d_dst)(args...) = 0; + }); + + (*d_value) += res; +} +template +inline void deep_copy_pullback( + const ::Kokkos::View& dst, + const ::Kokkos::View& /*src*/, + ::std::enable_if_t< + (::std::is_void< + typename ::Kokkos::ViewTraits::specialize>::value && + ::std::is_void< + typename ::Kokkos::ViewTraits::specialize>::value && + ((unsigned int)(::Kokkos::ViewTraits::rank) != 0 || + (unsigned int)(::Kokkos::ViewTraits::rank) != 0))>*, + ::Kokkos::View* d_dst, ::Kokkos::View* d_src, + ::std::enable_if_t< + (::std::is_void< + typename ::Kokkos::ViewTraits::specialize>::value && + ::std::is_void< + typename ::Kokkos::ViewTraits::specialize>::value && + ((unsigned int)(::Kokkos::ViewTraits::rank) != 0 || + (unsigned int)(::Kokkos::ViewTraits::rank) != 0))>*) { + iterate_over_all_view_elements<::Kokkos::View, + ::Kokkos::ViewTraits::rank>:: + run(dst, [&d_src, &d_dst](auto&&... args) { + (*d_src)(args...) += (*d_dst)(args...); + (*d_dst)(args...) = 0; + }); +} + template diff --git a/unittests/Kokkos/ViewAccess.cpp b/unittests/Kokkos/ViewAccess.cpp index 12cc355d1..e42475ccd 100644 --- a/unittests/Kokkos/ViewAccess.cpp +++ b/unittests/Kokkos/ViewAccess.cpp @@ -14,11 +14,11 @@ double f(double x, double y) { Kokkos::View b("b", N1); a(0, 0) = x; - b(0, 0) = y; + b(1, 1) = y; - b(0, 0) += a(0, 0) * b(0, 0); + b(1, 1) += a(0, 0) * b(1, 1); - return a(0, 0) * a(0, 0) * b(0, 0) + b(0, 0); + return a(0, 0) * a(0, 0) * b(1, 1) + b(1, 1); } double f_2(double x, double y) { @@ -37,6 +37,22 @@ double f_2(double x, double y) { return a(0, 0); } +double f_3(double x, double y) { + + const int N1 = 4; + + Kokkos::View a("a", N1); + Kokkos::View b("b", N1); + + Kokkos::deep_copy(a, 3 * x + y); + b(0, 0) = y; + Kokkos::deep_copy(b, a); + + b(0, 0) += a(0, 0) * b(0, 0); + + return a(0, 0) + b(0, 0); +} + TEST(ViewAccess, Test1) { EXPECT_NEAR(f(0, 1), 1, 1e-8); EXPECT_NEAR(f(0, 2), 2, 1e-8); @@ -51,7 +67,6 @@ TEST(ViewAccess, Test2) { std::function f_tmp = [](double x) { return f(x, 4.); }; double dx_f_FD = finite_difference_tangent(f_tmp, 3., epsilon); - EXPECT_NEAR(f_x.execute(3, 4), dx_f_FD, tolerance * dx_f_FD); auto f_2_x = clad::differentiate(f_2, "x"); @@ -60,13 +75,24 @@ TEST(ViewAccess, Test2) { double dx_f_2_FD = finite_difference_tangent(f_2_tmp, 3., epsilon); EXPECT_NEAR(f_2_x.execute(3, 4), dx_f_2_FD, tolerance * dx_f_2_FD); + auto f_3_y = clad::differentiate(f_3, "y"); + + std::function f_3_tmp = [](double y) { return f_3(3., y); }; + double dy_f_3_FD = finite_difference_tangent(f_3_tmp, 4., epsilon); + EXPECT_NEAR(f_3_y.execute(3, 4), dy_f_3_FD, tolerance * dy_f_3_FD); + auto f_grad_exe = clad::gradient(f); double dx, dy; f_grad_exe.execute(3., 4., &dx, &dy); EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx); - // double dx_2, dy_2; - // auto f_2_grad_exe = clad::gradient(f_2); - // f_2_grad_exe.execute(3., 4., &dx_2, &dy_2); - // EXPECT_NEAR(f_2_x.execute(3, 4),dx_2,tolerance*dx_2); + double dx_2, dy_2; + auto f_2_grad_exe = clad::gradient(f_2); + f_2_grad_exe.execute(3., 4., &dx_2, &dy_2); + EXPECT_NEAR(f_2_x.execute(3, 4), dx_2, tolerance * dx_2); + + double dx_3, dy_3; + auto f_3_grad_exe = clad::gradient(f_3); + f_3_grad_exe.execute(3., 4., &dx_3, &dy_3); + EXPECT_NEAR(f_3_y.execute(3, 4), dy_3, tolerance * dy_3); } \ No newline at end of file From f86eedee99509b2eda4d16d5660ca8ff051c6d55 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Sat, 12 Oct 2024 22:00:07 +0300 Subject: [PATCH 4/8] Fix synthesizing literals function for enums (#1113) --- lib/Differentiator/ConstantFolder.cpp | 16 ++- test/Gradient/Switch.C | 153 ++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index 6439c6eaa..900e87a90 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -7,6 +7,7 @@ //----------------------------------------------------------------------------// #include "ConstantFolder.h" +#include "clad/Differentiator/Compatibility.h" #include "clang/AST/ASTContext.h" @@ -141,7 +142,20 @@ namespace clad { // SourceLocation noLoc; Expr* Result = 0; QT = QT.getCanonicalType(); - if (QT->isPointerType()) { + if (QT->isEnumeralType()) { + llvm::APInt APVal(C.getIntWidth(QT), val, + QT->isSignedIntegerOrEnumerationType()); + Result = clad::synthesizeLiteral( + dyn_cast(QT)->getDecl()->getIntegerType(), C, APVal); + SourceLocation noLoc; + Expr* cast = CXXStaticCastExpr::Create( + C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + clang::CastKind::CK_IntegralCast, Result, /*CXXCastPath=*/nullptr, + C.getTrivialTypeSourceInfo(QT, noLoc) + CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO, + noLoc, noLoc, SourceRange()); + Result = cast; + } else if (QT->isPointerType()) { Result = clad::synthesizeLiteral(QT, C); } else if (QT->isBooleanType()) { Result = clad::synthesizeLiteral(QT, C, (bool)val); diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 6e18bc04d..98a176807 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -682,6 +682,146 @@ double fn7(double u, double v) { // CHECK-NEXT: } // CHECK-NEXT: } +enum Op { + Add, + Sub, + Mul, + Div +}; + +double fn24(double x, double y, Op op) { + double res = 0; + switch (op) { + case Add: + res = x + y; + break; + case Sub: + res = x - y; + break; + case Mul: + res = x * y; + break; + case Div: + res = x / y; + break; + } + return res; +} + +// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) { +// CHECK-NEXT: Op _d_op = static_cast(0U); +// CHECK-NEXT: Op _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = op; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case Add: +// CHECK-NEXT: res = x + y; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Sub: +// CHECK-NEXT: res = x - y; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Mul: +// CHECK-NEXT: res = x * y; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{3U|3UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Div: +// CHECK-NEXT: res = x / y; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{4U|4UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, {{5U|5UL}}); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case {{5U|5UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: case {{4U|4UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d3 / y; +// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y)); +// CHECK-NEXT: _d_y += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Div == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{3U|3UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d2 * y; +// CHECK-NEXT: _d_y += x * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (Mul == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{2U|2UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d1; +// CHECK-NEXT: _d_y += -_r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (Sub == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{1U|1UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: _d_y += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Add == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + #define TEST_2(F, x, y) \ { \ @@ -691,6 +831,14 @@ double fn7(double u, double v) { printf("{%.2f, %.2f}\n", result[0], result[1]); \ } +#define TEST_2_Op(F, x, y, op) \ +{ \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F, "x, y"); \ + d_##F.execute(x, y, op, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ +} + int main() { double result[2] = {}; @@ -705,4 +853,9 @@ int main() { TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00} + + TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00} + TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00} + TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00} + TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12} } From 871d538476c0c5cae17875c6d48242485e7c0ef5 Mon Sep 17 00:00:00 2001 From: kchristin Date: Mon, 23 Sep 2024 17:59:44 +0300 Subject: [PATCH 5/8] Add support for device functions as pullback functions For this purpose, a deeper look into atomic ops had to be taken. Atomic ops can only be applied on global or shared GPU memory. Hence, we needed to identify which call args of the device function pullback are actually kernel args and, thus, global. The indexes of those args are stored in a vector in the differentiation request for the internal device function and appended to the name of the pullback function. Later on, when deriving the encountered device function, the global call args are matched with the function's params based on their stored indexes. This way, the atomic ops are minimized to the absolute necessary number and no error arises. --- include/clad/Differentiator/DiffPlanner.h | 2 + .../clad/Differentiator/ExternalRMVSource.h | 2 +- .../clad/Differentiator/ReverseModeVisitor.h | 9 +- lib/Differentiator/ReverseModeVisitor.cpp | 113 ++++++-- lib/Differentiator/VisitorBase.cpp | 3 +- test/CUDA/GradientKernels.cu | 267 +++++++++++++++++- 6 files changed, 368 insertions(+), 28 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index ef29b7246..663b24b47 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -46,6 +46,8 @@ struct DiffRequest { clang::CallExpr* CallContext = nullptr; /// Args provided to the call to clad::gradient/differentiate. const clang::Expr* Args = nullptr; + /// Indexes of global GPU args of function as a subset of Args. + std::vector GlobalArgsIndexes; /// Requested differentiation mode, forward or reverse. DiffMode Mode = DiffMode::unknown; /// If function appears in the call to clad::gradient/differentiate, diff --git a/include/clad/Differentiator/ExternalRMVSource.h b/include/clad/Differentiator/ExternalRMVSource.h index 4da9d09fc..72fc596b0 100644 --- a/include/clad/Differentiator/ExternalRMVSource.h +++ b/include/clad/Differentiator/ExternalRMVSource.h @@ -124,7 +124,7 @@ class ExternalRMVSource { /// This is called just before finalising `VisitReturnStmt`. virtual void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) {} - /// This ic called just before finalising `VisitCallExpr`. + /// This is called just before finalising `VisitCallExpr`. /// /// \param CE call expression that is being visited. /// \param CallArgs diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index b044ee0ec..e58d77398 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -38,6 +38,11 @@ namespace clad { // several private/protected members of the visitor classes. friend class ErrorEstimationHandler; llvm::SmallVector m_IndependentVars; + /// Set used to keep track of parameter variables w.r.t which the + /// the derivative (gradient) is being computed. This is separate from the + /// m_Variables map because all other intermediate variables will + /// not be stored here. + std::unordered_set m_ParamVarsWithDiff; /// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in /// the reverse mode we also accumulate Stmts for the reverse pass which /// will be executed on return. @@ -51,6 +56,8 @@ namespace clad { /// that will be put immediately in the beginning of derivative function /// block. Stmts m_Globals; + /// Global GPU args of the function. + std::unordered_set m_GlobalArgs; //// A reference to the output parameter of the gradient function. clang::Expr* m_Result; /// A flag indicating if the Stmt we are currently visiting is inside loop. @@ -432,7 +439,7 @@ namespace clad { /// Helper function that checks whether the function to be derived /// is meant to be executed only by the GPU - bool shouldUseCudaAtomicOps(); + bool shouldUseCudaAtomicOps(const clang::Expr* E); /// Add call to cuda::atomicAdd for the given LHS and RHS expressions. /// diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 2a7a16aa4..f867a321c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -104,10 +104,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return CladTapeResult{*this, PushExpr, PopExpr, TapeRef}; } - bool ReverseModeVisitor::shouldUseCudaAtomicOps() { - return m_DiffReq->hasAttr() || - (m_DiffReq->hasAttr() && - !m_DiffReq->hasAttr()); + bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) { + // Same as checking whether this is a function executed by the GPU + if (!m_GlobalArgs.empty()) + if (const auto* DRE = dyn_cast(E)) + if (const auto* PVD = dyn_cast(DRE->getDecl())) + // we need to check whether this param is in the global memory of the + // GPU + return m_GlobalArgs.find(PVD) != m_GlobalArgs.end(); + + return false; } clang::Expr* ReverseModeVisitor::BuildCallToCudaAtomicAdd(clang::Expr* LHS, @@ -123,8 +129,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get(); Expr* finalLHS = LHS; - if (isa(LHS)) + if (auto* UO = dyn_cast(LHS)) { + if (UO->getOpcode() == UnaryOperatorKind::UO_Deref) + finalLHS = UO->getSubExpr()->IgnoreImplicit(); + } else if (!LHS->getType()->isPointerType() && + !LHS->getType()->isReferenceType()) finalLHS = BuildOp(UnaryOperatorKind::UO_AddrOf, LHS); + llvm::SmallVector atomicArgs = {finalLHS, RHS}; assert(!m_Builder.noOverloadExists(UnresolvedLookup, atomicArgs) && @@ -440,6 +451,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParams(params); + // if the function is a global kernel, all its parameters reside in the + // global memory of the GPU + if (m_DiffReq->hasAttr()) + for (auto param : params) + m_GlobalArgs.emplace(param); + llvm::ArrayRef paramsRef = clad_compat::makeArrayRef(params.data(), params.size()); gradientFD->setParams(paramsRef); @@ -546,6 +563,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; + for (auto index : m_DiffReq.GlobalArgsIndexes) + derivativeName += "_" + std::to_string(index); auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); auto paramTypes = ComputeParamTypes(args); @@ -587,6 +606,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterCreatingDerivedFnParams(params); m_Derivative->setParams(params); + // Match the global arguments of the call to the device function to the + // pullback function's parameters. + if (!m_DiffReq.GlobalArgsIndexes.empty()) + for (auto index : m_DiffReq.GlobalArgsIndexes) + m_GlobalArgs.emplace(m_Derivative->getParamDecl(index)); + // If the function is a global kernel, all its parameters reside in the + // global memory of the GPU + else if (m_DiffReq->hasAttr()) + for (auto param : params) + m_GlobalArgs.emplace(param); m_Derivative->setBody(nullptr); if (!m_DiffReq.DeclarationOnly) { @@ -1519,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildArraySubscript(target, forwSweepDerivativeIndices); // Create the (target += dfdx) statement. if (dfdx()) { - if (shouldUseCudaAtomicOps()) { + if (shouldUseCudaAtomicOps(target)) { Expr* atomicCall = BuildCallToCudaAtomicAdd(result, dfdx()); // Add it to the body statements. addToCurrentBlock(atomicCall, direction::reverse); @@ -1583,9 +1612,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: not sure if this is generic. // Don't update derivatives of record types. if (!VD->getType()->isRecordType()) { - auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); + Expr* base = it->second; + if (auto* UO = dyn_cast(it->second)) + base = UO->getSubExpr()->IgnoreImpCasts(); + if (shouldUseCudaAtomicOps(base)) { + Expr* atomicCall = BuildCallToCudaAtomicAdd(it->second, dfdx()); + // Add it to the body statements. + addToCurrentBlock(atomicCall, direction::reverse); + } else { + auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); + addToCurrentBlock(add_assign, direction::reverse); + } } } return StmtDiff(clonedDRE, it->second, it->second); @@ -1728,20 +1765,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (const Expr* Arg : CE->arguments()) { StmtDiff ArgDiff = Visit(Arg, dfdx()); CallArgs.push_back(ArgDiff.getExpr()); - DerivedCallArgs.push_back(ArgDiff.getExpr_dx()); + if (auto* DRE = dyn_cast(ArgDiff.getExpr())) { + // If the arg is used for differentiation of the function, then we + // cannot free it in the end as it's the result to be returned to the + // user. + if (m_ParamVarsWithDiff.find(DRE->getDecl()) == + m_ParamVarsWithDiff.end()) + DerivedCallArgs.push_back(ArgDiff.getExpr_dx()); + } } Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, llvm::MutableArrayRef(CallArgs), Loc) .get(); - Expr* call_dx = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, - llvm::MutableArrayRef(DerivedCallArgs), Loc) - .get(); m_DeallocExprs.push_back(call); - m_DeallocExprs.push_back(call_dx); + + if (!DerivedCallArgs.empty()) { + Expr* call_dx = + m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, + llvm::MutableArrayRef(DerivedCallArgs), + Loc) + .get(); + m_DeallocExprs.push_back(call_dx); + } return StmtDiff(); } @@ -1887,6 +1935,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If it has more args or f_darg0 was not found, we look for its pullback // function. const auto* MD = dyn_cast(FD); + std::vector globalCallArgs; if (!OverloadedDerivedFn) { size_t idx = 0; @@ -1952,12 +2001,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullback); // Try to find it in builtin derivatives + std::string customPullback = + clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; + // Add the indexes of the global args to the custom pullback name + if (!m_GlobalArgs.empty()) + for (size_t i = 0; i < pullbackCallArgs.size(); i++) + if (auto* DRE = dyn_cast(pullbackCallArgs[i])) + if (auto* param = dyn_cast(DRE->getDecl())) + if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) { + customPullback += "_" + std::to_string(i); + globalCallArgs.emplace_back(i); + } + if (baseDiff.getExpr()) pullbackCallArgs.insert( pullbackCallArgs.begin(), BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr())); - std::string customPullback = - clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; + OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), @@ -1990,6 +2050,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // derive the called function. DiffRequest pullbackRequest{}; pullbackRequest.Function = FD; + + // Mark the indexes of the global args. Necessary if the argument of the + // call has a different name than the function's signature parameter. + pullbackRequest.GlobalArgsIndexes = globalCallArgs; + pullbackRequest.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); pullbackRequest.Mode = DiffMode::experimental_pullback; @@ -2237,12 +2302,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i); Expr* gradElem = BuildArraySubscript(gradRef, {idx}); Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem); - PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); + if (shouldUseCudaAtomicOps(outputArgs[i])) + PostCallStmts.push_back( + BuildCallToCudaAtomicAdd(outputArgs[i], gradExpr)); + else + PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); NumDiffArgs.push_back(args[i]); } std::string Name = "central_difference"; return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + Name, NumDiffArgs, getCurrentScope(), + /*OriginalFnDC=*/nullptr, /*forCustomDerv=*/false, /*namespaceShouldExist=*/false); } @@ -2344,7 +2414,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx); // Create the (target += dfdx) statement. if (dfdx()) { - if (shouldUseCudaAtomicOps()) { + if (shouldUseCudaAtomicOps(diff_dx)) { Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx()); // Add it to the body statements. addToCurrentBlock(atomicCall, direction::reverse); @@ -4556,6 +4626,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); } + m_ParamVarsWithDiff.emplace(*it); } } } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index b6156397c..63ae3c369 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -783,7 +783,8 @@ namespace clad { // Return the found overload. std::string Name = "forward_central_difference"; return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + Name, NumDiffArgs, getCurrentScope(), + /*OriginalFnDC=*/nullptr, /*forCustomDerv=*/false, /*namespaceShouldExist=*/false); } diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu index 171341a7e..a60604fa6 100644 --- a/test/CUDA/GradientKernels.cu +++ b/test/CUDA/GradientKernels.cu @@ -288,6 +288,186 @@ __global__ void add_kernel_7(double *a, double *b) { //CHECK-NEXT: } //CHECK-NEXT:} +__device__ double device_fn(double in, double val) { + return in + val; +} + +__global__ void kernel_with_device_call(double *out, double *in, double val) { + int index = threadIdx.x; + out[index] = device_fn(in[index], val); +} + +// CHECK: void kernel_with_device_call_grad_0_2(double *out, double *in, double val, double *_d_out, double *_d_val) { +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x; +//CHECK-NEXT: double _t0 = out[index0]; +//CHECK-NEXT: out[index0] = device_fn(in[index0], val); +//CHECK-NEXT: { +//CHECK-NEXT: out[index0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_out[index0]; +//CHECK-NEXT: _d_out[index0] = 0.; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: double _r1 = 0.; +//CHECK-NEXT: device_fn_pullback_1(in[index0], val, _r_d0, &_r0, &_r1); +//CHECK-NEXT: atomicAdd(_d_val, _r1); +//CHECK-NEXT: } +//CHECK-NEXT:} + +__device__ double device_fn_2(double *in, double val) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + return in[index] + val; +} + +__global__ void kernel_with_device_call_2(double *out, double *in, double val) { + int index = threadIdx.x; + out[index] = device_fn_2(in, val); +} + +__global__ void dup_kernel_with_device_call_2(double *out, double *in, double val) { + int index = threadIdx.x; + out[index] = device_fn_2(in, val); +} + +// CHECK: void kernel_with_device_call_2_grad_0_2(double *out, double *in, double val, double *_d_out, double *_d_val) { +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x; +//CHECK-NEXT: double _t0 = out[index0]; +//CHECK-NEXT: out[index0] = device_fn_2(in, val); +//CHECK-NEXT: { +//CHECK-NEXT: out[index0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_out[index0]; +//CHECK-NEXT: _d_out[index0] = 0.; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: device_fn_2_pullback_0_1(in, val, _r_d0, &_r0); +//CHECK-NEXT: atomicAdd(_d_val, _r0); +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: void kernel_with_device_call_2_grad_0_1(double *out, double *in, double val, double *_d_out, double *_d_in) { +//CHECK-NEXT: double _d_val = 0.; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x; +//CHECK-NEXT: double _t0 = out[index0]; +//CHECK-NEXT: out[index0] = device_fn_2(in, val); +//CHECK-NEXT: { +//CHECK-NEXT: out[index0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_out[index0]; +//CHECK-NEXT: _d_out[index0] = 0.; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: device_fn_2_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0); +//CHECK-NEXT: _d_val += _r0; +//CHECK-NEXT: } +//CHECK-NEXT:} + +__device__ double device_fn_3(double *in, double *val) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + return in[index] + *val; +} + +__global__ void kernel_with_device_call_3(double *out, double *in, double *val) { + int index = threadIdx.x; + out[index] = device_fn_3(in, val); +} + +// CHECK: void kernel_with_device_call_3_grad(double *out, double *in, double *val, double *_d_out, double *_d_in, double *_d_val) { +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x; +//CHECK-NEXT: double _t0 = out[index0]; +//CHECK-NEXT: out[index0] = device_fn_3(in, val); +//CHECK-NEXT: { +//CHECK-NEXT: out[index0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_out[index0]; +//CHECK-NEXT: _d_out[index0] = 0.; +//CHECK-NEXT: device_fn_3_pullback_0_1_3_4(in, val, _r_d0, _d_in, _d_val); +//CHECK-NEXT: } +//CHECK-NEXT:} + +__device__ double device_fn_4(double *in, double val) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + return in[index] + val; +} + +__device__ double device_with_device_call(double *in, double val) { + return device_fn_4(in, val); +} + +__global__ void kernel_with_nested_device_call(double *out, double *in, double val) { + int index = threadIdx.x; + out[index] = device_with_device_call(in, val); +} + +// CHECK: void kernel_with_nested_device_call_grad_0_1(double *out, double *in, double val, double *_d_out, double *_d_in) { +//CHECK-NEXT: double _d_val = 0.; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x; +//CHECK-NEXT: double _t0 = out[index0]; +//CHECK-NEXT: out[index0] = device_with_device_call(in, val); +//CHECK-NEXT: { +//CHECK-NEXT: out[index0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_out[index0]; +//CHECK-NEXT: _d_out[index0] = 0.; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: device_with_device_call_pullback_0_1_3(in, val, _r_d0, _d_in, &_r0); +//CHECK-NEXT: _d_val += _r0; +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_fn_pullback_1(double in, double val, double _d_y, double *_d_in, double *_d_val) { +//CHECK-NEXT: { +//CHECK-NEXT: *_d_in += _d_y; +//CHECK-NEXT: *_d_val += _d_y; +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1(double *in, double val, double _d_y, double *_d_val) { +//CHECK-NEXT: unsigned int _t1 = blockIdx.x; +//CHECK-NEXT: unsigned int _t0 = blockDim.x; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0; +//CHECK-NEXT: *_d_val += _d_y; +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_fn_2_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +//CHECK-NEXT: unsigned int _t1 = blockIdx.x; +//CHECK-NEXT: unsigned int _t0 = blockDim.x; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0; +//CHECK-NEXT: { +//CHECK-NEXT: atomicAdd(&_d_in[index0], _d_y); +//CHECK-NEXT: *_d_val += _d_y; +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_fn_3_pullback_0_1_3_4(double *in, double *val, double _d_y, double *_d_in, double *_d_val) { +//CHECK-NEXT: unsigned int _t1 = blockIdx.x; +//CHECK-NEXT: unsigned int _t0 = blockDim.x; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0; +//CHECK-NEXT: { +//CHECK-NEXT: atomicAdd(&_d_in[index0], _d_y); +//CHECK-NEXT: atomicAdd(_d_val, _d_y); +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_with_device_call_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: device_fn_4_pullback_0_1_3(in, val, _d_y, _d_in, &_r0); +//CHECK-NEXT: *_d_val += _r0; +//CHECK-NEXT: } +//CHECK-NEXT:} + +// CHECK: __attribute__((device)) void device_fn_4_pullback_0_1_3(double *in, double val, double _d_y, double *_d_in, double *_d_val) { +//CHECK-NEXT: unsigned int _t1 = blockIdx.x; +//CHECK-NEXT: unsigned int _t0 = blockDim.x; +//CHECK-NEXT: int _d_index = 0; +//CHECK-NEXT: int index0 = threadIdx.x + _t1 * _t0; +//CHECK-NEXT: { +//CHECK-NEXT: atomicAdd(&_d_in[index0], _d_y); +//CHECK-NEXT: *_d_val += _d_y; +//CHECK-NEXT: } +//CHECK-NEXT:} + #define TEST(F, grid, block, shared_mem, use_stream, x, dx, N) \ { \ int *fives = (int*)malloc(N * sizeof(int)); \ @@ -345,9 +525,9 @@ __global__ void add_kernel_7(double *a, double *b) { else { \ test.execute_kernel(grid, block, y, x, dy, dx); \ } \ - cudaDeviceSynchronize(); \ int *res = (int*)malloc(N * sizeof(int)); \ cudaMemcpy(res, dx, N * sizeof(int), cudaMemcpyDeviceToHost); \ + cudaDeviceSynchronize(); \ for (int i = 0; i < (N - 1); i++) { \ printf("%d, ", res[i]); \ } \ @@ -380,9 +560,9 @@ __global__ void add_kernel_7(double *a, double *b) { else { \ test.execute_kernel(grid, block, y, x, N, dy, dx); \ } \ - cudaDeviceSynchronize(); \ int *res = (int*)malloc(N * sizeof(int)); \ cudaMemcpy(res, dx, N * sizeof(int), cudaMemcpyDeviceToHost); \ + cudaDeviceSynchronize(); \ for (int i = 0; i < (N - 1); i++) { \ printf("%d, ", res[i]); \ } \ @@ -415,9 +595,9 @@ __global__ void add_kernel_7(double *a, double *b) { else { \ test.execute_kernel(grid, block, y, x, dy, dx); \ } \ - cudaDeviceSynchronize(); \ double *res = (double*)malloc(N * sizeof(double)); \ cudaMemcpy(res, dx, N * sizeof(double), cudaMemcpyDeviceToHost); \ + cudaDeviceSynchronize(); \ for (int i = 0; i < (N - 1); i++) { \ printf("%0.2f, ", res[i]); \ } \ @@ -427,6 +607,25 @@ __global__ void add_kernel_7(double *a, double *b) { free(res); \ } +#define INIT(x, y, val, dx, dy, d_val) \ +{ \ + double *fives = (double*)malloc(10 * sizeof(double)); \ + for(int i = 0; i < 10; i++) { \ + fives[i] = 5; \ + } \ + double *zeros = (double*)malloc(10 * sizeof(double)); \ + for(int i = 0; i < 10; i++) { \ + zeros[i] = 0; \ + } \ + cudaMemcpy(x, fives, 10 * sizeof(double), cudaMemcpyHostToDevice); \ + cudaMemcpy(y, zeros, 10 * sizeof(double), cudaMemcpyHostToDevice); \ + cudaMemcpy(val, fives, sizeof(double), cudaMemcpyHostToDevice); \ + cudaMemcpy(dx, zeros, 10 * sizeof(double), cudaMemcpyHostToDevice); \ + cudaMemcpy(dy, fives, 10 * sizeof(double), cudaMemcpyHostToDevice); \ + cudaMemcpy(d_val, zeros, sizeof(double), cudaMemcpyHostToDevice); \ + free(fives); \ + free(zeros); \ +} int main(void) { int *a, *d_a; @@ -472,11 +671,71 @@ int main(void) { TEST_2_D(add_kernel_7, dim3(1), dim3(5, 1, 1), 0, false, "a, b", dummy_out_double, dummy_in_double, d_out_double, d_in_double, 10); // CHECK-EXEC: 50.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 + double *val; + cudaMalloc(&val, sizeof(double)); + double *d_val; + cudaMalloc(&d_val, sizeof(double)); + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto test_device = clad::gradient(kernel_with_device_call, "out, val"); + test_device.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_val); + double *res = (double*)malloc(10 * sizeof(double)); + cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("%0.2f\n", *res); // CHECK-EXEC: 50.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto test_device_2 = clad::gradient(kernel_with_device_call_2, "out, val"); + test_device_2.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_val); + cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("%0.2f\n", *res); // CHECK-EXEC: 50.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto check_dup = clad::gradient(dup_kernel_with_device_call_2, "out, val"); // check that the pullback function is not regenerated + check_dup.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_val); + cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("%s\n", cudaGetErrorString(cudaGetLastError())); // CHECK-EXEC: no error + printf("%0.2f\n", *res); // CHECK-EXEC: 50.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto test_device_3 = clad::gradient(kernel_with_device_call_2, "out, in"); + test_device_3.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_in_double); + cudaDeviceSynchronize(); + cudaMemcpy(res, d_in_double, 10 * sizeof(double), cudaMemcpyDeviceToHost); + printf("%0.2f, %0.2f, %0.2f\n", res[0], res[1], res[2]); // CHECK-EXEC: 5.00, 5.00, 5.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto test_device_4 = clad::gradient(kernel_with_device_call_3); + test_device_4.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, val, d_out_double, d_in_double, d_val); + cudaDeviceSynchronize(); + cudaMemcpy(res, d_in_double, 10 * sizeof(double), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + printf("%0.2f, %0.2f, %0.2f\n", res[0], res[1], res[2]); // CHECK-EXEC: 5.00, 5.00, 5.00 + cudaMemcpy(res, d_val, sizeof(double), cudaMemcpyDeviceToHost); + printf("%0.2f\n", *res); // CHECK-EXEC: 50.00 + + INIT(dummy_in_double, dummy_out_double, val, d_in_double, d_out_double, d_val); + + auto nested_device = clad::gradient(kernel_with_nested_device_call, "out, in"); + nested_device.execute_kernel(dim3(1), dim3(10, 1, 1), dummy_out_double, dummy_in_double, 5, d_out_double, d_in_double); + cudaDeviceSynchronize(); + cudaMemcpy(res, d_in_double, 10 * sizeof(double), cudaMemcpyDeviceToHost); + printf("%0.2f, %0.2f, %0.2f\n", res[0], res[1], res[2]); // CHECK-EXEC: 5.00, 5.00, 5.00 + + free(res); cudaFree(dummy_in_double); cudaFree(dummy_out_double); cudaFree(d_out_double); cudaFree(d_in_double); - + cudaFree(val); + cudaFree(d_val); return 0; } From 7bdfde2a812dead8125787ccd7b19443f45500b9 Mon Sep 17 00:00:00 2001 From: kchristin Date: Mon, 14 Oct 2024 12:04:58 +0300 Subject: [PATCH 6/8] Improve names of GlobalArgs and GlobalArgsIndexes --- include/clad/Differentiator/DiffPlanner.h | 2 +- .../clad/Differentiator/ReverseModeVisitor.h | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 22 +++++++++---------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 663b24b47..a4b06a148 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -47,7 +47,7 @@ struct DiffRequest { /// Args provided to the call to clad::gradient/differentiate. const clang::Expr* Args = nullptr; /// Indexes of global GPU args of function as a subset of Args. - std::vector GlobalArgsIndexes; + std::vector CUDAGlobalArgsIndexes; /// Requested differentiation mode, forward or reverse. DiffMode Mode = DiffMode::unknown; /// If function appears in the call to clad::gradient/differentiate, diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index e58d77398..dabcfd256 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -57,7 +57,7 @@ namespace clad { /// block. Stmts m_Globals; /// Global GPU args of the function. - std::unordered_set m_GlobalArgs; + std::unordered_set m_CUDAGlobalArgs; //// A reference to the output parameter of the gradient function. clang::Expr* m_Result; /// A flag indicating if the Stmt we are currently visiting is inside loop. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f867a321c..c3f0f3b76 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -106,12 +106,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) { // Same as checking whether this is a function executed by the GPU - if (!m_GlobalArgs.empty()) + if (!m_CUDAGlobalArgs.empty()) if (const auto* DRE = dyn_cast(E)) if (const auto* PVD = dyn_cast(DRE->getDecl())) // we need to check whether this param is in the global memory of the // GPU - return m_GlobalArgs.find(PVD) != m_GlobalArgs.end(); + return m_CUDAGlobalArgs.find(PVD) != m_CUDAGlobalArgs.end(); return false; } @@ -455,7 +455,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // global memory of the GPU if (m_DiffReq->hasAttr()) for (auto param : params) - m_GlobalArgs.emplace(param); + m_CUDAGlobalArgs.emplace(param); llvm::ArrayRef paramsRef = clad_compat::makeArrayRef(params.data(), params.size()); @@ -563,7 +563,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; - for (auto index : m_DiffReq.GlobalArgsIndexes) + for (auto index : m_DiffReq.CUDAGlobalArgsIndexes) derivativeName += "_" + std::to_string(index); auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); @@ -608,14 +608,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Derivative->setParams(params); // Match the global arguments of the call to the device function to the // pullback function's parameters. - if (!m_DiffReq.GlobalArgsIndexes.empty()) - for (auto index : m_DiffReq.GlobalArgsIndexes) - m_GlobalArgs.emplace(m_Derivative->getParamDecl(index)); + if (!m_DiffReq.CUDAGlobalArgsIndexes.empty()) + for (auto index : m_DiffReq.CUDAGlobalArgsIndexes) + m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index)); // If the function is a global kernel, all its parameters reside in the // global memory of the GPU else if (m_DiffReq->hasAttr()) for (auto param : params) - m_GlobalArgs.emplace(param); + m_CUDAGlobalArgs.emplace(param); m_Derivative->setBody(nullptr); if (!m_DiffReq.DeclarationOnly) { @@ -2004,11 +2004,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::string customPullback = clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; // Add the indexes of the global args to the custom pullback name - if (!m_GlobalArgs.empty()) + if (!m_CUDAGlobalArgs.empty()) for (size_t i = 0; i < pullbackCallArgs.size(); i++) if (auto* DRE = dyn_cast(pullbackCallArgs[i])) if (auto* param = dyn_cast(DRE->getDecl())) - if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) { + if (m_CUDAGlobalArgs.find(param) != m_CUDAGlobalArgs.end()) { customPullback += "_" + std::to_string(i); globalCallArgs.emplace_back(i); } @@ -2053,7 +2053,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Mark the indexes of the global args. Necessary if the argument of the // call has a different name than the function's signature parameter. - pullbackRequest.GlobalArgsIndexes = globalCallArgs; + pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; pullbackRequest.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); From 77a020b10119c07d01989ee2c588e271b121fecc Mon Sep 17 00:00:00 2001 From: kchristin Date: Tue, 15 Oct 2024 12:56:57 +0300 Subject: [PATCH 7/8] Fix suggestions --- lib/Differentiator/ReverseModeVisitor.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c3f0f3b76..276028011 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -109,8 +109,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!m_CUDAGlobalArgs.empty()) if (const auto* DRE = dyn_cast(E)) if (const auto* PVD = dyn_cast(DRE->getDecl())) - // we need to check whether this param is in the global memory of the - // GPU + // Check whether this param is in the global memory of the GPU return m_CUDAGlobalArgs.find(PVD) != m_CUDAGlobalArgs.end(); return false; @@ -454,7 +453,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // if the function is a global kernel, all its parameters reside in the // global memory of the GPU if (m_DiffReq->hasAttr()) - for (auto param : params) + for (auto* param : params) m_CUDAGlobalArgs.emplace(param); llvm::ArrayRef paramsRef = @@ -611,11 +610,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!m_DiffReq.CUDAGlobalArgsIndexes.empty()) for (auto index : m_DiffReq.CUDAGlobalArgsIndexes) m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index)); - // If the function is a global kernel, all its parameters reside in the - // global memory of the GPU - else if (m_DiffReq->hasAttr()) - for (auto param : params) - m_CUDAGlobalArgs.emplace(param); + m_Derivative->setBody(nullptr); if (!m_DiffReq.DeclarationOnly) { @@ -2302,11 +2297,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i); Expr* gradElem = BuildArraySubscript(gradRef, {idx}); Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem); - if (shouldUseCudaAtomicOps(outputArgs[i])) - PostCallStmts.push_back( - BuildCallToCudaAtomicAdd(outputArgs[i], gradExpr)); - else - PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); + // Inputs were not pointers, so the output args are not in global GPU + // memory. Hence, no need to use atomic ops. + PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); NumDiffArgs.push_back(args[i]); } std::string Name = "central_difference"; From 8accef3450853500e5e6b893257db23bece4b85a Mon Sep 17 00:00:00 2001 From: kchristin Date: Tue, 15 Oct 2024 18:13:07 +0300 Subject: [PATCH 8/8] Check that E arg of shouldUseAtomicOps is defined before calling it --- lib/Differentiator/ReverseModeVisitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 276028011..badbf2591 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2406,7 +2406,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, else { derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx); // Create the (target += dfdx) statement. - if (dfdx()) { + if (dfdx() && derivedE) { if (shouldUseCudaAtomicOps(diff_dx)) { Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx()); // Add it to the body statements.