Skip to content

Commit

Permalink
Dedicated files for utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-PLACET committed Sep 4, 2024
1 parent 965cfca commit 9e73425
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 222 deletions.
208 changes: 6 additions & 202 deletions include/sparrow/arrow_array_schema_proxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,198 +16,23 @@
#include <variant>
#include <vector>

#include "sparrow/array/data_type.hpp"
#include "sparrow/arrow_interface/arrow_array.hpp"
#include "sparrow/arrow_interface/arrow_array_schema_proxy_utils.hpp"
#include "sparrow/arrow_interface/arrow_flag_utils.hpp"
#include "sparrow/arrow_interface/arrow_schema.hpp"
#include "sparrow/buffer/buffer_view.hpp"
#include "sparrow/c_interface.hpp"
#include "sparrow/utils/contracts.hpp"
#include "sparrow/utils/mp_utils.hpp"

namespace sparrow
{
constexpr bool is_valid_ArrowFlag_value(int64_t value) noexcept
{
constexpr std::array<ArrowFlag, 3> valid_values = {
ArrowFlag::DICTIONARY_ORDERED,
ArrowFlag::NULLABLE,
ArrowFlag::MAP_KEYS_SORTED
};
return std::ranges::any_of(
valid_values,
[value](ArrowFlag flag)
{
return static_cast<std::underlying_type_t<ArrowFlag>>(flag) == value;
}
);
}

constexpr std::vector<ArrowFlag> to_vector_of_ArrowFlags(int64_t flag_values)
{
constexpr size_t n_bits = sizeof(flag_values) * 8;
std::vector<ArrowFlag> flags;
for (size_t i = 0; i < n_bits; ++i)
{
const int64_t flag_value = static_cast<int64_t>(1) << i;
if ((flag_values & flag_value) != 0)
{
if (!is_valid_ArrowFlag_value(flag_value))
{
// TODO: Replace with a more specific exception
throw std::runtime_error("Invalid ArrowFlag value");
}
flags.push_back(static_cast<ArrowFlag>(flag_value));
}
}
return flags;
}

constexpr int64_t to_ArrowFlag_value(const std::vector<ArrowFlag>& flags)
{
int64_t flag_values = 0;
for (const ArrowFlag flag : flags)
{
flag_values |= static_cast<int64_t>(flag);
}
return flag_values;
}

/**
 * Checks that `n_buffers` matches the buffer count expected for `data_type`.
 *
 * @param data_type Type whose expected buffer count is looked up through
 *                  get_expected_buffer_count().
 * @param n_buffers Buffer count to validate; it is converted to std::size_t
 *                  before the comparison.
 * @return true when both counts are equal.
 */
constexpr bool validate_buffers_count(data_type data_type, int64_t n_buffers)
{
    return get_expected_buffer_count(data_type) == static_cast<std::size_t>(n_buffers);
}

/**
 * Returns the number of children this header expects for `data_type`.
 *
 * NOTE(review): the counts are fixed per type. The Arrow columnar format
 * describes struct and union arrays as having one child per field, and
 * run-end encoded arrays as having two children; confirm that these fixed
 * values are what callers rely on.
 *
 * @param data_type Type to look up.
 * @return 2 for DENSE_UNION, 1 for the list-like/struct/map/sparse-union
 *         group, 0 for every other enumerator handled below.
 */
constexpr std::size_t get_expected_children_count(data_type data_type)
{
    switch (data_type)
    {
        // Two children.
        case data_type::DENSE_UNION:
            return 2;

        // One child.
        case data_type::LIST:
        case data_type::LARGE_LIST:
        case data_type::LIST_VIEW:
        case data_type::LARGE_LIST_VIEW:
        case data_type::FIXED_SIZED_LIST:
        case data_type::STRUCT:
        case data_type::MAP:
        case data_type::SPARSE_UNION:
            return 1;

        // No children.
        case data_type::NA:
        case data_type::RUN_ENCODED:
        case data_type::BOOL:
        case data_type::UINT8:
        case data_type::INT8:
        case data_type::UINT16:
        case data_type::INT16:
        case data_type::UINT32:
        case data_type::INT32:
        case data_type::UINT64:
        case data_type::INT64:
        case data_type::HALF_FLOAT:
        case data_type::FLOAT:
        case data_type::DOUBLE:
        case data_type::TIMESTAMP:
        case data_type::FIXED_SIZE_BINARY:
        case data_type::DECIMAL:
        case data_type::FIXED_WIDTH_BINARY:
        case data_type::STRING:
            return 0;
    }
    // Only reached if `data_type` holds a value outside the enumerators above.
    mpl::unreachable();
}

/**
 * Checks that an ArrowArray has the buffer and children counts expected for
 * `data_type`.
 *
 * Declared `inline` because the definition lives in a header: without it,
 * every translation unit including this header would emit its own external
 * definition and the program would fail to link.
 *
 * @param data_type Type the array is supposed to hold.
 * @param array     Array whose `n_buffers` and `n_children` are inspected.
 * @return true when both counts match the expectations for `data_type`.
 */
inline bool validate_format_with_arrow_array(data_type data_type, const ArrowArray& array)
{
    const bool buffers_count_valid = validate_buffers_count(data_type, array.n_buffers);
    const bool children_count_valid = static_cast<std::size_t>(array.n_children)
                                      == get_expected_children_count(data_type);
    return buffers_count_valid && children_count_valid;
}

// Kinds of buffers that can appear in an array's buffer list.
// get_buffer_types_from_data_type() returns, for a given data type, an
// ordered sequence of these values.
enum class buffer_type
{
// Listed first for every data type that has it.
VALIDITY,
// Listed for the fixed-width types and, after the offsets, for STRING.
DATA,
// Listed for STRING, LIST, LIST_VIEW and DENSE_UNION.
OFFSETS_32BIT,
// Listed for LARGE_LIST and LARGE_LIST_VIEW.
OFFSETS_64BIT,
// Not returned by get_buffer_types_from_data_type() for any data type.
VIEWS,
// Listed for SPARSE_UNION and DENSE_UNION.
TYPE_IDS,
// Listed for LIST_VIEW, after the offsets.
SIZES_32BIT,
// Listed for LARGE_LIST_VIEW, after the offsets.
SIZES_64BIT,
};

/**
 * Describes, in order, the kinds of buffers associated with `data_type`.
 *
 * NOTE(review): MAP yields an empty list here, whereas LIST yields
 * {VALIDITY, OFFSETS_32BIT}; confirm this is intended.
 *
 * @param data_type Type to look up.
 * @return Ordered buffer kinds; empty for NA, MAP and RUN_ENCODED.
 */
constexpr std::vector<buffer_type> get_buffer_types_from_data_type(data_type data_type)
{
    switch (data_type)
    {
        // No buffers reported.
        case data_type::NA:
        case data_type::MAP:
        case data_type::RUN_ENCODED:
            return {};

        // Union types.
        case data_type::SPARSE_UNION:
            return {buffer_type::TYPE_IDS};
        case data_type::DENSE_UNION:
            return {buffer_type::TYPE_IDS, buffer_type::OFFSETS_32BIT};

        // Validity only.
        case data_type::FIXED_SIZED_LIST:
        case data_type::STRUCT:
            return {buffer_type::VALIDITY};

        // Validity followed by offsets (and sizes for the view variants).
        case data_type::LIST:
            return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT};
        case data_type::LARGE_LIST:
            return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT};
        case data_type::LIST_VIEW:
            return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::SIZES_32BIT};
        case data_type::LARGE_LIST_VIEW:
            return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT, buffer_type::SIZES_64BIT};

        // Validity, offsets, then data.
        case data_type::STRING:
            return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::DATA};

        // Validity followed by a single data buffer.
        case data_type::BOOL:
        case data_type::UINT8:
        case data_type::INT8:
        case data_type::UINT16:
        case data_type::INT16:
        case data_type::UINT32:
        case data_type::INT32:
        case data_type::UINT64:
        case data_type::INT64:
        case data_type::HALF_FLOAT:
        case data_type::FLOAT:
        case data_type::DOUBLE:
        case data_type::TIMESTAMP:
        case data_type::FIXED_SIZE_BINARY:
        case data_type::DECIMAL:
        case data_type::FIXED_WIDTH_BINARY:
            return {buffer_type::VALIDITY, buffer_type::DATA};
    }
    // Only reached if `data_type` holds a value outside the enumerators above.
    mpl::unreachable();
}

/**
 * Computes the number of entries in the offsets buffer of an array.
 *
 * @param data_type Type of the array.
 * @param length    Number of elements in the array; converted to std::size_t.
 * @return `length` for LIST_VIEW, LARGE_LIST_VIEW and DENSE_UNION;
 *         `length + 1` for STRING, LIST and LARGE_LIST.
 * @throws std::runtime_error for any other data type.
 */
constexpr std::size_t get_offset_size(data_type data_type, int64_t length)
{
    const auto element_count = static_cast<std::size_t>(length);
    switch (data_type)
    {
        // One offset per element.
        case data_type::LIST_VIEW:
        case data_type::LARGE_LIST_VIEW:
        case data_type::DENSE_UNION:
            return element_count;

        // One extra trailing offset.
        case data_type::STRING:
        case data_type::LIST:
        case data_type::LARGE_LIST:
            return element_count + 1;

        default:
            throw std::runtime_error("Unsupported data type");
    }
}

class arrow_proxy
{
public:

explicit arrow_proxy(ArrowArray&& array, ArrowSchema&& schema);
explicit arrow_proxy(ArrowArray&& array, ArrowSchema* schema);
explicit arrow_proxy(ArrowArray* array, ArrowSchema* schema);


~arrow_proxy();

Expand All @@ -226,9 +51,6 @@ namespace sparrow

[[nodiscard]] bool is_created_with_sparrow() const;

void release();
[[nodiscard]] bool is_released() const;

[[nodiscard]] void* private_data() const;

private:
Expand Down Expand Up @@ -327,24 +149,23 @@ namespace sparrow

/**
 * Releases the Arrow structures stored in alternative 1 of `m_array` and
 * `m_schema`, if their `release` callback is still set.
 *
 * Nothing is released when the variants hold another alternative
 * (presumably a non-owning pointer supplied by the caller; confirm against
 * the member declarations).
 *
 * Declared `inline` because this out-of-class definition lives in a header.
 */
inline arrow_proxy::~arrow_proxy()
{
    // std::get_if yields nullptr unless alternative 1 is the active one.
    if (auto* owned_array = std::get_if<1>(&m_array);
        owned_array != nullptr && owned_array->release != nullptr)
    {
        owned_array->release(owned_array);
    }

    if (auto* owned_schema = std::get_if<1>(&m_schema);
        owned_schema != nullptr && owned_schema->release != nullptr)
    {
        owned_schema->release(owned_schema);
    }
}

[[nodiscard]] std::string_view arrow_proxy::format() const
Expand Down Expand Up @@ -437,23 +258,6 @@ namespace sparrow
return arrow_proxy{array().dictionary, schema().dictionary};
}

/**
 * Invokes the `release` callback of the array and of the schema, skipping
 * whichever callback is already null.
 *
 * Declared `inline` because this out-of-class definition lives in a header;
 * otherwise each including translation unit would emit its own definition.
 */
inline void arrow_proxy::release()
{
    if (array().release != nullptr)
    {
        array().release(&array());
    }
    if (schema().release != nullptr)
    {
        schema().release(&schema());
    }
}

/**
 * Tells whether both the array and the schema have a null `release` callback.
 *
 * Declared `inline` because this out-of-class definition lives in a header.
 *
 * @return true only when both `release` callbacks are null.
 */
[[nodiscard]] inline bool arrow_proxy::is_released() const
{
    return array().release == nullptr && schema().release == nullptr;
}

[[nodiscard]] bool arrow_proxy::is_created_with_sparrow() const
{
return (array().release == &sparrow::release_arrow_array)
Expand Down
Loading

0 comments on commit 9e73425

Please sign in to comment.