Skip to content

Commit

Permalink
Rename KokkosFFT::FFT_Normalization into KokkosFFT::Normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 14, 2023
1 parent 0ffcc27 commit 45d0294
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 170 deletions.
12 changes: 6 additions & 6 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
enum class FFT_Normalization {
enum class Normalization {
FORWARD,
BACKWARD,
ORTHO
Expand All @@ -24,27 +24,27 @@ namespace KokkosFFT {
}

template <typename ViewType>
auto _coefficients(const ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
auto _coefficients(const ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::non_const_value_type>;
value_type coef = 1;
bool to_normalize = false;

switch (normalization) {
case FFT_Normalization::FORWARD:
case Normalization::FORWARD:
if(direction == KOKKOS_FFT_FORWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::BACKWARD:
case Normalization::BACKWARD:
if(direction == KOKKOS_FFT_BACKWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::ORTHO:
case Normalization::ORTHO:
coef = static_cast<value_type>(1) / Kokkos::sqrt(static_cast<value_type>(fft_size));
to_normalize = true;

Expand All @@ -54,7 +54,7 @@ namespace KokkosFFT {
}

template <typename ViewType>
void normalize(ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
void normalize(ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
auto [coef, to_normalize] = _coefficients(inout, direction, normalization, fft_size);
if(to_normalize) _normalize(inout, coef);
}
Expand Down
12 changes: 6 additions & 6 deletions common/unit_test/Test_Normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ TEST(Normalization, Forward) {
Kokkos::fence();

// Backward FFT with Forward Normalization -> Do nothing
KokkosFFT::normalize(x, KOKKOS_FFT_BACKWARD, KokkosFFT::FFT_Normalization::FORWARD, len);
KokkosFFT::normalize(x, KOKKOS_FFT_BACKWARD, KokkosFFT::Normalization::FORWARD, len);
EXPECT_TRUE( allclose(x, ref_b, 1.e-5, 1.e-12) );

// Forward FFT with Forward Normalization -> 1/N normalization
KokkosFFT::normalize(x, KOKKOS_FFT_FORWARD, KokkosFFT::FFT_Normalization::FORWARD, len);
KokkosFFT::normalize(x, KOKKOS_FFT_FORWARD, KokkosFFT::Normalization::FORWARD, len);
EXPECT_TRUE( allclose(x, ref_f, 1.e-5, 1.e-12) );
}

Expand All @@ -45,11 +45,11 @@ TEST(Normalization, Backward) {
Kokkos::fence();

// Forward FFT with Backward Normalization -> Do nothing
KokkosFFT::normalize(x, KOKKOS_FFT_FORWARD, KokkosFFT::FFT_Normalization::BACKWARD, len);
KokkosFFT::normalize(x, KOKKOS_FFT_FORWARD, KokkosFFT::Normalization::BACKWARD, len);
EXPECT_TRUE( allclose(x, ref_f, 1.e-5, 1.e-12) );

// Backward FFT with Backward Normalization -> 1/N normalization
KokkosFFT::normalize(x, KOKKOS_FFT_BACKWARD, KokkosFFT::FFT_Normalization::BACKWARD, len);
KokkosFFT::normalize(x, KOKKOS_FFT_BACKWARD, KokkosFFT::Normalization::BACKWARD, len);
EXPECT_TRUE( allclose(x, ref_b, 1.e-5, 1.e-12) );
}

Expand All @@ -72,10 +72,10 @@ TEST(Normalization, Ortho) {
Kokkos::fence();

// Forward FFT with Ortho Normalization -> 1 / sqrt(N) normalization
KokkosFFT::normalize(x_f, KOKKOS_FFT_FORWARD, KokkosFFT::FFT_Normalization::ORTHO, len);
KokkosFFT::normalize(x_f, KOKKOS_FFT_FORWARD, KokkosFFT::Normalization::ORTHO, len);
EXPECT_TRUE( allclose(x_f, ref_f, 1.e-5, 1.e-12) );

// Backward FFT with Ortho Normalization -> 1 / sqrt(N) normalization
KokkosFFT::normalize(x_b, KOKKOS_FFT_BACKWARD, KokkosFFT::FFT_Normalization::ORTHO, len);
KokkosFFT::normalize(x_b, KOKKOS_FFT_BACKWARD, KokkosFFT::Normalization::ORTHO, len);
EXPECT_TRUE( allclose(x_b, ref_b, 1.e-5, 1.e-12) );
}
8 changes: 4 additions & 4 deletions examples/04_batchedFFT/04_batchedFFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@ int main( int argc, char* argv[] ) {
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Kokkos::fill_random(xc2c, random_pool, I);

KokkosFFT::fft(xc2c, xc2c_hat, KokkosFFT::FFT_Normalization::BACKWARD, /*axis=*/-1);
KokkosFFT::ifft(xc2c_hat, xc2c_inv, KokkosFFT::FFT_Normalization::BACKWARD, /*axis=*/-1);
KokkosFFT::fft(xc2c, xc2c_hat, KokkosFFT::Normalization::BACKWARD, /*axis=*/-1);
KokkosFFT::ifft(xc2c_hat, xc2c_inv, KokkosFFT::Normalization::BACKWARD, /*axis=*/-1);

// 1D batched R2C FFT
View3D<double> xr2c("xr2c", n0, n1, n2);
View3D<Kokkos::complex<double> > xr2c_hat("xr2c_hat", n0, n1, n2/2+1);
Kokkos::fill_random(xr2c, random_pool, 1);

KokkosFFT::rfft(xr2c, xr2c_hat, KokkosFFT::FFT_Normalization::BACKWARD, /*axis=*/-1);
KokkosFFT::rfft(xr2c, xr2c_hat, KokkosFFT::Normalization::BACKWARD, /*axis=*/-1);

// 1D batched C2R FFT
View3D<Kokkos::complex<double> > xc2r("xr2c_hat", n0, n1, n2/2+1);
View3D<double> xc2r_hat("xc2r", n0, n1, n2);
Kokkos::fill_random(xc2r, random_pool, I);

KokkosFFT::irfft(xc2r, xc2r_hat, KokkosFFT::FFT_Normalization::BACKWARD, /*axis=*/-1);
KokkosFFT::irfft(xc2r, xc2r_hat, KokkosFFT::Normalization::BACKWARD, /*axis=*/-1);
}
Kokkos::finalize();

Expand Down
36 changes: 18 additions & 18 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
// 1D Transform
namespace KokkosFFT {
template <typename PlanType, typename InViewType, typename OutViewType>
void _fft(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void _fft(PlanType& plan, const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_fft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -45,7 +45,7 @@ namespace KokkosFFT {
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _ifft(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void _ifft(PlanType& plan, const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_ifft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -62,7 +62,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void fft(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, int axis=-1) {
void fft(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, int axis=-1) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -86,7 +86,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void ifft(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, int axis=-1) {
void ifft(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, int axis=-1) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -110,7 +110,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void rfft(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, int axis=-1) {
void rfft(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, int axis=-1) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -128,7 +128,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void irfft(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, int axis=-1) {
void irfft(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, int axis=-1) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -148,7 +148,7 @@ namespace KokkosFFT {

namespace KokkosFFT {
template <typename InViewType, typename OutViewType>
void fft2(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
void fft2(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fft2: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -171,7 +171,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void ifft2(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
void ifft2(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifft2: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -194,7 +194,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void rfft2(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
void rfft2(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfft2: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -212,7 +212,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void irfft2(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
void irfft2(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD, axis_type<2> axes={-2, -1}) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfft2: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -232,7 +232,7 @@ namespace KokkosFFT {

namespace KokkosFFT {
template <typename InViewType, typename OutViewType>
void fftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void fftn(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand Down Expand Up @@ -260,7 +260,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void fftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void fftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -283,7 +283,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void ifftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void ifftn(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand Down Expand Up @@ -311,7 +311,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void ifftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void ifftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -334,7 +334,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void rfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void rfftn(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -352,7 +352,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void rfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void rfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -370,7 +370,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType>
void irfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void irfftn(const InViewType& in, OutViewType& out, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand All @@ -388,7 +388,7 @@ namespace KokkosFFT {
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void irfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
void irfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, KokkosFFT::Normalization norm=KokkosFFT::Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
Expand Down
Loading

0 comments on commit 45d0294

Please sign in to comment.