Skip to content

Commit

Permalink
Enhance the support of std::vector and std::array in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch authored and vgvassilev committed Sep 18, 2024
1 parent dc1ebff commit b9a390d
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 0 deletions.
56 changes: 56 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,22 @@ void operator_subscript_pullback(::std::vector<T>* vec,
(*d_vec)[idx] += d_y;
}

template <typename T>
clad::ValueAndAdjoint<T&, T&>
at_reverse_forw(::std::vector<T>* vec, typename ::std::vector<T>::size_type idx,
::std::vector<T>* d_vec,
typename ::std::vector<T>::size_type d_idx) {
return {(*vec)[idx], (*d_vec)[idx]};
}

template <typename T, typename P>
void at_pullback(::std::vector<T>* vec,
typename ::std::vector<T>::size_type idx, P d_y,
::std::vector<T>* d_vec,
typename ::std::vector<T>::size_type* d_idx) {
(*d_vec)[idx] += d_y;
}

template <typename T, typename S, typename U>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
Expand All @@ -443,6 +459,43 @@ void constructor_pullback(::std::vector<T>* v, S count, U val,
d_v->clear();
}

template <typename T, typename U, typename dU>
void assign_pullback(::std::vector<T>* v,
typename ::std::vector<T>::size_type n, U /*val*/,
::std::vector<T>* d_v,
typename ::std::vector<T>::size_type* /*d_n*/, dU* d_val) {
for (typename ::std::vector<T>::size_type i = 0; i < n; ++i) {
(*d_val) += (*d_v)[i];
(*d_v)[i] = 0;
}
}

template <typename T>
void reserve_pullback(::std::vector<T>* v,
typename ::std::vector<T>::size_type n,
::std::vector<T>* d_v,
typename ::std::vector<T>::size_type* /*d_n*/) noexcept {}

template <typename T>
void shrink_to_fit_pullback(::std::vector<T>* /*v*/,
::std::vector<T>* /*d_v*/) noexcept {}

template <typename T>
void size_pullback(::std::vector<T>* /*v*/,
::std::vector<T>* /*d_v*/) noexcept {}

template <typename T>
void capacity_pullback(::std::vector<T>* /*v*/,
::std::vector<T>* /*d_v*/) noexcept {}

template <typename T, typename U>
void size_pullback(::std::vector<T>* /*v*/, U /*d_y*/,
::std::vector<T>* /*d_v*/) noexcept {}

template <typename T, typename U>
void capacity_pullback(::std::vector<T>* /*v*/, U /*d_y*/,
::std::vector<T>* /*d_v*/) noexcept {}

// array reverse mode

template <typename T, ::std::size_t N>
Expand Down Expand Up @@ -514,6 +567,9 @@ void front_pullback(::std::array<T, N>* arr,
}
template <typename T, ::std::size_t N>
void size_pullback(::std::array<T, N>* a, ::std::array<T, N>* d_a) noexcept {}
template <typename T, ::std::size_t N, typename U>
void size_pullback(::std::array<T, N>* /*a*/, U /*d_y*/,
::std::array<T, N>* /*d_a*/) noexcept {}
template <typename T, ::std::size_t N>
::clad::ValueAndAdjoint<::std::array<T, N>, ::std::array<T, N>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array<T, N>>,
Expand Down
182 changes: 182 additions & 0 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,36 @@ double fn18(double x, double y) {
return a[1];
}

double fn19(double x, double y) {
std::vector<double> v;
for (size_t i = 0; i < 3; ++i) {
v.push_back(x);
}
double res = 0;
for (size_t i = 0; i < v.size(); ++i) {
res += v.at(i);
}

v.assign(3, 0);
v.assign(2, y);

return res + v[0] + v[1] + v[2]; // 3x+2y
}

double fn20(double x, double y) {
std::vector<double> v;

v.reserve(10);

double res = x*v.capacity();

v.push_back(x);
v.shrink_to_fit();
res += y*v.capacity() + x*v.size();

return res; // 11x+y
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
Expand All @@ -158,6 +188,8 @@ int main() {
INIT_GRADIENT(fn16);
INIT_GRADIENT(fn17);
INIT_GRADIENT(fn18);
INIT_GRADIENT(fn19);
INIT_GRADIENT(fn20);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
Expand All @@ -168,6 +200,8 @@ int main() {
TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {108.00, 27.00}
TEST_GRADIENT(fn17, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00}
TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {2.00, 0.00}
TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00}
TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -659,3 +693,151 @@ int main() {
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t0, 1, 0., &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn19_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: size_t _d_i = {{0U|0UL|0}};
// CHECK-NEXT: size_t i = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}tape<double> _t1 = {};
// CHECK-NEXT: {{.*}}tape<{{.*}}vector<double> > _t2 = {};
// CHECK-NEXT: size_t _d_i0 = {{0U|0UL|0}};
// CHECK-NEXT: size_t i0 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}tape<{{.*}}vector<double> > _t4 = {};
// CHECK-NEXT: {{.*}}tape<double> _t5 = {};
// CHECK-NEXT: {{.*}}tape<{{.*}}vector<double> > _t6 = {};
// CHECK-NEXT: {{.*}}vector<double> _d_v({});
// CHECK-NEXT: {{.*}}vector<double> v;
// CHECK-NEXT: {{.*}} _t0 = {{0U|0UL|0}};
// CHECK-NEXT: for (i = 0; ; ++i) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!(i < 3))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: {{.*}}push(_t1, x);
// CHECK-NEXT: {{.*}}push(_t2, v);
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&v, x, &_d_v, *_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {{.*}} _t3 = {{0U|0UL|0}};
// CHECK-NEXT: for (i0 = 0; ; ++i0) {
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}push(_t4, v);
// CHECK-NEXT: }
// CHECK-NEXT: if (!(i0 < v.size()))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t3++;
// CHECK-NEXT: {{.*}}push(_t5, res);
// CHECK-NEXT: {{.*}}push(_t6, v);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t7 = {{.*}}at_reverse_forw(&v, i0, &_d_v, _r0);
// CHECK-NEXT: res += _t7.value;
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}}vector<double> _t8 = v;
// CHECK-NEXT: v.assign(3, 0);
// CHECK-NEXT: double _t9 = y;
// CHECK-NEXT: {{.*}}vector<double> _t10 = v;
// CHECK-NEXT: v.assign(2, y);
// CHECK-NEXT: {{.*}}vector<double> _t11 = v;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t12 = {{.*}}operator_subscript_reverse_forw(&v, 0, &_d_v, _r4);
// CHECK-NEXT: {{.*}}vector<double> _t13 = v;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t14 = {{.*}}operator_subscript_reverse_forw(&v, 1, &_d_v, _r5);
// CHECK-NEXT: {{.*}}vector<double> _t15 = v;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t16 = {{.*}}operator_subscript_reverse_forw(&v, 2, &_d_v, _r6);
// CHECK-NEXT: {
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {{.*}}size_type _r4 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t11, 0, 1, &_d_v, &_r4);
// CHECK-NEXT: {{.*}}size_type _r5 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t13, 1, 1, &_d_v, &_r5);
// CHECK-NEXT: {{.*}}size_type _r6 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t15, 2, 1, &_d_v, &_r6);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: y = _t9;
// CHECK-NEXT: {{.*}}size_type _r3 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}assign_pullback(&_t10, 2, _t9, &_d_v, &_r3, &*_d_y);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r1 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}value_type _r2 = 0.;
// CHECK-NEXT: {{.*}}assign_pullback(&_t8, 3, 0, &_d_v, &_r1, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: for (;; _t3--) {
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_pullback(&{{.*}}back(_t4), &_d_v);
// CHECK-NEXT: {{.*}}pop(_t4);
// CHECK-NEXT: }
// CHECK-NEXT: if (!_t3)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: --i0;
// CHECK-NEXT: {
// CHECK-NEXT: res = {{.*}}pop(_t5);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: size_t _r0 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}at_pullback(&{{.*}}back(_t6), i0, _r_d0, &_d_v, &_r0);
// CHECK-NEXT: _d_i0 += _r0;
// CHECK-NEXT: {{.*}}pop(_t6);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: --i;
// CHECK-NEXT: {
// CHECK-NEXT: x = {{.*}}back(_t1);
// CHECK-NEXT: {{.*}}push_back_pullback(&{{.*}}back(_t2), {{.*}}back(_t1), &_d_v, &*_d_x);
// CHECK-NEXT: {{.*}}pop(_t1);
// CHECK-NEXT: {{.*}}pop(_t2);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn20_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: {{.*}}vector<double> _d_v({});
// CHECK-NEXT: {{.*}}vector<double> v;
// CHECK-NEXT: {{.*}}vector<double> _t0 = v;
// CHECK-NEXT: v.reserve(10);
// CHECK-NEXT: {{.*}}vector<double> _t2 = v;
// CHECK-NEXT: double _t1 = v.capacity();
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = x * _t1;
// CHECK-NEXT: double _t3 = x;
// CHECK-NEXT: {{.*}}vector<double> _t4 = v;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&v, x, &_d_v, *_d_x);
// CHECK-NEXT: {{.*}}vector<double> _t5 = v;
// CHECK-NEXT: v.shrink_to_fit();
// CHECK-NEXT: double _t6 = res;
// CHECK-NEXT: {{.*}}vector<double> _t8 = v;
// CHECK-NEXT: double _t7 = v.capacity();
// CHECK-NEXT: {{.*}}vector<double> _t10 = v;
// CHECK-NEXT: double _t9 = v.size();
// CHECK-NEXT: res += y * _t7 + x * _t9;
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: res = _t6;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_y += _r_d0 * _t7;
// CHECK-NEXT: {{.*}}capacity_pullback(&_t8, y * _r_d0, &_d_v);
// CHECK-NEXT: *_d_x += _r_d0 * _t9;
// CHECK-NEXT: {{.*}}size_pullback(&_t10, x * _r_d0, &_d_v);
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}}shrink_to_fit_pullback(&_t5, &_d_v);
// CHECK-NEXT: {
// CHECK-NEXT: x = _t3;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t4, _t3, &_d_v, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: *_d_x += _d_res * _t1;
// CHECK-NEXT: {{.*}}capacity_pullback(&_t2, x * _d_res, &_d_v);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r0 = {{0U|0UL|0}};
// CHECK-NEXT: {{.*}}reserve_pullback(&_t0, 10, &_d_v, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit b9a390d

Please sign in to comment.