Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
PhrygianGates committed Aug 20, 2023
1 parent a6eff5f commit b18b297
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,11 @@ public:
// CHECK-NEXT: }
// CHECK-NEXT: }

SimpleFunctions& operator+=(double value) {
x += value;
return *this;
}

void mem_fn_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j);
void const_mem_fn_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j);
void volatile_mem_fn_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j);
Expand Down Expand Up @@ -756,6 +761,47 @@ double fn(double i,double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn3(SimpleFunctions& v, double value) {
v += value;
return v.x;
}

// CHECK: void operator_plus_equal_pullback(double value, SimpleFunctions _d_y, clad::array_ref<SimpleFunctions> _d_this, clad::array_ref<double> _d_value) {
// CHECK-NEXT: this->x += value;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: double _r_d0 = (* _d_this).x;
// CHECK-NEXT: (* _d_this).x += _r_d0;
// CHECK-NEXT: * _d_value += _r_d0;
// CHECK-NEXT: (* _d_this).x -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, clad::array_ref<SimpleFunctions> _d_this, clad::array_ref<SimpleFunctions> _d_value) {
// CHECK-NEXT: this->x += value;
// CHECK-NEXT: return {*this, (* _d_this)};
// CHECK-NEXT: }

// CHECK: void fn3_grad(SimpleFunctions &v, double value, clad::array_ref<SimpleFunctions> _d_v, clad::array_ref<double> _d_value) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: SimpleFunctions _t1;
// CHECK-NEXT: _t0 = value;
// CHECK-NEXT: _t1 = v;
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t2 = _t1.operator_plus_equal_forw(_t0, &(* _d_v), nullptr);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: (* _d_v).x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: _t1.operator_plus_equal_pullback(_t0, {}, &(* _d_v), &_grad0);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_value += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }


int main() {
auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn);
auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn);
Expand Down Expand Up @@ -790,6 +836,12 @@ int main() {
printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00
}

SimpleFunctions sf(2, 3);
SimpleFunctions d_sf;
auto d_fn3 = clad::gradient(fn3);
d_fn3.execute(sf, 2, &d_sf, &result[0]);
printf("%.2f", result[0]); //CHECK-EXEC: 41.00

auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i");

// CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref<volatile SimpleFunctions> _d_this, clad::array_ref<double> _d_i) const volatile & {
Expand Down

0 comments on commit b18b297

Please sign in to comment.