diff --git a/include/sparrow/arrow_array_schema_proxy.hpp b/include/sparrow/arrow_array_schema_proxy.hpp index 757b582df..d6d3c7c60 100644 --- a/include/sparrow/arrow_array_schema_proxy.hpp +++ b/include/sparrow/arrow_array_schema_proxy.hpp @@ -16,190 +16,16 @@ #include #include -#include "sparrow/array/data_type.hpp" #include "sparrow/arrow_interface/arrow_array.hpp" +#include "sparrow/arrow_interface/arrow_array_schema_proxy_utils.hpp" +#include "sparrow/arrow_interface/arrow_flag_utils.hpp" #include "sparrow/arrow_interface/arrow_schema.hpp" #include "sparrow/buffer/buffer_view.hpp" -#include "sparrow/c_interface.hpp" #include "sparrow/utils/contracts.hpp" #include "sparrow/utils/mp_utils.hpp" namespace sparrow { - constexpr bool is_valid_ArrowFlag_value(int64_t value) noexcept - { - constexpr std::array valid_values = { - ArrowFlag::DICTIONARY_ORDERED, - ArrowFlag::NULLABLE, - ArrowFlag::MAP_KEYS_SORTED - }; - return std::ranges::any_of( - valid_values, - [value](ArrowFlag flag) - { - return static_cast>(flag) == value; - } - ); - } - - constexpr std::vector to_vector_of_ArrowFlags(int64_t flag_values) - { - constexpr size_t n_bits = sizeof(flag_values) * 8; - std::vector flags; - for (size_t i = 0; i < n_bits; ++i) - { - const int64_t flag_value = static_cast(1) << i; - if ((flag_values & flag_value) != 0) - { - if (!is_valid_ArrowFlag_value(flag_value)) - { - // TODO: Replace with a more specific exception - throw std::runtime_error("Invalid ArrowFlag value"); - } - flags.push_back(static_cast(flag_value)); - } - } - return flags; - } - - constexpr int64_t to_ArrowFlag_value(const std::vector& flags) - { - int64_t flag_values = 0; - for (const ArrowFlag flag : flags) - { - flag_values |= static_cast(flag); - } - return flag_values; - } - - constexpr bool validate_buffers_count(data_type data_type, int64_t n_buffers) - { - const std::size_t expected_buffer_count = get_expected_buffer_count(data_type); - return static_cast(n_buffers) == expected_buffer_count; - } - - constexpr std::size_t get_expected_children_count(data_type data_type) - { - switch (data_type) - { - case data_type::NA: - case data_type::RUN_ENCODED: - case data_type::BOOL: - case data_type::UINT8: - case data_type::INT8: - case data_type::UINT16: - case data_type::INT16: - case data_type::UINT32: - case data_type::INT32: - case data_type::FLOAT: - case data_type::UINT64: - case data_type::INT64: - case data_type::DOUBLE: - case data_type::HALF_FLOAT: - case data_type::TIMESTAMP: - case data_type::FIXED_SIZE_BINARY: - case data_type::DECIMAL: - case data_type::FIXED_WIDTH_BINARY: - case data_type::STRING: - return 0; - case data_type::LIST: - case data_type::LARGE_LIST: - case data_type::LIST_VIEW: - case data_type::LARGE_LIST_VIEW: - case data_type::FIXED_SIZED_LIST: - case data_type::STRUCT: - case data_type::MAP: - case data_type::SPARSE_UNION: - return 1; - case data_type::DENSE_UNION: - return 2; - } - mpl::unreachable(); - } - - bool validate_format_with_arrow_array(data_type data_type, const ArrowArray& array) - { - const bool buffers_count_valid = validate_buffers_count(data_type, array.n_buffers); - const bool children_count_valid = static_cast(array.n_children) - == get_expected_children_count(data_type); - return buffers_count_valid && children_count_valid; - } - - enum class buffer_type - { - VALIDITY, - DATA, - OFFSETS_32BIT, - OFFSETS_64BIT, - VIEWS, - TYPE_IDS, - SIZES_32BIT, - SIZES_64BIT, - }; - - constexpr std::vector get_buffer_types_from_data_type(data_type data_type) - { - switch (data_type) - { - case data_type::BOOL: - case data_type::UINT8: - case data_type::INT8: - case data_type::UINT16: - case data_type::INT16: - case data_type::UINT32: - case data_type::INT32: - case data_type::FLOAT: - case data_type::UINT64: - case data_type::INT64: - case data_type::DOUBLE: - case data_type::HALF_FLOAT: - case data_type::TIMESTAMP: - case data_type::FIXED_SIZE_BINARY: - case data_type::DECIMAL: - case data_type::FIXED_WIDTH_BINARY: - return {buffer_type::VALIDITY, buffer_type::DATA}; - case data_type::STRING: - return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::DATA}; - case data_type::LIST: - return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT}; - case data_type::LARGE_LIST: - return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT}; - case data_type::LIST_VIEW: - return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::SIZES_32BIT}; - case data_type::LARGE_LIST_VIEW: - return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT, buffer_type::SIZES_64BIT}; - case data_type::FIXED_SIZED_LIST: - case data_type::STRUCT: - return {buffer_type::VALIDITY}; - case data_type::SPARSE_UNION: - return {buffer_type::TYPE_IDS}; - case data_type::DENSE_UNION: - return {buffer_type::TYPE_IDS, buffer_type::OFFSETS_32BIT}; - case data_type::NA: - case data_type::MAP: - case data_type::RUN_ENCODED: - return {}; - } - mpl::unreachable(); - } - - constexpr std::size_t get_offset_size(data_type data_type, int64_t length) - { - switch (data_type) - { - case data_type::STRING: - case data_type::LIST: - case data_type::LARGE_LIST: - return static_cast(length) + 1; - case data_type::LIST_VIEW: - case data_type::LARGE_LIST_VIEW: - case data_type::DENSE_UNION: - return static_cast(length); - default: - throw std::runtime_error("Unsupported data type"); - } - } - class arrow_proxy { public: @@ -207,7 +33,6 @@ namespace sparrow explicit arrow_proxy(ArrowArray&& array, ArrowSchema&& schema); explicit arrow_proxy(ArrowArray&& array, ArrowSchema* schema); explicit arrow_proxy(ArrowArray* array, ArrowSchema* schema); - ~arrow_proxy(); @@ -226,9 +51,6 @@ namespace sparrow [[nodiscard]] bool is_created_with_sparrow() const; - void release(); - [[nodiscard]] bool is_released() const; - [[nodiscard]] void* private_data() const; private: @@ -327,24 +149,23 @@ namespace sparrow arrow_proxy::~arrow_proxy() { - if(m_array.index() == 1) + if (m_array.index() == 1) { ArrowArray& array = std::get<1>(m_array); - if(array.release != nullptr) + if (array.release != nullptr) { array.release(&array); } } - if(m_schema.index() == 1) + if (m_schema.index() == 1) { ArrowSchema& schema = std::get<1>(m_schema); - if(schema.release != nullptr) + if (schema.release != nullptr) { schema.release(&schema); } } - } [[nodiscard]] std::string_view arrow_proxy::format() const @@ -437,23 +258,6 @@ namespace sparrow return arrow_proxy{array().dictionary, schema().dictionary}; } - void arrow_proxy::release() - { - if (array().release != nullptr) - { - array().release(&array()); - } - if (schema().release != nullptr) - { - schema().release(&schema()); - } - } - - [[nodiscard]] bool arrow_proxy::is_released() const - { - return array().release == nullptr && schema().release == nullptr; - } - [[nodiscard]] bool arrow_proxy::is_created_with_sparrow() const { return (array().release == &sparrow::release_arrow_array) diff --git a/include/sparrow/arrow_interface/arrow_array_schema_proxy_utils.hpp b/include/sparrow/arrow_interface/arrow_array_schema_proxy_utils.hpp new file mode 100644 index 000000000..bf77312f5 --- /dev/null +++ b/include/sparrow/arrow_interface/arrow_array_schema_proxy_utils.hpp @@ -0,0 +1,149 @@ +// Copyright 2024 Man Group Operations Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "sparrow/array/data_type.hpp" +#include "sparrow/c_interface.hpp" + +namespace sparrow +{ + constexpr bool validate_buffers_count(data_type data_type, int64_t n_buffers) + { + const std::size_t expected_buffer_count = get_expected_buffer_count(data_type); + return static_cast(n_buffers) == expected_buffer_count; + } + + constexpr std::size_t get_expected_children_count(data_type data_type) + { + switch (data_type) + { + case data_type::NA: + case data_type::RUN_ENCODED: + case data_type::BOOL: + case data_type::UINT8: + case data_type::INT8: + case data_type::UINT16: + case data_type::INT16: + case data_type::UINT32: + case data_type::INT32: + case data_type::FLOAT: + case data_type::UINT64: + case data_type::INT64: + case data_type::DOUBLE: + case data_type::HALF_FLOAT: + case data_type::TIMESTAMP: + case data_type::FIXED_SIZE_BINARY: + case data_type::DECIMAL: + case data_type::FIXED_WIDTH_BINARY: + case data_type::STRING: + return 0; + case data_type::LIST: + case data_type::LARGE_LIST: + case data_type::LIST_VIEW: + case data_type::LARGE_LIST_VIEW: + case data_type::FIXED_SIZED_LIST: + case data_type::STRUCT: + case data_type::MAP: + case data_type::SPARSE_UNION: + return 1; + case data_type::DENSE_UNION: + return 2; + } + mpl::unreachable(); + } + + bool validate_format_with_arrow_array(data_type data_type, const ArrowArray& array) + { + const bool buffers_count_valid = validate_buffers_count(data_type, array.n_buffers); + const bool children_count_valid = static_cast(array.n_children) + == get_expected_children_count(data_type); + return buffers_count_valid && children_count_valid; + } + + enum class buffer_type + { + VALIDITY, + DATA, + OFFSETS_32BIT, + OFFSETS_64BIT, + VIEWS, + TYPE_IDS, + SIZES_32BIT, + SIZES_64BIT, + }; + + constexpr std::vector get_buffer_types_from_data_type(data_type data_type) + { + switch (data_type) + { + case data_type::BOOL: + case data_type::UINT8: + case data_type::INT8: + case data_type::UINT16: + case data_type::INT16: + case data_type::UINT32: + case data_type::INT32: + case data_type::FLOAT: + case data_type::UINT64: + case data_type::INT64: + case data_type::DOUBLE: + case data_type::HALF_FLOAT: + case data_type::TIMESTAMP: + case data_type::FIXED_SIZE_BINARY: + case data_type::DECIMAL: + case data_type::FIXED_WIDTH_BINARY: + return {buffer_type::VALIDITY, buffer_type::DATA}; + case data_type::STRING: + return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::DATA}; + case data_type::LIST: + return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT}; + case data_type::LARGE_LIST: + return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT}; + case data_type::LIST_VIEW: + return {buffer_type::VALIDITY, buffer_type::OFFSETS_32BIT, buffer_type::SIZES_32BIT}; + case data_type::LARGE_LIST_VIEW: + return {buffer_type::VALIDITY, buffer_type::OFFSETS_64BIT, buffer_type::SIZES_64BIT}; + case data_type::FIXED_SIZED_LIST: + case data_type::STRUCT: + return {buffer_type::VALIDITY}; + case data_type::SPARSE_UNION: + return {buffer_type::TYPE_IDS}; + case data_type::DENSE_UNION: + return {buffer_type::TYPE_IDS, buffer_type::OFFSETS_32BIT}; + case data_type::NA: + case data_type::MAP: + case data_type::RUN_ENCODED: + return {}; + } + mpl::unreachable(); + } + + constexpr std::size_t get_offset_size(data_type data_type, int64_t length) + { + switch (data_type) + { + case data_type::STRING: + case data_type::LIST: + case data_type::LARGE_LIST: + return static_cast(length) + 1; + case data_type::LIST_VIEW: + case data_type::LARGE_LIST_VIEW: + case data_type::DENSE_UNION: + return static_cast(length); + default: + throw std::runtime_error("Unsupported data type"); + } + } + +} diff --git a/include/sparrow/arrow_interface/arrow_flag_utils.hpp b/include/sparrow/arrow_interface/arrow_flag_utils.hpp new file mode 100644 index 000000000..a5f528af4 --- /dev/null +++ b/include/sparrow/arrow_interface/arrow_flag_utils.hpp @@ -0,0 +1,65 @@ +// Copyright 2024 Man Group Operations Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparrow/c_interface.hpp" + +namespace sparrow +{ + constexpr bool is_valid_ArrowFlag_value(int64_t value) noexcept + { + constexpr std::array valid_values = { + ArrowFlag::DICTIONARY_ORDERED, + ArrowFlag::NULLABLE, + ArrowFlag::MAP_KEYS_SORTED + }; + return std::ranges::any_of( + valid_values, + [value](ArrowFlag flag) + { + return static_cast>(flag) == value; + } + ); + } + + constexpr std::vector to_vector_of_ArrowFlags(int64_t flag_values) + { + constexpr size_t n_bits = sizeof(flag_values) * 8; + std::vector flags; + for (size_t i = 0; i < n_bits; ++i) + { + const int64_t flag_value = static_cast(1) << i; + if ((flag_values & flag_value) != 0) + { + if (!is_valid_ArrowFlag_value(flag_value)) + { + // TODO: Replace with a more specific exception + throw std::runtime_error("Invalid ArrowFlag value"); + } + flags.push_back(static_cast(flag_value)); + } + } + return flags; + } + + constexpr int64_t to_ArrowFlag_value(const std::vector& flags) + { + int64_t flag_values = 0; + for (const ArrowFlag flag : flags) + { + flag_values |= static_cast(flag); + } + return flag_values; + } + +} diff --git a/test/test_arrow_array_schema_proxy.cpp b/test/test_arrow_array_schema_proxy.cpp index e07f20ba8..e65ebfa42 100644 --- a/test/test_arrow_array_schema_proxy.cpp +++ b/test/test_arrow_array_schema_proxy.cpp @@ -195,26 +195,6 @@ TEST_SUITE("ArrowArrowSchemaProxy") CHECK_FALSE(proxy.dictionary().has_value()); } - TEST_CASE("release") - { - auto [schema, array] = make_default_arrow_schema_and_array(); - sparrow::arrow_proxy proxy(&array, &schema); - proxy.release(); - const bool array_release_is_null = array.release == nullptr; - CHECK(array_release_is_null); - const bool schema_release_is_null = schema.release == nullptr; - CHECK(schema_release_is_null); - } - - TEST_CASE("is_released") - { - auto [schema, array] = make_default_arrow_schema_and_array(); - sparrow::arrow_proxy proxy(&array, &schema); - CHECK_FALSE(proxy.is_released()); - proxy.release(); - CHECK(proxy.is_released()); - } - TEST_CASE("is_created_with_sparrow") { auto [schema, array] = make_default_arrow_schema_and_array();