Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup #8

Merged
merged 9 commits into from
Dec 14, 2023
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_CUDA_TYPES_HPP__
#define __KOKKOSFFT_CUDA_TYPES_HPP__
#ifndef KOKKOSFFT_CUDA_TYPES_HPP
#define KOKKOSFFT_CUDA_TYPES_HPP

#include <cufft.h>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_HIP_TYPES_HPP__
#define __KOKKOSFFT_HIP_TYPES_HPP__
#ifndef KOKKOSFFT_HIP_TYPES_HPP
#define KOKKOSFFT_HIP_TYPES_HPP

#include <hipfft/hipfft.h>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_OpenMP_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_OPENMP_TYPES_HPP__
#define __KOKKOSFFT_OPENMP_TYPES_HPP__
#ifndef KOKKOSFFT_OPENMP_TYPES_HPP
#define KOKKOSFFT_OPENMP_TYPES_HPP

#include <fftw3.h>
#include "KokkosFFT_utils.hpp"
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_default_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_DEFAULT_TYPES_HPP__
#define __KOKKOSFFT_DEFAULT_TYPES_HPP__
#ifndef KOKKOSFFT_DEFAULT_TYPES_HPP
#define KOKKOSFFT_DEFAULT_TYPES_HPP

#include <Kokkos_Core.hpp>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_LAYOUTS_HPP__
#define __KOKKOSFFT_LAYOUTS_HPP__
#ifndef KOKKOSFFT_LAYOUTS_HPP
#define KOKKOSFFT_LAYOUTS_HPP


#include <vector>
Expand Down
44 changes: 26 additions & 18 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#ifndef __KOKKOSFFT_NORMALIZATION_HPP__
#define __KOKKOSFFT_NORMALIZATION_HPP__
#ifndef KOKKOSFFT_NORMALIZATION_HPP
#define KOKKOSFFT_NORMALIZATION_HPP

#include <tuple>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
enum class FFT_Normalization {
enum class Normalization {
FORWARD,
BACKWARD,
ORTHO
Expand All @@ -23,32 +24,39 @@ namespace KokkosFFT {
}

template <typename ViewType>
auto _coefficients(const ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::value_type>;
auto _coefficients(const ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::non_const_value_type>;
value_type coef = 1;
bool to_normalize = false;

switch (normalization) {
case FFT_Normalization::FORWARD:
coef = direction == KOKKOS_FFT_FORWARD
? static_cast<value_type>(1) / static_cast<value_type>(fft_size)
: 1;
case Normalization::FORWARD:
if(direction == KOKKOS_FFT_FORWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::BACKWARD:
coef = direction == KOKKOS_FFT_BACKWARD
? static_cast<value_type>(1) / static_cast<value_type>(fft_size)
: 1;
case Normalization::BACKWARD:
if(direction == KOKKOS_FFT_BACKWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::ORTHO:
case Normalization::ORTHO:
coef = static_cast<value_type>(1) / Kokkos::sqrt(static_cast<value_type>(fft_size));
to_normalize = true;

break;
};
return coef;
return std::tuple<value_type, bool> ({coef, to_normalize});
}

template <typename ViewType>
void normalize(ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
auto coef = _coefficients(inout, direction, normalization, fft_size);
_normalize(inout, coef);
void normalize(ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
auto [coef, to_normalize] = _coefficients(inout, direction, normalization, fft_size);
if(to_normalize) _normalize(inout, coef);
}
};

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_TRANSPOSE_HPP__
#define __KOKKOSFFT_TRANSPOSE_HPP__
#ifndef KOKKOSFFT_TRANSPOSE_HPP
#define KOKKOSFFT_TRANSPOSE_HPP

#include <numeric>
#include <tuple>
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_UTILS_HPP__
#define __KOKKOSFFT_UTILS_HPP__
#ifndef KOKKOSFFT_UTILS_HPP
#define KOKKOSFFT_UTILS_HPP

#include <Kokkos_Core.hpp>
#include <vector>
Expand Down
Loading