Skip to content

Commit

Permalink
Add a trait to extract the base value type from a container
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jul 31, 2024
1 parent 0b1ab3d commit 4a1fa0d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
25 changes: 25 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ inline constexpr bool are_operatable_views_v =

// Other traits

template <typename ContainerType>
struct base_container_value;

template <template <typename, typename...> class ContainerType,
typename ValueType, typename... Args>
struct base_container_value<ContainerType<ValueType, Args...>> {
using value_type = ValueType;
};

// Specialization for std::array
template <typename ValueType, std::size_t N>
struct base_container_value<std::array<ValueType, N>> {
using value_type = ValueType;
};

// Specialization for Kokkos::Array
template <typename ValueType, std::size_t N>
struct base_container_value<Kokkos::Array<ValueType, N>> {
using value_type = ValueType;
};

/// \brief Helper to extract the base value type from a container
template <typename T>
using base_container_value_type = typename base_container_value<T>::value_type;

/// \brief Helper to define a managable View type from the original view type
template <typename T>
struct managable_view_type {
Expand Down
49 changes: 49 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// All the tests in this file are compile time tests, so we skip all the tests
// by GTEST_SKIP(). gtest is used for type parameterization.

// Int like types
using base_int_types = ::testing::Types<int, std::size_t>;

// Define the types to combine
using base_real_types = std::tuple<float, double, long double>;

Expand Down Expand Up @@ -39,6 +42,19 @@ using paired_view_types =
tuple_to_types_t<cartesian_product_t<base_real_types, base_layout_types,
base_real_types, base_layout_types>>;

template <typename T>
struct ContainerTypes : public ::testing::Test {
static constexpr std::size_t rank = 3;
using value_type = T;
using vector_type = std::vector<T>;
using std_array_type = std::array<T, rank>;
using Kokkos_array_type = Kokkos::Array<T, rank>;

virtual void SetUp() {
GTEST_SKIP() << "Skipping all tests for this fixture";
}
};

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
Expand Down Expand Up @@ -91,12 +107,45 @@ struct PairedViewTypes : public ::testing::Test {
}
};

TYPED_TEST_SUITE(ContainerTypes, base_int_types);
TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);
TYPED_TEST_SUITE(PairedValueTypes, paired_value_types);
TYPED_TEST_SUITE(PairedLayoutTypes, paired_layout_types);
TYPED_TEST_SUITE(PairedViewTypes, paired_view_types);

// Tests for base value type deduction
template <typename ValueType, typename ContainerType>
void test_get_container_value_type() {
using value_type_ContainerType =
KokkosFFT::Impl::base_container_value_type<ContainerType>;

// base value type of ContainerType is ValueType
static_assert(std::is_same_v<value_type_ContainerType, ValueType>,
"Value type not deduced correctly from ContainerType");
}

TYPED_TEST(ContainerTypes, get_value_type_from_vector) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::vector_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_std_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::std_array_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_kokkos_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::Kokkos_array_type;

test_get_container_value_type<value_type, container_type>();
}

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
Expand Down

0 comments on commit 4a1fa0d

Please sign in to comment.