Skip to content

Commit

Permalink
extend check functions to work on std::array or std::vector
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jul 31, 2024
1 parent 4a1fa0d commit 1cbed11
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 34 deletions.
62 changes: 50 additions & 12 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_common_types.hpp"

#if defined(KOKKOS_ENABLE_CXX17)
#include <cstdlib>
Expand Down Expand Up @@ -85,38 +86,75 @@ auto convert_negative_shift(const ViewType& view, int _shift, int _axis) {
return std::tuple<int, int, int>({shift0, shift1, shift2});
}

template <typename T>
bool is_found(std::vector<T>& values, const T& key) {
template <typename ContainerType, typename ValueType>
bool is_found(ContainerType& values, const ValueType& key) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
return std::find(values.begin(), values.end(), key) != values.end();
}

template <typename T>
bool has_duplicate_values(const std::vector<T>& values) {
std::set<T> set_values(values.begin(), values.end());
template <typename ContainerType>
bool has_duplicate_values(const ContainerType& values) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
std::set<value_type> set_values(values.begin(), values.end());
return set_values.size() < values.size();
}

template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>,
std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const std::vector<IntType>& values,
IntType max) {
template <
typename ContainerType, typename IntType,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const ContainerType& values, IntType max) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, IntType>,
"Container value type must match IntType");
bool is_included = false;
for (auto value : values) {
is_included = value >= max;
}
return is_included;
}

template <
typename ViewType, template <typename, std::size_t> class ArrayType,
typename IntType, std::size_t DIM = 1,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool are_valid_axes(const ViewType& view, const ArrayType<IntType, DIM>& axes) {
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"are_valid_axes: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"are_valid_axes: View rank must be larger than or equal to the "
"Rank of FFT axes");
constexpr int rank = ViewType::rank();

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> non_negative_axes;
for (std::size_t i = 0; i < DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(view, axes[i]);
non_negative_axes.push_back(axis);
}

bool is_valid = (!KokkosFFT::Impl::has_duplicate_values(non_negative_axes)) &&
(!KokkosFFT::Impl::is_out_of_range_value_included(
non_negative_axes, rank));

return is_valid;
}

template <std::size_t DIM = 1>
bool is_transpose_needed(std::array<int, DIM> map) {
std::array<int, DIM> contiguous_map;
std::iota(contiguous_map.begin(), contiguous_map.end(), 0);
return map != contiguous_map;
}

template <typename T>
std::size_t get_index(std::vector<T>& values, const T& key) {
auto it = find(values.begin(), values.end(), key);
template <typename ContainerType, typename ValueType>
std::size_t get_index(ContainerType& values, const ValueType& key) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
auto it = std::find(values.begin(), values.end(), key);
std::size_t index = 0;
if (it != values.end()) {
index = it - values.begin();
Expand Down
179 changes: 157 additions & 22 deletions common/unit_test/Test_Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

using test_types = ::testing::Types<Kokkos::LayoutLeft, Kokkos::LayoutRight>;

// Int like types
using base_int_types = ::testing::Types<int, std::size_t>;

// Basically the same fixtures, used for labeling tests
template <typename T>
struct ConvertNegativeAxis : public ::testing::Test {
Expand All @@ -19,8 +22,17 @@ struct ConvertNegativeShift : public ::testing::Test {
using layout_type = T;
};

template <typename T>
struct ContainerTypes : public ::testing::Test {
static constexpr std::size_t rank = 5;
using value_type = T;
using vector_type = std::vector<T>;
using array_type = std::array<T, rank>;
};

TYPED_TEST_SUITE(ConvertNegativeAxis, test_types);
TYPED_TEST_SUITE(ConvertNegativeShift, test_types);
TYPED_TEST_SUITE(ContainerTypes, base_int_types);

// Tests for convert_negative_axes over ND views
template <typename LayoutType>
Expand Down Expand Up @@ -267,39 +279,162 @@ TEST(IsTransposeNeeded, 1Dto3D) {
EXPECT_TRUE(KokkosFFT::Impl::is_transpose_needed(map3D_210));
}

TEST(GetIndex, Vectors) {
std::vector<int> v = {0, 1, 4, 2, 3};

EXPECT_EQ(KokkosFFT::Impl::get_index(v, 0), 0);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, 1), 1);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, 2), 3);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, 3), 4);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, 4), 2);

EXPECT_THROW(KokkosFFT::Impl::get_index(v, -1), std::runtime_error);
template <typename ContainerType>
void test_is_found() {
using IntType = KokkosFFT::Impl::base_container_value_type<ContainerType>;
ContainerType v = {0, 1, 4, 2, 3};

EXPECT_TRUE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(0)));
EXPECT_TRUE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(1)));
EXPECT_TRUE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(2)));
EXPECT_TRUE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(3)));
EXPECT_TRUE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(4)));

if constexpr (std::is_signed_v<IntType>) {
EXPECT_FALSE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(-1)));
}
EXPECT_FALSE(KokkosFFT::Impl::is_found(v, static_cast<IntType>(5)));
}

EXPECT_THROW(KokkosFFT::Impl::get_index(v, 5), std::runtime_error);
template <typename ContainerType>
void test_get_index() {
using IntType = KokkosFFT::Impl::base_container_value_type<ContainerType>;
ContainerType v = {0, 1, 4, 2, 3};

EXPECT_EQ(KokkosFFT::Impl::get_index(v, static_cast<IntType>(0)), 0);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, static_cast<IntType>(1)), 1);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, static_cast<IntType>(2)), 3);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, static_cast<IntType>(3)), 4);
EXPECT_EQ(KokkosFFT::Impl::get_index(v, static_cast<IntType>(4)), 2);

if constexpr (std::is_signed_v<IntType>) {
EXPECT_THROW(KokkosFFT::Impl::get_index(v, static_cast<IntType>(-1)),
std::runtime_error);
}
EXPECT_THROW(KokkosFFT::Impl::get_index(v, static_cast<IntType>(5)),
std::runtime_error);
}

TEST(HasDuplicateValues, Array) {
std::vector<int> v0 = {0, 1, 1};
std::vector<int> v1 = {0, 1, 1, 1};
std::vector<int> v2 = {0, 1, 2, 3};
std::vector<int> v3 = {0};
template <typename ContainerType0, typename ContainerType1,
typename ContainerType2>
void test_has_duplicate_values() {
ContainerType0 v0 = {0, 1, 1};
ContainerType1 v1 = {0, 1, 1, 1};
ContainerType1 v2 = {0, 1, 2, 3};
ContainerType2 v3 = {0};

EXPECT_TRUE(KokkosFFT::Impl::has_duplicate_values(v0));
EXPECT_TRUE(KokkosFFT::Impl::has_duplicate_values(v1));
EXPECT_FALSE(KokkosFFT::Impl::has_duplicate_values(v2));
EXPECT_FALSE(KokkosFFT::Impl::has_duplicate_values(v3));
}

TEST(IsOutOfRangeValueIncluded, Array) {
std::vector<int> v = {0, 1, 2, 3};
template <typename ContainerType>
void test_is_out_of_range_value_included() {
using IntType = KokkosFFT::Impl::base_container_value_type<ContainerType>;
ContainerType v = {0, 1, 2, 3, 4};

EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(
v, static_cast<IntType>(2)));
EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(
v, static_cast<IntType>(3)));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(
v, static_cast<IntType>(5)));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(
v, static_cast<IntType>(6)));
}

template <typename IntType>
void test_are_valid_axes() {
using real_type = double;
using View1DType = Kokkos::View<real_type*>;
using View2DType = Kokkos::View<real_type**>;
using View3DType = Kokkos::View<real_type***>;
using View4DType = Kokkos::View<real_type****>;

std::array<IntType, 1> axes0 = {0};
std::array<IntType, 2> axes1 = {0, 1};
std::array<IntType, 3> axes2 = {0, 1, 2};
std::array<IntType, 3> axes3 = {0, 1, 1};
std::array<IntType, 3> axes4 = {0, 1, 3};

View1DType view1;
View2DType view2;
View3DType view3;
View4DType view4;

// 1D axes on 1D+ Views
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view1, axes0));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view2, axes0));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view3, axes0));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view4, axes0));

// 2D axes on 2D+ Views
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view2, axes1));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view3, axes1));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view4, axes1));

// 3D axes on 3D+ Views
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view3, axes2));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view4, axes2));
EXPECT_TRUE(KokkosFFT::Impl::are_valid_axes(view4, axes4));

// 3D axes on 3D Views with out of range -> should throw
EXPECT_THROW(KokkosFFT::Impl::are_valid_axes(view3, axes4),
std::runtime_error);

// axes include overlap -> should fail
EXPECT_FALSE(KokkosFFT::Impl::are_valid_axes(view3, axes3));
EXPECT_FALSE(KokkosFFT::Impl::are_valid_axes(view4, axes3));
}

TYPED_TEST(ContainerTypes, is_found_from_vector) {
using container_type = typename TestFixture::vector_type;
test_is_found<container_type>();
}

TYPED_TEST(ContainerTypes, is_found_from_array) {
using container_type = typename TestFixture::array_type;
test_is_found<container_type>();
}

TYPED_TEST(ContainerTypes, get_index_from_vector) {
using container_type = typename TestFixture::vector_type;
test_get_index<container_type>();
}

TYPED_TEST(ContainerTypes, get_index_from_array) {
using container_type = typename TestFixture::array_type;
test_get_index<container_type>();
}

TYPED_TEST(ContainerTypes, has_duplicate_values_in_vector) {
using container_type = typename TestFixture::vector_type;
test_has_duplicate_values<container_type, container_type, container_type>();
}

TYPED_TEST(ContainerTypes, has_duplicate_values_in_array) {
using value_type = typename TestFixture::value_type;
using container_type0 = std::array<value_type, 3>;
using container_type1 = std::array<value_type, 4>;
using container_type2 = std::array<value_type, 1>;
test_has_duplicate_values<container_type0, container_type1,
container_type2>();
}

TYPED_TEST(ContainerTypes, is_out_of_range_value_included_in_vector) {
using container_type = typename TestFixture::vector_type;
test_is_out_of_range_value_included<container_type>();
}

TYPED_TEST(ContainerTypes, is_out_of_range_value_included_in_array) {
using container_type = typename TestFixture::array_type;
test_is_out_of_range_value_included<container_type>();
}

EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 2));
EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 3));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(v, 4));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(v, 5));
TYPED_TEST(ContainerTypes, are_valid_axes) {
using value_type = typename TestFixture::value_type;
test_are_valid_axes<value_type>();
}

TEST(ExtractExtents, 1Dto8D) {
Expand Down

0 comments on commit 1cbed11

Please sign in to comment.