diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index c81b0cd2b..eb4f2d2a7 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -122,15 +122,11 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); if (noHiddenParam) { MutableArrayRef FDparam = FD->parameters(); - m_Varied = true; - m_Marking = true; for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); TraverseStmt(par); m_VariedDecls.insert(FDparam[i]); } - m_Varied = false; - m_Marking = false; } return true; } @@ -141,7 +137,8 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { m_Varied = false; TraverseStmt(init); m_Marking = true; - if (m_Varied) + QualType VDTy = cast(D)->getType(); + if (m_Varied || utils::isArrayOrPointerType(VDTy)) copyVarToCurBlock(cast(D)); m_Marking = false; } diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 28485a413..8f4cb12f6 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -636,11 +636,19 @@ namespace clad { if (!EnableVariedAnalysis) return true; - if (VD->getType()->isPointerType() || isa(VD->getType())) - return true; - if (!m_ActivityRunInfo.HasAnalysisRun) { - std::copy(Function->param_begin(), Function->param_end(), + ArrayRef FDparam = Function->parameters(); + std::vector derivedParam; + + for (auto* parameter : FDparam) { + QualType parType = parameter->getType(); + while (parType->isPointerType()) + parType = parType->getPointeeType(); + if (!parType.isConstQualified()) + derivedParam.push_back(parameter); + } + + std::copy(derivedParam.begin(), derivedParam.end(), std::inserter(m_ActivityRunInfo.ToBeRecorded, m_ActivityRunInfo.ToBeRecorded.end())); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b34086b00..026164498 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1826,7 +1826,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // subexpression. if (const auto* MTE = dyn_cast(arg)) arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); - if (!arg->isEvaluatable(m_Context)) { + // FIXME: We should consider moving this code in the VariedAnalysis + // where we could decide to remove pullback requests from the + // diff graph. + class VariedChecker : public RecursiveASTVisitor { + const DiffRequest& Request; + + public: + VariedChecker(const DiffRequest& DR) : Request(DR) {} + bool isVariedE(const clang::Expr* E) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return !TraverseStmt(const_cast(E)); + } + bool VisitDeclRefExpr(const clang::DeclRefExpr* DRE) { + if (!isa(DRE->getDecl())) + return true; + if (Request.shouldHaveAdjoint(cast(DRE->getDecl()))) + return false; + return true; + } + } analyzer(m_DiffReq); + if (analyzer.isVariedE(arg)) { allArgsAreConstantLiterals = false; break; } diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 1a28f83be..9979c85ce 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -264,6 +264,46 @@ double f8(double x){ // CHECK-NEXT: } // CHECK-NEXT: } +double fn9(double x, double const *obs) +{ + double res = 0.0; + for (int loopIdx0 = 0; loopIdx0 < 2; loopIdx0++) { + res += std::lgamma(obs[2 + loopIdx0] + 1) + x; + } + return res; +} + +// CHECK: void fn9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { +// CHECK-NEXT: int loopIdx0 = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0.; +// CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +// CHECK-NEXT: for (loopIdx0 = 0; ; loopIdx0++) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!(loopIdx0 < 2)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t1, res); +// CHECK-NEXT: res += std::lgamma(obs[2 + loopIdx0] + 1) + x; +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!_t0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: loopIdx0--; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -272,7 +312,10 @@ double f8(double x){ } int main(){ + double arr[] = {1,2,3,4,5}; + double darr[] = {0,0,0,0,0}; double result[3] = {}; + double dx; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} TEST(f3, 3);// CHECK-EXEC: {0.00} @@ -281,6 +324,9 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} + auto grad = clad::gradient(fn9); + grad.execute(3, arr, &dx, darr); + printf("%.2f\n", dx);// CHECK-EXEC: 2.00 } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) {