Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove non-batched version of get_extents #95

Merged
merged 3 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 2 additions & 95 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,71 +19,8 @@ namespace Impl {
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents(InViewType& in, OutViewType& out, axis_type<DIM> _axes) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft> ? 0
: rank - 1;

std::vector<int> in_extents, out_extents, fft_extents;
for (std::size_t i = 0; i < rank; i++) {
auto _idx = map.at(i);
in_extents.push_back(in.extent(_idx));
out_extents.push_back(out.extent(_idx));

// The extent for transform is always equal to the extent
// of the extent of real type (R2C or C2R)
// For C2C, the in and out extents are the same.
// In the end, we can just use the largest extent among in and out extents.
auto fft_extent = std::max(in.extent(_idx), out.extent(_idx));
fft_extents.push_back(fft_extent);
}

if (std::is_floating_point<in_value_type>::value) {
// Then R2C
if (is_complex<out_value_type>::value) {
assert(out_extents.at(inner_most_axis) ==
in_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
}
}

if (std::is_floating_point<out_value_type>::value) {
// Then C2R
if (is_complex<in_value_type>::value) {
assert(in_extents.at(inner_most_axis) ==
out_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
}
}

if (std::is_same<array_layout_type, Kokkos::LayoutLeft>::value) {
std::reverse(in_extents.begin(), in_extents.end());
std::reverse(out_extents.begin(), out_extents.end());
std::reverse(fft_extents.begin(), fft_extents.end());
}

return std::tuple<std::vector<int>, std::vector<int>, std::vector<int> >(
{in_extents, out_extents, fft_extents});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents_batched(InViewType& in, OutViewType& out,
axis_type<DIM> _axes) {
auto get_extents(const InViewType& in, const OutViewType& out,
axis_type<DIM> _axes) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;
Expand Down Expand Up @@ -164,36 +101,6 @@ auto get_extents_batched(InViewType& in, OutViewType& out,
return std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, int>(
{in_extents, out_extents, fft_extents, howmany});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents(InViewType& in, OutViewType& out) {
using array_layout_type = typename InViewType::array_layout;

constexpr std::size_t rank = InViewType::rank;
int inner_most_axis =
std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft> ? 0
: rank - 1;
return get_extents(in, out, axis_type<1>{inner_most_axis});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents(InViewType& in, OutViewType& out, int _axis) {
return get_extents(in, out, axis_type<1>{_axis});
}

/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
template <typename InViewType, typename OutViewType>
auto get_extents_batched(InViewType& in, OutViewType& out, int _axis) {
return get_extents_batched(in, out, axis_type<1>{_axis});
}
} // namespace Impl
}; // namespace KokkosFFT

Expand Down
Loading
Loading