Skip to content

Commit

Permalink
Generic traits for standard base types (partial fix #162) (#163)
Browse files Browse the repository at this point in the history
* Automatic support for any standard floating-point type.

* Generalized support of standard integral types + fixes for floating-point types.

* Removed unnecessary trait specialization

* merged standard integral and flaoting-point traits types

* Added some tests

* fixup tests
  • Loading branch information
Klaim authored Sep 3, 2024
1 parent 54f2a6e commit b0d6a5e
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 75 deletions.
81 changes: 7 additions & 74 deletions include/sparrow/array/data_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <concepts>

#include "sparrow/array/data_type.hpp"
#include "sparrow/layout/fixed_size_layout.hpp"
#include "sparrow/layout/null_layout.hpp"
Expand All @@ -39,83 +41,14 @@ namespace sparrow
using default_layout = null_layout<DS>;
};

template <>
struct arrow_traits<bool> : common_native_types_traits<bool>
{
static constexpr data_type type_id = data_type::BOOL;
};

template <>
struct arrow_traits<std::uint8_t> : common_native_types_traits<std::uint8_t>
{
static constexpr data_type type_id = data_type::UINT8;
};

template <>
struct arrow_traits<std::int8_t> : common_native_types_traits<std::int8_t>
{
static constexpr data_type type_id = data_type::INT8;
};

template <>
struct arrow_traits<char> : common_native_types_traits<char>
{
static constexpr data_type type_id = data_type::UINT8;
};

template <>
struct arrow_traits<std::uint16_t> : common_native_types_traits<std::uint16_t>
{
static constexpr data_type type_id = data_type::UINT16;
};

template <>
struct arrow_traits<std::int16_t> : common_native_types_traits<std::int16_t>
{
static constexpr data_type type_id = data_type::INT16;
};

template <>
struct arrow_traits<std::uint32_t> : common_native_types_traits<std::uint32_t>
{
static constexpr data_type type_id = data_type::UINT32;
};

template <>
struct arrow_traits<std::int32_t> : common_native_types_traits<std::int32_t>
{
static constexpr data_type type_id = data_type::INT32;
};

template <>
struct arrow_traits<std::uint64_t> : common_native_types_traits<std::uint64_t>
{
static constexpr data_type type_id = data_type::UINT64;
};

template <>
struct arrow_traits<std::int64_t> : common_native_types_traits<std::int64_t>
{
static constexpr data_type type_id = data_type::INT64;
};

template <>
struct arrow_traits<float16_t> : common_native_types_traits<float16_t>
{
static constexpr data_type type_id = data_type::HALF_FLOAT;
};

template <>
struct arrow_traits<float32_t> : common_native_types_traits<float32_t>
// Define automatically all standard floating-point and integral types support, including `bool`.
template <class T>
requires std::integral<T> or std::floating_point<T>
struct arrow_traits<T> : common_native_types_traits<T>
{
static constexpr data_type type_id = data_type::FLOAT;
static constexpr data_type type_id = data_type_from_size<T>();
};

template <>
struct arrow_traits<float64_t> : common_native_types_traits<float64_t>
{
static constexpr data_type type_id = data_type::DOUBLE;
};

template <>
struct arrow_traits<std::string>
Expand Down
58 changes: 58 additions & 0 deletions include/sparrow/array/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace date = std::chrono;
#include <climits>
#include <cstdint>
#include <cstring>
#include <concepts>
#include <string>
#include <vector>

Expand Down Expand Up @@ -249,6 +250,63 @@ namespace sparrow
mpl::unreachable();
}

/// @returns The default floating-point `data_type` that should be associated with the provided type.
/// The deduction will be based on the size of the type. Calling this function with unsupported sizes
/// will not compile.
template<std::floating_point T>
requires (sizeof(T) >= 2 && sizeof(T) <= 8)
constexpr data_type data_type_from_size(T = {})
{
// TODO: consider rewriting this to benefit from if constexpr? might not be necessary
switch(sizeof(T))
{
case 2: return data_type::HALF_FLOAT;
case 4: return data_type::FLOAT;
case 8: return data_type::DOUBLE;
}

mpl::unreachable();
}

/// @returns The default integral `data_type` that should be associated with the provided type.
/// The deduction will be based on the size of the type. Calling this function with unsupported
/// sizes will not compile.
template <std::integral T>
requires(sizeof(T) >= 1 && sizeof(T) <= 8)
constexpr data_type data_type_from_size(T = {})
{
if constexpr (std::same_as<bool, T>)
{
return data_type::BOOL;
}
else if constexpr (std::signed_integral<T>)
{
// TODO: consider rewriting this to benefit from if constexpr? might not be necessary
switch (sizeof(T))
{
case 1: return data_type::INT8;
case 2: return data_type::INT16;
case 4: return data_type::INT32;
case 8: return data_type::INT64;
}
}
else
{
static_assert(std::unsigned_integral<T>);

// TODO: consider rewriting this to benefit from if constexpr? might not be necessary
switch (sizeof(T))
{
case 1: return data_type::UINT8;
case 2: return data_type::UINT16;
case 4: return data_type::UINT32;
case 8: return data_type::UINT64;
}
}

mpl::unreachable();
}


/// C++ types value representation types matching Arrow types.
// NOTE: this needs to be in sync-order with `data_type`
Expand Down
120 changes: 119 additions & 1 deletion test/test_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <sparrow/array/data_traits.hpp>


/////////////////////////////////////////////////////////////////////////////////////////
// Opt-in support for custom C++ representations of arrow data types.

Expand All @@ -39,4 +40,121 @@ namespace sparrow
{
static_assert(mpl::all_of(all_base_types_t{}, predicate::is_arrow_base_type));
static_assert(mpl::all_of(all_base_types_t{}, predicate::has_arrow_traits));
}


// Native basic standard types support

using basic_native_types = mpl::typelist<
bool,
char, unsigned char, signed char,
short, unsigned short,
int, unsigned int,
long, unsigned long, // `long long` could be bigger than 64bits and is not supported
float, double, // `long double` could be bigger than 64bit and is not supported
std::uint8_t,
std::int8_t,
std::uint16_t,
std::int16_t,
std::uint32_t,
std::int32_t,
std::uint64_t,
std::int64_t,
float16_t,
float32_t,
float64_t
>;

template <std::integral T>
consteval
bool is_possible_arrow_data_type(data_type type_id)
{
// NOTE:
// `char` is not specified by the C and C++ standard to be `signed` or `unsigned`.
// The language also specifies `signed char` and `unsigned char` as distinct types from `char`.
// Therefore, for `char`, sign-ness can vary depending on the C++ target platform and compiler.
// The best we can check is that `sizeof(char)` matches the associated arrow data
// integral type minimum size, whatever the sign-ness, which is why we don't treat
// `char` as a special case below.

if constexpr (std::same_as<T, bool>)
{
return type_id == data_type::BOOL;
}
// we need T to be able to store the data coming from an arrow data value
// of the associated arrow data type
else if constexpr (std::is_signed_v<T>)
{
switch (type_id)
{
case data_type::INT8:
return sizeof(T) == 1;
case data_type::INT16:
return sizeof(T) <= 2;
case data_type::INT32:
return sizeof(T) <= 4;
case data_type::INT64:
return sizeof(T) <= 8;
default:
return false;
}
}
else
{
static_assert(std::is_unsigned_v<T>);

switch (type_id)
{
case data_type::UINT8:
return sizeof(T) == 1;
case data_type::UINT16:
return sizeof(T) <= 2;
case data_type::UINT32:
return sizeof(T) <= 4;
case data_type::UINT64:
return sizeof(T) <= 8;
default:
return false;
}
}

return false;
}

template <std::floating_point T>
consteval
bool is_possible_arrow_data_type(data_type type_id)
{
switch (type_id)
{
case data_type::HALF_FLOAT:
return sizeof(T) <= 2;
case data_type::FLOAT:
return sizeof(T) <= 4;
case data_type::DOUBLE:
return sizeof(T) <= 8;
default:
return false;
}
}

// Tests `data_type_from_size` and it's usage in `arrow_traits<T>::type_id`
struct
{
template <class T>
requires has_arrow_type_traits<T>
consteval
bool operator()(mpl::typelist<T>)
{
constexpr auto deduced_type_id = data_type_from_size<T>();
static_assert(deduced_type_id == arrow_traits<T>::type_id);

return is_possible_arrow_data_type<T>(arrow_traits<T>::type_id);
}
} constexpr has_possible_arrow_data_type;

// Every basic native types must have an arrow trait, whatever the platform,
// including when fixed-size standard library names are or not alias to basic types.
// Only exceptions: types that could be bigger than 64bit (`long long`, `long double`, etc.)
static_assert(mpl::all_of(basic_native_types{}, has_possible_arrow_data_type));

}

0 comments on commit b0d6a5e

Please sign in to comment.