Skip to content

Commit

Permalink
add default axes to ND fft functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Sep 24, 2024
1 parent 6b65a27 commit f6001ed
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 36 deletions.
48 changes: 28 additions & 20 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,13 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void fftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void fftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -507,11 +509,13 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void ifftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -542,11 +546,13 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void rfftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -583,11 +589,13 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void irfftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down
32 changes: 16 additions & 16 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2322,8 +2322,8 @@ void test_fftn_2dfft_2dview() {

using axes_type = KokkosFFT::axis_type<2>;
axes_type axes = {-2, -1};
KokkosFFT::fftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out_b, axes,
KokkosFFT::Normalization::backward);
KokkosFFT::fftn(execution_space(), x, out_o, axes,
Expand Down Expand Up @@ -2387,8 +2387,8 @@ void test_ifftn_2dfft_2dview() {
KokkosFFT::ifft(execution_space(), out1, out2,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::ifftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out_b, axes,
KokkosFFT::Normalization::backward);
KokkosFFT::ifftn(execution_space(), x, out_o, axes,
Expand Down Expand Up @@ -2452,8 +2452,8 @@ void test_rfftn_2dfft_2dview() {
KokkosFFT::Normalization::backward, /*axis=*/0);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::rfftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out_b, axes,
Expand Down Expand Up @@ -2531,8 +2531,8 @@ void test_irfftn_2dfft_2dview() {
KokkosFFT::Normalization::backward, /*axis=*/1);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::irfftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out_b, axes,
Expand Down Expand Up @@ -2698,8 +2698,8 @@ void test_fftn_3dfft_3dview(T atol = 1.0e-6) {
KokkosFFT::fft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::fftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out_b, axes,
KokkosFFT::Normalization::backward);
KokkosFFT::fftn(execution_space(), x, out_o, axes,
Expand Down Expand Up @@ -2746,8 +2746,8 @@ void test_ifftn_3dfft_3dview() {
KokkosFFT::ifft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::ifftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out_b, axes,
KokkosFFT::Normalization::backward);
KokkosFFT::ifftn(execution_space(), x, out_o, axes,
Expand Down Expand Up @@ -2796,8 +2796,8 @@ void test_rfftn_3dfft_3dview() {
KokkosFFT::Normalization::backward, /*axis=*/0);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::rfftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out_b, axes,
Expand Down Expand Up @@ -2853,8 +2853,8 @@ void test_irfftn_3dfft_3dview() {
KokkosFFT::Normalization::backward, /*axis=*/2);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::irfftn(execution_space(), x,
out); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out_b, axes,
Expand Down

0 comments on commit f6001ed

Please sign in to comment.