From 216201a8b8f5748d4fd2d7ef1946d360e855ec9c Mon Sep 17 00:00:00 2001 From: Rohan Julka Date: Thu, 20 Jun 2024 13:33:45 +0100 Subject: [PATCH] Add support for && operator Add support for differentiation of expressions which include && operator. Check whether then/else block of if stmt is empty before adding it to reverse or forward block. --- lib/Differentiator/ReverseModeVisitor.cpp | 24 +++- test/Gradient/FunctionCalls.C | 40 ++++-- test/Gradient/Gradients.C | 155 ++++++++++++++++++++++ 3 files changed, 208 insertions(+), 11 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c0adadbe9..6f1df389a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -895,7 +895,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock), - utils::unwrapIfSingleStmt(ReverseBlock)); + utils::unwrapIfSingleStmt(ReverseBlock), + /*forwSweepDiff=*/nullptr, + /*valueForRevSweep=*/condDiffStored); } StmtDiff ReverseModeVisitor::VisitConditionalOperator( @@ -2382,6 +2384,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dfdx()); valueForRevPass = Ldiff.getRevSweepAsExpr(); ResultRef = Ldiff.getExpr(); + } else if (opCode == BO_LAnd) { + VarDecl* condVar = GlobalStoreImpl(m_Context.BoolTy, "_cond"); + VarDecl* derivedCondVar = GlobalStoreImpl( + m_Context.DoubleTy, "_d" + condVar->getNameAsString()); + Expr* condVarRef = BuildDeclRef(condVar); + Expr* assignExpr = BuildOp(BO_Assign, condVarRef, Clone(R)); + m_Variables.emplace(condVar, BuildDeclRef(derivedCondVar)); + auto* IfStmt = clad_compat::IfStmt_Create( + /*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false, + /*Init=*/nullptr, /*Var=*/nullptr, + /*Cond=*/L, /*LPL=*/noLoc, /*RPL=*/noLoc, /*Then=*/assignExpr, + /*EL=*/noLoc, + /*Else=*/nullptr); + + StmtDiff IfStmtDiff = VisitIfStmt(IfStmt); + addToCurrentBlock(utils::unwrapIfSingleStmt(IfStmtDiff.getStmt())); + addToCurrentBlock(utils::unwrapIfSingleStmt(IfStmtDiff.getStmt_dx()), + direction::reverse); + auto* condDiffStored = IfStmtDiff.getRevSweepAsExpr(); + return BuildOp(BO_LAnd, condDiffStored, condVarRef); } else { // We should not output any warning on visiting boolean conditions // FIXME: We should support boolean differentiation or ignore it diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index a419b5ab2..6004d39db 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -966,16 +966,36 @@ double sq_defined_later(double x) { // CHECK-NEXT: } // CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) { -// CHECK-NEXT: bool _cond0; -// CHECK-NEXT: { -// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a'; -// CHECK-NEXT: if (_cond0) -// CHECK-NEXT: goto _label0; -// CHECK-NEXT: } -// CHECK-NEXT: if (_cond0) -// CHECK-NEXT: _label0: -// CHECK-NEXT: *_d_x += _d_y; -// CHECK-NEXT: } +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: double _d_cond0; +// CHECK-NEXT: bool _cond1; +// CHECK-NEXT: bool _t0; +// CHECK-NEXT: bool _cond2; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: _cond1 = c == 'a'; +// CHECK-NEXT: if (_cond1) { +// CHECK-NEXT: _t0 = _cond0; +// CHECK-NEXT: _cond0 = s[0] == 'a'; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _cond2 = _cond1 && _cond0; +// CHECK-NEXT: if (_cond2) +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond2) +// CHECK-NEXT: _label0: +// CHECK-NEXT: *_d_x += _d_y; +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond1) { +// CHECK-NEXT: _cond0 = _t0; +// CHECK-NEXT: double _r_d0 = _d_cond0; +// CHECK-NEXT: _d_cond0 -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} // CHECK: void custom_max_pullback(const double &a, const double &b, double _d_y, double *_d_a, double *_d_b) { // CHECK-NEXT: bool _cond0; diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index 6c09c4038..730ca8884 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -924,6 +924,154 @@ double fn_empty_if_else(double x) { //CHECK-NEXT: } //CHECK-NEXT:} +double fn_cond_false(double i, double j) { + double res = 0; + if (i*j && res > 0) { + res = 6 * i * j; + } + return res; +} + +// CHECK: void fn_cond_false_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: double _d_cond0; +// CHECK-NEXT: bool _cond1; +// CHECK-NEXT: bool _t0; +// CHECK-NEXT: bool _cond2; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: _cond1 = i * j; +// CHECK-NEXT: if (_cond1) { +// CHECK-NEXT: _t0 = _cond0; +// CHECK-NEXT: _cond0 = res > 0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _cond2 = _cond1 && _cond0; +// CHECK-NEXT: if (_cond2) { +// CHECK-NEXT: _t1 = res; +// CHECK-NEXT: res = 6 * i * j; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond2) { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t1; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res -= _r_d1; +// CHECK-NEXT: *_d_i += 6 * _r_d1 * j; +// CHECK-NEXT: *_d_j += 6 * i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond1) { +// CHECK-NEXT: _cond0 = _t0; +// CHECK-NEXT: double _r_d0 = _d_cond0; +// CHECK-NEXT: _d_cond0 -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + +double fn_cond_add_assign(double i, double j) { + double res = 0; + if ((res = 2 * i * j) && (res += 3 * i * j) && (res += 5 * i * j)) { + res += 6 * i * j; + } + return res; +} + +// CHECK: void fn_cond_add_assign_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: double _d_cond0; +// CHECK-NEXT: bool _cond1; +// CHECK-NEXT: double _d_cond1; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: bool _cond2; +// CHECK-NEXT: bool _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: bool _cond3; +// CHECK-NEXT: bool _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: bool _cond4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: _cond2 = (res = 2 * i * j); +// CHECK-NEXT: if (_cond2) { +// CHECK-NEXT: _t1 = _cond1; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: _cond1 = (res += 3 * i * j); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _cond3 = _cond2 && _cond1; +// CHECK-NEXT: if (_cond3) { +// CHECK-NEXT: _t3 = _cond0; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: _cond0 = (res += 5 * i * j); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _cond4 = _cond3 && _cond0; +// CHECK-NEXT: if (_cond4) { +// CHECK-NEXT: _t5 = res; +// CHECK-NEXT: res += 6 * i * j; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond4) { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t5; +// CHECK-NEXT: double _r_d5 = _d_res; +// CHECK-NEXT: *_d_i += 6 * _r_d5 * j; +// CHECK-NEXT: *_d_j += 6 * i * _r_d5; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond3) { +// CHECK-NEXT: _cond0 = _t3; +// CHECK-NEXT: double _r_d3 = _d_cond0; +// CHECK-NEXT: _d_cond0 -= _r_d3; +// CHECK-NEXT: _d_res += _r_d3; +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d4 = _d_res; +// CHECK-NEXT: *_d_i += 5 * _r_d4 * j; +// CHECK-NEXT: *_d_j += 5 * i * _r_d4; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: if (_cond2) { +// CHECK-NEXT: _cond1 = _t1; +// CHECK-NEXT: double _r_d1 = _d_cond1; +// CHECK-NEXT: _d_cond1 -= _r_d1; +// CHECK-NEXT: _d_res += _r_d1; +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: *_d_i += 3 * _r_d2 * j; +// CHECK-NEXT: *_d_j += 3 * i * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: *_d_i += 2 * _r_d0 * j; +// CHECK-NEXT: *_d_j += 2 * i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -1006,4 +1154,11 @@ int main() { INIT_GRADIENT(fn_empty_if_else); TEST_GRADIENT(fn_empty_if_else, /*numOfDerivativeArgs=*/1, 1, &dx); // CHECK-EXEC: 5.00 + INIT_GRADIENT(fn_cond_false); + TEST_GRADIENT(fn_cond_false, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {0.00, 0.00} + + INIT_GRADIENT(fn_cond_add_assign); + TEST_GRADIENT(fn_cond_add_assign, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {80.00, 48.00} + + }