Skip to content

Commit

Permalink
Make the signatues of KokkosBuiltins.h more general
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch authored and vgvassilev committed Aug 28, 2024
1 parent e2f4638 commit 181208c
Showing 1 changed file with 64 additions and 53 deletions.
117 changes: 64 additions & 53 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ constructor_pushforward(
const ::std::string& name, const size_t& idx0, const size_t& idx1,
const size_t& idx2, const size_t& idx3, const size_t& idx4,
const size_t& idx5, const size_t& idx6, const size_t& idx7,
const ::std::string& d_name, const size_t& d_idx0, const size_t& d_idx1,
const size_t& d_idx2, const size_t& d_idx3, const size_t& d_idx4,
const size_t& d_idx5, const size_t& d_idx6, const size_t& d_idx7) {
const ::std::string& /*d_name*/, const size_t& /*d_idx0*/,
const size_t& /*d_idx1*/, const size_t& /*d_idx2*/,
const size_t& /*d_idx3*/, const size_t& /*d_idx4*/,
const size_t& /*d_idx5*/, const size_t& /*d_idx6*/,
const size_t& /*d_idx7*/) {
return {Kokkos::View<DataType, ViewParams...>(name, idx0, idx1, idx2, idx3,
idx4, idx5, idx6, idx7),
Kokkos::View<DataType, ViewParams...>(
Expand All @@ -37,63 +39,71 @@ operator_call_pushforward(const View* v, Idx i0, const View* d_v,
Idx /*d_i0*/) {
return {(*v)(i0), (*d_v)(i0)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, const View* d_v,
Idx /*d_i0*/, Idx /*d_i1*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, const View* d_v,
Idx0 /*d_i0*/, Idx1 /*d_i1*/) {
return {(*v)(i0, i1), (*d_v)(i0, i1)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2,
const View* d_v, Idx0 /*d_i0*/, Idx1 /*d_i1*/,
Idx2 /*d_i2*/) {
return {(*v)(i0, i1, i2), (*d_v)(i0, i1, i2)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3,
const View* d_v, Idx0 /*d_i0*/, Idx1 /*d_i1*/,
Idx2 /*d_i2*/, Idx3 /*d_i3*/) {
return {(*v)(i0, i1, i2, i3), (*d_v)(i0, i1, i2, i3)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/, Idx /*d_i4*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3,
Idx4 i4, const View* d_v, Idx0 /*d_i0*/,
Idx1 /*d_i1*/, Idx2 /*d_i2*/, Idx3 /*d_i3*/,
Idx4 /*d_i4*/) {
return {(*v)(i0, i1, i2, i3, i4), (*d_v)(i0, i1, i2, i3, i4)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/, Idx /*d_i4*/,
Idx /*d_i5*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3,
Idx4 i4, Idx5 i5, const View* d_v, Idx0 /*d_i0*/,
Idx1 /*d_i1*/, Idx2 /*d_i2*/, Idx3 /*d_i3*/,
Idx4 /*d_i4*/, Idx5 /*d_i5*/) {
return {(*v)(i0, i1, i2, i3, i4, i5), (*d_v)(i0, i1, i2, i3, i4, i5)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, Idx i6, const View* d_v, Idx /*d_i0*/,
Idx /*d_i1*/, Idx /*d_i2*/, Idx /*d_i3*/,
Idx /*d_i4*/, Idx /*d_i5*/, Idx /*d_i6*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3,
Idx4 i4, Idx5 i5, Idx6 i6, const View* d_v,
Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/,
Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/,
Idx6 /*d_i6*/) {
return {(*v)(i0, i1, i2, i3, i4, i5, i6), (*d_v)(i0, i1, i2, i3, i4, i5, i6)};
}
template <typename View, typename Idx>
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, Idx i6, Idx i7, const View* d_v, Idx /*d_i0*/,
Idx /*d_i1*/, Idx /*d_i2*/, Idx /*d_i3*/,
Idx /*d_i4*/, Idx /*d_i5*/, Idx /*d_i6*/,
Idx /*d_i7*/) {
operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3,
Idx4 i4, Idx5 i5, Idx6 i6, Idx7 i7, const View* d_v,
Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/,
Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/,
Idx6 /*d_i6*/, Idx7 /*d_i7*/) {
return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7),
(*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)};
}
Expand All @@ -108,28 +118,29 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
deep_copy(dst, src);
deep_copy(d_dst, d_src);
}
template <class View>
inline void resize_pushforward(View& v, const size_t n0, const size_t n1,
const size_t n2, const size_t n3,
const size_t n4, const size_t n5,
const size_t n6, const size_t n7, View& d_v,
const size_t /*d_n0*/, const size_t /*d_n1*/,
const size_t /*d_n2*/, const size_t /*d_n3*/,
const size_t /*d_n4*/, const size_t /*d_n5*/,
const size_t /*d_n6*/, const size_t /*d_n7*/) {
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7>
inline void
resize_pushforward(View& v, const Idx0 n0, const Idx1 n1, const Idx2 n2,
const Idx3 n3, const Idx4 n4, const Idx5 n5, const Idx6 n6,
const Idx7 n7, View& d_v, const Idx0 /*d_n*/,
const Idx1 /*d_n*/, const Idx2 /*d_n*/, const Idx3 /*d_n*/,
const Idx4 /*d_n*/, const Idx5 /*d_n*/, const Idx6 /*d_n*/,
const Idx7 /*d_n*/) {
::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
template <class I, class dI, class View>
inline void resize_pushforward(const I& arg, View& v, const size_t n0,
const size_t n1, const size_t n2,
const size_t n3, const size_t n4,
const size_t n5, const size_t n6,
const size_t n7, const dI& /*d_arg*/, View& d_v,
const size_t /*d_n0*/, const size_t /*d_n1*/,
const size_t /*d_n2*/, const size_t /*d_n3*/,
const size_t /*d_n4*/, const size_t /*d_n5*/,
const size_t /*d_n6*/, const size_t /*d_n7*/) {
template <class I, class dI, class View, typename Idx0, typename Idx1,
typename Idx2, typename Idx3, typename Idx4, typename Idx5,
typename Idx6, typename Idx7>
inline void
resize_pushforward(const I& arg, View& v, const Idx0 n0, const Idx1 n1,
const Idx2 n2, const Idx3 n3, const Idx4 n4, const Idx5 n5,
const Idx6 n6, const Idx7 n7, const dI& /*d_arg*/, View& d_v,
const Idx0 /*d_n*/, const Idx1 /*d_n*/, const Idx2 /*d_n*/,
const Idx3 /*d_n*/, const Idx4 /*d_n*/, const Idx5 /*d_n*/,
const Idx6 /*d_n*/, const Idx7 /*d_n*/) {
::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
Expand Down

0 comments on commit 181208c

Please sign in to comment.