From 27958c363a4349cca496698ccfd44bda885411a2 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sat, 14 Dec 2024 21:02:41 +0000 Subject: [PATCH] Respect shadow declarations when writing propagators. In cases where the public declaration is introduced with using declaration pointing to an internal namespace with the implementation details, we should put the propagator function in the namespace of the public function and not the implementation. That would allow users to position their pullbacks in the same namespace structure as the used functions. --- .../clad/Differentiator/DerivativeBuilder.h | 6 ++-- .../clad/Differentiator/ReverseModeVisitor.h | 2 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 9 ++--- lib/Differentiator/DerivativeBuilder.cpp | 22 ++++++++++++- lib/Differentiator/ReverseModeVisitor.cpp | 33 ++++++++++--------- test/FirstDerivative/BuiltinDerivatives.C | 25 ++++++++++++++ test/Gradient/Assignments.C | 2 +- test/Gradient/Gradients.C | 6 ++-- test/Jacobian/FunctionCalls.C | 4 +-- test/NestedCalls/NestedCalls.C | 8 ++--- 10 files changed, 82 insertions(+), 35 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 7c9d843d7..3c98f3ff2 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -107,6 +107,8 @@ namespace clad { /// overload to be found. /// \param[in] CallArgs The call args to be used to resolve to the /// correct overload. + /// \param[in] callSite - The call expression which triggers the custom + /// derivative call. /// \param[in] forCustomDerv A flag to keep track of which /// namespace we should look in for the overloads. /// \param[in] namespaceShouldExist A flag to enforce assertion failure @@ -117,8 +119,8 @@ namespace clad { /// null otherwise. clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, const clang::DeclContext* originalFnDC, - bool forCustomDerv = true, bool namespaceShouldExist = true, + clang::Scope* S, const clang::Expr* callSite, bool forCustomDerv = true, + bool namespaceShouldExist = true, clang::Expr* CUDAExecConfig = nullptr); bool noOverloadExists(clang::Expr* UnresolvedLookup, llvm::MutableArrayRef ARargs); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index ba0cf7c14..42f25cf1d 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -107,7 +107,7 @@ namespace clad { /// Tries to find and build call to user-provided `_forw` function. clang::Expr* BuildCallToCustomForwPassFn( - const clang::FunctionDecl* FD, llvm::ArrayRef primalArgs, + const clang::Expr* callSite, llvm::ArrayRef primalArgs, llvm::ArrayRef derivedArgs, clang::Expr* baseExpr); public: diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 382ba00db..41b174f5c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1226,8 +1226,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { std::string customPushforward = clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, customDerivativeArgs, getCurrentScope(), - FD->getDeclContext()); + customPushforward, customDerivativeArgs, getCurrentScope(), CE); // Custom derivative templates can be written in a // general way that works for both vectorized and non-vectorized // modes. We have to also look for the pushforward with the regular name. @@ -1235,8 +1234,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { customPushforward = clad::utils::ComputeEffectiveFnName(FD) + "_pushforward"; callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, customDerivativeArgs, getCurrentScope(), - FD->getDeclContext()); + customPushforward, customDerivativeArgs, getCurrentScope(), CE); } if (!isLambda) { // Check if it is a recursive call. @@ -2316,8 +2314,7 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall( clad::utils::ComputeEffectiveFnName(CE->getConstructor()) + GetPushForwardFunctionSuffix(); Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPushforwardName, customPushforwardArgs, getCurrentScope(), - CE->getConstructor()->getDeclContext()); + customPushforwardName, customPushforwardArgs, getCurrentScope(), CE); return pushforwardCall; } } // end namespace clad diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 859a52b18..b1fc567c8 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -247,9 +247,29 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, const clang::DeclContext* originalFnDC, + clang::Scope* S, const clang::Expr* callSite, bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/, Expr* CUDAExecConfig /*=nullptr*/) { + const DeclContext* originalFnDC = nullptr; + + // FIXME: callSite must not be numm but it comes when we try to build + // a numerical diff call. We should merge both paths and remove the + // special branches being taken for propagators and numerical diff. + if (callSite) { + // Check if the callSite is not associated with a shadow declaration. + if (auto* ME = dyn_cast(callSite)) { + originalFnDC = ME->getMethodDecl()->getParent(); + } else if (auto* CE = dyn_cast(callSite)) { + const Expr* Callee = CE->getCallee()->IgnoreParenCasts(); + if (auto* DRE = dyn_cast(Callee)) + originalFnDC = DRE->getFoundDecl()->getDeclContext(); + else if (auto* MemberE = dyn_cast(Callee)) + originalFnDC = MemberE->getFoundDecl().getDecl()->getDeclContext(); + } else if (auto* CtorExpr = dyn_cast(callSite)) { + originalFnDC = CtorExpr->getConstructor()->getDeclContext(); + } + } + CXXScopeSpec SS; LookupResult R = LookupCustomDerivativeOrNumericalDiff( Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 7e6f5486e..01766e53b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1827,8 +1827,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallArgs.front()->getType(), m_Context, 1)); OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, pushforwardCallArgs, getCurrentScope(), - FD->getDeclContext(), + customPushforward, pushforwardCallArgs, getCurrentScope(), CE, /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (OverloadedDerivedFn) @@ -1931,8 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPullback, pullbackCallArgs, getCurrentScope(), - FD->getDeclContext(), + customPullback, pullbackCallArgs, getCurrentScope(), CE, /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (baseDiff.getExpr()) @@ -2064,7 +2062,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseDiff.getExpr_dx(), Loc)); if (Expr* customForwardPassCE = - BuildCallToCustomForwPassFn(FD, CallArgs, CallArgDx, baseExpr)) { + BuildCallToCustomForwPassFn(CE, CallArgs, CallArgDx, baseExpr)) { if (!utils::isNonConstReferenceType(returnType) && !returnType->isPointerType()) return StmtDiff{customForwardPassCE}; @@ -2214,7 +2212,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::string Name = "central_difference"; return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( Name, NumDiffArgs, getCurrentScope(), - /*OriginalFnDC=*/nullptr, + /*callSite=*/nullptr, /*forCustomDerv=*/false, /*namespaceShouldExist=*/false, CUDAExecConfig); } @@ -4247,8 +4245,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::string customPullbackName = "constructor_pullback"; if (Expr* customPullbackCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPullbackName, pullbackArgs, getCurrentScope(), - CE->getConstructor()->getDeclContext())) { + customPullbackName, pullbackArgs, getCurrentScope(), CE)) { curRevBlock.insert(it, customPullbackCall); if (m_TrackConstructorPullbackInfo) { setConstructorPullbackCallInfo(llvm::cast(customPullbackCall), @@ -4278,9 +4275,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // SomeClass _d_c = _t0.adjoint; // SomeClass c = _t0.value; // ``` - if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn( - CE->getConstructor(), primalArgs, reverseForwAdjointArgs, - /*baseExpr=*/nullptr)) { + if (Expr* customReverseForwFnCall = + BuildCallToCustomForwPassFn(CE, primalArgs, reverseForwAdjointArgs, + /*baseExpr=*/nullptr)) { if (RD->isAggregate()) { SmallString<128> Name_class; llvm::raw_svector_ostream OS_class(Name_class); @@ -4555,16 +4552,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn( - const FunctionDecl* FD, llvm::ArrayRef primalArgs, + const Expr* callSite, llvm::ArrayRef primalArgs, llvm::ArrayRef derivedArgs, Expr* baseExpr) { - std::string forwPassFnName = - clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw"; llvm::SmallVector args; if (baseExpr) { baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr, m_DiffReq->getLocation()); args.push_back(baseExpr); } + const FunctionDecl* FD = nullptr; + if (auto* CE = dyn_cast(callSite)) + FD = CE->getDirectCallee(); + else + FD = cast(callSite)->getConstructor(); + if (auto CD = llvm::dyn_cast(FD)) { const RecordDecl* RD = CD->getParent(); QualType constructorReverseForwTagT = @@ -4582,9 +4583,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } args.append(primalArgs.begin(), primalArgs.end()); args.append(derivedArgs.begin(), derivedArgs.end()); + std::string forwPassFnName = + clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw"; Expr* customForwPassCE = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - forwPassFnName, args, getCurrentScope(), FD->getDeclContext()); + forwPassFnName, args, getCurrentScope(), callSite); return customForwPassCE; } diff --git a/test/FirstDerivative/BuiltinDerivatives.C b/test/FirstDerivative/BuiltinDerivatives.C index 1d0807176..d66168fb2 100644 --- a/test/FirstDerivative/BuiltinDerivatives.C +++ b/test/FirstDerivative/BuiltinDerivatives.C @@ -7,6 +7,28 @@ #include "../TestUtils.h" extern "C" int printf(const char* fmt, ...); + +namespace N { + namespace impl { + double sq(double x); + } + using impl::sq; // using shadow +} + +namespace clad { + namespace custom_derivatives { + namespace N { + clad::ValueAndPushforward sq_pushforward(double x, double *d_x) { + return { x * x, 2 * x }; + } + } + } +} + +float f0 (float x) { + return N::sq(x); // must find the sq_pushforward. +} + namespace clad{ namespace custom_derivatives{ float f1_darg0(float x) { @@ -296,6 +318,9 @@ int main () { //expected-no-diagnostics double d_result[2]; int i_result[1]; + auto f0_darg0 = clad::differentiate(f0, 0); + printf("Result is = %f\n", f0_darg0.execute(2)); // CHECK-EXEC: Result is = -0.952413 + auto f1_darg0 = clad::differentiate(f1, 0); printf("Result is = %f\n", f1_darg0.execute(60)); // CHECK-EXEC: Result is = -0.952413 diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 99ba93898..c88ed4308 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -716,7 +716,7 @@ double f19(double a, double b) { //CHECK-NEXT: double _r0 = 0.; //CHECK-NEXT: double _r1 = 0.; //CHECK-NEXT: double _r2 = 0.; -//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2); +//CHECK-NEXT: clad::custom_derivatives::std::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2); //CHECK-NEXT: *_d_a += _r0; //CHECK-NEXT: *_d_b += _r1; //CHECK-NEXT: *_d_b += _r2; diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index a23fe2e8b..e2a4f172f 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -406,7 +406,7 @@ void f_norm_grad(double x, //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.; //CHECK-NEXT: double _r5 = 0.; -//CHECK-NEXT: clad::custom_derivatives::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5); +//CHECK-NEXT: clad::custom_derivatives::std::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5); //CHECK-NEXT: double _r1 = 0.; //CHECK-NEXT: double _r2 = 0.; //CHECK-NEXT: double _r3 = 0.; @@ -430,10 +430,10 @@ void f_sin_grad(double x, double y, double *_d_x, double *_d_y); //CHECK-NEXT: double _t0 = (std::sin(x) + std::sin(y)); //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.; -//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; +//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward; //CHECK-NEXT: *_d_x += _r0; //CHECK-NEXT: double _r1 = 0.; -//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward; +//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::std::sin_pushforward(y, 1.).pushforward; //CHECK-NEXT: *_d_y += _r1; //CHECK-NEXT: *_d_x += _t0 * 1; //CHECK-NEXT: *_d_y += _t0 * 1; diff --git a/test/Jacobian/FunctionCalls.C b/test/Jacobian/FunctionCalls.C index f46fb1931..4d3fba25d 100644 --- a/test/Jacobian/FunctionCalls.C +++ b/test/Jacobian/FunctionCalls.C @@ -19,10 +19,10 @@ void fn1(double i, double j, double* output) { // CHECK-NEXT: clad::array _d_vector_i = clad::one_hot_vector(indepVarCount, {{0U|0UL|0ULL}}); // CHECK-NEXT: clad::array _d_vector_j = clad::one_hot_vector(indepVarCount, {{1U|1UL|1ULL}}); // CHECK-NEXT: *_d_vector_output = clad::identity_matrix(_d_vector_output->rows(), indepVarCount, {{2U|2UL|2ULL}}); -// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::pow_pushforward(i, j, _d_vector_i, _d_vector_j); +// CHECK-NEXT: {{.*}} _t0 = clad::custom_derivatives::std::pow_pushforward(i, j, _d_vector_i, _d_vector_j); // CHECK-NEXT: *_d_vector_output[0] = _t0.pushforward; // CHECK-NEXT: output[0] = _t0.value; -// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::pow_pushforward(j, i, _d_vector_j, _d_vector_i); +// CHECK-NEXT: {{.*}} _t1 = clad::custom_derivatives::std::pow_pushforward(j, i, _d_vector_j, _d_vector_i); // CHECK-NEXT: *_d_vector_output[1] = _t1.pushforward; // CHECK-NEXT: output[1] = _t1.value; // CHECK-NEXT: } diff --git a/test/NestedCalls/NestedCalls.C b/test/NestedCalls/NestedCalls.C index 9c9eb56be..dc7911462 100644 --- a/test/NestedCalls/NestedCalls.C +++ b/test/NestedCalls/NestedCalls.C @@ -57,9 +57,9 @@ int main () { // expected-no-diagnostics // CHECK: clad::ValueAndPushforward sq_pushforward(double x, double _d_x); // CHECK: clad::ValueAndPushforward one_pushforward(double x, double _d_x) { -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives::sin_pushforward(x, _d_x); +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives::std::sin_pushforward(x, _d_x); // CHECK-NEXT: clad::ValueAndPushforward _t1 = sq_pushforward(_t0.value, _t0.pushforward); -// CHECK-NEXT: ValueAndPushforward _t2 = clad::custom_derivatives::cos_pushforward(x, _d_x); +// CHECK-NEXT: ValueAndPushforward _t2 = clad::custom_derivatives::std::cos_pushforward(x, _d_x); // CHECK-NEXT: clad::ValueAndPushforward _t3 = sq_pushforward(_t2.value, _t2.pushforward); // CHECK-NEXT: return {_t1.value + _t3.value, _t1.pushforward + _t3.pushforward}; // CHECK-NEXT: } @@ -71,12 +71,12 @@ int main () { // expected-no-diagnostics //CHECK-NEXT: double _r0 = 0.; //CHECK-NEXT: sq_pullback(std::sin(x), _d_y, &_r0); //CHECK-NEXT: double _r1 = 0.; -//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; +//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::std::sin_pushforward(x, 1.).pushforward; //CHECK-NEXT: *_d_x += _r1; //CHECK-NEXT: double _r2 = 0.; //CHECK-NEXT: sq_pullback(std::cos(x), _d_y, &_r2); //CHECK-NEXT: double _r3 = 0.; -//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::cos_pushforward(x, 1.).pushforward; +//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::std::cos_pushforward(x, 1.).pushforward; //CHECK-NEXT: *_d_x += _r3; //CHECK-NEXT: } //CHECK-NEXT: }