diff --git a/common/src/KokkosFFT_transpose.hpp b/common/src/KokkosFFT_transpose.hpp index 737a7ffe..cfbbc521 100644 --- a/common/src/KokkosFFT_transpose.hpp +++ b/common/src/KokkosFFT_transpose.hpp @@ -13,62 +13,60 @@ namespace KokkosFFT { namespace Impl { template -auto get_map_axes(const ViewType& view, axis_type _axes) { - KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, _axes), +auto get_map_axes(const ViewType& view, axis_type axes) { + KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, axes), "get_map_axes: input axes are not valid for the view"); // Convert the input axes to be in the range of [0, rank-1] - std::vector axes; + axis_type non_negative_axes = {}; for (std::size_t i = 0; i < DIM; i++) { - int axis = KokkosFFT::Impl::convert_negative_axis(view, _axes.at(i)); - axes.push_back(axis); + int axis = KokkosFFT::Impl::convert_negative_axis(view, axes.at(i)); + non_negative_axes[i] = axis; } // how indices are map // For 5D View and axes are (2,3), map would be (0, 1, 4, 2, 3) constexpr int rank = static_cast(ViewType::rank()); - std::vector map, map_inv; + std::vector map; map.reserve(rank); - map_inv.reserve(rank); if (std::is_same_v) { // Stack axes not specified by axes (0, 1, 4) for (int i = 0; i < rank; i++) { - if (!is_found(axes, i)) { + if (!is_found(non_negative_axes, i)) { map.push_back(i); } } // Stack axes on the map (For layout Right) // Then stack (2, 3) to have (0, 1, 4, 2, 3) - for (auto axis : axes) { + for (auto axis : non_negative_axes) { map.push_back(axis); } } else { // For layout Left, stack innermost axes first - std::reverse(axes.begin(), axes.end()); - for (auto axis : axes) { + std::reverse(non_negative_axes.begin(), non_negative_axes.end()); + for (auto axis : non_negative_axes) { map.push_back(axis); } // Then stack remaining axes for (int i = 0; i < rank; i++) { - if (!is_found(axes, i)) { + if (!is_found(non_negative_axes, i)) { map.push_back(i); } } } + using full_axis_type = axis_type; + full_axis_type array_map = {}, array_map_inv = {}; + std::copy_n(map.begin(), rank, array_map.begin()); + // Construct inverse map for (int i = 0; i < rank; i++) { - map_inv.push_back(get_index(map, i)); + array_map_inv[i] = get_index(array_map, i); } - using full_axis_type = axis_type; - full_axis_type array_map = {0}, array_map_inv = {0}; - std::copy(map.begin(), map.end(), array_map.begin()); - std::copy(map_inv.begin(), map_inv.end(), array_map_inv.begin()); - return std::tuple({array_map, array_map_inv}); }