From 985097e2391c9706808d3fea59c2018f72048407 Mon Sep 17 00:00:00 2001 From: Parth Date: Sun, 10 Apr 2022 21:45:09 +0530 Subject: [PATCH] add support for operator overload in reverse mode --- benchmark/BenchmarkedFunctions.h | 8 ++ benchmark/CMakeLists.txt | 1 + benchmark/VectorModeComparison.cpp | 122 ++++++++++++++++++ .../ReverseModeForwPassVisitor.h | 1 + .../clad/Differentiator/ReverseModeVisitor.h | 2 +- .../ReverseModeForwPassVisitor.cpp | 32 ++++- lib/Differentiator/ReverseModeVisitor.cpp | 54 +++++--- test/Gradient/MemberFunctions.C | 115 ++++++++++++++++- 8 files changed, 310 insertions(+), 25 deletions(-) create mode 100644 benchmark/VectorModeComparison.cpp diff --git a/benchmark/BenchmarkedFunctions.h b/benchmark/BenchmarkedFunctions.h index 0c8a8ff1a..138ff6ca4 100644 --- a/benchmark/BenchmarkedFunctions.h +++ b/benchmark/BenchmarkedFunctions.h @@ -39,3 +39,11 @@ inline double product(double p[], int n) { } return prod; } + +///\returns the weighted sum of the elements in \p +inline double weightedSum(double p[], double w[], int n) { + double sum = 0; + for (int i = 0; i < n; i++) + sum += p[i] * w[i]; + return sum; +} \ No newline at end of file diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index d8553179e..2d22a6ea4 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -8,6 +8,7 @@ CB_ADD_GBENCHMARK(Simple Simple.cpp) CB_ADD_GBENCHMARK(AlgorithmicComplexity AlgorithmicComplexity.cpp) CB_ADD_GBENCHMARK(EnzymeCladComparison EnzymeCladComparison.cpp) CB_ADD_GBENCHMARK(MemoryComplexity MemoryComplexity.cpp) +CB_ADD_GBENCHMARK(VectorModeComparison VectorModeComparison.cpp) set (CLAD_BENCHMARK_DEPS clad) get_property(_benchmark_names DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY TESTS) diff --git a/benchmark/VectorModeComparison.cpp b/benchmark/VectorModeComparison.cpp new file mode 100644 index 000000000..3618b7d63 --- /dev/null +++ b/benchmark/VectorModeComparison.cpp @@ -0,0 +1,122 @@ +#include 
"benchmark/benchmark.h" + +#include "clad/Differentiator/Differentiator.h" + +#include "BenchmarkedFunctions.h" + +// Benchmark forward mode for weighted sum. +static void BM_ForwardModeWeightedSum(benchmark::State& state) { + auto dp0 = clad::differentiate(weightedSum, "p[0]"); + auto dp1 = clad::differentiate(weightedSum, "p[1]"); + auto dp2 = clad::differentiate(weightedSum, "p[2]"); + auto dp3 = clad::differentiate(weightedSum, "p[3]"); + auto dp4 = clad::differentiate(weightedSum, "p[4]"); + + auto dw0 = clad::differentiate(weightedSum, "w[0]"); + auto dw1 = clad::differentiate(weightedSum, "w[1]"); + auto dw2 = clad::differentiate(weightedSum, "w[2]"); + auto dw3 = clad::differentiate(weightedSum, "w[3]"); + auto dw4 = clad::differentiate(weightedSum, "w[4]"); + + constexpr int n = 5; + double inputs[n]; + double weights[n]; + for (int i = 0; i < n; ++i) { + inputs[i] = i + 1; + weights[i] = 1.0 / (double)(i + 1); + } + + double sum = 0; + for (auto _ : state) { + benchmark::DoNotOptimize( + sum += + dp0.execute(inputs, weights, n) + dp1.execute(inputs, weights, n) + + dp2.execute(inputs, weights, n) + dp3.execute(inputs, weights, n) + + dp4.execute(inputs, weights, n) + dw0.execute(inputs, weights, n) + + dw1.execute(inputs, weights, n) + dw2.execute(inputs, weights, n) + + dw3.execute(inputs, weights, n) + dw4.execute(inputs, weights, n)); + } +} +BENCHMARK(BM_ForwardModeWeightedSum); + +// Benchmark reverse mode for weighted sum. 
+static void BM_ReverseModeWeightedSum(benchmark::State& state) { + auto grad = clad::gradient(weightedSum, "p, w"); + constexpr int n = 5; + + double inputs[n]; + double weights[n]; + for (int i = 0; i < n; ++i) { + inputs[i] = i + 1; + weights[i] = 1.0 / (double)(i + 1); + } + + double dinp[n]; + double dweights[n]; + clad::array_ref<double> dinp_ref(dinp, n); + clad::array_ref<double> dweights_ref(dweights, n); + + double sum = 0; + for (auto _ : state) { + grad.execute(inputs, weights, n, dinp_ref, dweights_ref); + for (int i = 0; i < n; ++i) + sum += dinp[i] + dweights[i]; + } +} +BENCHMARK(BM_ReverseModeWeightedSum); + +// Benchmark enzyme's reverse mode for weighted sum. +static void BM_EnzymeReverseModeWeightedSum(benchmark::State& state) { + auto grad = clad::gradient<clad::opts::use_enzyme>(weightedSum, "p, w"); + constexpr int n = 5; + + double inputs[n]; + double weights[n]; + for (int i = 0; i < n; ++i) { + inputs[i] = i + 1; + weights[i] = 1.0 / (double)(i + 1); + } + + double dinp[n]; + double dweights[n]; + clad::array_ref<double> dinp_ref(dinp, n); + clad::array_ref<double> dweights_ref(dweights, n); + + double sum = 0; + for (auto _ : state) { + grad.execute(inputs, weights, n, dinp_ref, dweights_ref); + for (int i = 0; i < n; ++i) + sum += dinp[i] + dweights[i]; + } +} +BENCHMARK(BM_EnzymeReverseModeWeightedSum); + +// Benchmark vector forward mode for weighted sum. 
+static void BM_VectorForwardModeWeightedSum(benchmark::State& state) { + auto vm_grad = + clad::differentiate<clad::opts::vector_mode>(weightedSum, "p, w"); + constexpr int n = 5; + + double inputs[n]; + double weights[n]; + for (int i = 0; i < n; ++i) { + inputs[i] = i + 1; + weights[i] = 1.0 / (double)(i + 1); + } + + double dinp[n]; + double dweights[n]; + clad::array_ref<double> dinp_ref(dinp, n); + clad::array_ref<double> dweights_ref(dweights, n); + + double sum = 0; + for (auto _ : state) { + vm_grad.execute(inputs, weights, n, dinp_ref, dweights_ref); + for (int i = 0; i < n; ++i) + sum += dinp[i] + dweights[i]; + } +} +BENCHMARK(BM_VectorForwardModeWeightedSum); + +// Define our main. +BENCHMARK_MAIN(); diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h index 5d60e6cb6..fc2236306 100644 --- a/include/clad/Differentiator/ReverseModeForwPassVisitor.h +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -30,6 +30,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor { StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override; StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; + StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp) override; }; } // namespace clad diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b7c4b1cb..64dcb52af 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -336,7 +336,7 @@ namespace clad { StmtDiff VisitParenExpr(const clang::ParenExpr* PE); virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); StmtDiff VisitStmt(const clang::Stmt* S); - StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); + virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups*
EWC); /// Decl is not Stmt, so it cannot be visited directly. StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index c22f4fadf..3386cfb47 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -29,7 +29,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, DiffParams args{}; std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - auto fnName = m_Function->getNameAsString() + "_forw"; + auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw"; auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); auto paramTypes = ComputeParamTypes(args); @@ -86,8 +86,6 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, QualType ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, QualType xType) { - assert(yType.getNonReferenceType()->isRealType() && - "yType should be a builtin-numerical scalar type!!"); QualType xValueType = utils::GetValueType(xType); // derivative variables should always be of non-const type. xValueType.removeLocalConst(); @@ -106,7 +104,7 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { QualType effectiveReturnType = m_Function->getReturnType().getNonReferenceType(); - + if (const auto* MD = dyn_cast(m_Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { @@ -240,4 +238,30 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); return {newRS}; } + +StmtDiff +ReverseModeForwPassVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { + auto opCode = UnOp->getOpcode(); + StmtDiff diff{}; + // If it is a post-increment/decrement operator, its result is a reference + // and we should return it. 
+ Expr* ResultRef = nullptr; + if (opCode == UnaryOperatorKind::UO_Deref) { + if (const auto* MD = dyn_cast(m_Function)) { + if (MD->isInstance()) { + diff = Visit(UnOp->getSubExpr()); + Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); + Expr* derivedE = diff.getExpr_dx(); + return {cloneE, derivedE}; + } + } + } else if (opCode == UO_Plus) + diff = Visit(UnOp->getSubExpr(), dfdx()); + else if (opCode == UO_Minus) { + auto d = BuildOp(UO_Minus, dfdx()); + diff = Visit(UnOp->getSubExpr(), d); + } + Expr* op = BuildOp(opCode, diff.getExpr()); + return StmtDiff(op, ResultRef); +} } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9e00797f2..a64afbe90 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1371,7 +1371,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not // contribute to gradient. - if (!NArgs && !isa(CE)) + if ((NArgs == 0U) && !isa(CE) && + !isa(CE)) return StmtDiff(Clone(CE)); // Stores the call arguments for the function to be derived @@ -1391,7 +1392,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // derived function. In the case of member functions, `implicit` // this object is always passed by reference. if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(CE)) { + !isa(CE) && !isa(CE)) { for (const Expr* Arg : CE->arguments()) { StmtDiff ArgDiff = Visit(Arg, dfdx()); CallArgs.push_back(ArgDiff.getExpr()); @@ -1415,9 +1416,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: We should add instructions for handling non-differentiable // arguments. Currently we are implicitly assuming function call only // contains differentiable arguments. 
- for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + bool isCXXOperatorCall = isa(CE); + + for (std::size_t i = static_cast(isCXXOperatorCall), + e = CE->getNumArgs(); + i != e; ++i) { const Expr* arg = CE->getArg(i); - auto PVD = FD->getParamDecl(i); + const auto* PVD = + FD->getParamDecl(i - static_cast(isCXXOperatorCall)); StmtDiff argDiff{}; bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); // We do not need to create result arg for arguments passed by reference @@ -1597,8 +1603,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. - if (auto MCE = dyn_cast(CE)) { + if (auto MCE = dyn_cast(CE)) baseDiff = Visit(MCE->getImplicitObjectArgument()); + else if (const auto* OCE = dyn_cast(CE)) + baseDiff = Visit(OCE->getArg(0)); + if (baseDiff.getExpr()) { StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); if (isInsideLoop) { addToCurrentBlock(baseDiffStore.getExpr()); @@ -1689,15 +1698,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackCallArgs = DerivedCallArgs; if (pullback) - pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), + pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - + static_cast(isCXXOperatorCall), pullback); // Try to find it in builtin derivatives - std::string customPullback = FD->getNameAsString() + "_pullback"; + if (baseDiff.getExpr()) + pullbackCallArgs.insert( + pullbackCallArgs.begin(), + BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr())); + std::string customPullback = + clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), const_cast(FD->getDeclContext())); + if (baseDiff.getExpr()) + pullbackCallArgs.erase(pullbackCallArgs.begin()); } // should be true if we 
are using numerical differentiation to differentiate @@ -1728,7 +1745,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // derive the called function. DiffRequest pullbackRequest{}; pullbackRequest.Function = FD; - pullbackRequest.BaseFunctionName = FD->getNameAsString(); + pullbackRequest.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); pullbackRequest.Mode = DiffMode::experimental_pullback; // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; @@ -1775,7 +1793,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, usingNumericalDiff = true; } } else if (pullbackFD) { - if (isa(CE)) { + if (baseDiff.getExpr()) { Expr* baseE = baseDiff.getExpr(); OverloadedDerivedFn = BuildCallExprToMemFn( baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD); @@ -1861,7 +1879,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; - calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); calleeFnForwPassReq.VerboseDiags = true; FunctionDecl* calleeFnForwPassFD = plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); @@ -1878,20 +1897,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We cannot reuse the derivatives previously computed because // they might contain 'clad::pop(..)` expression. 
- if (isa(CE)) { + if (isa(CE) || isa(CE)) { Expr* derivedBase = baseDiff.getExpr_dx(); // FIXME: We may need this if-block once we support pointers, and // passing pointers-by-reference if // (isCladArrayType(derivedBase->getType())) // CallArgs.push_back(derivedBase); // else + // Currently derivedBase `*d_this` can never be CladArrayType CallArgs.push_back( BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); } - for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + for (std::size_t i = static_cast(isCXXOperatorCall), + e = CE->getNumArgs(); + i != e; ++i) { const Expr* arg = CE->getArg(i); - const ParmVarDecl* PVD = FD->getParamDecl(i); + const ParmVarDecl* PVD = + FD->getParamDecl(i - static_cast(isCXXOperatorCall)); StmtDiff argDiff = Visit(arg); if ((argDiff.getExpr_dx() != nullptr) && PVD->getType()->isReferenceType()) { @@ -1906,7 +1929,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); } - if (isa(CE)) { + if (isa(CE) || isa(CE)) { Expr* baseE = baseDiff.getExpr(); call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, calleeFnForwPassFD); @@ -3154,9 +3177,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_Mode == DiffMode::reverse) assert(yType->isRealType() && "yType should be a non-reference builtin-numerical scalar type!!"); - else if (m_Mode == DiffMode::experimental_pullback) - assert(yType.getNonReferenceType()->isRealType() && - "yType should be a builtin-numerical scalar type!!"); QualType xValueType = utils::GetValueType(xType); // derivative variables should always be of non-const type. 
xValueType.removeLocalConst(); diff --git a/test/Gradient/MemberFunctions.C b/test/Gradient/MemberFunctions.C index a6ffeac36..c7d09c4b7 100644 --- a/test/Gradient/MemberFunctions.C +++ b/test/Gradient/MemberFunctions.C @@ -703,7 +703,19 @@ public: // CHECK-NEXT: } // CHECK-NEXT: } - double& ref_mem_fn(double i) {return x;} + double& ref_mem_fn(double i) { + x = +i; + x = -i; + return x; + } + SimpleFunctions& operator+=(double value) { + x += value; + return *this; + } + SimpleFunctions& operator++() { + x += 1.0; + return *this; + } void mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -763,13 +775,29 @@ double fn2(SimpleFunctions& sf, double i) { } // CHECK: void ref_mem_fn_pullback(double i, double _d_y, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: this->x = +i; +// CHECK-NEXT: this->x = -i; // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: (* _d_this).x += _d_y; +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = (* _d_this).x; +// CHECK-NEXT: * _d_i += -_r_d1; +// CHECK-NEXT: (* _d_this).x -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = (* _d_this).x; +// CHECK-NEXT: * _d_i += _r_d0; +// CHECK-NEXT: (* _d_this).x -= _r_d0; +// CHECK-NEXT: } // CHECK-NEXT: } + // CHECK: clad::ValueAndAdjoint ref_mem_fn_forw(double i, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: this->x = +i; +// CHECK-NEXT: this->x = -i; // CHECK-NEXT: return {this->x, (* _d_this).x}; // CHECK-NEXT: } + // CHECK: void fn2_grad(SimpleFunctions &sf, double i, clad::array_ref _d_sf, clad::array_ref _d_i) { // CHECK-NEXT: double _t0; // CHECK-NEXT: SimpleFunctions _t1; @@ -787,6 +815,78 @@ double fn2(SimpleFunctions& sf, double i) { // CHECK-NEXT: } +double fn3(SimpleFunctions& v, double value) { + v += value; + return v.x; +} + +// CHECK: void operator_plus_equal_pullback(double value, 
SimpleFunctions _d_y, clad::array_ref _d_this, clad::array_ref _d_value) { +// CHECK-NEXT: this->x += value; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = (* _d_this).x; +// CHECK-NEXT: (* _d_this).x += _r_d0; +// CHECK-NEXT: * _d_value += _r_d0; +// CHECK-NEXT: (* _d_this).x -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: clad::ValueAndAdjoint operator_plus_equal_forw(double value, clad::array_ref _d_this, clad::array_ref _d_value) { +// CHECK-NEXT: this->x += value; +// CHECK-NEXT: return {*this, (* _d_this)}; +// CHECK-NEXT: } + +// CHECK: void fn3_grad(SimpleFunctions &v, double value, clad::array_ref _d_v, clad::array_ref _d_value) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = value; +// CHECK-NEXT: _t1 = v; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.operator_plus_equal_forw(_t0, &(* _d_v), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_v).x += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.operator_plus_equal_pullback(_t0, {}, &(* _d_v), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_value += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn4(SimpleFunctions& v) { + ++v; + return v.x; +} + +// CHECK: void operator_plus_plus_pullback(SimpleFunctions _d_y, clad::array_ref _d_this) { +// CHECK-NEXT: this->x += 1.; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = (* _d_this).x; +// CHECK-NEXT: (* _d_this).x += _r_d0; +// CHECK-NEXT: (* _d_this).x -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: clad::ValueAndAdjoint operator_plus_plus_forw(clad::array_ref _d_this) { +// CHECK-NEXT: this->x += 1.; +// CHECK-NEXT: return {*this, (* _d_this)}; +// CHECK-NEXT: } + +// CHECK: void fn4_grad(SimpleFunctions &v, clad::array_ref _d_v) { +// CHECK-NEXT: 
SimpleFunctions _t0; +// CHECK-NEXT: _t0 = v; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = _t0.operator_plus_plus_forw(&(* _d_v)); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_v).x += 1; +// CHECK-NEXT: _t0.operator_plus_plus_pullback({}, &(* _d_v)); +// CHECK-NEXT: } + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -821,12 +921,21 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } - SimpleFunctions sf(2, 3); + SimpleFunctions sf1(2, 3), sf2(3, 4), sf3(4, 5); SimpleFunctions d_sf; + auto d_fn2 = clad::gradient(fn2); - d_fn2.execute(sf, 2, &d_sf, &result[0]); + d_fn2.execute(sf1, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 39.00 + + auto d_fn3 = clad::gradient(fn3); + d_fn3.execute(sf2, 3, &d_sf, &result[0]); printf("%.2f", result[0]); //CHECK-EXEC: 40.00 + auto d_fn4 = clad::gradient(fn4); + d_fn4.execute(sf3, &d_sf); + printf("%.2f", d_sf.x); //CHECK-EXEC: 2.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & {