diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index 639e66785eb67..1226dcca9c93a 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -30,6 +30,9 @@ template void FieldDataImpl::FillFieldData(const void* source, ssize_t element_count) { + AssertInfo(!nullable_, + "need to fill valid_data, use the 3-argument version instead"); + if (element_count == 0) { return; } @@ -44,6 +47,37 @@ FieldDataImpl::FillFieldData(const void* source, length_ += element_count; } +template +void +FieldDataImpl::FillFieldData( + const void* field_data, const uint8_t* valid_data, ssize_t element_count) { + AssertInfo( + nullable_, + "no need to fill valid_data, use the 2-argument version instead"); + if (element_count == 0) { + return; + } + + std::lock_guard lck(tell_mutex_); + if (length_ + element_count > get_num_rows()) { + resize_field_data(length_ + element_count); + } + std::copy_n(static_cast(field_data), + element_count * dim_, + field_data_.data() + length_ * dim_); + + ssize_t byte_count = (element_count + 7) / 8; + // Note: if 'nullable == true` and valid_data is nullptr + // means null_count == 0, will fill it with 0xFF + if (valid_data == nullptr) { + std::fill_n(valid_data_.get(), byte_count, 0xFF); + } else { + std::copy_n(valid_data, byte_count, valid_data_.get()); + } + + length_ += element_count; +} + template std::pair GetDataInfoFromArray(const std::shared_ptr array) { @@ -66,6 +100,7 @@ FieldDataImpl::FillFieldData( if (element_count == 0) { return; } + null_count = array->null_count(); switch (data_type_) { case DataType::BOOL: { AssertInfo(array->type()->id() == arrow::Type::type::BOOL, @@ -76,42 +111,71 @@ FieldDataImpl::FillFieldData( for (size_t index = 0; index < element_count; ++index) { values[index] = bool_array->Value(index); } + if (nullable_) { + return FillFieldData(values.data(), + bool_array->null_bitmap_data(), + element_count); + } return FillFieldData(values.data(), element_count); } case DataType::INT8: { auto array_info = GetDataInfoFromArray( array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::INT16: { auto array_info = GetDataInfoFromArray(array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::INT32: { auto array_info = GetDataInfoFromArray(array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::INT64: { auto array_info = GetDataInfoFromArray(array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::FLOAT: { auto array_info = GetDataInfoFromArray(array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::DOUBLE: { auto array_info = GetDataInfoFromArray(array); + if (nullable_) { + return FillFieldData( + array_info.first, array->null_bitmap_data(), element_count); + } return FillFieldData(array_info.first, array_info.second); } case DataType::STRING: @@ -124,6 +188,10 @@ FieldDataImpl::FillFieldData( for (size_t index = 0; index < element_count; ++index) { values[index] = string_array->GetString(index); } + if (nullable_) { + return FillFieldData( + values.data(), array->null_bitmap_data(), element_count); + } return FillFieldData(values.data(), element_count); } case DataType::JSON: { @@ -136,17 +204,33 @@ FieldDataImpl::FillFieldData( values[index] = Json(simdjson::padded_string(json_array->GetString(index))); } + if (nullable_) { + return FillFieldData( + values.data(), array->null_bitmap_data(), element_count); + } return FillFieldData(values.data(), element_count); } case DataType::ARRAY: { auto array_array = std::dynamic_pointer_cast(array); std::vector values(element_count); + int null_number = 0; for (size_t index = 0; index < element_count; ++index) { ScalarArray field_data; - field_data.ParseFromString(array_array->GetString(index)); + if (array_array->GetString(index) == "") { + null_number++; + continue; + } + auto success = + field_data.ParseFromString(array_array->GetString(index)); + AssertInfo(success, "parse from string failed"); values[index] = Array(field_data); } + if (nullable_) { + return FillFieldData( + values.data(), array->null_bitmap_data(), element_count); + } + AssertInfo(null_number == 0, "get empty string when not nullable"); return FillFieldData(values.data(), element_count); } case DataType::VECTOR_FLOAT: @@ -201,27 +285,33 @@ template class FieldDataImpl; template class FieldDataImpl, true>; FieldDataPtr -InitScalarFieldData(const DataType& type, int64_t cap_rows) { +InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) { switch (type) { case DataType::BOOL: - return std::make_shared>(type, cap_rows); + return std::make_shared>(type, nullable, cap_rows); case DataType::INT8: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::INT16: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::INT32: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::INT64: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::FLOAT: - return std::make_shared>(type, cap_rows); + return std::make_shared>(type, nullable, cap_rows); case DataType::DOUBLE: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::STRING: case DataType::VARCHAR: - return std::make_shared>(type, cap_rows); + return std::make_shared>( + type, nullable, cap_rows); case DataType::JSON: - return std::make_shared>(type, cap_rows); + return std::make_shared>(type, nullable, cap_rows); default: throw NotSupportedDataTypeException( "InitScalarFieldData not support data type " + diff --git a/internal/core/src/common/FieldData.h b/internal/core/src/common/FieldData.h index 60e0c74b3ad56..12719dac7865d 100644 --- a/internal/core/src/common/FieldData.h +++ b/internal/core/src/common/FieldData.h @@ -30,9 +30,11 @@ template class FieldData : public FieldDataImpl { public: static_assert(IsScalar || std::is_same_v); - explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) + explicit FieldData(DataType data_type, + bool nullable, + int64_t buffered_num_rows = 0) : FieldDataImpl::FieldDataImpl( - 1, data_type, buffered_num_rows) { + 1, data_type, nullable, buffered_num_rows) { } static_assert(IsScalar || std::is_same_v); explicit FieldData(DataType data_type, FixedVector&& inner_data) @@ -45,8 +47,10 @@ template <> class FieldData : public FieldDataStringImpl { public: static_assert(IsScalar || std::is_same_v); - explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) - : FieldDataStringImpl(data_type, buffered_num_rows) { + explicit FieldData(DataType data_type, + bool nullable, + int64_t buffered_num_rows = 0) + : FieldDataStringImpl(data_type, nullable, buffered_num_rows) { } }; @@ -54,8 +58,10 @@ template <> class FieldData : public FieldDataJsonImpl { public: static_assert(IsScalar || std::is_same_v); - explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) - : FieldDataJsonImpl(data_type, buffered_num_rows) { + explicit FieldData(DataType data_type, + bool nullable, + int64_t buffered_num_rows = 0) + : FieldDataJsonImpl(data_type, nullable, buffered_num_rows) { } }; @@ -63,8 +69,10 @@ template <> class FieldData : public FieldDataArrayImpl { public: static_assert(IsScalar || std::is_same_v); - explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) - : FieldDataArrayImpl(data_type, buffered_num_rows) { + explicit FieldData(DataType data_type, + bool nullable, + int64_t buffered_num_rows = 0) + : FieldDataArrayImpl(data_type, nullable, buffered_num_rows) { } }; @@ -75,7 +83,7 @@ class FieldData : public FieldDataImpl { DataType data_type, int64_t buffered_num_rows = 0) : FieldDataImpl::FieldDataImpl( - dim, data_type, buffered_num_rows) { + dim, data_type, false, buffered_num_rows) { } }; @@ -86,7 +94,7 @@ class FieldData : public FieldDataImpl { DataType data_type, int64_t buffered_num_rows = 0) : binary_dim_(dim), - FieldDataImpl(dim / 8, data_type, buffered_num_rows) { + FieldDataImpl(dim / 8, data_type, false, buffered_num_rows) { Assert(dim % 8 == 0); } @@ -106,7 +114,7 @@ class FieldData : public FieldDataImpl { DataType data_type, int64_t buffered_num_rows = 0) : FieldDataImpl::FieldDataImpl( - dim, data_type, buffered_num_rows) { + dim, data_type, false, buffered_num_rows) { } }; @@ -134,6 +142,6 @@ using FieldDataChannel = Channel; using FieldDataChannelPtr = std::shared_ptr; FieldDataPtr -InitScalarFieldData(const DataType& type, int64_t cap_rows); +InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows); } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/common/FieldDataInterface.h b/internal/core/src/common/FieldDataInterface.h index 17916f08e6259..b44693696519c 100644 --- a/internal/core/src/common/FieldDataInterface.h +++ b/internal/core/src/common/FieldDataInterface.h @@ -40,7 +40,8 @@ using DataType = milvus::DataType; class FieldDataBase { public: - explicit FieldDataBase(DataType data_type) : data_type_(data_type) { + explicit FieldDataBase(DataType data_type, bool nullable) + : data_type_(data_type), nullable_(nullable) { } virtual ~FieldDataBase() = default; @@ -49,6 +50,11 @@ class FieldDataBase { virtual void FillFieldData(const void* source, ssize_t element_count) = 0; + virtual void + FillFieldData(const void* field_data, + const uint8_t* valid_data, + ssize_t element_count) = 0; + virtual void FillFieldData(const std::shared_ptr array) = 0; @@ -57,6 +63,9 @@ class FieldDataBase { virtual void* Data() = 0; + virtual const uint8_t* + ValidData() const = 0; + // For all FieldDataImpl subclasses, this method returns a Type* that points // at the offset-th row of this field data. virtual const void* @@ -66,9 +75,15 @@ class FieldDataBase { virtual int64_t Size() const = 0; + virtual int64_t + DataSize() const = 0; + + virtual int64_t + ValidDataSize() const = 0; + // Returns the serialized bytes size of the index-th row. virtual int64_t - Size(ssize_t index) const = 0; + DataSize(ssize_t index) const = 0; // Number of filled rows virtual size_t @@ -77,6 +92,9 @@ class FieldDataBase { virtual bool IsFull() const = 0; + virtual bool + IsNullable() const = 0; + virtual void Reserve(size_t cap) = 0; @@ -94,8 +112,15 @@ class FieldDataBase { return data_type_; } + virtual int64_t + get_null_count() const = 0; + + virtual bool + is_null(ssize_t offset) const = 0; + protected: const DataType data_type_; + const bool nullable_; }; template @@ -112,17 +137,23 @@ class FieldDataImpl : public FieldDataBase { public: explicit FieldDataImpl(ssize_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataBase(data_type), + : FieldDataBase(data_type, nullable), num_rows_(buffered_num_rows), dim_(is_type_entire_row ? 1 : dim) { field_data_.resize(num_rows_ * dim_); + if (nullable) { + valid_data_ = + std::shared_ptr(new uint8_t[(num_rows_ + 7) / 8]); + } } explicit FieldDataImpl(size_t dim, DataType type, + bool nullable, FixedVector&& field_data) - : FieldDataBase(type), dim_(is_type_entire_row ? 1 : dim) { + : FieldDataBase(type, nullable), dim_(is_type_entire_row ? 1 : dim) { field_data_ = std::move(field_data); Assert(field_data.size() % dim == 0); num_rows_ = field_data.size() / dim; @@ -131,6 +162,11 @@ class FieldDataImpl : public FieldDataBase { void FillFieldData(const void* source, ssize_t element_count) override; + void + FillFieldData(const void* field_data, + const uint8_t* valid_data, + ssize_t element_count) override; + void FillFieldData(const std::shared_ptr array) override; @@ -158,6 +194,11 @@ class FieldDataImpl : public FieldDataBase { return field_data_.data(); } + const uint8_t* + ValidData() const override { + return valid_data_.get(); + } + const void* RawValue(ssize_t offset) const override { AssertInfo(offset < get_num_rows(), @@ -167,13 +208,33 @@ class FieldDataImpl : public FieldDataBase { return &field_data_[offset]; } + std::optional + Value(ssize_t offset) { + if (!is_type_entire_row) { + return RawValue(offset); + } + AssertInfo(offset < get_num_rows(), + "field data subscript out of range"); + AssertInfo(offset < length(), + "subscript position don't has valid value"); + if (nullable_ && !valid_data_[offset]) { + return std::nullopt; + } + return &field_data_[offset]; + } + int64_t Size() const override { + return DataSize() + ValidDataSize(); + } + + int64_t + DataSize() const override { return sizeof(Type) * length() * dim_; } int64_t - Size(ssize_t offset) const override { + DataSize(ssize_t offset) const override { AssertInfo(offset < get_num_rows(), "field data subscript out of range"); AssertInfo(offset < length(), @@ -181,6 +242,15 @@ class FieldDataImpl : public FieldDataBase { return sizeof(Type) * dim_; } + int64_t + ValidDataSize() const override { + int byteSize = (length() + 7) / 8; + if (nullable_) { + return sizeof(uint8_t) * byteSize; + } + return 0; + } + size_t Length() const override { return length_; @@ -193,6 +263,11 @@ class FieldDataImpl : public FieldDataBase { return buffered_num_rows == filled_num_rows; } + bool + IsNullable() const override { + return nullable_; + } + void Reserve(size_t cap) override { std::lock_guard lck(num_rows_mutex_); @@ -200,6 +275,9 @@ class FieldDataImpl : public FieldDataBase { num_rows_ = cap; field_data_.resize(num_rows_ * dim_); } + if (nullable_) { + valid_data_ = std::shared_ptr(new uint8_t[num_rows_]); + } } public: @@ -215,6 +293,11 @@ class FieldDataImpl : public FieldDataBase { if (num_rows > num_rows_) { num_rows_ = num_rows; field_data_.resize(num_rows_ * dim_); + if (nullable_) { + ssize_t byte_count = (num_rows + 7) / 8; + valid_data_ = + std::shared_ptr(new uint8_t[byte_count]); + } } } @@ -229,11 +312,29 @@ class FieldDataImpl : public FieldDataBase { return dim_; } + int64_t + get_null_count() const override { + std::shared_lock lck(tell_mutex_); + return null_count; + } + + virtual bool + is_null(ssize_t offset) const override { + std::shared_lock lck(tell_mutex_); + if (!nullable_) { + return false; + } + auto bit = (valid_data_[offset >> 3] >> ((offset & 0x07))) & 1; + return !bit; + } + protected: FixedVector field_data_; + std::shared_ptr valid_data_; // number of elements field_data_ can hold int64_t num_rows_; mutable std::shared_mutex num_rows_mutex_; + int64_t null_count; // number of actual elements in field_data_ size_t length_{}; mutable std::shared_mutex tell_mutex_; @@ -244,12 +345,15 @@ class FieldDataImpl : public FieldDataBase { class FieldDataStringImpl : public FieldDataImpl { public: - explicit FieldDataStringImpl(DataType data_type, int64_t total_num_rows = 0) - : FieldDataImpl(1, data_type, total_num_rows) { + explicit FieldDataStringImpl(DataType data_type, + bool nullable, + int64_t total_num_rows = 0) + : FieldDataImpl( + 1, data_type, nullable, total_num_rows) { } int64_t - Size() const override { + DataSize() const override { int64_t data_size = 0; for (size_t offset = 0; offset < length(); ++offset) { data_size += field_data_[offset].size(); @@ -259,7 +363,7 @@ class FieldDataStringImpl : public FieldDataImpl { } int64_t - Size(ssize_t offset) const override { + DataSize(ssize_t offset) const override { AssertInfo(offset < get_num_rows(), "field data subscript out of range"); AssertInfo(offset < length(), @@ -290,12 +394,14 @@ class FieldDataStringImpl : public FieldDataImpl { class FieldDataJsonImpl : public FieldDataImpl { public: - explicit FieldDataJsonImpl(DataType data_type, int64_t total_num_rows = 0) - : FieldDataImpl(1, data_type, total_num_rows) { + explicit FieldDataJsonImpl(DataType data_type, + bool nullable, + int64_t total_num_rows = 0) + : FieldDataImpl(1, data_type, nullable, total_num_rows) { } int64_t - Size() const override { + DataSize() const override { int64_t data_size = 0; for (size_t offset = 0; offset < length(); ++offset) { data_size += field_data_[offset].data().size(); @@ -305,7 +411,7 @@ class FieldDataJsonImpl : public FieldDataImpl { } int64_t - Size(ssize_t offset) const override { + DataSize(ssize_t offset) const override { AssertInfo(offset < get_num_rows(), "field data subscript out of range"); AssertInfo(offset < length(), @@ -349,16 +455,17 @@ class FieldDataSparseVectorImpl : public FieldDataImpl, true> { public: explicit FieldDataSparseVectorImpl(DataType data_type, + bool nullable, int64_t total_num_rows = 0) : FieldDataImpl, true>( - /*dim=*/1, data_type, total_num_rows), + /*dim=*/1, data_type, nullable, total_num_rows), vec_dim_(0) { AssertInfo(data_type == DataType::VECTOR_SPARSE_FLOAT, "invalid data type for sparse vector"); } int64_t - Size() const override { + DataSize() const override { int64_t data_size = 0; for (size_t i = 0; i < length(); ++i) { data_size += field_data_[i].data_byte_size(); @@ -367,7 +474,7 @@ class FieldDataSparseVectorImpl } int64_t - Size(ssize_t offset) const override { + DataSize(ssize_t offset) const override { AssertInfo(offset < get_num_rows(), "field data subscript out of range"); AssertInfo(offset < length(), @@ -430,8 +537,10 @@ class FieldDataSparseVectorImpl class FieldDataArrayImpl : public FieldDataImpl { public: - explicit FieldDataArrayImpl(DataType data_type, int64_t total_num_rows = 0) - : FieldDataImpl(1, data_type, total_num_rows) { + explicit FieldDataArrayImpl(DataType data_type, + bool nullable, + int64_t total_num_rows = 0) + : FieldDataImpl(1, data_type, nullable, total_num_rows) { } int64_t diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 4f751f1d68328..39999af68f65d 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -35,27 +35,34 @@ class FieldMeta { FieldMeta& operator=(FieldMeta&&) = default; - FieldMeta(const FieldName& name, FieldId id, DataType type) - : name_(name), id_(id), type_(type) { + FieldMeta(const FieldName& name, FieldId id, DataType type, bool nullable) + : name_(name), id_(id), type_(type), nullable_(nullable) { Assert(!IsVectorDataType(type_)); } FieldMeta(const FieldName& name, FieldId id, DataType type, - int64_t max_length) + int64_t max_length, + bool nullable) : name_(name), id_(id), type_(type), - string_info_(StringInfo{max_length}) { + string_info_(StringInfo{max_length}), + nullable_(nullable) { Assert(IsStringDataType(type_)); } FieldMeta(const FieldName& name, FieldId id, DataType type, - DataType element_type) - : name_(name), id_(id), type_(type), element_type_(element_type) { + DataType element_type, + bool nullable) + : name_(name), + id_(id), + type_(type), + element_type_(element_type), + nullable_(nullable) { Assert(IsArrayDataType(type_)); } @@ -71,6 +78,7 @@ class FieldMeta { type_(type), vector_info_(VectorInfo{dim, std::move(metric_type)}) { Assert(IsVectorDataType(type_)); + nullable_ = false; } int64_t @@ -126,6 +134,11 @@ class FieldMeta { return IsStringDataType(type_); } + bool + is_nullable() const { + return nullable_; + } + size_t get_sizeof() const { AssertInfo(!IsSparseFloatVectorDataType(type_), @@ -156,6 +169,7 @@ class FieldMeta { FieldId id_; DataType type_ = DataType::NONE; DataType element_type_ = DataType::NONE; + bool nullable_; std::optional vector_info_; std::optional string_info_; }; diff --git a/internal/core/src/common/Schema.cpp b/internal/core/src/common/Schema.cpp index 7aa4fc1630bcb..70843149901a9 100644 --- a/internal/core/src/common/Schema.cpp +++ b/internal/core/src/common/Schema.cpp @@ -38,6 +38,7 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) { schema_proto.fields()) { auto field_id = FieldId(child.fieldid()); auto name = FieldName(child.name()); + auto nullable = child.nullable(); if (field_id.get() < 100) { // system field id @@ -70,12 +71,15 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) { AssertInfo(type_map.count(MAX_LENGTH), "max_length not found"); auto max_len = boost::lexical_cast(type_map.at(MAX_LENGTH)); - schema->AddField(name, field_id, data_type, max_len); + schema->AddField(name, field_id, data_type, max_len, nullable); } else if (IsArrayDataType(data_type)) { - schema->AddField( - name, field_id, data_type, DataType(child.element_type())); + schema->AddField(name, + field_id, + data_type, + DataType(child.element_type()), + nullable); } else { - schema->AddField(name, field_id, data_type); + schema->AddField(name, field_id, data_type, nullable); } if (child.is_primary_key()) { @@ -93,6 +97,7 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) { const FieldMeta FieldMeta::RowIdMeta(FieldName("RowID"), RowFieldID, - DataType::INT64); + DataType::INT64, + false); } // namespace milvus diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index b1068dd650392..54d139f510ea3 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -34,20 +34,24 @@ static int64_t debug_id = START_USER_FIELDID; class Schema { public: FieldId - AddDebugField(const std::string& name, DataType data_type) { + AddDebugField(const std::string& name, + DataType data_type, + bool nullable = false) { auto field_id = FieldId(debug_id); debug_id++; - this->AddField(FieldName(name), field_id, data_type); + this->AddField(FieldName(name), field_id, data_type, nullable); return field_id; } FieldId AddDebugField(const std::string& name, DataType data_type, - DataType element_type) { + DataType element_type, + bool nullable = false) { auto field_id = FieldId(debug_id); debug_id++; - this->AddField(FieldName(name), field_id, data_type, element_type); + this->AddField( + FieldName(name), field_id, data_type, element_type, nullable); return field_id; } @@ -67,8 +71,11 @@ class Schema { // scalar type void - AddField(const FieldName& name, const FieldId id, DataType data_type) { - auto field_meta = FieldMeta(name, id, data_type); + AddField(const FieldName& name, + const FieldId id, + DataType data_type, + bool nullable) { + auto field_meta = FieldMeta(name, id, data_type, nullable); this->AddField(std::move(field_meta)); } @@ -77,8 +84,10 @@ class Schema { AddField(const FieldName& name, const FieldId id, DataType data_type, - DataType element_type) { - auto field_meta = FieldMeta(name, id, data_type, element_type); + DataType element_type, + bool nullable) { + auto field_meta = + FieldMeta(name, id, data_type, element_type, nullable); this->AddField(std::move(field_meta)); } @@ -87,8 +96,9 @@ class Schema { AddField(const FieldName& name, const FieldId id, DataType data_type, - int64_t max_length) { - auto field_meta = FieldMeta(name, id, data_type, max_length); + int64_t max_length, + bool nullable) { + auto field_meta = FieldMeta(name, id, data_type, max_length, nullable); this->AddField(std::move(field_meta)); } diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index dab66ffb18a31..3117d0c34f468 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -65,7 +65,8 @@ class ColumnVector final : public BaseVector { size_t length, std::optional null_count = std::nullopt) : BaseVector(data_type, length, null_count) { - values_ = InitScalarFieldData(data_type, length); + //todo(smellthemoon): use false temporarily + values_ = InitScalarFieldData(data_type, false, length); } // ColumnVector(FixedVector&& data) @@ -78,7 +79,7 @@ class ColumnVector final : public BaseVector { ColumnVector(TargetBitmap&& bitmap) : BaseVector(DataType::INT8, bitmap.size()) { values_ = std::make_shared>( - bitmap.size(), DataType::INT8, std::move(bitmap).into()); + bitmap.size(), DataType::INT8, false, std::move(bitmap).into()); } virtual ~ColumnVector() override { diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index bcb401ea5bf09..95365770768c7 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -72,8 +72,15 @@ ScalarIndexSort::BuildV2(const Config& config) { auto data = rec.ValueUnsafe(); auto total_num_rows = data->num_rows(); auto col_data = data->GetColumnByName(field_name); + auto nullable = + col_data->type()->id() == arrow::Type::NA ? true : false; + // will support build scalar index when nullable in the future just skip it + // now, not support to build index in nullable field_data + // todo(smellthemoon) + AssertInfo(!nullable, + "not support to build index in nullable field_data"); auto field_data = storage::CreateFieldData( - DataType(GetDType()), 0, total_num_rows); + DataType(GetDType()), nullable, 0, total_num_rows); field_data->FillFieldData(col_data); field_datas.push_back(field_data); } diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index aa41438e2bc8d..20f72fd5d1511 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -83,8 +83,15 @@ StringIndexMarisa::BuildV2(const Config& config) { auto data = rec.ValueUnsafe(); auto total_num_rows = data->num_rows(); auto col_data = data->GetColumnByName(field_name); - auto field_data = - storage::CreateFieldData(DataType::STRING, 0, total_num_rows); + auto nullable = + col_data->type()->id() == arrow::Type::NA ? true : false; + // will support build scalar index when nullable in the future just skip it + // now, not support to build index in nullable field_data + // todo(smellthemoon) + AssertInfo(!nullable, + "not support to build index in nullable field_data"); + auto field_data = storage::CreateFieldData( + DataType::STRING, nullable, 0, total_num_rows); field_data->FillFieldData(col_data); field_datas.push_back(field_data); } diff --git a/internal/core/src/index/Utils.cpp b/internal/core/src/index/Utils.cpp index a9ad1cf1a0d91..b29caf254f3f4 100644 --- a/internal/core/src/index/Utils.cpp +++ b/internal/core/src/index/Utils.cpp @@ -240,9 +240,9 @@ AssembleIndexDatas(std::map& index_datas) { std::string prefix = item[NAME]; int slice_num = item[SLICE_NUM]; auto total_len = static_cast(item[TOTAL_LEN]); - + // todo:smellthemoon: use false here temporarily auto new_field_data = - storage::CreateFieldData(DataType::INT8, 1, total_len); + storage::CreateFieldData(DataType::INT8, false, 1, total_len); for (auto i = 0; i < slice_num; ++i) { std::string file_name = GenSlicedFileName(prefix, i); @@ -279,9 +279,9 @@ AssembleIndexDatas(std::map& index_datas, std::string prefix = item[NAME]; int slice_num = item[SLICE_NUM]; auto total_len = static_cast(item[TOTAL_LEN]); - + // todo:smellthemoon: use false here temporarily auto new_field_data = - storage::CreateFieldData(DataType::INT8, 1, total_len); + storage::CreateFieldData(DataType::INT8, false, 1, total_len); for (auto i = 0; i < slice_num; ++i) { std::string file_name = GenSlicedFileName(prefix, i); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 2c99b09204944..e78dca41ae9a1 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -261,8 +261,8 @@ VectorMemIndex::LoadV2(const Config& config) { int slice_num = item[SLICE_NUM]; auto total_len = static_cast(item[TOTAL_LEN]); - auto new_field_data = - milvus::storage::CreateFieldData(DataType::INT8, 1, total_len); + auto new_field_data = milvus::storage::CreateFieldData( + DataType::INT8, false, 1, total_len); for (auto i = 0; i < slice_num; ++i) { std::string file_name = index_prefix + "/" + GenSlicedFileName(prefix, i); @@ -361,7 +361,7 @@ VectorMemIndex::Load(milvus::tracer::TraceContext ctx, auto total_len = static_cast(item[TOTAL_LEN]); auto new_field_data = milvus::storage::CreateFieldData( - DataType::INT8, 1, total_len); + DataType::INT8, false, 1, total_len); std::vector batch; batch.reserve(slice_num); @@ -464,7 +464,7 @@ VectorMemIndex::BuildV2(const Config& config) { auto total_num_rows = data->num_rows(); auto col_data = data->GetColumnByName(field_name); auto field_data = - storage::CreateFieldData(field_type, dim, total_num_rows); + storage::CreateFieldData(field_type, false, dim, total_num_rows); field_data->FillFieldData(col_data); field_datas.push_back(field_data); } diff --git a/internal/core/src/mmap/Column.h b/internal/core/src/mmap/Column.h index bda4ca16a9edd..90225bc0431f5 100644 --- a/internal/core/src/mmap/Column.h +++ b/internal/core/src/mmap/Column.h @@ -65,10 +65,17 @@ class ColumnBase { return; } - cap_size_ = type_size_ * reserve; + if (!field_meta.is_vector()) { + is_scalar = true; + } else { + AssertInfo(!field_meta.is_nullable(), + "only support null in scalar"); + } + + data_cap_size_ = field_meta.get_sizeof() * reserve; // use anon mapping so we are able to free these memory with munmap only - size_t mapped_size = cap_size_ + padding_; + size_t mapped_size = data_cap_size_ + padding_; data_ = static_cast(mmap(nullptr, mapped_size, PROT_READ | PROT_WRITE, @@ -80,6 +87,20 @@ class ColumnBase { strerror(errno), mapped_size); + if (field_meta.is_nullable()) { + nullable = true; + valid_data_cap_size_ = (reserve + 7) / 8; + mapped_size += valid_data_cap_size_; + valid_data_ = static_cast(mmap(nullptr, + valid_data_cap_size_, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANON, + -1, + 0)); + AssertInfo(valid_data_ != MAP_FAILED, + "failed to create anon map, err: {}", + strerror(errno)); + } UpdateMetricWhenMmap(mapped_size); } @@ -92,9 +113,9 @@ class ColumnBase { num_rows_(size / type_size_) { SetPaddingSize(field_meta.get_data_type()); - size_ = size; - cap_size_ = size; - size_t mapped_size = cap_size_ + padding_; + data_size_ = size; + data_cap_size_ = size; + size_t mapped_size = data_cap_size_ + padding_; data_ = static_cast(mmap( nullptr, mapped_size, PROT_READ, MAP_SHARED, file.Descriptor(), 0)); AssertInfo(data_ != MAP_FAILED, @@ -102,6 +123,26 @@ class ColumnBase { strerror(errno)); madvise(data_, mapped_size, MADV_WILLNEED); + if (!field_meta.is_vector()) { + is_scalar = true; + if (field_meta.is_nullable()) { + nullable = true; + valid_data_cap_size_ = (num_rows_ + 7) / 8; + valid_data_size_ = (num_rows_ + 7) / 8; + mapped_size += valid_data_size_; + valid_data_ = static_cast(mmap(nullptr, + valid_data_cap_size_, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANON, + file.Descriptor(), + 0)); + AssertInfo(valid_data_ != MAP_FAILED, + "failed to create file-backed map, err: {}", + strerror(errno)); + madvise(valid_data_, valid_data_cap_size_, MADV_WILLNEED); + } + } + UpdateMetricWhenMmap(mapped_size); } @@ -109,27 +150,45 @@ class ColumnBase { ColumnBase(const File& file, size_t size, int dim, - const DataType& data_type) - : type_size_(GetDataTypeSize(data_type, dim)), + const DataType& data_type, + bool nullable) + : nullable(nullable), + type_size_(GetDataTypeSize(data_type, dim)), num_rows_(size / GetDataTypeSize(data_type, dim)), - size_(size), - cap_size_(size), + data_size_(size), + data_cap_size_(size), is_map_anonymous_(false) { SetPaddingSize(data_type); - size_t mapped_size = cap_size_ + padding_; + size_t mapped_size = data_cap_size_ + padding_; data_ = static_cast(mmap( nullptr, mapped_size, PROT_READ, MAP_SHARED, file.Descriptor(), 0)); AssertInfo(data_ != MAP_FAILED, "failed to create file-backed map, err: {}", strerror(errno)); - + if (dim == 1) { + is_scalar = true; + if (nullable) { + valid_data_cap_size_ = (num_rows_ + 7) / 8; + valid_data_size_ = (num_rows_ + 7) / 8; + mapped_size += valid_data_size_; + valid_data_ = static_cast(mmap(nullptr, + valid_data_cap_size_, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANON, + file.Descriptor(), + 0)); + AssertInfo(valid_data_ != MAP_FAILED, + "failed to create file-backed map, err: {}", + strerror(errno)); + } + } UpdateMetricWhenMmap(mapped_size); } virtual ~ColumnBase() { if (data_ != nullptr) { - size_t mapped_size = cap_size_ + padding_; + size_t mapped_size = data_cap_size_ + padding_; if (munmap(data_, mapped_size)) { AssertInfo(true, "failed to unmap variable field, err={}", @@ -137,20 +196,36 @@ class ColumnBase { } UpdateMetricWhenMunmap(mapped_size); } + if (valid_data_ != nullptr) { + if (munmap(valid_data_, valid_data_cap_size_)) { + AssertInfo(true, + "failed to unmap variable field, err={}", + strerror(errno)); + } + UpdateMetricWhenMunmap(valid_data_cap_size_); + } } ColumnBase(ColumnBase&& column) noexcept : data_(column.data_), - cap_size_(column.cap_size_), + nullable(column.nullable), + valid_data_(column.valid_data_), + valid_data_cap_size_(column.valid_data_cap_size_), + data_cap_size_(column.data_cap_size_), padding_(column.padding_), type_size_(column.type_size_), num_rows_(column.num_rows_), - size_(column.size_) { + data_size_(column.data_size_), + valid_data_size_(column.valid_data_size_) { column.data_ = nullptr; - column.cap_size_ = 0; + column.data_cap_size_ = 0; column.padding_ = 0; column.num_rows_ = 0; - column.size_ = 0; + column.data_size_ = 0; + column.nullable = false; + column.valid_data_ = nullptr; + column.valid_data_cap_size_ = 0; + column.valid_data_size_ = 0; } virtual const char* @@ -158,6 +233,26 @@ class ColumnBase { return data_; } + const uint8_t* + ValidData() const { + return valid_data_; + } + + bool + IsNullable() const { + return nullable; + } + + size_t + DataSize() const { + return data_size_; + } + + size_t + ValidDataSize() const { + return valid_data_size_; + } + size_t NumRows() const { return num_rows_; @@ -165,14 +260,14 @@ class ColumnBase { virtual size_t ByteSize() const { - return cap_size_ + padding_; + return data_cap_size_ + padding_ + valid_data_cap_size_; } // The capacity of the column, // DO NOT call this for variable length column(including SparseFloatColumn). virtual size_t Capacity() const { - return cap_size_ / type_size_; + return data_cap_size_ / type_size_; } virtual SpanBase @@ -180,31 +275,44 @@ class ColumnBase { virtual void AppendBatch(const FieldDataPtr data) { - size_t required_size = size_ + data->Size(); - if (required_size > cap_size_) { - Expand(required_size * 2 + padding_); + size_t required_size = data_size_ + data->DataSize(); + if (required_size > data_cap_size_) { + ExpandData(required_size * 2 + padding_); } std::copy_n(static_cast(data->Data()), - data->Size(), - data_ + size_); - size_ = required_size; + data->DataSize(), + data_ + data_size_); + data_size_ = required_size; num_rows_ += data->Length(); + AppendValidData(data->ValidData(), data->ValidDataSize()); } // Append one row virtual void Append(const char* data, size_t size) { - size_t required_size = size_ + size; - if (required_size > cap_size_) { - Expand(required_size * 2); + size_t required_size = data_size_ + size; + if (required_size > data_cap_size_) { + ExpandData(required_size * 2); } - std::copy_n(data, size, data_ + size_); - size_ = required_size; + std::copy_n(data, size, data_ + data_size_); + data_size_ = required_size; num_rows_++; } + // append valid_data don't need to change num_rows + void + AppendValidData(const uint8_t* valid_data, size_t size) { + if (nullable == true) { + size_t required_size = valid_data_size_ + size; + if (required_size > valid_data_cap_size_) { + ExpandValidData(required_size * 2); + } + std::copy(valid_data, valid_data + size, valid_data_); + } + } + void SetPaddingSize(const DataType& type) { switch (type) { @@ -228,7 +336,7 @@ class ColumnBase { protected: // only for memory mode, not mmap void - Expand(size_t new_size) { + ExpandData(size_t new_size) { if (new_size == 0) { return; } @@ -248,8 +356,8 @@ class ColumnBase { new_size + padding_); if (data_ != nullptr) { - std::memcpy(data, data_, size_); - if (munmap(data_, cap_size_ + padding_)) { + std::memcpy(data, data_, data_size_); + if (munmap(data_, data_cap_size_ + padding_)) { auto err = errno; size_t mapped_size = new_size + padding_; munmap(data, mapped_size); @@ -259,25 +367,66 @@ class ColumnBase { false, "failed to unmap while expanding: {}, old_map_size={}", strerror(err), - cap_size_ + padding_); + data_cap_size_ + padding_); } - UpdateMetricWhenMunmap(cap_size_ + padding_); + UpdateMetricWhenMunmap(data_cap_size_ + padding_); } data_ = data; - cap_size_ = new_size; + data_cap_size_ = new_size; + is_map_anonymous_ = true; + } + + // only for memory mode, not mmap + void + ExpandValidData(size_t new_size) { + if (new_size == 0) { + return; + } + auto valid_data = static_cast(mmap(nullptr, + new_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANON, + -1, + 0)); + UpdateMetricWhenMmap(true, new_size); + AssertInfo(valid_data != MAP_FAILED, + "failed to create map: {}", + strerror(errno)); + + if (valid_data_ != nullptr) { + std::memcpy(valid_data, valid_data_, valid_data_size_); + if (munmap(valid_data_, valid_data_cap_size_)) { + auto err = errno; + munmap(valid_data, new_size); + UpdateMetricWhenMunmap(new_size); + AssertInfo(false, + "failed to unmap while expanding, err={}", + strerror(errno)); + } + UpdateMetricWhenMunmap(new_size); + } + + valid_data_ = valid_data; + valid_data_cap_size_ = new_size; is_map_anonymous_ = true; } char* data_{nullptr}; + bool nullable{false}; + uint8_t* valid_data_{nullptr}; + size_t valid_data_cap_size_{0}; + // std::shared_ptr valid_data_{nullptr}; + bool is_scalar{false}; // capacity in bytes - size_t cap_size_{0}; + size_t data_cap_size_{0}; size_t padding_{0}; const size_t type_size_{1}; size_t num_rows_{0}; // length in bytes - size_t size_{0}; + size_t data_size_{0}; + size_t valid_data_size_{0}; private: void @@ -329,8 +478,12 @@ class Column : public ColumnBase { } // mmap mode ctor - Column(const File& file, size_t size, int dim, DataType data_type) - : ColumnBase(file, size, dim, data_type) { + Column(const File& file, + size_t size, + int dim, + DataType data_type, + bool nullable) + : ColumnBase(file, size, dim, data_type, nullable) { } Column(Column&& column) noexcept : ColumnBase(std::move(column)) { @@ -340,7 +493,7 @@ class Column : public ColumnBase { SpanBase Span() const override { - return SpanBase(data_, num_rows_, cap_size_ / num_rows_); + return SpanBase(data_, num_rows_, data_cap_size_ / num_rows_); } }; @@ -459,7 +612,7 @@ class VariableColumn : public ColumnBase { std::string_view RawAt(const int i) const { - size_t len = (i == indices_.size() - 1) ? size_ - indices_.back() + size_t len = (i == indices_.size() - 1) ? data_size_ - indices_.back() : indices_[i + 1] - indices_[i]; return std::string_view(data_ + indices_[i], len); } @@ -469,8 +622,8 @@ class VariableColumn : public ColumnBase { for (auto i = 0; i < chunk->get_num_rows(); i++) { auto data = static_cast(chunk->RawValue(i)); - indices_.emplace_back(size_); - size_ += data->size(); + indices_.emplace_back(data_size_); + data_size_ += data->size(); } load_buf_.emplace(std::move(chunk)); } @@ -485,9 +638,13 @@ class VariableColumn : public ColumnBase { // for variable length column in memory mode only if (data_ == nullptr) { - size_t total_size = size_; - size_ = 0; - Expand(total_size); + size_t total_data_size = data_size_; + data_size_ = 0; + ExpandData(total_data_size); + + size_t total_valid_data_size = valid_data_size_; + valid_data_size_ = 0; + ExpandValidData(total_valid_data_size); while (!load_buf_.empty()) { auto chunk = std::move(load_buf_.front()); @@ -495,9 +652,16 @@ class VariableColumn : public ColumnBase { for (auto i = 0; i < chunk->get_num_rows(); i++) { auto data = static_cast(chunk->RawValue(i)); - std::copy_n(data->c_str(), data->size(), data_ + size_); - size_ += data->size(); + std::copy_n( + data->c_str(), data->size(), data_ + data_size_); + data_size_ += data->size(); + } + if (nullable == true) { + std::copy(chunk->ValidData(), + chunk->ValidDataSize() + chunk->ValidData(), + valid_data_); } + valid_data_size_ += chunk->ValidDataSize(); } } @@ -512,7 +676,8 @@ class VariableColumn : public ColumnBase { views_.emplace_back(data_ + indices_[i], indices_[i + 1] - indices_[i]); } - views_.emplace_back(data_ + indices_.back(), size_ - indices_.back()); + views_.emplace_back(data_ + indices_.back(), + data_size_ - indices_.back()); } private: @@ -570,7 +735,7 @@ class ArrayColumn : public ColumnBase { void Append(const Array& array) { - indices_.emplace_back(size_); + indices_.emplace_back(data_size_); element_indices_.emplace_back(array.get_offsets()); ColumnBase::Append(static_cast(array.data()), array.byte_size()); @@ -597,7 +762,7 @@ class ArrayColumn : public ColumnBase { std::move(element_indices_[i])); } views_.emplace_back(data_ + indices_.back(), - size_ - indices_.back(), + data_size_ - indices_.back(), element_type_, std::move(element_indices_[indices_.size() - 1])); element_indices_.clear(); diff --git a/internal/core/src/mmap/Utils.h b/internal/core/src/mmap/Utils.h index db06c1f0d0cf1..118a243dce62f 100644 --- a/internal/core/src/mmap/Utils.h +++ b/internal/core/src/mmap/Utils.h @@ -39,6 +39,9 @@ WriteFieldData(File& file, const FieldDataPtr& data, std::vector>& element_indices) { size_t total_written{0}; + if (data->IsNullable()) { + total_written += file.Write(data->ValidData(), data->ValidDataSize()); + } if (IsVariableDataType(data_type)) { switch (data_type) { case DataType::VARCHAR: @@ -92,7 +95,7 @@ WriteFieldData(File& file, GetDataTypeName(data_type)); } } else { - total_written += file.Write(data->Data(), data->Size()); + total_written += file.Write(data->Data(), data->DataSize()); } return total_written; diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index aaa900405b807..95b1c8e9c504f 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -136,7 +136,9 @@ class VectorBase { const int64_t size_per_chunk_; }; -template +template class ConcurrentVectorImpl : public VectorBase { public: // constants @@ -169,6 +171,13 @@ class ConcurrentVectorImpl : public VectorBase { elements_per_row_(is_type_entire_row ? 1 : elements_per_row) { } + void + grow_to_at_least(int64_t element_count) { + auto chunk_count = upper_div(element_count, size_per_chunk_); + chunks_.emplace_to_at_least(chunk_count, + elements_per_row_ * size_per_chunk_); + } + Span get_span(int64_t chunk_id) const { auto& chunk = get_chunk(chunk_id); @@ -233,6 +242,42 @@ class ConcurrentVectorImpl : public VectorBase { element_offset, static_cast(source), element_count); } + void + set_data(ssize_t element_offset, + const Type* source, + ssize_t element_count) { + auto chunk_id = element_offset / size_per_chunk_; + auto chunk_offset = element_offset % size_per_chunk_; + ssize_t source_offset = 0; + // first partition: + if (chunk_offset + element_count <= size_per_chunk_) { + // only first + fill_chunk( + chunk_id, chunk_offset, element_count, source, source_offset); + return; + } + + auto first_size = size_per_chunk_ - chunk_offset; + fill_chunk(chunk_id, chunk_offset, first_size, source, source_offset); + + source_offset += size_per_chunk_ - chunk_offset; + element_count -= first_size; + ++chunk_id; + + // the middle + while (element_count >= size_per_chunk_) { + fill_chunk(chunk_id, 0, size_per_chunk_, source, source_offset); + source_offset += size_per_chunk_; + element_count -= size_per_chunk_; + ++chunk_id; + } + + // the final + if (element_count > 0) { + fill_chunk(chunk_id, 0, element_count, source, source_offset); + } + } + const Chunk& get_chunk(ssize_t chunk_index) const { return chunks_[chunk_index]; @@ -290,42 +335,6 @@ class ConcurrentVectorImpl : public VectorBase { } private: - void - set_data(ssize_t element_offset, - const Type* source, - ssize_t element_count) { - auto chunk_id = element_offset / size_per_chunk_; - auto chunk_offset = element_offset % size_per_chunk_; - ssize_t source_offset = 0; - // first partition: - if (chunk_offset + element_count <= size_per_chunk_) { - // only first - fill_chunk( - chunk_id, chunk_offset, element_count, source, source_offset); - return; - } - - auto first_size = size_per_chunk_ - chunk_offset; - fill_chunk(chunk_id, chunk_offset, first_size, source, source_offset); - - source_offset += size_per_chunk_ - chunk_offset; - element_count -= first_size; - ++chunk_id; - - // the middle - while (element_count >= size_per_chunk_) { - fill_chunk(chunk_id, 0, size_per_chunk_, source, source_offset); - source_offset += size_per_chunk_; - element_count -= size_per_chunk_; - ++chunk_id; - } - - // the final - if (element_count > 0) { - fill_chunk(chunk_id, 0, element_count, source, source_offset); - } - } - void fill_chunk(ssize_t chunk_id, ssize_t chunk_offset, @@ -399,6 +408,60 @@ class ConcurrentVector int64_t dim_; }; +class ConcurrentValidDataVector : public ConcurrentVectorImpl { + public: + static_assert(IsScalar); + explicit ConcurrentValidDataVector(int64_t size_per_chunk) + : ConcurrentVectorImpl::ConcurrentVectorImpl( + 1, size_per_chunk) { + } + void + set_data_raw(ssize_t element_offset, + const std::vector& datas) override { + for (auto& field_data : datas) { + auto num_rows = field_data->get_num_rows(); + auto valid_data = std::make_unique(num_rows); + for (size_t i = 0; i < num_rows; ++i) { + auto bit = + (field_data->ValidData()[i >> 3] >> ((i & 0x07))) & 1; + valid_data[i] = bit; + } + set_data_raw(element_offset, valid_data.get(), num_rows); + element_offset += num_rows; + } + } + void + set_data_raw(ssize_t element_offset, + ssize_t element_count, + const DataArray* data, + const FieldMeta& field_meta) override { + if (field_meta.is_nullable()) { + return set_data_raw( + element_offset, data->valid_data().data(), element_count); + } + } + + void + set_data_raw(ssize_t element_offset, + const void* source, + ssize_t element_count) override { + throw SegcoreError( + NotImplemented, + "source type is specified in ConcurrentValidDataVector"); + } + + void + set_data_raw(ssize_t element_offset, + const bool* source, + ssize_t element_count) { + if (element_count == 0) { + return; + } + this->grow_to_at_least(element_offset + element_count); + this->set_data(element_offset, source, element_count); + } +}; + template <> class ConcurrentVector : public ConcurrentVectorImpl { diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index dcf559d2913a4..f7e7b5661e6a4 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -297,6 +297,9 @@ struct InsertRecord { for (auto& field : schema) { auto field_id = field.first; auto& field_meta = field.second; + if (field_meta.is_nullable()) { + this->append_valid_data(field_id, size_per_chunk); + } if (pk2offset_ == nullptr && pk_field_id.has_value() && pk_field_id.value() == field_id) { switch (field_meta.get_data_type()) { @@ -553,6 +556,21 @@ struct InsertRecord { return ptr; } + ConcurrentValidDataVector* + get_valid_data(FieldId field_id) const { + AssertInfo(valid_data_.find(field_id) != valid_data_.end(), + "Cannot find valid_data with field_id: " + + std::to_string(field_id.get())); + auto ptr = valid_data_.at(field_id).get(); + Assert(ptr); + return ptr; + } + + bool + is_valid_data_exist(FieldId field_id) { + return valid_data_.find(field_id) != valid_data_.end(); + } + // append a column of scalar or sparse float vector type template void @@ -562,6 +580,14 @@ struct InsertRecord { field_id, std::make_unique>(size_per_chunk)); } + // append a column of scalar type + void + append_valid_data(FieldId field_id, int64_t size_per_chunk) { + valid_data_.emplace( + field_id, + std::make_unique(size_per_chunk)); + } + // append a column of vector type template void @@ -573,8 +599,9 @@ struct InsertRecord { } void - drop_field_data(FieldId field_id) { + drop_data(FieldId field_id) { fields_data_.erase(field_id); + valid_data_.erase(field_id); } const ConcurrentVector& @@ -614,6 +641,8 @@ struct InsertRecord { private: std::unordered_map> fields_data_{}; + std::unordered_map> + valid_data_{}; mutable std::shared_mutex shared_mutex_{}; }; diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 3d1f277c43d89..f9ca551257758 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -127,6 +127,13 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset, num_rows, &insert_record_proto->fields_data(data_offset), field_meta); + if (insert_record_.is_valid_data_exist(field_id)) { + insert_record_.get_valid_data(field_id)->set_data_raw( + reserved_offset, + num_rows, + &insert_record_proto->fields_data(data_offset), + field_meta); + } } //insert vector data into index if (segcore_config_.get_enable_interim_segment_index()) { @@ -231,6 +238,10 @@ SegmentGrowingImpl::LoadFieldData(const LoadFieldDataInfo& infos) { if (!indexing_record_.SyncDataWithIndex(field_id)) { insert_record_.get_field_data_base(field_id)->set_data_raw( reserved_offset, field_data); + if (insert_record_.is_valid_data_exist(field_id)) { + insert_record_.get_valid_data(field_id)->set_data_raw( + reserved_offset, field_data); + } } if (segcore_config_.get_enable_interim_segment_index()) { auto offset = reserved_offset; @@ -508,6 +519,15 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id, AssertInfo(!field_meta.is_vector(), "Scalar field meta type is vector type"); auto result = CreateScalarDataArray(count, field_meta); + if (field_meta.is_nullable()) { + auto valid_data_ptr = insert_record_.get_valid_data(field_id); + auto res = result->mutable_valid_data()->mutable_data(); + auto& valid_data = *valid_data_ptr; + for (int64_t i = 0; i < count; ++i) { + auto offset = seg_offsets[i]; + res[i] = valid_data[offset]; + } + } switch (field_meta.get_data_type()) { case DataType::BOOL: { bulk_subscript_impl(vec_ptr, diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 18f4bfcee3610..45a3354da87eb 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -393,6 +393,9 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { FieldDataPtr field_data; while (data.channel->pop(field_data)) { var_column->Append(std::move(field_data)); + var_column->AppendValidData( + field_data->ValidData(), + field_data->ValidDataSize()); } var_column->Seal(); field_data_size = var_column->ByteSize(); @@ -408,6 +411,9 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { FieldDataPtr field_data; while (data.channel->pop(field_data)) { var_column->Append(std::move(field_data)); + var_column->AppendValidData( + field_data->ValidData(), + field_data->ValidDataSize()); } var_column->Seal(); stats_.mem_size += var_column->ByteSize(); @@ -432,6 +438,9 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { stats_.mem_size += array->byte_size() + sizeof(uint64_t); } + var_column->AppendValidData( + field_data->ValidData(), + field_data->ValidDataSize()); } var_column->Seal(); column = std::move(var_column); @@ -544,7 +553,7 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) { strerror(errno))); for (auto i = 0; i < field_data->get_num_rows(); i++) { - auto size = field_data->Size(i); + auto size = field_data->DataSize(i); indices.emplace_back(total_written); total_written += size; } @@ -1168,6 +1177,16 @@ SegmentSealedImpl::get_raw_data(FieldId field_id, // to make sure it won't get released if segment released auto column = fields_.at(field_id); auto ret = fill_with_empty(field_id, count); + if (column->IsNullable()) { + auto dst = ret->mutable_valid_data()->mutable_data(); + // auto valid_data = std::make_unique(count); + for (size_t i = 0; i < count; ++i) { + auto offset = seg_offsets[i]; + auto bit = + (column->ValidData()[offset >> 3] >> ((offset & 0x07))) & 1; + dst[i] = bit; + } + } switch (field_meta.get_data_type()) { case DataType::VARCHAR: case DataType::STRING: { diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 0adc911c2e921..9cb1dd9d2acba 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -232,6 +232,10 @@ CreateScalarDataArray(int64_t count, const FieldMeta& field_meta) { data_array->set_type(static_cast( field_meta.get_data_type())); + if (field_meta.is_nullable()) { + data_array->mutable_valid_data()->Resize(count, false); + } + auto scalar_array = data_array->mutable_scalars(); switch (data_type) { case DataType::BOOL: { @@ -360,6 +364,7 @@ CreateVectorDataArray(int64_t count, const FieldMeta& field_meta) { std::unique_ptr CreateScalarDataArrayFrom(const void* data_raw, + const void* valid_data, int64_t count, const FieldMeta& field_meta) { auto data_type = field_meta.get_data_type(); @@ -367,6 +372,11 @@ CreateScalarDataArrayFrom(const void* data_raw, data_array->set_field_id(field_meta.get_id().get()); data_array->set_type(static_cast( field_meta.get_data_type())); + if (field_meta.is_nullable()) { + auto valid_data_ = reinterpret_cast(valid_data); + auto obj = data_array->mutable_valid_data(); + obj->Add(valid_data_, valid_data_ + count); + } auto scalar_array = data_array->mutable_scalars(); switch (data_type) { @@ -517,12 +527,14 @@ CreateVectorDataArrayFrom(const void* data_raw, std::unique_ptr CreateDataArrayFrom(const void* data_raw, + const void* valid_data, int64_t count, const FieldMeta& field_meta) { auto data_type = field_meta.get_data_type(); if (!IsVectorDataType(data_type)) { - return CreateScalarDataArrayFrom(data_raw, count, field_meta); + return CreateScalarDataArrayFrom( + data_raw, valid_data, count, field_meta); } return CreateVectorDataArrayFrom(data_raw, count, field_meta); @@ -536,6 +548,7 @@ MergeDataArray( auto data_type = field_meta.get_data_type(); auto data_array = std::make_unique(); data_array->set_field_id(field_meta.get_id().get()); + auto nullable = field_meta.is_nullable(); data_array->set_type(static_cast( field_meta.get_data_type())); @@ -590,6 +603,12 @@ MergeDataArray( continue; } + if (nullable) { + auto data = src_field_data->valid_data().data(); + auto obj = data_array->mutable_valid_data(); + *(obj->Add()) = data[src_offset]; + } + auto scalar_array = data_array->mutable_scalars(); switch (data_type) { case DataType::BOOL: { @@ -784,6 +803,7 @@ LoadFieldDatasFromRemote2(std::shared_ptr space, data->GetColumnByName(field.second.get_name().get()); auto field_data = storage::CreateFieldData( field.second.get_data_type(), + field.second.is_nullable(), field.second.is_vector() ? field.second.get_dim() : 0, total_num_rows); field_data->FillFieldData(col_data); diff --git a/internal/core/src/segcore/Utils.h b/internal/core/src/segcore/Utils.h index b3987597c5828..5ec2c993d7234 100644 --- a/internal/core/src/segcore/Utils.h +++ b/internal/core/src/segcore/Utils.h @@ -63,6 +63,7 @@ CreateVectorDataArray(int64_t count, const FieldMeta& field_meta); std::unique_ptr CreateScalarDataArrayFrom(const void* data_raw, + const void* valid_data, int64_t count, const FieldMeta& field_meta); @@ -73,6 +74,7 @@ CreateVectorDataArrayFrom(const void* data_raw, std::unique_ptr CreateDataArrayFrom(const void* data_raw, + const void* valid_data, int64_t count, const FieldMeta& field_meta); diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index d33ad985e4e9e..5c7a71c6a33b4 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -343,7 +343,8 @@ LoadFieldRawData(CSegmentInterface c_segment, dim = field_meta.get_dim(); } } - auto field_data = milvus::storage::CreateFieldData(data_type, dim); + auto field_data = + milvus::storage::CreateFieldData(data_type, false, dim); field_data->FillFieldData(data, row_count); milvus::FieldDataChannelPtr channel = std::make_shared(); diff --git a/internal/core/src/storage/ChunkCache.cpp b/internal/core/src/storage/ChunkCache.cpp index 4f9c80418f007..bc27001ec2b40 100644 --- a/internal/core/src/storage/ChunkCache.cpp +++ b/internal/core/src/storage/ChunkCache.cpp @@ -104,7 +104,8 @@ ChunkCache::Mmap(const std::filesystem::path& path, AssertInfo( false, "TODO: unimplemented for variable data type: {}", data_type); } else { - column = std::make_shared(file, data_size, dim, data_type); + column = std::make_shared( + file, data_size, dim, data_type, field_data->IsNullable()); } // unlink diff --git a/internal/core/src/storage/DataCodec.cpp b/internal/core/src/storage/DataCodec.cpp index 2e37f7bf732bc..48890529aeffe 100644 --- a/internal/core/src/storage/DataCodec.cpp +++ b/internal/core/src/storage/DataCodec.cpp @@ -31,6 +31,7 @@ DeserializeRemoteFileData(BinlogReaderPtr reader) { DescriptorEvent descriptor_event(reader); DataType data_type = DataType(descriptor_event.event_data.fix_part.data_type); + bool nullable = descriptor_event.event_data.fix_part.nullable; auto descriptor_fix_part = descriptor_event.event_data.fix_part; FieldDataMeta data_meta{descriptor_fix_part.collection_id, descriptor_fix_part.partition_id, @@ -42,7 +43,7 @@ DeserializeRemoteFileData(BinlogReaderPtr reader) { auto event_data_length = header.event_length_ - GetEventHeaderSize(header); auto insert_event_data = - InsertEventData(reader, event_data_length, data_type); + InsertEventData(reader, event_data_length, data_type, nullable); auto insert_data = std::make_unique(insert_event_data.field_data); insert_data->SetFieldDataMeta(data_meta); @@ -54,7 +55,7 @@ DeserializeRemoteFileData(BinlogReaderPtr reader) { auto event_data_length = header.event_length_ - GetEventHeaderSize(header); auto index_event_data = - IndexEventData(reader, event_data_length, data_type); + IndexEventData(reader, event_data_length, data_type, nullable); auto field_data = index_event_data.field_data; // for compatible with golang indexcode.Serialize, which set dataType to String if (data_type == DataType::STRING) { @@ -63,7 +64,7 @@ DeserializeRemoteFileData(BinlogReaderPtr reader) { AssertInfo( field_data->get_num_rows() == 1, "wrong length of string num in old index binlog file"); - auto new_field_data = CreateFieldData(DataType::INT8); + auto new_field_data = CreateFieldData(DataType::INT8, nullable); new_field_data->FillFieldData( (*static_cast(field_data->RawValue(0))) .c_str(), diff --git a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index 57dbce8728fb0..fc1eca03b7a39 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -410,7 +410,7 @@ DiskFileManagerImpl::CacheRawDataToDisk( num_rows += total_num_rows; auto col_data = data->GetColumnByName(index_meta_.field_name); auto field_data = storage::CreateFieldData( - index_meta_.field_type, index_meta_.dim, total_num_rows); + index_meta_.field_type, false, index_meta_.dim, total_num_rows); field_data->FillFieldData(col_data); dim = field_data->get_dim(); auto data_size = diff --git a/internal/core/src/storage/Event.cpp b/internal/core/src/storage/Event.cpp index a3d6e5ef6bfb1..9d1ea8354dbfa 100644 --- a/internal/core/src/storage/Event.cpp +++ b/internal/core/src/storage/Event.cpp @@ -34,7 +34,7 @@ GetFixPartSize(DescriptorEventData& data) { sizeof(data.fix_part.segment_id) + sizeof(data.fix_part.field_id) + sizeof(data.fix_part.start_timestamp) + sizeof(data.fix_part.end_timestamp) + - sizeof(data.fix_part.data_type); + sizeof(data.fix_part.data_type) + sizeof(data.fix_part.nullable); } int GetFixPartSize(BaseEventData& data) { @@ -107,6 +107,8 @@ DescriptorEventDataFixPart::DescriptorEventDataFixPart(BinlogReaderPtr reader) { assert(ast.ok()); ast = reader->Read(sizeof(field_id), &field_id); assert(ast.ok()); + ast = reader->Read(sizeof(nullable), &nullable); + assert(ast.ok()); ast = reader->Read(sizeof(start_timestamp), &start_timestamp); assert(ast.ok()); ast = reader->Read(sizeof(end_timestamp), &end_timestamp); @@ -120,7 +122,7 @@ DescriptorEventDataFixPart::Serialize() { auto fix_part_size = sizeof(collection_id) + sizeof(partition_id) + sizeof(segment_id) + sizeof(field_id) + sizeof(start_timestamp) + sizeof(end_timestamp) + - sizeof(data_type); + sizeof(data_type) + sizeof(nullable); std::vector res(fix_part_size); int offset = 0; memcpy(res.data() + offset, &collection_id, sizeof(collection_id)); @@ -131,6 +133,8 @@ DescriptorEventDataFixPart::Serialize() { offset += sizeof(segment_id); memcpy(res.data() + offset, &field_id, sizeof(field_id)); offset += sizeof(field_id); + memcpy(res.data() + offset, &nullable, sizeof(nullable)); + offset += sizeof(nullable); memcpy(res.data() + offset, &start_timestamp, sizeof(start_timestamp)); offset += sizeof(start_timestamp); memcpy(res.data() + offset, &end_timestamp, sizeof(end_timestamp)); @@ -196,7 +200,8 @@ DescriptorEventData::Serialize() { BaseEventData::BaseEventData(BinlogReaderPtr reader, int event_length, - DataType data_type) { + DataType data_type, + bool nullable) { auto ast = reader->Read(sizeof(start_timestamp), &start_timestamp); AssertInfo(ast.ok(), "read start timestamp failed"); ast = reader->Read(sizeof(end_timestamp), &end_timestamp); @@ -207,7 +212,7 @@ BaseEventData::BaseEventData(BinlogReaderPtr reader, auto res = reader->Read(payload_length); AssertInfo(res.first.ok(), "read payload failed"); auto payload_reader = std::make_shared( - res.second.get(), payload_length, data_type); + res.second.get(), payload_length, data_type, nullable); field_data = payload_reader->get_field_data(); } @@ -217,10 +222,11 @@ BaseEventData::Serialize() { std::shared_ptr payload_writer; if (IsVectorDataType(data_type) && !IsSparseFloatVectorDataType(data_type)) { - payload_writer = - std::make_unique(data_type, field_data->get_dim()); + payload_writer = std::make_unique( + data_type, field_data->get_dim(), field_data->IsNullable()); } else { - payload_writer = std::make_unique(data_type); + payload_writer = std::make_unique( + data_type, field_data->IsNullable()); } switch (data_type) { case DataType::VARCHAR: @@ -229,8 +235,8 @@ BaseEventData::Serialize() { ++offset) { auto str = static_cast( field_data->RawValue(offset)); - payload_writer->add_one_string_payload(str->c_str(), - str->size()); + auto size = field_data->is_null(offset) ? -1 : str->size(); + payload_writer->add_one_string_payload(str->c_str(), size); } break; } @@ -240,10 +246,12 @@ BaseEventData::Serialize() { auto array = static_cast(field_data->RawValue(offset)); auto array_string = array->output_data().SerializeAsString(); + auto size = + field_data->is_null(offset) ? -1 : array_string.size(); payload_writer->add_one_binary_payload( reinterpret_cast(array_string.c_str()), - array_string.size()); + size); } break; } @@ -276,8 +284,10 @@ BaseEventData::Serialize() { auto payload = Payload{data_type, static_cast(field_data->Data()), + field_data->ValidData(), field_data->get_num_rows(), - field_data->get_dim()}; + field_data->get_dim(), + field_data->IsNullable()}; payload_writer->add_payload(payload); } } @@ -297,11 +307,13 @@ BaseEventData::Serialize() { return res; } -BaseEvent::BaseEvent(BinlogReaderPtr reader, DataType data_type) { +BaseEvent::BaseEvent(BinlogReaderPtr reader, + DataType data_type, + bool nullable) { event_header = EventHeader(reader); auto event_data_length = event_header.event_length_ - GetEventHeaderSize(event_header); - event_data = BaseEventData(reader, event_data_length, data_type); + event_data = BaseEventData(reader, event_data_length, data_type, nullable); } std::vector @@ -380,7 +392,7 @@ LocalIndexEvent::LocalIndexEvent(BinlogReaderPtr reader) { auto res = reader->Read(index_size); AssertInfo(res.first.ok(), "read payload failed"); auto payload_reader = std::make_shared( - res.second.get(), index_size, DataType::INT8); + res.second.get(), index_size, DataType::INT8, false); field_data = payload_reader->get_field_data(); } diff --git a/internal/core/src/storage/Event.h b/internal/core/src/storage/Event.h index 87a5d0eb4d927..54a9096745fdc 100644 --- a/internal/core/src/storage/Event.h +++ b/internal/core/src/storage/Event.h @@ -46,6 +46,7 @@ struct DescriptorEventDataFixPart { int64_t partition_id; int64_t segment_id; int64_t field_id; + bool nullable; Timestamp start_timestamp; Timestamp end_timestamp; milvus::proto::schema::DataType data_type; @@ -79,7 +80,8 @@ struct BaseEventData { BaseEventData() = default; explicit BaseEventData(BinlogReaderPtr reader, int event_length, - DataType data_type); + DataType data_type, + bool nullable); std::vector Serialize(); @@ -102,7 +104,9 @@ struct BaseEvent { int64_t event_offset; BaseEvent() = default; - explicit BaseEvent(BinlogReaderPtr reader, DataType data_type); + explicit BaseEvent(BinlogReaderPtr reader, + DataType data_type, + bool nullable); std::vector Serialize(); diff --git a/internal/core/src/storage/InsertData.cpp b/internal/core/src/storage/InsertData.cpp index 514d98d56aac6..f112d7e7659ca 100644 --- a/internal/core/src/storage/InsertData.cpp +++ b/internal/core/src/storage/InsertData.cpp @@ -60,6 +60,7 @@ InsertData::serialize_to_remote_file() { des_fix_part.field_id = field_data_meta_->field_id; des_fix_part.start_timestamp = time_range_.first; des_fix_part.end_timestamp = time_range_.second; + des_fix_part.nullable = field_data_->IsNullable(); des_fix_part.data_type = milvus::proto::schema::DataType(data_type); for (auto i = int8_t(EventType::DescriptorEvent); i < int8_t(EventType::EventTypeEnd); diff --git a/internal/core/src/storage/PayloadReader.cpp b/internal/core/src/storage/PayloadReader.cpp index 81b0cae4e0607..f468bd343d810 100644 --- a/internal/core/src/storage/PayloadReader.cpp +++ b/internal/core/src/storage/PayloadReader.cpp @@ -27,8 +27,9 @@ namespace milvus::storage { PayloadReader::PayloadReader(const uint8_t* data, int length, - DataType data_type) - : column_type_(data_type) { + DataType data_type, + bool nullable) + : column_type_(data_type), nullable(nullable) { auto input = std::make_shared(data, length); init(input); } @@ -72,11 +73,12 @@ PayloadReader::init(std::shared_ptr input) { st = arrow_reader->GetRecordBatchReader(&rb_reader); AssertInfo(st.ok(), "get record batch reader"); - field_data_ = CreateFieldData(column_type_, dim_, total_num_rows); + field_data_ = CreateFieldData(column_type_, nullable, dim_, total_num_rows); for (arrow::Result> maybe_batch : *rb_reader) { AssertInfo(maybe_batch.ok(), "get batch record success"); auto array = maybe_batch.ValueOrDie()->column(column_index); + // to read field_data_->FillFieldData(array); } AssertInfo(field_data_->IsFull(), "field data hasn't been filled done"); diff --git a/internal/core/src/storage/PayloadReader.h b/internal/core/src/storage/PayloadReader.h index b5fb22084dab4..39aa6420fd14d 100644 --- a/internal/core/src/storage/PayloadReader.h +++ b/internal/core/src/storage/PayloadReader.h @@ -26,7 +26,10 @@ namespace milvus::storage { class PayloadReader { public: - explicit PayloadReader(const uint8_t* data, int length, DataType data_type); + explicit PayloadReader(const uint8_t* data, + int length, + DataType data_type, + bool nullable); ~PayloadReader() = default; @@ -41,6 +44,7 @@ class PayloadReader { private: DataType column_type_; int dim_; + bool nullable; FieldDataPtr field_data_; }; diff --git a/internal/core/src/storage/PayloadStream.h b/internal/core/src/storage/PayloadStream.h index c23c7816367b5..8639ab9a97dab 100644 --- a/internal/core/src/storage/PayloadStream.h +++ b/internal/core/src/storage/PayloadStream.h @@ -32,8 +32,10 @@ class PayloadInputStream; struct Payload { DataType data_type; const uint8_t* raw_data; - int64_t rows; + const uint8_t* valid_data; + const int64_t rows; std::optional dimension; + bool nullable; }; class PayloadOutputStream : public arrow::io::OutputStream { diff --git a/internal/core/src/storage/PayloadWriter.cpp b/internal/core/src/storage/PayloadWriter.cpp index d9b1db7dc5cba..c7722c11b88c7 100644 --- a/internal/core/src/storage/PayloadWriter.cpp +++ b/internal/core/src/storage/PayloadWriter.cpp @@ -23,18 +23,19 @@ namespace milvus::storage { // create payload writer for numeric data type -PayloadWriter::PayloadWriter(const DataType column_type) - : column_type_(column_type) { +PayloadWriter::PayloadWriter(const DataType column_type, bool nullable) + : column_type_(column_type), nullable_(nullable) { builder_ = CreateArrowBuilder(column_type); - schema_ = CreateArrowSchema(column_type); + schema_ = CreateArrowSchema(column_type, nullable); } // create payload writer for vector data type -PayloadWriter::PayloadWriter(const DataType column_type, int dim) - : column_type_(column_type) { +PayloadWriter::PayloadWriter(const DataType column_type, int dim, bool nullable) + : column_type_(column_type), nullable_(nullable) { AssertInfo(column_type != DataType::VECTOR_SPARSE_FLOAT, "PayloadWriter for Sparse Float Vector should be created " "using the constructor without dimension"); + AssertInfo(nullable == false, "only scalcar type support null now"); init_dimension(dim); } @@ -48,7 +49,7 @@ PayloadWriter::init_dimension(int dim) { dimension_ = dim; builder_ = CreateArrowBuilder(column_type_, dim); - schema_ = CreateArrowSchema(column_type_, dim); + schema_ = CreateArrowSchema(column_type_, dim, nullable_); } void diff --git a/internal/core/src/storage/PayloadWriter.h b/internal/core/src/storage/PayloadWriter.h index 1bd2d652be9a8..86ca281bb6b62 100644 --- a/internal/core/src/storage/PayloadWriter.h +++ b/internal/core/src/storage/PayloadWriter.h @@ -25,8 +25,8 @@ namespace milvus::storage { class PayloadWriter { public: - explicit PayloadWriter(const DataType column_type); - explicit PayloadWriter(const DataType column_type, int dim); + explicit PayloadWriter(const DataType column_type, int dim, bool nullable); + explicit PayloadWriter(const DataType column_type, bool nullable); ~PayloadWriter() = default; void @@ -58,6 +58,7 @@ class PayloadWriter { private: DataType column_type_; + bool nullable_; std::shared_ptr builder_; std::shared_ptr schema_; std::shared_ptr output_; diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index 0e714f0a97362..ce62c50e38474 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -75,6 +75,17 @@ std::map ReadAheadPolicy_Map = { {"willneed", MADV_WILLNEED}, {"dontneed", MADV_DONTNEED}}; +// in arrow, null_bitmap read from the least significant bit +std::vector +genValidIter(const uint8_t* valid_data, int length) { + std::vector valid_data_; + for (size_t i = 0; i < length; ++i) { + auto bit = (valid_data[i >> 3] >> ((i & 0x07))) & 1; + valid_data_.push_back(bit); + } + return valid_data_; +} + StorageType ReadMediumType(BinlogReaderPtr reader) { AssertInfo(reader->Tell() == 0, @@ -106,12 +117,22 @@ template void add_numeric_payload(std::shared_ptr builder, DT* start, + const uint8_t* valid_data, + bool nullable, int length) { AssertInfo(builder != nullptr, "empty arrow builder"); auto numeric_builder = std::dynamic_pointer_cast(builder); - auto ast = numeric_builder->AppendValues(start, start + length); - AssertInfo( - ast.ok(), "append value to arrow builder failed: {}", ast.ToString()); + arrow::Status ast; + if (nullable) { + // need iter to read valid_data when write + auto iter = genValidIter(valid_data, length); + ast = + numeric_builder->AppendValues(start, start + length, iter.begin()); + AssertInfo(ast.ok(), "append value to arrow builder failed"); + } else { + ast = numeric_builder->AppendValues(start, start + length); + AssertInfo(ast.ok(), "append value to arrow builder failed"); + } } void @@ -121,48 +142,49 @@ AddPayloadToArrowBuilder(std::shared_ptr builder, auto raw_data = const_cast(payload.raw_data); auto length = payload.rows; auto data_type = payload.data_type; + auto nullable = payload.nullable; switch (data_type) { case DataType::BOOL: { auto bool_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, bool_data, length); + builder, bool_data, payload.valid_data, nullable, length); break; } case DataType::INT8: { auto int8_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, int8_data, length); + builder, int8_data, payload.valid_data, nullable, length); break; } case DataType::INT16: { auto int16_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, int16_data, length); + builder, int16_data, payload.valid_data, nullable, length); break; } case DataType::INT32: { auto int32_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, int32_data, length); + builder, int32_data, payload.valid_data, nullable, length); break; } case DataType::INT64: { auto int64_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, int64_data, length); + builder, int64_data, payload.valid_data, nullable, length); break; } case DataType::FLOAT: { auto float_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, float_data, length); + builder, float_data, payload.valid_data, nullable, length); break; } case DataType::DOUBLE: { auto double_data = reinterpret_cast(raw_data); add_numeric_payload( - builder, double_data, length); + builder, double_data, payload.valid_data, nullable, length); break; } case DataType::VECTOR_FLOAT16: @@ -292,40 +314,50 @@ CreateArrowBuilder(DataType data_type, int dim) { } std::shared_ptr -CreateArrowSchema(DataType data_type) { +CreateArrowSchema(DataType data_type, bool nullable) { switch (static_cast(data_type)) { case DataType::BOOL: { - return arrow::schema({arrow::field("val", arrow::boolean())}); + return arrow::schema( + {arrow::field("val", arrow::boolean(), nullable)}); } case DataType::INT8: { - return arrow::schema({arrow::field("val", arrow::int8())}); + return arrow::schema( + {arrow::field("val", arrow::int8(), nullable)}); } case DataType::INT16: { - return arrow::schema({arrow::field("val", arrow::int16())}); + return arrow::schema( + {arrow::field("val", arrow::int16(), nullable)}); } case DataType::INT32: { - return arrow::schema({arrow::field("val", arrow::int32())}); + return arrow::schema( + {arrow::field("val", arrow::int32(), nullable)}); } case DataType::INT64: { - return arrow::schema({arrow::field("val", arrow::int64())}); + return arrow::schema( + {arrow::field("val", arrow::int64(), nullable)}); } case DataType::FLOAT: { - return arrow::schema({arrow::field("val", arrow::float32())}); + return arrow::schema( + {arrow::field("val", arrow::float32(), nullable)}); } case DataType::DOUBLE: { - return arrow::schema({arrow::field("val", arrow::float64())}); + return arrow::schema( + {arrow::field("val", arrow::float64(), nullable)}); } case DataType::VARCHAR: case DataType::STRING: { - return arrow::schema({arrow::field("val", arrow::utf8())}); + return arrow::schema( + {arrow::field("val", arrow::utf8(), nullable)}); } case DataType::ARRAY: case DataType::JSON: { - return arrow::schema({arrow::field("val", arrow::binary())}); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable)}); } // sparse float vector doesn't require a dim case DataType::VECTOR_SPARSE_FLOAT: { - return arrow::schema({arrow::field("val", arrow::binary())}); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable)}); } default: { PanicInfo( @@ -335,30 +367,37 @@ CreateArrowSchema(DataType data_type) { } std::shared_ptr -CreateArrowSchema(DataType data_type, int dim) { +CreateArrowSchema(DataType data_type, int dim, bool nullable) { switch (static_cast(data_type)) { case DataType::VECTOR_FLOAT: { AssertInfo(dim > 0, "invalid dim value: {}", dim); - return arrow::schema({arrow::field( - "val", arrow::fixed_size_binary(dim * sizeof(float)))}); + return arrow::schema( + {arrow::field("val", + arrow::fixed_size_binary(dim * sizeof(float)), + nullable)}); } case DataType::VECTOR_BINARY: { AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim); - return arrow::schema( - {arrow::field("val", arrow::fixed_size_binary(dim / 8))}); + return arrow::schema({arrow::field( + "val", arrow::fixed_size_binary(dim / 8), nullable)}); } case DataType::VECTOR_FLOAT16: { AssertInfo(dim > 0, "invalid dim value: {}", dim); - return arrow::schema({arrow::field( - "val", arrow::fixed_size_binary(dim * sizeof(float16)))}); + return arrow::schema( + {arrow::field("val", + arrow::fixed_size_binary(dim * sizeof(float16)), + nullable)}); } case DataType::VECTOR_BFLOAT16: { AssertInfo(dim > 0, "invalid dim value"); - return arrow::schema({arrow::field( - "val", arrow::fixed_size_binary(dim * sizeof(bfloat16)))}); + return arrow::schema( + {arrow::field("val", + arrow::fixed_size_binary(dim * sizeof(bfloat16)), + nullable)}); } case DataType::VECTOR_SPARSE_FLOAT: { - return arrow::schema({arrow::field("val", arrow::binary())}); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable)}); } default: { PanicInfo( @@ -499,7 +538,7 @@ EncodeAndUploadIndexSlice(ChunkManager* chunk_manager, IndexMeta index_meta, FieldDataMeta field_meta, std::string object_key) { - auto field_data = CreateFieldData(DataType::INT8); + auto field_data = CreateFieldData(DataType::INT8, false); field_data->FillFieldData(buf, batch_size); auto indexData = std::make_shared(field_data); indexData->set_index_meta(index_meta); @@ -518,7 +557,7 @@ EncodeAndUploadIndexSlice2(std::shared_ptr space, IndexMeta index_meta, FieldDataMeta field_meta, std::string object_key) { - auto field_data = CreateFieldData(DataType::INT8); + auto field_data = CreateFieldData(DataType::INT8, false); field_data->FillFieldData(buf, batch_size); auto indexData = std::make_shared(field_data); indexData->set_index_meta(index_meta); @@ -538,8 +577,10 @@ EncodeAndUploadFieldSlice(ChunkManager* chunk_manager, FieldDataMeta field_data_meta, const FieldMeta& field_meta, std::string object_key) { - auto field_data = - CreateFieldData(field_meta.get_data_type(), field_meta.get_dim(), 0); + auto field_data = CreateFieldData(field_meta.get_data_type(), + field_meta.is_nullable(), + field_meta.get_dim(), + 0); field_data->FillFieldData(buf, element_count); auto insertData = std::make_shared(field_data); insertData->SetFieldDataMeta(field_data_meta); @@ -745,30 +786,42 @@ CreateChunkManager(const StorageConfig& storage_config) { } FieldDataPtr -CreateFieldData(const DataType& type, int64_t dim, int64_t total_num_rows) { +CreateFieldData(const DataType& type, + bool nullable, + int64_t dim, + int64_t total_num_rows) { switch (type) { case DataType::BOOL: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::INT8: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::INT16: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::INT32: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::INT64: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::FLOAT: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::DOUBLE: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::STRING: case DataType::VARCHAR: - return std::make_shared>(type, - total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::JSON: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::ARRAY: - return std::make_shared>(type, total_num_rows); + return std::make_shared>( + type, nullable, total_num_rows); case DataType::VECTOR_FLOAT: return std::make_shared>( dim, type, total_num_rows); @@ -825,11 +878,16 @@ MergeFieldData(std::vector& data_array) { for (const auto& data : data_array) { total_length += data->Length(); } - - auto merged_data = storage::CreateFieldData(data_array[0]->get_data_type()); + auto merged_data = storage::CreateFieldData(data_array[0]->get_data_type(), + data_array[0]->IsNullable()); merged_data->Reserve(total_length); for (const auto& data : data_array) { - merged_data->FillFieldData(data->Data(), data->Length()); + if (merged_data->IsNullable()) { + merged_data->FillFieldData( + data->Data(), data->ValidData(), data->Length()); + } else { + merged_data->FillFieldData(data->Data(), data->Length()); + } } return merged_data; } diff --git a/internal/core/src/storage/Util.h b/internal/core/src/storage/Util.h index acb6d233c0e00..b10f8c40c28f4 100644 --- a/internal/core/src/storage/Util.h +++ b/internal/core/src/storage/Util.h @@ -58,10 +58,10 @@ std::shared_ptr CreateArrowBuilder(DataType data_type, int dim); std::shared_ptr -CreateArrowSchema(DataType data_type); +CreateArrowSchema(DataType data_type, bool nullable); std::shared_ptr -CreateArrowSchema(DataType data_type, int dim); +CreateArrowSchema(DataType data_type, int dim, bool nullable); int GetDimensionFromFileMetaData(const parquet::ColumnDescriptor* schema, @@ -156,6 +156,7 @@ CreateChunkManager(const StorageConfig& storage_config); FieldDataPtr CreateFieldData(const DataType& type, + bool nullable, int64_t dim = 1, int64_t total_num_rows = 0); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 433b327aed190..55b339c3648b3 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -109,6 +109,38 @@ get_default_schema_config() { return conf.c_str(); } +const char* +get_default_schema_config_nullable() { + static std::string conf = R"(name: "default-collection" + fields: < + fieldID: 100 + name: "fakevec" + data_type: FloatVector + type_params: < + key: "dim" + value: "16" + > + index_params: < + key: "metric_type" + value: "L2" + > + > + fields: < + fieldID: 101 + name: "age" + data_type: Int64 + is_primary_key: true + > + fields: < + fieldID: 102 + name: "nullable" + data_type: Int32 + nullable:true + >)"; + static std::string fake_conf = ""; + return conf.c_str(); +} + const char* get_float16_schema_config() { static std::string conf = R"(name: "float16-collection" @@ -1048,6 +1080,74 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) { DeleteSegment(segment); } +TEST(CApiTest, SearcTestWhenNullable) { + auto c_collection = NewCollection(get_default_schema_config_nullable()); + CSegmentInterface segment; + auto status = NewSegment(c_collection, Growing, -1, &segment); + ASSERT_EQ(status.error_code, Success); + auto col = (milvus::segcore::Collection*)c_collection; + + int N = 10000; + auto dataset = DataGen(col->get_schema(), N); + int64_t ts_offset = 1000; + + int64_t offset; + PreInsert(segment, N, &offset); + + auto insert_data = serialize(dataset.raw_); + auto ins_res = Insert(segment, + offset, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + insert_data.data(), + insert_data.size()); + ASSERT_EQ(ins_res.error_code, Success); + + milvus::proto::plan::PlanNode plan_node; + auto vector_anns = plan_node.mutable_vector_anns(); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); + vector_anns->set_placeholder_tag("$0"); + vector_anns->set_field_id(100); + auto query_info = vector_anns->mutable_query_info(); + query_info->set_topk(10); + query_info->set_round_decimal(3); + query_info->set_metric_type("L2"); + query_info->set_search_params(R"({"nprobe": 10})"); + auto plan_str = plan_node.SerializeAsString(); + + int num_queries = 10; + auto blob = generate_query_data(num_queries); + + void* plan = nullptr; + status = CreateSearchPlanByExpr( + c_collection, plan_str.data(), plan_str.size(), &plan); + ASSERT_EQ(status.error_code, Success); + + void* placeholderGroup = nullptr; + status = ParsePlaceholderGroup( + plan, blob.data(), blob.length(), &placeholderGroup); + ASSERT_EQ(status.error_code, Success); + + std::vector placeholderGroups; + placeholderGroups.push_back(placeholderGroup); + + CSearchResult search_result; + auto res = Search(segment, plan, placeholderGroup, {}, &search_result); + ASSERT_EQ(res.error_code, Success); + + CSearchResult search_result2; + auto res2 = Search(segment, plan, placeholderGroup, {}, &search_result2); + ASSERT_EQ(res2.error_code, Success); + + DeleteSearchPlan(plan); + DeletePlaceholderGroup(placeholderGroup); + DeleteSearchResult(search_result); + DeleteSearchResult(search_result2); + DeleteCollection(c_collection); + DeleteSegment(segment); +} + TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) { auto collection = NewCollection(get_default_schema_config()); CSegmentInterface segment; diff --git a/internal/core/unittest/test_data_codec.cpp b/internal/core/unittest/test_data_codec.cpp index 0a4e7b36ff657..8b5fe06fa8bdd 100644 --- a/internal/core/unittest/test_data_codec.cpp +++ b/internal/core/unittest/test_data_codec.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include +#include #include "storage/DataCodec.h" #include "storage/InsertData.h" @@ -22,14 +23,16 @@ #include "storage/Util.h" #include "common/Consts.h" #include "common/Json.h" -#include "test_utils/Constants.h" -#include "test_utils/DataGen.h" +#include +#include +#include using namespace milvus; TEST(storage, InsertDataBool) { FixedVector data = {true, false, true, false, true}; - auto field_data = milvus::storage::CreateFieldData(storage::DataType::BOOL); + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::BOOL, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -48,14 +51,51 @@ TEST(storage, InsertDataBool) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::BOOL); ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataBoolNullable) { + FixedVector data = {true, false, false, false, true}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::BOOL, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::BOOL); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 2); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + // valid_data is 0001 0011, read from LSB, '1' means the according index is valid + ASSERT_EQ(data[0], new_data[0]); + ASSERT_EQ(data[1], new_data[1]); + ASSERT_EQ(data[4], new_data[4]); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} + TEST(storage, InsertDataInt8) { FixedVector data = {1, 2, 3, 4, 5}; - auto field_data = milvus::storage::CreateFieldData(storage::DataType::INT8); + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT8, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -63,6 +103,35 @@ TEST(storage, InsertDataInt8) { insert_data.SetFieldDataMeta(field_data_meta); insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT8); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt8Nullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT8, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); std::shared_ptr serialized_data_ptr(serialized_bytes.data(), [&](uint8_t*) {}); @@ -76,13 +145,17 @@ TEST(storage, InsertDataInt8) { ASSERT_EQ(new_payload->get_num_rows(), data.size()); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; } TEST(storage, InsertDataInt16) { FixedVector data = {1, 2, 3, 4, 5}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::INT16); + milvus::storage::CreateFieldData(storage::DataType::INT16, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -101,15 +174,48 @@ TEST(storage, InsertDataInt16) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT16); ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataInt16Nullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT16, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT16); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; + ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} + TEST(storage, InsertDataInt32) { FixedVector data = {true, false, true, false, true}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::INT32); + milvus::storage::CreateFieldData(storage::DataType::INT32, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -117,6 +223,35 @@ TEST(storage, InsertDataInt32) { insert_data.SetFieldDataMeta(field_data_meta); insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT32); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt32Nullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT32, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); std::shared_ptr serialized_data_ptr(serialized_bytes.data(), [&](uint8_t*) {}); @@ -130,13 +265,17 @@ TEST(storage, InsertDataInt32) { ASSERT_EQ(new_payload->get_num_rows(), data.size()); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; } TEST(storage, InsertDataInt64) { FixedVector data = {1, 2, 3, 4, 5}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::INT64); + milvus::storage::CreateFieldData(storage::DataType::INT64, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -155,16 +294,49 @@ TEST(storage, InsertDataInt64) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT64); ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataInt64Nullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT64, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT64); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; + ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} + TEST(storage, InsertDataString) { FixedVector data = { "test1", "test2", "test3", "test4", "test5"}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::VARCHAR); + milvus::storage::CreateFieldData(storage::DataType::VARCHAR, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -184,18 +356,56 @@ TEST(storage, InsertDataString) { ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VARCHAR); ASSERT_EQ(new_payload->get_num_rows(), data.size()); FixedVector new_data(data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); for (int i = 0; i < data.size(); ++i) { new_data[i] = *static_cast(new_payload->RawValue(i)); - ASSERT_EQ(new_payload->Size(i), data[i].size()); + ASSERT_EQ(new_payload->DataSize(i), data[i].size()); } ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataStringNullable) { + FixedVector data = { + "test1", "test2", "test3", "test4", "test5"}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::STRING, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::STRING); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {"test1", "test2", "", "", "test5"}; + for (int i = 0; i < data.size(); ++i) { + new_data[i] = + *static_cast(new_payload->RawValue(i)); + ASSERT_EQ(new_payload->DataSize(i), data[i].size()); + } + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} + TEST(storage, InsertDataFloat) { FixedVector data = {1, 2, 3, 4, 5}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::FLOAT); + milvus::storage::CreateFieldData(storage::DataType::FLOAT, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -203,6 +413,35 @@ TEST(storage, InsertDataFloat) { insert_data.SetFieldDataMeta(field_data_meta); insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::FLOAT); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataFloatNullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::FLOAT, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); std::shared_ptr serialized_data_ptr(serialized_bytes.data(), [&](uint8_t*) {}); @@ -216,13 +455,16 @@ TEST(storage, InsertDataFloat) { ASSERT_EQ(new_payload->get_num_rows(), data.size()); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); } TEST(storage, InsertDataDouble) { FixedVector data = {1.0, 2.0, 3.0, 4.2, 5.3}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::DOUBLE); + milvus::storage::CreateFieldData(storage::DataType::DOUBLE, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -241,16 +483,49 @@ TEST(storage, InsertDataDouble) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::DOUBLE); ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 0); FixedVector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataDoubleNullable) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::DOUBLE, true); + uint8_t* valid_data = new uint8_t[1]{0x13}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::DOUBLE); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + data = {1, 2, 0, 0, 5}; + ASSERT_EQ(data, new_data); + ASSERT_EQ(new_payload->get_null_count(), 2); + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} + TEST(storage, InsertDataFloatVector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 2; - auto field_data = - milvus::storage::CreateFieldData(storage::DataType::VECTOR_FLOAT, DIM); + auto field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_FLOAT, false, DIM); field_data->FillFieldData(data.data(), data.size() / DIM); storage::InsertData insert_data(field_data); @@ -269,6 +544,7 @@ TEST(storage, InsertDataFloatVector) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_FLOAT); ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM); + ASSERT_EQ(new_payload->get_null_count(), 0); std::vector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), @@ -281,7 +557,7 @@ TEST(storage, InsertDataSparseFloat) { auto vecs = milvus::segcore::GenerateRandomSparseFloatVector( n_rows, kTestSparseDim, kTestSparseVectorDensity); auto field_data = milvus::storage::CreateFieldData( - storage::DataType::VECTOR_SPARSE_FLOAT, kTestSparseDim, n_rows); + storage::DataType::VECTOR_SPARSE_FLOAT, false, kTestSparseDim, n_rows); field_data->FillFieldData(vecs.get(), n_rows); storage::InsertData insert_data(field_data); @@ -301,6 +577,7 @@ TEST(storage, InsertDataSparseFloat) { ASSERT_TRUE(new_payload->get_data_type() == storage::DataType::VECTOR_SPARSE_FLOAT); ASSERT_EQ(new_payload->get_num_rows(), n_rows); + ASSERT_EQ(new_payload->get_null_count(), 0); auto new_data = static_cast*>( new_payload->Data()); @@ -318,8 +595,8 @@ TEST(storage, InsertDataSparseFloat) { TEST(storage, InsertDataBinaryVector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 16; - auto field_data = - milvus::storage::CreateFieldData(storage::DataType::VECTOR_BINARY, DIM); + auto field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_BINARY, false, DIM); field_data->FillFieldData(data.data(), data.size() * 8 / DIM); storage::InsertData insert_data(field_data); @@ -338,6 +615,7 @@ TEST(storage, InsertDataBinaryVector) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_BINARY); ASSERT_EQ(new_payload->get_num_rows(), data.size() * 8 / DIM); + ASSERT_EQ(new_payload->get_null_count(), 0); std::vector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); @@ -347,7 +625,7 @@ TEST(storage, InsertDataFloat16Vector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 2; auto field_data = milvus::storage::CreateFieldData( - storage::DataType::VECTOR_FLOAT16, DIM); + storage::DataType::VECTOR_FLOAT16, false, DIM); field_data->FillFieldData(data.data(), data.size() / DIM); storage::InsertData insert_data(field_data); @@ -366,6 +644,7 @@ TEST(storage, InsertDataFloat16Vector) { auto new_payload = new_insert_data->GetFieldData(); ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_FLOAT16); ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM); + ASSERT_EQ(new_payload->get_null_count(), 0); std::vector new_data(data.size()); memcpy(new_data.data(), new_payload->Data(), @@ -373,39 +652,10 @@ TEST(storage, InsertDataFloat16Vector) { ASSERT_EQ(data, new_data); } -TEST(storage, InsertDataBFloat16Vector) { - std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; - int DIM = 2; - auto field_data = milvus::storage::CreateFieldData( - storage::DataType::VECTOR_BFLOAT16, DIM); - field_data->FillFieldData(data.data(), data.size() / DIM); - - storage::InsertData insert_data(field_data); - storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; - insert_data.SetFieldDataMeta(field_data_meta); - insert_data.SetTimestamps(0, 100); - - auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); - std::shared_ptr serialized_data_ptr(serialized_bytes.data(), - [&](uint8_t*) {}); - auto new_insert_data = storage::DeserializeFileData( - serialized_data_ptr, serialized_bytes.size()); - ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); - ASSERT_EQ(new_insert_data->GetTimeRage(), - std::make_pair(Timestamp(0), Timestamp(100))); - auto new_payload = new_insert_data->GetFieldData(); - ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_BFLOAT16); - ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM); - std::vector new_data(data.size()); - memcpy(new_data.data(), - new_payload->Data(), - new_payload->get_num_rows() * sizeof(bfloat16) * DIM); - ASSERT_EQ(data, new_data); -} - TEST(storage, IndexData) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; - auto field_data = milvus::storage::CreateFieldData(storage::DataType::INT8); + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::INT8, false); field_data->FillFieldData(data.data(), data.size()); storage::IndexData index_data(field_data); @@ -441,7 +691,7 @@ TEST(storage, InsertDataStringArray) { auto string_array = Array(field_string_data); FixedVector data = {string_array}; auto field_data = - milvus::storage::CreateFieldData(storage::DataType::ARRAY); + milvus::storage::CreateFieldData(storage::DataType::ARRAY, false); field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); @@ -463,7 +713,56 @@ TEST(storage, InsertDataStringArray) { FixedVector new_data(data.size()); for (int i = 0; i < data.size(); ++i) { new_data[i] = *static_cast(new_payload->RawValue(i)); - ASSERT_EQ(new_payload->Size(i), data[i].byte_size()); + ASSERT_EQ(new_payload->DataSize(i), data[i].byte_size()); ASSERT_TRUE(data[i].operator==(new_data[i])); } } + +TEST(storage, InsertDataStringArrayNullable) { + milvus::proto::schema::ScalarField field_string_data; + field_string_data.mutable_string_data()->add_data("test_array1"); + field_string_data.mutable_string_data()->add_data("test_array2"); + field_string_data.mutable_string_data()->add_data("test_array3"); + field_string_data.mutable_string_data()->add_data("test_array4"); + field_string_data.mutable_string_data()->add_data("test_array5"); + auto string_array = Array(field_string_data); + milvus::proto::schema::ScalarField field_int_data; + field_string_data.mutable_int_data()->add_data(1); + field_string_data.mutable_int_data()->add_data(2); + field_string_data.mutable_int_data()->add_data(3); + field_string_data.mutable_int_data()->add_data(4); + field_string_data.mutable_int_data()->add_data(5); + auto int_array = Array(field_int_data); + FixedVector data = {string_array, int_array}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::ARRAY, true); + uint8_t* valid_data = new uint8_t[1]{0x01}; + field_data->FillFieldData(data.data(), valid_data, data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::ARRAY); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + ASSERT_EQ(new_payload->get_null_count(), 1); + FixedVector expected_data = {string_array, Array()}; + FixedVector new_data(data.size()); + for (int i = 0; i < data.size(); ++i) { + new_data[i] = *static_cast(new_payload->RawValue(i)); + ASSERT_EQ(new_payload->DataSize(i), expected_data[i].byte_size()); + ASSERT_TRUE(expected_data[i].operator==(new_data[i])); + } + ASSERT_EQ(*new_payload->ValidData(), *valid_data); + delete[] valid_data; +} diff --git a/internal/core/unittest/test_growing.cpp b/internal/core/unittest/test_growing.cpp index f5421384e02fb..1ec8e53b2df8d 100644 --- a/internal/core/unittest/test_growing.cpp +++ b/internal/core/unittest/test_growing.cpp @@ -263,5 +263,176 @@ TEST_P(GrowingTest, FillData) { num_inserted); EXPECT_EQ(float_array_result->scalars().array_data().data_size(), num_inserted); + + EXPECT_EQ(bool_result->valid_data_size(), 0); + EXPECT_EQ(int8_result->valid_data_size(), 0); + EXPECT_EQ(int16_result->valid_data_size(), 0); + EXPECT_EQ(int32_result->valid_data_size(), 0); + EXPECT_EQ(int64_result->valid_data_size(), 0); + EXPECT_EQ(float_result->valid_data_size(), 0); + EXPECT_EQ(double_result->valid_data_size(), 0); + EXPECT_EQ(varchar_result->valid_data_size(), 0); + EXPECT_EQ(json_result->valid_data_size(), 0); + EXPECT_EQ(int_array_result->valid_data_size(), 0); + EXPECT_EQ(long_array_result->valid_data_size(), 0); + EXPECT_EQ(bool_array_result->valid_data_size(), 0); + EXPECT_EQ(string_array_result->valid_data_size(), 0); + EXPECT_EQ(double_array_result->valid_data_size(), 0); + EXPECT_EQ(float_array_result->valid_data_size(), 0); + } +} + +TEST(Growing, FillNullableData) { + auto schema = std::make_shared(); + auto metric_type = knowhere::metric::L2; + auto bool_field = schema->AddDebugField("bool", DataType::BOOL, true); + auto int8_field = schema->AddDebugField("int8", DataType::INT8, true); + auto int16_field = schema->AddDebugField("int16", DataType::INT16, true); + auto int32_field = schema->AddDebugField("int32", DataType::INT32, true); + auto int64_field = schema->AddDebugField("int64", DataType::INT64); + auto float_field = schema->AddDebugField("float", DataType::FLOAT, true); + auto double_field = schema->AddDebugField("double", DataType::DOUBLE, true); + auto varchar_field = + schema->AddDebugField("varchar", DataType::VARCHAR, true); + auto json_field = schema->AddDebugField("json", DataType::JSON, true); + auto int_array_field = schema->AddDebugField( + "int_array", DataType::ARRAY, DataType::INT8, true); + auto long_array_field = schema->AddDebugField( + "long_array", DataType::ARRAY, DataType::INT64, true); + auto bool_array_field = schema->AddDebugField( + "bool_array", DataType::ARRAY, DataType::BOOL, true); + auto string_array_field = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR, true); + auto double_array_field = schema->AddDebugField( + "double_array", DataType::ARRAY, DataType::DOUBLE, true); + auto float_array_field = schema->AddDebugField( + "float_array", DataType::ARRAY, DataType::FLOAT, true); + auto vec = schema->AddDebugField( + "embeddings", DataType::VECTOR_FLOAT, 128, metric_type); + schema->set_primary_field_id(int64_field); + + std::map index_params = { + {"index_type", "IVF_FLAT"}, + {"metric_type", metric_type}, + {"nlist", "128"}}; + std::map type_params = {{"dim", "128"}}; + FieldIndexMeta fieldIndexMeta( + vec, std::move(index_params), std::move(type_params)); + auto config = SegcoreConfig::default_config(); + config.set_chunk_rows(1024); + config.set_enable_interim_segment_index(true); + std::map filedMap = {{vec, fieldIndexMeta}}; + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(filedMap)); + auto segment_growing = CreateGrowingSegment(schema, metaPtr, 1, config); + auto segment = dynamic_cast(segment_growing.get()); + + int64_t per_batch = 1000; + int64_t n_batch = 3; + int64_t dim = 128; + for (int64_t i = 0; i < n_batch; i++) { + auto dataset = DataGen(schema, per_batch); + auto bool_values = dataset.get_col(bool_field); + auto int8_values = dataset.get_col(int8_field); + auto int16_values = dataset.get_col(int16_field); + auto int32_values = dataset.get_col(int32_field); + auto int64_values = dataset.get_col(int64_field); + auto float_values = dataset.get_col(float_field); + auto double_values = dataset.get_col(double_field); + auto varchar_values = dataset.get_col(varchar_field); + auto json_values = dataset.get_col(json_field); + auto int_array_values = dataset.get_col(int_array_field); + auto long_array_values = dataset.get_col(long_array_field); + auto bool_array_values = dataset.get_col(bool_array_field); + auto string_array_values = + dataset.get_col(string_array_field); + auto double_array_values = + dataset.get_col(double_array_field); + auto float_array_values = + dataset.get_col(float_array_field); + auto vector_values = dataset.get_col(vec); + + auto offset = segment->PreInsert(per_batch); + segment->Insert(offset, + per_batch, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + auto num_inserted = (i + 1) * per_batch; + auto ids_ds = GenRandomIds(num_inserted); + auto bool_result = + segment->bulk_subscript(bool_field, ids_ds->GetIds(), num_inserted); + auto int8_result = + segment->bulk_subscript(int8_field, ids_ds->GetIds(), num_inserted); + auto int16_result = segment->bulk_subscript( + int16_field, ids_ds->GetIds(), num_inserted); + auto int32_result = segment->bulk_subscript( + int32_field, ids_ds->GetIds(), num_inserted); + auto int64_result = segment->bulk_subscript( + int64_field, ids_ds->GetIds(), num_inserted); + auto float_result = segment->bulk_subscript( + float_field, ids_ds->GetIds(), num_inserted); + auto double_result = segment->bulk_subscript( + double_field, ids_ds->GetIds(), num_inserted); + auto varchar_result = segment->bulk_subscript( + varchar_field, ids_ds->GetIds(), num_inserted); + auto json_result = + segment->bulk_subscript(json_field, ids_ds->GetIds(), num_inserted); + auto int_array_result = segment->bulk_subscript( + int_array_field, ids_ds->GetIds(), num_inserted); + auto long_array_result = segment->bulk_subscript( + long_array_field, ids_ds->GetIds(), num_inserted); + auto bool_array_result = segment->bulk_subscript( + bool_array_field, ids_ds->GetIds(), num_inserted); + auto string_array_result = segment->bulk_subscript( + string_array_field, ids_ds->GetIds(), num_inserted); + auto double_array_result = segment->bulk_subscript( + double_array_field, ids_ds->GetIds(), num_inserted); + auto float_array_result = segment->bulk_subscript( + float_array_field, ids_ds->GetIds(), num_inserted); + auto vec_result = + segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted); + + EXPECT_EQ(bool_result->scalars().bool_data().data_size(), num_inserted); + EXPECT_EQ(int8_result->scalars().int_data().data_size(), num_inserted); + EXPECT_EQ(int16_result->scalars().int_data().data_size(), num_inserted); + EXPECT_EQ(int32_result->scalars().int_data().data_size(), num_inserted); + EXPECT_EQ(int64_result->scalars().long_data().data_size(), + num_inserted); + EXPECT_EQ(float_result->scalars().float_data().data_size(), + num_inserted); + EXPECT_EQ(double_result->scalars().double_data().data_size(), + num_inserted); + EXPECT_EQ(varchar_result->scalars().string_data().data_size(), + num_inserted); + EXPECT_EQ(json_result->scalars().json_data().data_size(), num_inserted); + EXPECT_EQ(vec_result->vectors().float_vector().data_size(), + num_inserted * dim); + EXPECT_EQ(int_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(long_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(bool_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(string_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(double_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(float_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(bool_result->valid_data_size(), num_inserted); + EXPECT_EQ(int8_result->valid_data_size(), num_inserted); + EXPECT_EQ(int16_result->valid_data_size(), num_inserted); + EXPECT_EQ(int32_result->valid_data_size(), num_inserted); + EXPECT_EQ(float_result->valid_data_size(), num_inserted); + EXPECT_EQ(double_result->valid_data_size(), num_inserted); + EXPECT_EQ(varchar_result->valid_data_size(), num_inserted); + EXPECT_EQ(json_result->valid_data_size(), num_inserted); + EXPECT_EQ(int_array_result->valid_data_size(), 1); + EXPECT_EQ(long_array_result->valid_data_size(), 1); + EXPECT_EQ(bool_array_result->valid_data_size(), 1); + EXPECT_EQ(string_array_result->valid_data_size(), 1); + EXPECT_EQ(double_array_result->valid_data_size(), 1); + EXPECT_EQ(float_array_result->valid_data_size(), 1); } } diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 81abab1586b1e..ef564436be8a5 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -555,6 +555,7 @@ TEST(Query, FillSegment) { { auto field = proto.add_fields(); field->set_name("fakevec"); + field->set_nullable(false); field->set_is_primary_key(false); field->set_description("asdgfsagf"); field->set_fieldid(100); @@ -570,6 +571,7 @@ TEST(Query, FillSegment) { { auto field = proto.add_fields(); field->set_name("the_key"); + field->set_nullable(false); field->set_fieldid(101); field->set_is_primary_key(true); field->set_description("asdgfsagf"); @@ -579,6 +581,7 @@ TEST(Query, FillSegment) { { auto field = proto.add_fields(); field->set_name("the_value"); + field->set_nullable(false); field->set_fieldid(102); field->set_is_primary_key(false); field->set_description("asdgfsagf"); @@ -595,6 +598,7 @@ TEST(Query, FillSegment) { dataset.get_col(FieldId(100)); // vector field const auto std_i32_vec = dataset.get_col(FieldId(102)); // scalar field + const auto i32_vec_valid_data = dataset.get_col_valid(FieldId(102)); std::vector> segments; segments.emplace_back([&] { @@ -659,6 +663,8 @@ TEST(Query, FillSegment) { auto output_i32_field_data = fields_data.at(i32_field_id)->scalars().int_data().data(); ASSERT_EQ(output_i32_field_data.size(), topk * num_queries); + auto output_i32_valid_data = fields_data.at(i32_field_id)->valid_data(); + ASSERT_EQ(output_i32_valid_data.size(), topk * num_queries); for (int i = 0; i < topk * num_queries; i++) { int64_t val = std::get(result->primary_keys_[i]); @@ -666,6 +672,7 @@ TEST(Query, FillSegment) { auto internal_offset = result->seg_offsets_[i]; auto std_val = std_vec[internal_offset]; auto std_i32 = std_i32_vec[internal_offset]; + auto std_i32_valid = i32_vec_valid_data[internal_offset]; std::vector std_vfloat(dim); std::copy_n(std_vfloat_vec.begin() + dim * internal_offset, dim, @@ -684,6 +691,10 @@ TEST(Query, FillSegment) { int i32; memcpy(&i32, &output_i32_field_data[i], sizeof(int32_t)); ASSERT_EQ(i32, std_i32); + // check int32 valid field + bool i32_valid; + memcpy(&i32_valid, &output_i32_valid_data[i], sizeof(bool)); + ASSERT_EQ(i32_valid, std_i32_valid); } } } diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 1fed5034ce518..43a7929bdf5e0 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -861,7 +861,7 @@ TEST(Sealed, LoadScalarIndex) { FieldMeta row_id_field_meta( FieldName("RowID"), RowFieldID, DataType::INT64); auto field_data = - std::make_shared>(DataType::INT64); + std::make_shared>(DataType::INT64, false); field_data->FillFieldData(dataset.row_ids_.data(), N); auto field_data_info = FieldDataInfo{ RowFieldID.get(), N, std::vector{field_data}}; @@ -870,7 +870,8 @@ TEST(Sealed, LoadScalarIndex) { LoadFieldDataInfo ts_info; FieldMeta ts_field_meta( FieldName("Timestamp"), TimestampFieldID, DataType::INT64); - field_data = std::make_shared>(DataType::INT64); + field_data = + std::make_shared>(DataType::INT64, false); field_data->FillFieldData(dataset.timestamps_.data(), N); field_data_info = FieldDataInfo{ TimestampFieldID.get(), N, std::vector{field_data}}; @@ -1138,7 +1139,8 @@ TEST(Sealed, BF) { SealedLoadFieldData(dataset, *segment, {fake_id.get()}); auto vec_data = GenRandomFloatVecs(N, dim); - auto field_data = storage::CreateFieldData(DataType::VECTOR_FLOAT, dim); + auto field_data = + storage::CreateFieldData(DataType::VECTOR_FLOAT, false, dim); field_data->FillFieldData(vec_data.data(), N); auto field_data_info = FieldDataInfo{fake_id.get(), N, std::vector{field_data}}; @@ -1192,7 +1194,8 @@ TEST(Sealed, BF_Overflow) { SealedLoadFieldData(dataset, *segment, {fake_id.get()}); auto vec_data = GenMaxFloatVecs(N, dim); - auto field_data = storage::CreateFieldData(DataType::VECTOR_FLOAT, dim); + auto field_data = + storage::CreateFieldData(DataType::VECTOR_FLOAT, false, dim); field_data->FillFieldData(vec_data.data(), N); auto field_data_info = FieldDataInfo{fake_id.get(), N, std::vector{field_data}}; @@ -1719,7 +1722,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { //test for int64 std::vector pks = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - auto pk_field_data = storage::CreateFieldData(DataType::INT64, 1, 10); + auto pk_field_data = + storage::CreateFieldData(DataType::INT64, false, 1, 10); pk_field_data->FillFieldData(pks.data(), N); segment->LoadPrimitiveSkipIndex( pk_fid, 0, DataType::INT64, pk_field_data->Data(), N); @@ -1760,7 +1764,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { //test for int32 std::vector int32s = {2, 2, 3, 4, 5, 6, 7, 8, 9, 12}; - auto int32_field_data = storage::CreateFieldData(DataType::INT32, 1, 10); + auto int32_field_data = + storage::CreateFieldData(DataType::INT32, false, 1, 10); int32_field_data->FillFieldData(int32s.data(), N); segment->LoadPrimitiveSkipIndex( i32_fid, 0, DataType::INT32, int32_field_data->Data(), N); @@ -1770,7 +1775,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { //test for int16 std::vector int16s = {2, 2, 3, 4, 5, 6, 7, 8, 9, 12}; - auto int16_field_data = storage::CreateFieldData(DataType::INT16, 1, 10); + auto int16_field_data = + storage::CreateFieldData(DataType::INT16, false, 1, 10); int16_field_data->FillFieldData(int16s.data(), N); segment->LoadPrimitiveSkipIndex( i16_fid, 0, DataType::INT16, int16_field_data->Data(), N); @@ -1780,7 +1786,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { //test for int8 std::vector int8s = {2, 2, 3, 4, 5, 6, 7, 8, 9, 12}; - auto int8_field_data = storage::CreateFieldData(DataType::INT8, 1, 10); + auto int8_field_data = + storage::CreateFieldData(DataType::INT8, false, 1, 10); int8_field_data->FillFieldData(int8s.data(), N); segment->LoadPrimitiveSkipIndex( i8_fid, 0, DataType::INT8, int8_field_data->Data(), N); @@ -1791,7 +1798,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { // test for float std::vector floats = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; - auto float_field_data = storage::CreateFieldData(DataType::FLOAT, 1, 10); + auto float_field_data = + storage::CreateFieldData(DataType::FLOAT, false, 1, 10); float_field_data->FillFieldData(floats.data(), N); segment->LoadPrimitiveSkipIndex( float_fid, 0, DataType::FLOAT, float_field_data->Data(), N); @@ -1802,7 +1810,8 @@ TEST(Sealed, SkipIndexSkipUnaryRange) { // test for double std::vector doubles = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; - auto double_field_data = storage::CreateFieldData(DataType::DOUBLE, 1, 10); + auto double_field_data = + storage::CreateFieldData(DataType::DOUBLE, false, 1, 10); double_field_data->FillFieldData(doubles.data(), N); segment->LoadPrimitiveSkipIndex( double_fid, 0, DataType::DOUBLE, double_field_data->Data(), N); @@ -1825,7 +1834,8 @@ TEST(Sealed, SkipIndexSkipBinaryRange) { //test for int64 std::vector pks = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - auto pk_field_data = storage::CreateFieldData(DataType::INT64, 1, 10); + auto pk_field_data = + storage::CreateFieldData(DataType::INT64, false, 1, 10); pk_field_data->FillFieldData(pks.data(), N); segment->LoadPrimitiveSkipIndex( pk_fid, 0, DataType::INT64, pk_field_data->Data(), N); @@ -1860,7 +1870,8 @@ TEST(Sealed, SkipIndexSkipStringRange) { //test for string std::vector strings = {"e", "f", "g", "g", "j"}; - auto string_field_data = storage::CreateFieldData(DataType::VARCHAR, 1, N); + auto string_field_data = + storage::CreateFieldData(DataType::VARCHAR, false, 1, N); string_field_data->FillFieldData(strings.data(), N); auto string_field_data_info = FieldDataInfo{ string_fid.get(), N, std::vector{string_field_data}}; @@ -2036,4 +2047,174 @@ TEST(Sealed, QueryAllFields) { dataset_size); EXPECT_EQ(float_array_result->scalars().array_data().data_size(), dataset_size); + + EXPECT_EQ(bool_result->valid_data_size(), 0); + EXPECT_EQ(int8_result->valid_data_size(), 0); + EXPECT_EQ(int16_result->valid_data_size(), 0); + EXPECT_EQ(int32_result->valid_data_size(), 0); + EXPECT_EQ(int64_result->valid_data_size(), 0); + EXPECT_EQ(float_result->valid_data_size(), 0); + EXPECT_EQ(double_result->valid_data_size(), 0); + EXPECT_EQ(varchar_result->valid_data_size(), 0); + EXPECT_EQ(json_result->valid_data_size(), 0); + EXPECT_EQ(int_array_result->valid_data_size(), 0); + EXPECT_EQ(long_array_result->valid_data_size(), 0); + EXPECT_EQ(bool_array_result->valid_data_size(), 0); + EXPECT_EQ(string_array_result->valid_data_size(), 0); + EXPECT_EQ(double_array_result->valid_data_size(), 0); + EXPECT_EQ(float_array_result->valid_data_size(), 0); +} + +TEST(Sealed, QueryAllNullableFields) { + auto schema = std::make_shared(); + auto metric_type = knowhere::metric::L2; + auto bool_field = schema->AddDebugField("bool", DataType::BOOL, true); + auto int8_field = schema->AddDebugField("int8", DataType::INT8, true); + auto int16_field = schema->AddDebugField("int16", DataType::INT16, true); + auto int32_field = schema->AddDebugField("int32", DataType::INT32, true); + auto int64_field = schema->AddDebugField("int64", DataType::INT64, false); + auto float_field = schema->AddDebugField("float", DataType::FLOAT, true); + auto double_field = schema->AddDebugField("double", DataType::DOUBLE, true); + auto varchar_field = + schema->AddDebugField("varchar", DataType::VARCHAR, true); + auto json_field = schema->AddDebugField("json", DataType::JSON, true); + auto int_array_field = schema->AddDebugField( + "int_array", DataType::ARRAY, DataType::INT8, true); + auto long_array_field = schema->AddDebugField( + "long_array", DataType::ARRAY, DataType::INT64, true); + auto bool_array_field = schema->AddDebugField( + "bool_array", DataType::ARRAY, DataType::BOOL, true); + auto string_array_field = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR, true); + auto double_array_field = schema->AddDebugField( + "double_array", DataType::ARRAY, DataType::DOUBLE, true); + auto float_array_field = schema->AddDebugField( + "float_array", DataType::ARRAY, DataType::FLOAT, true); + auto vec = schema->AddDebugField( + "embeddings", DataType::VECTOR_FLOAT, 128, metric_type); + schema->set_primary_field_id(int64_field); + + std::map index_params = { + {"index_type", "IVF_FLAT"}, + {"metric_type", metric_type}, + {"nlist", "128"}}; + std::map type_params = {{"dim", "128"}}; + FieldIndexMeta fieldIndexMeta( + vec, std::move(index_params), std::move(type_params)); + std::map filedMap = {{vec, fieldIndexMeta}}; + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(filedMap)); + auto segment_sealed = CreateSealedSegment(schema, metaPtr); + auto segment = dynamic_cast(segment_sealed.get()); + + int64_t dataset_size = 1000; + int64_t dim = 128; + auto dataset = DataGen(schema, dataset_size); + SealedLoadFieldData(dataset, *segment); + + auto bool_values = dataset.get_col(bool_field); + auto int8_values = dataset.get_col(int8_field); + auto int16_values = dataset.get_col(int16_field); + auto int32_values = dataset.get_col(int32_field); + auto int64_values = dataset.get_col(int64_field); + auto float_values = dataset.get_col(float_field); + auto double_values = dataset.get_col(double_field); + auto varchar_values = dataset.get_col(varchar_field); + auto json_values = dataset.get_col(json_field); + auto int_array_values = dataset.get_col(int_array_field); + auto long_array_values = dataset.get_col(long_array_field); + auto bool_array_values = dataset.get_col(bool_array_field); + auto string_array_values = dataset.get_col(string_array_field); + auto double_array_values = dataset.get_col(double_array_field); + auto float_array_values = dataset.get_col(float_array_field); + auto vector_values = dataset.get_col(vec); + + auto bool_valid_values = dataset.get_col_valid(bool_field); + auto int8_valid_values = dataset.get_col_valid(int8_field); + auto int16_valid_values = dataset.get_col_valid(int16_field); + auto int32_valid_values = dataset.get_col_valid(int32_field); + auto float_valid_values = dataset.get_col_valid(float_field); + auto double_valid_values = dataset.get_col_valid(double_field); + auto varchar_valid_values = dataset.get_col_valid(varchar_field); + auto json_valid_values = dataset.get_col_valid(json_field); + auto int_array_valid_values = dataset.get_col_valid(int_array_field); + auto long_array_valid_values = dataset.get_col_valid(long_array_field); + auto bool_array_valid_values = dataset.get_col_valid(bool_array_field); + auto string_array_valid_values = dataset.get_col_valid(string_array_field); + auto double_array_valid_values = dataset.get_col_valid(double_array_field); + auto float_array_valid_values = dataset.get_col_valid(float_array_field); + + auto ids_ds = GenRandomIds(dataset_size); + auto bool_result = + segment->bulk_subscript(bool_field, ids_ds->GetIds(), dataset_size); + auto int8_result = + segment->bulk_subscript(int8_field, ids_ds->GetIds(), dataset_size); + auto int16_result = + segment->bulk_subscript(int16_field, ids_ds->GetIds(), dataset_size); + auto int32_result = + segment->bulk_subscript(int32_field, ids_ds->GetIds(), dataset_size); + auto int64_result = + segment->bulk_subscript(int64_field, ids_ds->GetIds(), dataset_size); + auto float_result = + segment->bulk_subscript(float_field, ids_ds->GetIds(), dataset_size); + auto double_result = + segment->bulk_subscript(double_field, ids_ds->GetIds(), dataset_size); + auto varchar_result = + segment->bulk_subscript(varchar_field, ids_ds->GetIds(), dataset_size); + auto json_result = + segment->bulk_subscript(json_field, ids_ds->GetIds(), dataset_size); + auto int_array_result = segment->bulk_subscript( + int_array_field, ids_ds->GetIds(), dataset_size); + auto long_array_result = segment->bulk_subscript( + long_array_field, ids_ds->GetIds(), dataset_size); + auto bool_array_result = segment->bulk_subscript( + bool_array_field, ids_ds->GetIds(), dataset_size); + auto string_array_result = segment->bulk_subscript( + string_array_field, ids_ds->GetIds(), dataset_size); + auto double_array_result = segment->bulk_subscript( + double_array_field, ids_ds->GetIds(), dataset_size); + auto float_array_result = segment->bulk_subscript( + float_array_field, ids_ds->GetIds(), dataset_size); + auto vec_result = + segment->bulk_subscript(vec, ids_ds->GetIds(), dataset_size); + + EXPECT_EQ(bool_result->scalars().bool_data().data_size(), dataset_size); + EXPECT_EQ(int8_result->scalars().int_data().data_size(), dataset_size); + EXPECT_EQ(int16_result->scalars().int_data().data_size(), dataset_size); + EXPECT_EQ(int32_result->scalars().int_data().data_size(), dataset_size); + EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size); + EXPECT_EQ(float_result->scalars().float_data().data_size(), dataset_size); + EXPECT_EQ(double_result->scalars().double_data().data_size(), dataset_size); + EXPECT_EQ(varchar_result->scalars().string_data().data_size(), + dataset_size); + EXPECT_EQ(json_result->scalars().json_data().data_size(), dataset_size); + EXPECT_EQ(vec_result->vectors().float_vector().data_size(), + dataset_size * dim); + EXPECT_EQ(int_array_result->scalars().array_data().data_size(), + dataset_size); + EXPECT_EQ(long_array_result->scalars().array_data().data_size(), + dataset_size); + EXPECT_EQ(bool_array_result->scalars().array_data().data_size(), + dataset_size); + EXPECT_EQ(string_array_result->scalars().array_data().data_size(), + dataset_size); + EXPECT_EQ(double_array_result->scalars().array_data().data_size(), + dataset_size); + EXPECT_EQ(float_array_result->scalars().array_data().data_size(), + dataset_size); + + EXPECT_EQ(bool_result->valid_data_size(), dataset_size); + EXPECT_EQ(int8_result->valid_data_size(), dataset_size); + EXPECT_EQ(int16_result->valid_data_size(), dataset_size); + EXPECT_EQ(int32_result->valid_data_size(), dataset_size); + EXPECT_EQ(float_result->valid_data_size(), dataset_size); + EXPECT_EQ(double_result->valid_data_size(), dataset_size); + EXPECT_EQ(varchar_result->valid_data_size(), dataset_size); + EXPECT_EQ(json_result->valid_data_size(), dataset_size); + EXPECT_EQ(int_array_result->valid_data_size(), 1); + EXPECT_EQ(long_array_result->valid_data_size(), 1); + EXPECT_EQ(bool_array_result->valid_data_size(), 1); + EXPECT_EQ(string_array_result->valid_data_size(), 1); + EXPECT_EQ(double_array_result->valid_data_size(), 1); + EXPECT_EQ(float_array_result->valid_data_size(), 1); } diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 9705b7f0c7ddd..be76dad96be2e 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -215,6 +215,21 @@ struct GeneratedData { return std::move(ret); } + FixedVector + get_col_valid(FieldId field_id) const { + for (const auto& target_field_data : raw_->fields_data()) { + if (field_id.get() == target_field_data.field_id()) { + auto& field_meta = schema_->operator[](field_id); + Assert(field_meta.is_nullable()); + FixedVector ret(raw_->num_rows()); + auto src_data = target_field_data.valid_data().data(); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + return ret; + } + } + PanicInfo(FieldIDInvalid, "field id not find"); + } + std::unique_ptr get_col(FieldId field_id) const { for (const auto& target_field_data : raw_->fields_data()) { @@ -301,8 +316,15 @@ inline GeneratedData DataGen(SchemaPtr schema, auto insert_data = std::make_unique(); auto insert_cols = [&insert_data]( auto& data, int64_t count, auto& field_meta) { + auto nullable = field_meta.is_nullable(); + FixedVector valid_data(count); + if (nullable) { + for (int i = 0; i < count; ++i) { + valid_data[i] = i % 2 == 0 ? true : false; + } + } auto array = milvus::segcore::CreateDataArrayFrom( - data.data(), count, field_meta); + data.data(), valid_data.data(), count, field_meta); insert_data->mutable_fields_data()->AddAllocated(array.release()); }; @@ -587,7 +609,7 @@ DataGenForJsonArray(SchemaPtr schema, auto insert_cols = [&insert_data]( auto& data, int64_t count, auto& field_meta) { auto array = milvus::segcore::CreateDataArrayFrom( - data.data(), count, field_meta); + data.data(), nullptr, count, field_meta); insert_data->mutable_fields_data()->AddAllocated(array.release()); }; for (auto field_id : schema->get_field_ids()) { @@ -893,9 +915,30 @@ CreateFieldDataFromDataArray(ssize_t raw_count, auto createFieldData = [&field_data, &raw_count](const void* raw_data, DataType data_type, int64_t dim) { - field_data = storage::CreateFieldData(data_type, dim); + field_data = storage::CreateFieldData(data_type, false, dim); field_data->FillFieldData(raw_data, raw_count); }; + auto createNullableFieldData = [&field_data, &raw_count]( + const void* raw_data, + const bool* raw_valid_data, + DataType data_type, + int64_t dim) { + field_data = storage::CreateFieldData(data_type, true, dim); + int byteSize = (raw_count + 7) / 8; + auto valid_data = std::make_unique(byteSize); + auto valid_data_ptr = valid_data.get(); + for (int i = 0; i < raw_count; i++) { + bool value = raw_valid_data[i]; + int byteIndex = i / 8; + int bitIndex = i % 8; + if (value) { + valid_data_ptr[byteIndex] |= (1 << bitIndex); + } else { + valid_data_ptr[byteIndex] &= ~(1 << bitIndex); + } + } + field_data->FillFieldData(raw_data, valid_data.get(), raw_count); + }; if (field_meta.is_vector()) { switch (field_meta.get_data_type()) { @@ -938,48 +981,98 @@ CreateFieldDataFromDataArray(ssize_t raw_count, switch (field_meta.get_data_type()) { case DataType::BOOL: { auto raw_data = data->scalars().bool_data().data().data(); - createFieldData(raw_data, DataType::BOOL, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + raw_data, raw_valid_data, DataType::BOOL, dim); + } else { + createFieldData(raw_data, DataType::BOOL, dim); + } break; } case DataType::INT8: { auto src_data = data->scalars().int_data().data(); std::vector data_raw(src_data.size()); std::copy_n(src_data.data(), src_data.size(), data_raw.data()); - createFieldData(data_raw.data(), DataType::INT8, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + data_raw.data(), raw_valid_data, DataType::INT8, dim); + } else { + createFieldData(data_raw.data(), DataType::INT8, dim); + } break; } case DataType::INT16: { auto src_data = data->scalars().int_data().data(); std::vector data_raw(src_data.size()); std::copy_n(src_data.data(), src_data.size(), data_raw.data()); - createFieldData(data_raw.data(), DataType::INT16, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + data_raw.data(), raw_valid_data, DataType::INT16, dim); + } else { + createFieldData(data_raw.data(), DataType::INT16, dim); + } break; } case DataType::INT32: { auto raw_data = data->scalars().int_data().data().data(); - createFieldData(raw_data, DataType::INT32, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + raw_data, raw_valid_data, DataType::INT32, dim); + } else { + createFieldData(raw_data, DataType::INT32, dim); + } break; } case DataType::INT64: { auto raw_data = data->scalars().long_data().data().data(); - createFieldData(raw_data, DataType::INT64, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + raw_data, raw_valid_data, DataType::INT64, dim); + } else { + createFieldData(raw_data, DataType::INT64, dim); + } break; } case DataType::FLOAT: { auto raw_data = data->scalars().float_data().data().data(); - createFieldData(raw_data, DataType::FLOAT, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + raw_data, raw_valid_data, DataType::FLOAT, dim); + } else { + createFieldData(raw_data, DataType::FLOAT, dim); + } break; } case DataType::DOUBLE: { auto raw_data = data->scalars().double_data().data().data(); - createFieldData(raw_data, DataType::DOUBLE, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + raw_data, raw_valid_data, DataType::DOUBLE, dim); + } else { + createFieldData(raw_data, DataType::DOUBLE, dim); + } break; } case DataType::VARCHAR: { auto begin = data->scalars().string_data().data().begin(); auto end = data->scalars().string_data().data().end(); std::vector data_raw(begin, end); - createFieldData(data_raw.data(), DataType::VARCHAR, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData(data_raw.data(), + raw_valid_data, + DataType::VARCHAR, + dim); + } else { + createFieldData(data_raw.data(), DataType::VARCHAR, dim); + } break; } case DataType::JSON: { @@ -989,7 +1082,13 @@ CreateFieldDataFromDataArray(ssize_t raw_count, auto str = src_data.Get(i); data_raw[i] = Json(simdjson::padded_string(str)); } - createFieldData(data_raw.data(), DataType::JSON, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + data_raw.data(), raw_valid_data, DataType::JSON, dim); + } else { + createFieldData(data_raw.data(), DataType::JSON, dim); + } break; } case DataType::ARRAY: { @@ -998,7 +1097,13 @@ CreateFieldDataFromDataArray(ssize_t raw_count, for (int i = 0; i < src_data.size(); i++) { data_raw[i] = Array(src_data.at(i)); } - createFieldData(data_raw.data(), DataType::ARRAY, dim); + if (field_meta.is_nullable()) { + auto raw_valid_data = data->valid_data().data(); + createNullableFieldData( + data_raw.data(), raw_valid_data, DataType::ARRAY, dim); + } else { + createFieldData(data_raw.data(), DataType::ARRAY, dim); + } break; } default: { @@ -1017,8 +1122,8 @@ SealedLoadFieldData(const GeneratedData& dataset, bool with_mmap = false) { auto row_count = dataset.row_ids_.size(); { - auto field_data = - std::make_shared>(DataType::INT64); + auto field_data = std::make_shared>( + DataType::INT64, false); field_data->FillFieldData(dataset.row_ids_.data(), row_count); auto field_data_info = FieldDataInfo(RowFieldID.get(), @@ -1027,8 +1132,8 @@ SealedLoadFieldData(const GeneratedData& dataset, seg.LoadFieldData(RowFieldID, field_data_info); } { - auto field_data = - std::make_shared>(DataType::INT64); + auto field_data = std::make_shared>( + DataType::INT64, false); field_data->FillFieldData(dataset.timestamps_.data(), row_count); auto field_data_info = FieldDataInfo(TimestampFieldID.get(), diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index 7eca359f3043d..866ee8c6eea2a 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -75,15 +75,15 @@ PrepareInsertBinlog(int64_t collection_id, }; { - auto field_data = - std::make_shared>(DataType::INT64); + auto field_data = std::make_shared>( + DataType::INT64, false); field_data->FillFieldData(dataset.row_ids_.data(), row_count); auto path = prefix + "/" + std::to_string(RowFieldID.get()); SaveFieldData(field_data, path, RowFieldID.get()); } { - auto field_data = - std::make_shared>(DataType::INT64); + auto field_data = std::make_shared>( + DataType::INT64, false); field_data->FillFieldData(dataset.timestamps_.data(), row_count); auto path = prefix + "/" + std::to_string(TimestampFieldID.get()); SaveFieldData(field_data, path, TimestampFieldID.get());