Skip to content

Commit

Permalink
Merge pull request #95 from kokkos/remove-get-extents-non-batched
Browse files Browse the repository at this point in the history
Remove non-batched version of get_extents
  • Loading branch information
yasahi-hpc authored Apr 9, 2024
2 parents 60bfd7d + 779adf7 commit b15bc9f
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 192 deletions.
97 changes: 2 additions & 95 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,71 +19,8 @@ namespace Impl {
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents(InViewType& in, OutViewType& out, axis_type<DIM> _axes) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft> ? 0
: rank - 1;

std::vector<int> in_extents, out_extents, fft_extents;
for (std::size_t i = 0; i < rank; i++) {
auto _idx = map.at(i);
in_extents.push_back(in.extent(_idx));
out_extents.push_back(out.extent(_idx));

// The extent for transform is always equal to the extent
// of the extent of real type (R2C or C2R)
// For C2C, the in and out extents are the same.
// In the end, we can just use the largest extent among in and out extents.
auto fft_extent = std::max(in.extent(_idx), out.extent(_idx));
fft_extents.push_back(fft_extent);
}

if (std::is_floating_point<in_value_type>::value) {
// Then R2C
if (is_complex<out_value_type>::value) {
assert(out_extents.at(inner_most_axis) ==
in_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
}
}

if (std::is_floating_point<out_value_type>::value) {
// Then C2R
if (is_complex<in_value_type>::value) {
assert(in_extents.at(inner_most_axis) ==
out_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
}
}

if (std::is_same<array_layout_type, Kokkos::LayoutLeft>::value) {
std::reverse(in_extents.begin(), in_extents.end());
std::reverse(out_extents.begin(), out_extents.end());
std::reverse(fft_extents.begin(), fft_extents.end());
}

return std::tuple<std::vector<int>, std::vector<int>, std::vector<int> >(
{in_extents, out_extents, fft_extents});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents_batched(InViewType& in, OutViewType& out,
axis_type<DIM> _axes) {
auto get_extents(const InViewType& in, const OutViewType& out,
axis_type<DIM> _axes) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;
Expand Down Expand Up @@ -164,36 +101,6 @@ auto get_extents_batched(InViewType& in, OutViewType& out,
return std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, int>(
{in_extents, out_extents, fft_extents, howmany});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents(InViewType& in, OutViewType& out) {
using array_layout_type = typename InViewType::array_layout;

constexpr std::size_t rank = InViewType::rank;
int inner_most_axis =
std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft> ? 0
: rank - 1;
return get_extents(in, out, axis_type<1>{inner_most_axis});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents(InViewType& in, OutViewType& out, int _axis) {
return get_extents(in, out, axis_type<1>{_axis});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents_batched(InViewType& in, OutViewType& out, int _axis) {
return get_extents_batched(in, out, axis_type<1>{_axis});
}
} // namespace Impl
}; // namespace KokkosFFT

Expand Down
Loading

0 comments on commit b15bc9f

Please sign in to comment.