Skip to content

Commit

Permalink
Merge pull request #2853 from andrjohns/vectorised-select
Browse files Browse the repository at this point in the history
Add vectorised select(), any(), and all() functions
  • Loading branch information
andrjohns authored Aug 11, 2023
2 parents 34881d4 + 64728a0 commit 38289cd
Show file tree
Hide file tree
Showing 8 changed files with 656 additions and 16 deletions.
17 changes: 1 addition & 16 deletions stan/math/opencl/kernel_generator/select.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifdef STAN_OPENCL

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
Expand Down Expand Up @@ -150,22 +151,6 @@ select(T_condition&& condition, T_then&& then, T_else&& els) { // NOLINT
as_operation_cl(std::forward<T_else>(els))};
}

/**
* Scalar overload of the selection operation.
* @tparam T_then type of then scalar
* @tparam T_else type of else scalar
* @param condition condition
* @param then then result
* @param els else result
* @return `condition ? then : els`
*/
template <typename T_then, typename T_else,
require_all_arithmetic_t<T_then, T_else>* = nullptr>
inline std::common_type_t<T_then, T_else> select(bool condition, T_then then,
T_else els) {
return condition ? then : els;
}

/** @}*/
} // namespace math
} // namespace stan
Expand Down
3 changes: 3 additions & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <stan/math/prim/fun/acosh.hpp>
#include <stan/math/prim/fun/add.hpp>
#include <stan/math/prim/fun/add_diag.hpp>
#include <stan/math/prim/fun/all.hpp>
#include <stan/math/prim/fun/any.hpp>
#include <stan/math/prim/fun/append_array.hpp>
#include <stan/math/prim/fun/append_col.hpp>
#include <stan/math/prim/fun/append_row.hpp>
Expand Down Expand Up @@ -305,6 +307,7 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
Expand Down
85 changes: 85 additions & 0 deletions stan/math/prim/fun/all.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef STAN_MATH_PRIM_FUN_ALL_HPP
#define STAN_MATH_PRIM_FUN_ALL_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/for_each.hpp>
#include <algorithm>

namespace stan {
namespace math {

/**
* Return true if all values in the input are true.
*
* Overload for a single integral input
*
* @tparam T Any type convertible to `bool`
* @param x integral input
* @return The input unchanged
*/
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
constexpr inline bool all(T x) {
return x;
}

/**
* Return true if all values in the input are true.
*
* Overload for Eigen types
*
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
* `integral` scalar type
* @param x Eigen object of boolean inputs
* @return Boolean indicating whether all elements are true
*/
template <typename ContainerT,
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
inline bool all(const ContainerT& x) {
return x.all();
}

// Forward-declaration for correct resolution of all(std::vector<std::tuple>)
template <typename... Types>
inline bool all(const std::tuple<Types...>& x);

/**
* Return true if all values in the input are true.
*
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
* approach cannot be used as std::vector<bool> types do not have a .data()
* member and are not always stored contiguously.
*
* @tparam InnerT Type within std::vector
* @param x Nested container of boolean inputs
* @return Boolean indicating whether all elements are true
*/
template <typename InnerT>
inline bool all(const std::vector<InnerT>& x) {
return std::all_of(x.begin(), x.end(), [](const auto& i) { return all(i); });
}

/**
* Return true if all values in the input are true.
*
* Overload for a tuple input.
*
* @tparam Types of items within tuple
* @param x Tuple of boolean scalar-type elements
* @return Boolean indicating whether all elements are true
*/
template <typename... Types>
inline bool all(const std::tuple<Types...>& x) {
bool all_true = true;
math::for_each(
[&all_true](const auto& i) {
all_true = all_true && all(i);
return;
},
x);
return all_true;
}

} // namespace math
} // namespace stan

#endif
85 changes: 85 additions & 0 deletions stan/math/prim/fun/any.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef STAN_MATH_PRIM_FUN_ANY_HPP
#define STAN_MATH_PRIM_FUN_ANY_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/for_each.hpp>
#include <algorithm>

namespace stan {
namespace math {

/**
* Return true if any values in the input are true.
*
* Overload for a single boolean input
*
* @tparam T Any type convertible to `bool`
* @param x boolean input
* @return The input unchanged
*/
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
constexpr inline bool any(T x) {
return x;
}

/**
* Return true if any values in the input are true.
*
* Overload for Eigen types
*
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
* `integral` scalar type
* @param x Eigen object of boolean inputs
* @return Boolean indicating whether any elements are true
*/
template <typename ContainerT,
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
inline bool any(const ContainerT& x) {
return x.any();
}

// Forward-declaration for correct resolution of any(std::vector<std::tuple>)
template <typename... Types>
inline bool any(const std::tuple<Types...>& x);

/**
* Return true if any values in the input are true.
*
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
* approach cannot be used as std::vector<bool> types do not have a .data()
* member and are not always stored contiguously.
*
* @tparam InnerT Type within std::vector
* @param x Nested container of boolean inputs
* @return Boolean indicating whether any elements are true
*/
template <typename InnerT>
inline bool any(const std::vector<InnerT>& x) {
return std::any_of(x.begin(), x.end(), [](const auto& i) { return any(i); });
}

/**
* Return true if any values in the input are true.
*
* Overload for a tuple input.
*
* @tparam Types of items within tuple
* @param x Tuple of boolean scalar-type elements
* @return Boolean indicating whether any elements are true
*/
template <typename... Types>
inline bool any(const std::tuple<Types...>& x) {
bool any_true = false;
math::for_each(
[&any_true](const auto& i) {
any_true = any_true || any(i);
return;
},
x);
return any_true;
}

} // namespace math
} // namespace stan

#endif
Loading

0 comments on commit 38289cd

Please sign in to comment.