Skip to content

Commit

Permalink
Add missing checks (#196)
Browse files Browse the repository at this point in the history
* Add assertions for completely mismatched extents

* Raise an error if inplace plan is executed on out-of-place views

* remove unnecessary fence

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Nov 27, 2024
1 parent bbf765a commit 18d113f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
16 changes: 16 additions & 0 deletions common/src/KokkosFFT_Extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ auto get_extents(const InViewType& in, const OutViewType& out,
static_assert(!(is_real_v<in_value_type> && is_real_v<out_value_type>),
"get_extents: real to real transform is not supported");

for (std::size_t i = 0; i < rank; i++) {
// The requirement for inner_most_axis is different for transform type
if (static_cast<int>(i) == inner_most_axis) continue;
KOKKOSFFT_THROW_IF(in_extents_full.at(i) != out_extents_full.at(i),
"input and output extents must be the same except for "
"the transform axis");
}

if constexpr (is_complex_v<in_value_type> && is_complex_v<out_value_type>) {
// Then C2C
KOKKOSFFT_THROW_IF(
in_extents_full.at(inner_most_axis) !=
out_extents_full.at(inner_most_axis),
"input and output extents must be the same for C2C transform");
}

if constexpr (is_real_v<in_value_type>) {
// Then R2C
if (is_inplace) {
Expand Down
43 changes: 43 additions & 0 deletions common/unit_test/Test_Extents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ void test_extents_1d_batched_FFT_2d() {
EXPECT_TRUE(fft_extents_c2c_axis1 == ref_fft_extents_r2c_axis1);
EXPECT_TRUE(out_extents_c2c_axis1 == ref_in_extents_r2c_axis1);
EXPECT_EQ(howmany_c2c_axis1, ref_howmany_r2c_axis1);

// Check if errors are correctly raised aginst invalid extents
ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1);
for (int i = 0; i < 2; i++) {
EXPECT_THROW(
{ KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong, axes_type({i})); },
std::runtime_error);
}
}

template <typename LayoutType>
Expand Down Expand Up @@ -306,6 +314,14 @@ void test_extents_1d_batched_FFT_3d() {
EXPECT_TRUE(fft_extents_c2c_axis2 == ref_fft_extents_r2c_axis2);
EXPECT_TRUE(out_extents_c2c_axis2 == ref_in_extents_r2c_axis2);
EXPECT_EQ(howmany_c2c_axis2, ref_howmany_r2c_axis2);

// Check if errors are correctly raised aginst invalid extents
ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2);
for (int i = 0; i < 3; i++) {
EXPECT_THROW(
{ KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong, axes_type({i})); },
std::runtime_error);
}
}

TYPED_TEST(Extents1D, 1DFFT_1DView) {
Expand Down Expand Up @@ -429,6 +445,20 @@ void test_extents_2d() {

EXPECT_EQ(howmany_c2c_axis01, 1);
EXPECT_EQ(howmany_c2c_axis10, 1);

// Check if errors are correctly raised aginst invalid extents
ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1);
for (int axis0 = 0; axis0 < 2; axis0++) {
for (int axis1 = 0; axis1 < 2; axis1++) {
if (axis0 == axis1) continue;
EXPECT_THROW(
{
KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong,
axes_type({axis0, axis1}));
},
std::runtime_error);
}
}
}

template <typename LayoutType>
Expand Down Expand Up @@ -709,6 +739,19 @@ void test_extents_2d_batched_FFT_3d() {
EXPECT_TRUE(fft_extents_c2c_axis_21 == ref_fft_extents_r2c_axis_21);
EXPECT_TRUE(out_extents_c2c_axis_21 == ref_in_extents_r2c_axis_21);
EXPECT_EQ(howmany_c2c_axis_21, ref_howmany_r2c_axis_21);

ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2 + 2);
for (int axis0 = 0; axis0 < 3; axis0++) {
for (int axis1 = 0; axis1 < 3; axis1++) {
if (axis0 == axis1) continue;
EXPECT_THROW(
{
KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong,
axes_type({axis0, axis1}));
},
std::runtime_error);
}
}
}

TYPED_TEST(Extents2D, 2DFFT_2DView) {
Expand Down
5 changes: 5 additions & 0 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ class Plan {
KOKKOSFFT_THROW_IF(out_extents != m_out_extents,
"extents of output View for plan and "
"execution are not identical.");

bool is_inplace = KokkosFFT::Impl::are_aliasing(in.data(), out.data());
KOKKOSFFT_THROW_IF(is_inplace != m_is_inplace,
"If the plan is in-place, the input and output Views "
"must be identical.");
}
};
} // namespace KokkosFFT
Expand Down
39 changes: 37 additions & 2 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) {
Kokkos::deep_copy(a_ref, a);
Kokkos::deep_copy(ar_ref, ar);

Kokkos::fence();

KokkosFFT::fft(execution_space(), a, a_hat);
KokkosFFT::ifft(execution_space(), a_hat, inv_a_hat);

Expand All @@ -115,6 +113,43 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) {

EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol));

// Create a plan for inplace transform
Kokkos::deep_copy(a_ref, a);
Kokkos::deep_copy(ar_ref, ar);

int axis = -1;
KokkosFFT::Plan fft_plan(execution_space(), a, a_hat,
KokkosFFT::Direction::forward, axis);
fft_plan.execute(a, a_hat);

KokkosFFT::Plan ifft_plan(execution_space(), a_hat, inv_a_hat,
KokkosFFT::Direction::backward, axis);
ifft_plan.execute(a_hat, inv_a_hat);

KokkosFFT::Plan rfft_plan(execution_space(), ar, ar_hat,
KokkosFFT::Direction::forward, axis);
rfft_plan.execute(ar, ar_hat);

KokkosFFT::Plan irfft_plan(execution_space(), ar_hat, inv_ar_hat,
KokkosFFT::Direction::backward, axis);
irfft_plan.execute(ar_hat, inv_ar_hat);

EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol));

// inplace Plan cannot be reused for out-of-place case
ComplexView1DType a_hat_out("a_hat_out", i),
inv_a_hat_out("inv_a_hat_out", i);

RealView1DType inv_ar_hat_out("inv_ar_hat_out", i);
ComplexView1DType ar_hat_out("ar_hat_out", i / 2 + 1);
EXPECT_THROW(fft_plan.execute(a, a_hat_out), std::runtime_error);
EXPECT_THROW(ifft_plan.execute(a_hat_out, inv_a_hat_out),
std::runtime_error);
EXPECT_THROW(rfft_plan.execute(ar, ar_hat_out), std::runtime_error);
EXPECT_THROW(irfft_plan.execute(ar_hat_out, inv_ar_hat_out),
std::runtime_error);
}
}

Expand Down

0 comments on commit 18d113f

Please sign in to comment.