Skip to content

Commit

Permalink
Merge pull request #93 from kokkos/hotfix-shape-arg
Browse files Browse the repository at this point in the history
Hotfix shape arg
  • Loading branch information
yasahi-hpc authored Apr 5, 2024
2 parents 7281024 + e929a4c commit a322cb5
Show file tree
Hide file tree
Showing 5 changed files with 1,770 additions and 1,024 deletions.
27 changes: 20 additions & 7 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,27 @@
namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_modified_shape(const ViewType& view, shape_type<DIM> shape) {
auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
axis_type<DIM> axes, bool is_C2R = false) {
static_assert(ViewType::rank() >= DIM,
"get_modified_shape: Rank of View must be larger "
"than or equal to the Rank of new shape");
static_assert(DIM > 0,
"get_modified_shape: Rank of FFT axes must be "
"larger than or equal to 1");

// [TO DO] Add a is_C2R arg. If is_C2R is true, then shape should be shape/2+1
constexpr int rank = static_cast<int>(ViewType::rank());

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> positive_axes;
for (std::size_t i = 0; i < DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(view, axes.at(i));
positive_axes.push_back(axis);
}

// Assert if the elements are overlapped
assert(!KokkosFFT::Impl::has_duplicate_values(positive_axes));
assert(!KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank));

using full_shape_type = shape_type<rank>;
full_shape_type modified_shape;
for (int i = 0; i < rank; i++) {
Expand All @@ -31,13 +41,16 @@ auto get_modified_shape(const ViewType& view, shape_type<DIM> shape) {

// Update shapes based on newly given shape
for (int i = 0; i < DIM; i++) {
int positive_axis = positive_axes.at(i);
assert(shape.at(i) > 0);
modified_shape.at(i) = shape.at(i);
modified_shape.at(positive_axis) = shape.at(i);
}

if (is_C2R) {
int reshaped_axis = positive_axes.back();
modified_shape.at(reshaped_axis) = modified_shape.at(reshaped_axis) / 2 + 1;
}

// [TO DO] may return, is_modification_needed if modified_shape is not equal
// view.extents() May be implement other function. is_crop_or_pad_needed(view,
// shape);
return modified_shape;
}

Expand Down
Loading

0 comments on commit a322cb5

Please sign in to comment.