Skip to content

Commit

Permalink
Replace KOKKOSFFT_EXPECTS with KOKKOSFFT_THROW_IF
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Sep 11, 2024
1 parent be5919c commit 89209d8
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 152 deletions.
8 changes: 4 additions & 4 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {

// Assert if the elements are overlapped
constexpr int rank = ViewType::rank();
KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(axes),
"Axes overlap");
KOKKOSFFT_EXPECTS(
!KokkosFFT::Impl::is_out_of_range_value_included(axes, rank),
KOKKOSFFT_THROW_IF(KokkosFFT::Impl::has_duplicate_values(axes),
"Axes overlap");
KOKKOSFFT_THROW_IF(
KokkosFFT::Impl::is_out_of_range_value_included(axes, rank),
"Axes include an out-of-range index."
"Axes must be in the range of [-rank, rank-1].");

Expand Down
6 changes: 3 additions & 3 deletions common/src/KokkosFFT_asserts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L
#include <source_location>
#define KOKKOSFFT_EXPECTS(expression, msg) \
#define KOKKOSFFT_THROW_IF(expression, msg) \
KokkosFFT::Impl::check_precondition( \
(expression), msg, std::source_location::current().file_name(), \
std::source_location::current().line(), \
std::source_location::current().function_name(), \
std::source_location::current().column())
#else
#include <cstdlib>
#define KOKKOSFFT_EXPECTS(expression, msg) \
#define KOKKOSFFT_THROW_IF(expression, msg) \
KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \
__FUNCTION__)
#endif
Expand All @@ -33,7 +33,7 @@ inline void check_precondition(const bool expression,
const char* function_name,
const int column = -1) {
// Quick return if possible
if (expression) return;
if (!expression) return;

std::stringstream ss("file: ");
if (column == -1) {
Expand Down
12 changes: 6 additions & 6 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ auto get_extents(const InViewType& in, const OutViewType& out,
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
Expand Down Expand Up @@ -65,17 +65,17 @@ auto get_extents(const InViewType& in, const OutViewType& out,

if (is_real_v<in_value_type>) {
// Then R2C
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
KOKKOSFFT_THROW_IF(
_out_extents.at(inner_most_axis) !=
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
}

if (is_real_v<out_value_type>) {
// Then C2R
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
KOKKOSFFT_THROW_IF(
_in_extents.at(inner_most_axis) !=
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */,
static_assert(
KokkosFFT::Impl::have_same_rank_v<InViewType, OutViewType>,
"get_modified_shape: Input View and Output View must have the same rank");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
Expand Down
8 changes: 4 additions & 4 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
Expand Down Expand Up @@ -400,8 +400,8 @@ void transpose(const ExecutionSpace& exec_space, InViewType& in,
"transpose: Rank of View must be equal to Rank of "
"transpose axes.");

KOKKOSFFT_EXPECTS(KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");

// in order not to call transpose_impl for 1D case
if constexpr (DIM > 1) {
Expand Down
66 changes: 33 additions & 33 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ auto convert_negative_axis(ViewType, int _axis = -1) {
"convert_negative_axis: ViewType must be a Kokkos::View.");
int rank = static_cast<int>(ViewType::rank());

KOKKOSFFT_EXPECTS(_axis >= -rank && _axis < rank,
"Axis must be in [-rank, rank-1]");
KOKKOSFFT_THROW_IF(_axis < -rank || _axis >= rank,
"Axis must be in [-rank, rank-1]");

int axis = _axis < 0 ? rank + _axis : _axis;
return axis;
Expand Down Expand Up @@ -130,7 +130,7 @@ std::size_t get_index(ContainerType& values, const ValueType& value) {
static_assert(std::is_same_v<value_type, ValueType>,
"get_index: Container value type must match ValueType");
auto it = std::find(values.begin(), values.end(), value);
KOKKOSFFT_EXPECTS(it != values.end(), "value is not included in values");
KOKKOSFFT_THROW_IF(it == values.end(), "value is not included in values");
return it - values.begin();
}

Expand Down Expand Up @@ -256,44 +256,44 @@ void create_view(ViewType& out, const Label& label,

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 1>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(extents[0]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 2>& extents) {
KOKKOSFFT_EXPECTS(
ViewType::required_allocation_size(out.layout()) >=
KOKKOSFFT_THROW_IF(
ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0], extents[1]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 3>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 4>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3]),
"reshape_view: insufficient memory");

out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 5>& extents) {
KOKKOSFFT_EXPECTS(
ViewType::required_allocation_size(out.layout()) >=
KOKKOSFFT_THROW_IF(
ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0], extents[1], extents[2],
extents[3], extents[4]),
"reshape_view: insufficient memory");
Expand All @@ -303,33 +303,33 @@ void reshape_view(ViewType& out, const std::array<int, 5>& extents) {

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 6>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 7>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 8>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]);
}
Expand Down
16 changes: 8 additions & 8 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -45,7 +45,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan1d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed");

return fft_size;
}
Expand All @@ -69,7 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -83,7 +83,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan2d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed");

return fft_size;
}
Expand All @@ -107,7 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -123,7 +123,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan3d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed");

return fft_size;
}
Expand Down Expand Up @@ -167,7 +167,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -176,7 +176,7 @@ auto create_plan(const ExecutionSpace& exec_space,
in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);

KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlanMany failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed");

return fft_size;
}
Expand Down
12 changes: 6 additions & 6 deletions fft/src/KokkosFFT_Cuda_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,42 @@ template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecR2C failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecR2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata,
cufftDoubleComplex* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecD2Z failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecD2Z failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2R failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2R failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleReal* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2D failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2D failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata,
cufftComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2C failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2Z failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2Z failed");
}
} // namespace Impl
} // namespace KokkosFFT
Expand Down
Loading

0 comments on commit 89209d8

Please sign in to comment.