Skip to content

Commit

Permalink
Add support for Kokkos::parallel_reduce in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch authored and vgvassilev committed Aug 27, 2024
1 parent c3b76c0 commit 3bd0624
Show file tree
Hide file tree
Showing 3 changed files with 385 additions and 2 deletions.
297 changes: 297 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,303 @@ void parallel_for_pushforward(
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

/// Parallel reduce
// TODO: ADD SUPPORT FOR MULTIPLE REDUCED ARGUMENTS
// TODO: ADD SUPPORT FOR UNNAMED LOOPS

// This structure is used to dispatch parallel reduce pushforward calls for
// multidimensional policies
// Primary template: selected when no specialization below matches. It
// launches a reduction whose kernel receives a work tag first, then a single
// index, then the value accumulator and the derivative accumulator.
template <class Policy, class FunctorType, class Reduced, class WT, int Rank>
struct diff_parallel_reduce_MDP_dispatch { // non-MDPolicy
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, Reduced& res,
                  const FunctorType& d_functor, Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // One `{}` is passed per non-reduced argument (tag, index); presumably
    // they fill the matching derivative parameters of the pushforward, which
    // is declared outside this file section.
    const auto kernel = [&](const auto tag, const auto& idx, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(tag, idx, acc, &d_functor, {}, {},
                                        d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for WT = void with any Rank: the kernel takes no work tag,
// only a single index followed by the value and derivative accumulators.
template <class Policy, class FunctorType, class Reduced, int Rank>
struct diff_parallel_reduce_MDP_dispatch<Policy, FunctorType, Reduced, void,
                                         Rank> { // non-MDPolicy
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, Reduced& res,
                  const FunctorType& d_functor, Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // The single `{}` stands in for the index; presumably it fills the
    // index's derivative parameter of the pushforward (declared elsewhere).
    const auto kernel = [&](const auto& idx, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(idx, acc, &d_functor, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 2 and a work tag WT: the
// kernel receives the tag, two indices, and the two accumulators.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced, class WT>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced, WT,
    2> { // MDPolicy, rank = 2
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Three `{}` arguments: one per tag/index argument (presumably their
    // derivative parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto tag, const auto& idx0, const auto& idx1,
                            auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(tag, idx0, idx1, acc, &d_functor, {},
                                        {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 2 and WT = void: the
// kernel receives two indices and the two accumulators (no work tag).
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 2> { // MDPolicy, rank = 2
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Two `{}` arguments: one per index (presumably their derivative
    // parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto& idx0, const auto& idx1, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(idx0, idx1, acc, &d_functor, {}, {},
                                        d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 3 and a work tag WT: the
// kernel receives the tag, three indices, and the two accumulators.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced, class WT>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced, WT,
    3> { // MDPolicy, rank = 3
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Four `{}` arguments: one per tag/index argument (presumably their
    // derivative parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto tag, const auto& idx0, const auto& idx1,
                            const auto& idx2, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(tag, idx0, idx1, idx2, acc, &d_functor,
                                        {}, {}, {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 3 and WT = void: the
// kernel receives three indices and the two accumulators (no work tag).
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 3> { // MDPolicy, rank = 3
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Three `{}` arguments: one per index (presumably their derivative
    // parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto& idx0, const auto& idx1,
                            const auto& idx2, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(idx0, idx1, idx2, acc, &d_functor, {},
                                        {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 4 and a work tag WT: the
// kernel receives the tag, four indices, and the two accumulators.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced, class WT>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced, WT,
    4> { // MDPolicy, rank = 4
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Five `{}` arguments: one per tag/index argument (presumably their
    // derivative parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto tag, const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(tag, idx0, idx1, idx2, idx3, acc,
                                        &d_functor, {}, {}, {}, {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 4 and WT = void: the
// kernel receives four indices and the two accumulators (no work tag).
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 4> { // MDPolicy, rank = 4
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Four `{}` arguments: one per index (presumably their derivative
    // parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(idx0, idx1, idx2, idx3, acc,
                                        &d_functor, {}, {}, {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 5 and a work tag WT: the
// kernel receives the tag, five indices, and the two accumulators.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced, class WT>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced, WT,
    5> { // MDPolicy, rank = 5
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Six `{}` arguments: one per tag/index argument (presumably their
    // derivative parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto tag, const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3,
                            const auto& idx4, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(tag, idx0, idx1, idx2, idx3, idx4, acc,
                                        &d_functor, {}, {}, {}, {}, {}, {},
                                        d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 5 and WT = void: the
// kernel receives five indices and the two accumulators (no work tag).
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 5> { // MDPolicy, rank = 5
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Five `{}` arguments: one per index (presumably their derivative
    // parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3,
                            const auto& idx4, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(idx0, idx1, idx2, idx3, idx4, acc,
                                        &d_functor, {}, {}, {}, {}, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 6 and a work tag WT: the
// kernel receives the tag, six indices, and the two accumulators.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced, class WT>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced, WT,
    6> { // MDPolicy, rank = 6
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Seven `{}` arguments: one per tag/index argument (presumably their
    // derivative parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto tag, const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3,
                            const auto& idx4, const auto& idx5, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(tag, idx0, idx1, idx2, idx3, idx4, idx5,
                                        acc, &d_functor, {}, {}, {}, {}, {}, {},
                                        {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Specialization for an MDRangePolicy with Rank == 6 and WT = void: the
// kernel receives six indices and the two accumulators (no work tag).
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 6> { // MDPolicy, rank = 6
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // Six `{}` arguments: one per index (presumably their derivative
    // parameters; the pushforward is declared elsewhere).
    const auto kernel = [&](const auto& idx0, const auto& idx1,
                            const auto& idx2, const auto& idx3,
                            const auto& idx4, const auto& idx5, auto& acc,
                            auto& d_acc) {
      functor.operator_call_pushforward(idx0, idx1, idx2, idx3, idx4, idx5, acc,
                                        &d_functor, {}, {}, {}, {}, {}, {},
                                        d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};
// Entry specialization for an MDRangePolicy requested with WT = void and
// Rank = 0: it reads the policy's own work_tag and rank and forwards the call
// to the diff_parallel_reduce_MDP_dispatch instantiation for those values.
template <class PolicyP, class... PolicyParams, class FunctorType,
          class Reduced>
struct diff_parallel_reduce_MDP_dispatch<
    ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType, Reduced,
    void, 0> { // MDPolicy matched, now figure out the rank
  static void
  run(const ::std::string& str,
      const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
      const FunctorType& functor, Reduced& res, const FunctorType& d_functor,
      Reduced& d_res) {
    using MDPolicy = ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>;
    // Re-dispatch on the actual work tag and rank carried by the policy type.
    using RankedDispatch =
        diff_parallel_reduce_MDP_dispatch<MDPolicy, FunctorType, Reduced,
                                          typename MDPolicy::work_tag,
                                          MDPolicy::rank>;
    RankedDispatch::run(str, policy, functor, res, d_functor, d_res);
  }
};

// This structure is used to dispatch parallel reduce pushforward calls for
// integral policies
template <class Policy, class FunctorType, class Reduced, bool isInt>
struct diff_parallel_reduce_int_dispatch { // non-integral policy
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, Reduced& res,
const FunctorType& d_functor, Reduced& d_res) {
diff_parallel_reduce_MDP_dispatch<Policy, FunctorType, Reduced, void,
0>::run(str, policy, functor, res,
d_functor, d_res);
}
};
// Specialization for isInt == true: the policy is an integral value, so the
// reduction is launched directly with a kernel that takes a single index and
// the value/derivative accumulators.
template <class Policy, class FunctorType, class Reduced> // integral policy
struct diff_parallel_reduce_int_dispatch<Policy, FunctorType, Reduced, true> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, Reduced& res,
                  const FunctorType& d_functor, Reduced& d_res) {
    // The differentiated loop is labelled with a "_diff_" prefix.
    const ::std::string label = "_diff_" + str;
    // The single `{}` stands in for the index; presumably it fills the
    // index's derivative parameter of the pushforward (declared elsewhere).
    const auto kernel = [&](const auto& idx, auto& acc, auto& d_acc) {
      functor.operator_call_pushforward(idx, acc, &d_functor, {}, d_acc);
    };
    ::Kokkos::parallel_reduce(label, policy, kernel, res, d_res);
  }
};

template <class Policy, class FunctorType,
class Reduced> // generally, this is matched
void parallel_reduce_pushforward(const ::std::string& str, const Policy& policy,
const FunctorType& functor, Reduced& res,
const ::std::string& /*d_str*/,
const Policy& /*d_policy*/,
const FunctorType& d_functor, Reduced& d_res) {
diff_parallel_reduce_int_dispatch<
Policy, FunctorType, Reduced,
::std::is_integral<Policy>::value>::run(str, policy, functor, res,
d_functor, d_res);
}

} // namespace Kokkos
} // namespace clad::custom_derivatives

Expand Down
1 change: 0 additions & 1 deletion unittests/Kokkos/ParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "clad/Differentiator/KokkosBuiltins.h"
#include "gtest/gtest.h"
// #include "TestUtils.h"
#include "ParallelAdd.h"

TEST(ParallelFor, HelloWorldLambdaLoopForward) {
// // check finite difference and forward mode similarity
Expand Down
Loading

0 comments on commit 3bd0624

Please sign in to comment.