diff --git a/test/Gradient/MemberFunctions.C b/test/Gradient/MemberFunctions.C index 2350b1d0f..148ce639f 100644 --- a/test/Gradient/MemberFunctions.C +++ b/test/Gradient/MemberFunctions.C @@ -703,6 +703,11 @@ public: // CHECK-NEXT: } // CHECK-NEXT: } + SimpleFunctions& operator+=(double value) { + x += value; + return *this; + } + void mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void volatile_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -756,6 +761,47 @@ double fn(double i,double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn3(SimpleFunctions& v, double value) { + v += value; + return v.x; +} + +// CHECK: void operator_plus_equal_pullback(double value, SimpleFunctions _d_y, clad::array_ref _d_this, clad::array_ref _d_value) { +// CHECK-NEXT: this->x += value; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = (* _d_this).x; +// CHECK-NEXT: (* _d_this).x += _r_d0; +// CHECK-NEXT: * _d_value += _r_d0; +// CHECK-NEXT: (* _d_this).x -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: clad::ValueAndAdjoint operator_plus_equal_forw(double value, clad::array_ref _d_this, clad::array_ref _d_value) { +// CHECK-NEXT: this->x += value; +// CHECK-NEXT: return {*this, (* _d_this)}; +// CHECK-NEXT: } + +// CHECK: void fn3_grad(SimpleFunctions &v, double value, clad::array_ref _d_v, clad::array_ref _d_value) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = value; +// CHECK-NEXT: _t1 = v; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.operator_plus_equal_forw(_t0, &(* _d_v), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_v).x += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.operator_plus_equal_pullback(_t0, {}, &(* _d_v), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_value += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -790,6 +836,12 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } + SimpleFunctions sf(2, 3); + SimpleFunctions d_sf; + auto d_fn3 = clad::gradient(fn3); + d_fn3.execute(sf, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 41.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & {