diff --git a/include/sparrow/array/data_traits.hpp b/include/sparrow/array/data_traits.hpp index 6fc887fd..70d831eb 100644 --- a/include/sparrow/array/data_traits.hpp +++ b/include/sparrow/array/data_traits.hpp @@ -14,6 +14,8 @@ #pragma once +#include + #include "sparrow/array/data_type.hpp" #include "sparrow/layout/fixed_size_layout.hpp" #include "sparrow/layout/null_layout.hpp" @@ -39,83 +41,14 @@ namespace sparrow using default_layout = null_layout; }; - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::BOOL; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::UINT8; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::INT8; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::UINT8; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::UINT16; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::INT16; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::UINT32; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::INT32; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::UINT64; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::INT64; - }; - - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::HALF_FLOAT; - }; - - template <> - struct arrow_traits : common_native_types_traits + // 
Define automatically all standard floating-point and integral types support, including `bool`. + template + requires std::integral or std::floating_point + struct arrow_traits : common_native_types_traits { - static constexpr data_type type_id = data_type::FLOAT; + static constexpr data_type type_id = data_type_from_size(); }; - template <> - struct arrow_traits : common_native_types_traits - { - static constexpr data_type type_id = data_type::DOUBLE; - }; template <> struct arrow_traits diff --git a/include/sparrow/array/data_type.hpp b/include/sparrow/array/data_type.hpp index dbee9d10..09c8f588 100644 --- a/include/sparrow/array/data_type.hpp +++ b/include/sparrow/array/data_type.hpp @@ -26,6 +26,7 @@ namespace date = std::chrono; #include #include #include +#include #include #include @@ -249,6 +250,63 @@ namespace sparrow mpl::unreachable(); } + /// @returns The default floating-point `data_type` that should be associated with the provided type. + /// The deduction will be based on the size of the type. Calling this function with unsupported sizes + /// will not compile. + template + requires (sizeof(T) >= 2 && sizeof(T) <= 8) + constexpr data_type data_type_from_size(T = {}) + { + // TODO: consider rewriting this to benefit from if constexpr? might not be necessary + switch(sizeof(T)) + { + case 2: return data_type::HALF_FLOAT; + case 4: return data_type::FLOAT; + case 8: return data_type::DOUBLE; + } + + mpl::unreachable(); + } + + /// @returns The default integral `data_type` that should be associated with the provided type. + /// The deduction will be based on the size of the type. Calling this function with unsupported + /// sizes will not compile. + template + requires(sizeof(T) >= 1 && sizeof(T) <= 8) + constexpr data_type data_type_from_size(T = {}) + { + if constexpr (std::same_as) + { + return data_type::BOOL; + } + else if constexpr (std::signed_integral) + { + // TODO: consider rewriting this to benefit from if constexpr? 
might not be necessary + switch (sizeof(T)) + { + case 1: return data_type::INT8; + case 2: return data_type::INT16; + case 4: return data_type::INT32; + case 8: return data_type::INT64; + } + } + else + { + static_assert(std::unsigned_integral); + + // TODO: consider rewriting this to benefit from if constexpr? might not be necessary + switch (sizeof(T)) + { + case 1: return data_type::UINT8; + case 2: return data_type::UINT16; + case 4: return data_type::UINT32; + case 8: return data_type::UINT64; + } + } + + mpl::unreachable(); + } + /// C++ types value representation types matching Arrow types. // NOTE: this needs to be in sync-order with `data_type` diff --git a/test/test_traits.cpp b/test/test_traits.cpp index 9c9291e4..9a51d7ee 100644 --- a/test/test_traits.cpp +++ b/test/test_traits.cpp @@ -14,6 +14,7 @@ #include + ///////////////////////////////////////////////////////////////////////////////////////// // Opt-in support for custom C++ representations of arrow data types. @@ -39,4 +40,121 @@ namespace sparrow { static_assert(mpl::all_of(all_base_types_t{}, predicate::is_arrow_base_type)); static_assert(mpl::all_of(all_base_types_t{}, predicate::has_arrow_traits)); -} + + +// Native basic standard types support + + using basic_native_types = mpl::typelist< + bool, + char, unsigned char, signed char, + short, unsigned short, + int, unsigned int, + long, unsigned long, // `long long` could be bigger than 64 bits and is not supported + float, double, // `long double` could be bigger than 64 bits and is not supported + std::uint8_t, + std::int8_t, + std::uint16_t, + std::int16_t, + std::uint32_t, + std::int32_t, + std::uint64_t, + std::int64_t, + float16_t, + float32_t, + float64_t + >; + + template + consteval + bool is_possible_arrow_data_type(data_type type_id) + { + // NOTE: + // `char` is not specified by the C and C++ standards to be `signed` or `unsigned`. + // The language also specifies `signed char` and `unsigned char` as distinct types from `char`. 
+ // Therefore, for `char`, signedness can vary depending on the C++ target platform and compiler. + // The best we can check is that `sizeof(char)` matches the associated arrow data + // integral type minimum size, whatever the signedness, which is why we don't treat + // `char` as a special case below. + + if constexpr (std::same_as) + { + return type_id == data_type::BOOL; + } + // we need T to be able to store the data coming from an arrow data value + // of the associated arrow data type + else if constexpr (std::is_signed_v) + { + switch (type_id) + { + case data_type::INT8: + return sizeof(T) == 1; + case data_type::INT16: + return sizeof(T) <= 2; + case data_type::INT32: + return sizeof(T) <= 4; + case data_type::INT64: + return sizeof(T) <= 8; + default: + return false; + } + } + else + { + static_assert(std::is_unsigned_v); + + switch (type_id) + { + case data_type::UINT8: + return sizeof(T) == 1; + case data_type::UINT16: + return sizeof(T) <= 2; + case data_type::UINT32: + return sizeof(T) <= 4; + case data_type::UINT64: + return sizeof(T) <= 8; + default: + return false; + } + } + + return false; + } + + template + consteval + bool is_possible_arrow_data_type(data_type type_id) + { + switch (type_id) + { + case data_type::HALF_FLOAT: + return sizeof(T) <= 2; + case data_type::FLOAT: + return sizeof(T) <= 4; + case data_type::DOUBLE: + return sizeof(T) <= 8; + default: + return false; + } + } + + // Tests `data_type_from_size` and its usage in `arrow_traits::type_id` + struct + { + template + requires has_arrow_type_traits + consteval + bool operator()(mpl::typelist) + { + constexpr auto deduced_type_id = data_type_from_size(); + static_assert(deduced_type_id == arrow_traits::type_id); + + return is_possible_arrow_data_type(arrow_traits::type_id); + } + } constexpr has_possible_arrow_data_type; + + // Every basic native type must have an arrow trait, whatever the platform, + // including when fixed-size standard library names are or are not aliases of basic 
types. + // Only exceptions: types that could be bigger than 64 bits (`long long`, `long double`, etc.) + static_assert(mpl::all_of(basic_native_types{}, has_possible_arrow_data_type)); + +} \ No newline at end of file