From 1b81084bd69b15c776582e24aa5d7bbe411c5b81 Mon Sep 17 00:00:00 2001 From: ovdiiuv <104850830+ovdiiuv@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:57:40 +0200 Subject: [PATCH] Product of references in different scope fix (#1030) --- lib/Differentiator/ReverseModeVisitor.cpp | 10 +++++++++- test/Gradient/Loops.C | 10 ++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 97c78d519..fb1102d66 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2292,7 +2292,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff RResult; // If R has no side effects, it can be just cloned // (no need to store it). - if (!ShouldRecompute(R)) { + + // Check if the local variable declaration is reference type, since it is + // moved to the global scope and the right side should be recomputed + bool promoteToFnScope = false; + if (auto* RDeclRef = dyn_cast(R->IgnoreImplicit())) + promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() && + !getCurrentScope()->isFunctionScope(); + + if (!ShouldRecompute(R) || promoteToFnScope) { RDelayed = std::unique_ptr( new DelayedStoreResult(DelayedGlobalStoreAndRef(R))); RResult = RDelayed->Result; diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index d648a1149..7dfd0ba9a 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2689,7 +2689,7 @@ double fn34(double x, double y){ double r = 0; double a[] = {y, x*y, x*x + y}; for(auto& i: a){ - r+=i; + r+=i*i; } return r; } @@ -2724,7 +2724,7 @@ double fn34(double x, double y){ //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } //CHECK-NEXT: clad::push(_t1, r); -//CHECK-NEXT: r += *i; +//CHECK-NEXT: r += *i * *i; //CHECK-NEXT: } //CHECK-NEXT: _d_r += 1; //CHECK-NEXT: for (; _t0; _t0--) { @@ -2737,7 +2737,8 @@ double fn34(double x, double y){ //CHECK-NEXT: { //CHECK-NEXT: r = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_r; -//CHECK-NEXT: *_d_i += _r_d0; +//CHECK-NEXT: *_d_i += _r_d0 * *i; +//CHECK-NEXT: *_d_i += *i * _r_d0; //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } @@ -2751,6 +2752,7 @@ double fn34(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: } + double fn35(double x, double y){ double a[] = {1, 2, 3}; double sum = 0; @@ -2901,7 +2903,7 @@ int main() { TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00} TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} - TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} + TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} }