From 70a186d7d5eb666038447d81483fc0a57eb2deef Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 14 Mar 2024 01:30:18 +0200 Subject: [PATCH] Fix recursive call differentiation and add a test --- lib/Differentiator/ReverseModeVisitor.cpp | 7 +--- test/Gradient/FunctionCalls.C | 46 +++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9471e898b..13ae8aa21 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1790,7 +1790,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!OverloadedDerivedFn) { if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { // Recursive call. - auto* selfRef = + Expr* selfRef = m_Sema .BuildDeclarationNameExpr( CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) @@ -1798,10 +1798,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, noLoc, - llvm::MutableArrayRef(DerivedCallArgs), - noLoc) - .get(); + .ActOnCallExpr(getCurrentScope(), selfRef, noLoc, pullbackCallArgs, noLoc).get(); } else { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 013dbc7b7..03fb4045b 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -715,6 +715,50 @@ double fn15(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double recFun (double x, double y) { + if (x > y) + return recFun(x-1, y); + return x * y; +} + +//CHECK: void recFun_pullback(double x, double y, double _d_y0, clad::array_ref _d_x, clad::array_ref _d_y) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: _cond0 = x > y; +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: goto _label1; +//CHECK-NEXT: _label1: +//CHECK-NEXT: { +//CHECK-NEXT: * _d_x += _d_y0 * y; +//CHECK-NEXT: * _d_y += x * _d_y0; +//CHECK-NEXT: } +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: recFun_pullback(x - 1, y, _d_y0, &_r0, &_r1); +//CHECK-NEXT: * _d_x += _r0; +//CHECK-NEXT: * _d_y += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn16(double x, double y) { + return recFun(x, y); +} + +//CHECK: void fn16_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: recFun_pullback(x, y, 1, &_r0, &_r1); +//CHECK-NEXT: * _d_x += _r0; +//CHECK-NEXT: * _d_y += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i