Skip to content

Commit

Permalink
Merge pull request #8 from CExA-project/cleanup
Browse files Browse the repository at this point in the history
Cleanup
  • Loading branch information
yasahi-hpc authored Dec 14, 2023
2 parents 0975bc9 + 45d0294 commit 2a25891
Show file tree
Hide file tree
Showing 24 changed files with 1,492 additions and 1,847 deletions.
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_CUDA_TYPES_HPP__
#define __KOKKOSFFT_CUDA_TYPES_HPP__
#ifndef KOKKOSFFT_CUDA_TYPES_HPP
#define KOKKOSFFT_CUDA_TYPES_HPP

#include <cufft.h>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_HIP_TYPES_HPP__
#define __KOKKOSFFT_HIP_TYPES_HPP__
#ifndef KOKKOSFFT_HIP_TYPES_HPP
#define KOKKOSFFT_HIP_TYPES_HPP

#include <hipfft/hipfft.h>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_OpenMP_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_OPENMP_TYPES_HPP__
#define __KOKKOSFFT_OPENMP_TYPES_HPP__
#ifndef KOKKOSFFT_OPENMP_TYPES_HPP
#define KOKKOSFFT_OPENMP_TYPES_HPP

#include <fftw3.h>
#include "KokkosFFT_utils.hpp"
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_default_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_DEFAULT_TYPES_HPP__
#define __KOKKOSFFT_DEFAULT_TYPES_HPP__
#ifndef KOKKOSFFT_DEFAULT_TYPES_HPP
#define KOKKOSFFT_DEFAULT_TYPES_HPP

#include <Kokkos_Core.hpp>

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_LAYOUTS_HPP__
#define __KOKKOSFFT_LAYOUTS_HPP__
#ifndef KOKKOSFFT_LAYOUTS_HPP
#define KOKKOSFFT_LAYOUTS_HPP


#include <vector>
Expand Down
44 changes: 26 additions & 18 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#ifndef __KOKKOSFFT_NORMALIZATION_HPP__
#define __KOKKOSFFT_NORMALIZATION_HPP__
#ifndef KOKKOSFFT_NORMALIZATION_HPP
#define KOKKOSFFT_NORMALIZATION_HPP

#include <tuple>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
enum class FFT_Normalization {
enum class Normalization {
FORWARD,
BACKWARD,
ORTHO
Expand All @@ -23,32 +24,39 @@ namespace KokkosFFT {
}

template <typename ViewType>
auto _coefficients(const ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::value_type>;
auto _coefficients(const ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::non_const_value_type>;
value_type coef = 1;
bool to_normalize = false;

switch (normalization) {
case FFT_Normalization::FORWARD:
coef = direction == KOKKOS_FFT_FORWARD
? static_cast<value_type>(1) / static_cast<value_type>(fft_size)
: 1;
case Normalization::FORWARD:
if(direction == KOKKOS_FFT_FORWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::BACKWARD:
coef = direction == KOKKOS_FFT_BACKWARD
? static_cast<value_type>(1) / static_cast<value_type>(fft_size)
: 1;
case Normalization::BACKWARD:
if(direction == KOKKOS_FFT_BACKWARD) {
coef = static_cast<value_type>(1) / static_cast<value_type>(fft_size);
to_normalize = true;
}

break;
case FFT_Normalization::ORTHO:
case Normalization::ORTHO:
coef = static_cast<value_type>(1) / Kokkos::sqrt(static_cast<value_type>(fft_size));
to_normalize = true;

break;
};
return coef;
return std::tuple<value_type, bool> ({coef, to_normalize});
}

template <typename ViewType>
void normalize(ViewType& inout, FFTDirectionType direction, FFT_Normalization normalization, std::size_t fft_size) {
auto coef = _coefficients(inout, direction, normalization, fft_size);
_normalize(inout, coef);
void normalize(ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
auto [coef, to_normalize] = _coefficients(inout, direction, normalization, fft_size);
if(to_normalize) _normalize(inout, coef);
}
};

Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_TRANSPOSE_HPP__
#define __KOKKOSFFT_TRANSPOSE_HPP__
#ifndef KOKKOSFFT_TRANSPOSE_HPP
#define KOKKOSFFT_TRANSPOSE_HPP

#include <numeric>
#include <tuple>
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef __KOKKOSFFT_UTILS_HPP__
#define __KOKKOSFFT_UTILS_HPP__
#ifndef KOKKOSFFT_UTILS_HPP
#define KOKKOSFFT_UTILS_HPP

#include <Kokkos_Core.hpp>
#include <vector>
Expand Down
Loading

0 comments on commit 2a25891

Please sign in to comment.