From 58904205703699cf4a84e8afb3081505a03730e7 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 26 Oct 2023 12:37:44 +0530 Subject: [PATCH] Add support for std::min, std::max and std::clamp functions --- .../clad/Differentiator/BuiltinDerivatives.h | 60 +++++++++++++++++ lib/Differentiator/BaseForwardModeVisitor.cpp | 3 +- lib/Differentiator/CladUtils.cpp | 9 ++- lib/Differentiator/ReverseModeVisitor.cpp | 6 +- test/FirstDerivative/FunctionCalls.C | 1 + test/Gradient/FunctionCalls.C | 65 ++++++++++++++++++- 6 files changed, 137 insertions(+), 7 deletions(-) diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index 54c710c4d..fac077f00 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -13,6 +13,7 @@ namespace custom_derivatives{} #include "clad/Differentiator/ArrayRef.h" #include "clad/Differentiator/CladConfig.h" +#include #include namespace clad { @@ -139,6 +140,61 @@ CUDA_HOST_DEVICE void fma_pullback(T1 a, T2 b, T3 c, T4 d_y, *d_c += d_y; } +template +CUDA_HOST_DEVICE ValueAndPushforward +min_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) { + return {::std::min(a, b), b < a ? d_b : d_a}; /* same selection as std::min: b only when b < a, otherwise a (ties, unordered operands) */ +} + +template +CUDA_HOST_DEVICE ValueAndPushforward +max_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) { + return {::std::max(a, b), a < b ? 
d_b : d_a}; /* same selection as std::max: b only when a < b, otherwise a */ +} + +template +CUDA_HOST_DEVICE void min_pullback(const T& a, const T& b, U d_y, + clad::array_ref d_a, + clad::array_ref d_b) { + if (b < a) /* std::min selected b, so the incoming adjoint d_y is accumulated into d_b */ + *d_b += d_y; + else /* std::min selected a (also for ties and unordered operands) */ + *d_a += d_y; +} + +template +CUDA_HOST_DEVICE void max_pullback(const T& a, const T& b, U d_y, + clad::array_ref d_a, + clad::array_ref d_b) { + if (a < b) /* std::max selected b, so d_y is accumulated into d_b */ + *d_b += d_y; + else + *d_a += d_y; +} + +#if __cplusplus >= 201703L +template +CUDA_HOST_DEVICE ValueAndPushforward +clamp_pushforward(const T& v, const T& lo, const T& hi, const T& d_v, + const T& d_lo, const T& d_hi) { + return {::std::clamp(v, lo, hi), v < lo ? d_lo : hi < v ? d_hi : d_v}; /* tangent of whichever of lo / hi / v std::clamp selects */ +} + +template +CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi, + const U& d_y, + clad::array_ref d_v, + clad::array_ref d_lo, + clad::array_ref d_hi) { + if (v < lo) /* clamped to lo: d_y is accumulated into d_lo */ + *d_lo += d_y; + else if (hi < v) /* clamped to hi: d_y is accumulated into d_hi */ + *d_hi += d_y; + else + *d_v += d_y; +} +#endif + } // namespace std // These are required because C variants of mathematical functions are // defined in global namespace. 
@@ -150,6 +206,10 @@ using std::floor_pushforward; using std::fma_pullback; using std::fma_pushforward; using std::log_pushforward; +using std::max_pullback; +using std::max_pushforward; +using std::min_pullback; +using std::min_pushforward; using std::pow_pullback; using std::pow_pushforward; using std::sin_pushforward; diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index dac3e77d5..cfae69aca 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -904,9 +904,8 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( SourceLocation Loc; - if (noOverloadExists(UnresolvedLookup, MARargs)) { + if (noOverloadExists(UnresolvedLookup, MARargs)) return 0; - } OverloadedFn = m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index b9d6a959a..08aa5ea56 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -158,8 +158,9 @@ namespace clad { auto typePtr = QT.getTypePtr(); if (typePtr->isRecordType() && !typePtr->getAs()) { CXXScopeSpec CSS; - clang::CXXRecordDecl const* recordDecl = typePtr->getAsCXXRecordDecl(); - clang::DeclContext const* declContext = static_cast(recordDecl); + const clang::CXXRecordDecl* recordDecl = typePtr->getAsCXXRecordDecl(); + const auto* declContext = + static_cast(recordDecl); utils::BuildNNS(semaRef, const_cast(declContext), CSS); NestedNameSpecifier* NS = CSS.getScopeRep(); if (auto* Prefix = NS->getPrefix()) @@ -185,6 +186,10 @@ namespace clad { DC2 = DC2->getParent(); continue; } + if (DC2->isInlineNamespace()) { + DC2 = DC2->getParent(); + continue; + } // We don't want to 'extend' the DC1 context with class declarations. 
// There are 2 main reasons for this: // - Class declaration context cannot be extended the way namespace diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5009c42d9..4e5d7020e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1394,7 +1394,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // if it's of type MaterializeTemporaryExpr, then check its // subexpression. if (const auto* MTE = dyn_cast(arg)) - arg = clad_compat::GetSubExpr(MTE); + arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); if (!isa(arg) && !isa(arg)) { allArgsAreConstantLiterals = false; break; @@ -1934,7 +1934,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* call = nullptr; - if (FD->getReturnType()->isReferenceType()) { + QualType returnType = FD->getReturnType(); + if (returnType->isReferenceType() && + !returnType.getNonReferenceType().isConstQualified()) { DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 276d44fca..639eb69bf 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -4,6 +4,7 @@ #include "clad/Differentiator/Differentiator.h" +#include #include int printf(const char* fmt, ...); diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index aa4afad44..60cfa5fb4 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -1,4 +1,4 @@ -// RUN: %cladnumdiffclang %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s +// RUN: %cladnumdiffclang -std=c++17 %s -I%S/../../include -oFunctionCalls.out 2>&1 | FileCheck %s // RUN: ./FunctionCalls.out | FileCheck -check-prefix=CHECK-EXEC %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -533,6 +533,67 @@ double fn9(double x, 
double y) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn10(double x, double y) { + double out = x; + out = std::max(out, 0.0); + out = std::min(out, 10.0); + out = std::clamp(out, 3.0, 7.0); + return out * y; +} + +// CHECK: void fn10_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +// CHECK-NEXT: double _d_out = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double out = x; +// CHECK-NEXT: _t0 = out; +// CHECK-NEXT: out = std::max(out, 0.); +// CHECK-NEXT: _t1 = out; +// CHECK-NEXT: out = std::min(out, 10.); +// CHECK-NEXT: _t2 = out; +// CHECK-NEXT: out = std::clamp(out, 3., 7.); +// CHECK-NEXT: _t4 = out; +// CHECK-NEXT: _t3 = y; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r7 = 1 * _t3; +// CHECK-NEXT: _d_out += _r7; +// CHECK-NEXT: double _r8 = _t4 * 1; +// CHECK-NEXT: * _d_y += _r8; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d2 = _d_out; +// CHECK-NEXT: double _grad5 = 0.; +// CHECK-NEXT: double _grad6 = 0.; +// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(_t2, 3., 7., _r_d2, &_d_out, &_grad5, &_grad6); +// CHECK-NEXT: double _r4 = _d_out; +// CHECK-NEXT: double _r5 = _grad5; +// CHECK-NEXT: double _r6 = _grad6; +// CHECK-NEXT: _d_out -= _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = _d_out; +// CHECK-NEXT: double _grad3 = 0.; +// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(_t1, 10., _r_d1, &_d_out, &_grad3); +// CHECK-NEXT: double _r2 = _d_out; +// CHECK-NEXT: double _r3 = _grad3; +// CHECK-NEXT: _d_out -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = _d_out; +// CHECK-NEXT: double _grad1 = 0.; +// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(_t0, 0., _r_d0, &_d_out, &_grad1); +// CHECK-NEXT: double _r0 = _d_out; +// CHECK-NEXT: double _r1 = _grad1; +// CHECK-NEXT: 
_d_out -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: * _d_x += _d_out; +// CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i