Skip to content

Commit

Permalink
enhance: support null in c data_datacodec and load null value
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Apr 26, 2024
1 parent 46d7298 commit 0022575
Show file tree
Hide file tree
Showing 40 changed files with 1,866 additions and 337 deletions.
112 changes: 101 additions & 11 deletions internal/core/src/common/FieldData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ template <typename Type, bool is_type_entire_row>
void
FieldDataImpl<Type, is_type_entire_row>::FillFieldData(const void* source,
ssize_t element_count) {
AssertInfo(!nullable_,
"need to fill valid_data, use the 3-argument version instead");

if (element_count == 0) {
return;
}
Expand All @@ -44,6 +47,37 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(const void* source,
length_ += element_count;
}

template <typename Type, bool is_type_entire_row>
void
FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
const void* field_data, const uint8_t* valid_data, ssize_t element_count) {
AssertInfo(
nullable_,
"no need to fill valid_data, use the 2-argument version instead");
if (element_count == 0) {
return;
}

std::lock_guard lck(tell_mutex_);
if (length_ + element_count > get_num_rows()) {
resize_field_data(length_ + element_count);
}
std::copy_n(static_cast<const Type*>(field_data),
element_count * dim_,
field_data_.data() + length_ * dim_);

ssize_t byte_count = (element_count + 7) / 8;
// Note: if 'nullable == true` and valid_data is nullptr
// means null_count == 0, will fill it with 0xFF
if (valid_data == nullptr) {
std::fill_n(valid_data_.get(), byte_count, 0xFF);
} else {
std::copy_n(valid_data, byte_count, valid_data_.get());
}

length_ += element_count;
}

template <typename ArrayType, arrow::Type::type ArrayDataType>
std::pair<const void*, int64_t>
GetDataInfoFromArray(const std::shared_ptr<arrow::Array> array) {
Expand All @@ -66,6 +100,7 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
if (element_count == 0) {
return;
}
null_count = array->null_count();
switch (data_type_) {
case DataType::BOOL: {
AssertInfo(array->type()->id() == arrow::Type::type::BOOL,
Expand All @@ -76,42 +111,71 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
for (size_t index = 0; index < element_count; ++index) {
values[index] = bool_array->Value(index);
}
if (nullable_) {
return FillFieldData(values.data(),
bool_array->null_bitmap_data(),
element_count);
}
return FillFieldData(values.data(), element_count);
}
case DataType::INT8: {
auto array_info =
GetDataInfoFromArray<arrow::Int8Array, arrow::Type::type::INT8>(
array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::INT16: {
auto array_info =
GetDataInfoFromArray<arrow::Int16Array,
arrow::Type::type::INT16>(array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::INT32: {
auto array_info =
GetDataInfoFromArray<arrow::Int32Array,
arrow::Type::type::INT32>(array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::INT64: {
auto array_info =
GetDataInfoFromArray<arrow::Int64Array,
arrow::Type::type::INT64>(array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::FLOAT: {
auto array_info =
GetDataInfoFromArray<arrow::FloatArray,
arrow::Type::type::FLOAT>(array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::DOUBLE: {
auto array_info =
GetDataInfoFromArray<arrow::DoubleArray,
arrow::Type::type::DOUBLE>(array);
if (nullable_) {
return FillFieldData(
array_info.first, array->null_bitmap_data(), element_count);
}
return FillFieldData(array_info.first, array_info.second);
}
case DataType::STRING:
Expand All @@ -124,6 +188,10 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
for (size_t index = 0; index < element_count; ++index) {
values[index] = string_array->GetString(index);
}
if (nullable_) {
return FillFieldData(
values.data(), array->null_bitmap_data(), element_count);
}
return FillFieldData(values.data(), element_count);
}
case DataType::JSON: {
Expand All @@ -136,17 +204,33 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
values[index] =
Json(simdjson::padded_string(json_array->GetString(index)));
}
if (nullable_) {
return FillFieldData(
values.data(), array->null_bitmap_data(), element_count);
}
return FillFieldData(values.data(), element_count);
}
case DataType::ARRAY: {
auto array_array =
std::dynamic_pointer_cast<arrow::BinaryArray>(array);
std::vector<Array> values(element_count);
int null_number = 0;
for (size_t index = 0; index < element_count; ++index) {
ScalarArray field_data;
field_data.ParseFromString(array_array->GetString(index));
if (array_array->GetString(index) == "") {
null_number++;
continue;
}
auto success =
field_data.ParseFromString(array_array->GetString(index));
AssertInfo(success, "parse from string failed");
values[index] = Array(field_data);
}
if (nullable_) {
return FillFieldData(
values.data(), array->null_bitmap_data(), element_count);
}
AssertInfo(null_number == 0, "get empty string when not nullable");
return FillFieldData(values.data(), element_count);
}
case DataType::VECTOR_FLOAT:
Expand Down Expand Up @@ -201,27 +285,33 @@ template class FieldDataImpl<bfloat16, false>;
template class FieldDataImpl<knowhere::sparse::SparseRow<float>, true>;

FieldDataPtr
InitScalarFieldData(const DataType& type, int64_t cap_rows) {
InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) {
switch (type) {
case DataType::BOOL:
return std::make_shared<FieldData<bool>>(type, cap_rows);
return std::make_shared<FieldData<bool>>(type, nullable, cap_rows);
case DataType::INT8:
return std::make_shared<FieldData<int8_t>>(type, cap_rows);
return std::make_shared<FieldData<int8_t>>(
type, nullable, cap_rows);
case DataType::INT16:
return std::make_shared<FieldData<int16_t>>(type, cap_rows);
return std::make_shared<FieldData<int16_t>>(
type, nullable, cap_rows);
case DataType::INT32:
return std::make_shared<FieldData<int32_t>>(type, cap_rows);
return std::make_shared<FieldData<int32_t>>(
type, nullable, cap_rows);
case DataType::INT64:
return std::make_shared<FieldData<int64_t>>(type, cap_rows);
return std::make_shared<FieldData<int64_t>>(
type, nullable, cap_rows);
case DataType::FLOAT:
return std::make_shared<FieldData<float>>(type, cap_rows);
return std::make_shared<FieldData<float>>(type, nullable, cap_rows);
case DataType::DOUBLE:
return std::make_shared<FieldData<double>>(type, cap_rows);
return std::make_shared<FieldData<double>>(
type, nullable, cap_rows);
case DataType::STRING:
case DataType::VARCHAR:
return std::make_shared<FieldData<std::string>>(type, cap_rows);
return std::make_shared<FieldData<std::string>>(
type, nullable, cap_rows);
case DataType::JSON:
return std::make_shared<FieldData<Json>>(type, cap_rows);
return std::make_shared<FieldData<Json>>(type, nullable, cap_rows);
default:
throw NotSupportedDataTypeException(
"InitScalarFieldData not support data type " +
Expand Down
32 changes: 20 additions & 12 deletions internal/core/src/common/FieldData.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ template <typename Type>
class FieldData : public FieldDataImpl<Type, true> {
public:
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0)
explicit FieldData(DataType data_type,
bool nullable,
int64_t buffered_num_rows = 0)
: FieldDataImpl<Type, true>::FieldDataImpl(
1, data_type, buffered_num_rows) {
1, data_type, nullable, buffered_num_rows) {
}
static_assert(IsScalar<Type> || std::is_same_v<Type, PkType>);
explicit FieldData(DataType data_type, FixedVector<Type>&& inner_data)
Expand All @@ -45,26 +47,32 @@ template <>
class FieldData<std::string> : public FieldDataStringImpl {
public:
static_assert(IsScalar<std::string> || std::is_same_v<std::string, PkType>);
explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0)
: FieldDataStringImpl(data_type, buffered_num_rows) {
explicit FieldData(DataType data_type,
bool nullable,
int64_t buffered_num_rows = 0)
: FieldDataStringImpl(data_type, nullable, buffered_num_rows) {
}
};

template <>
class FieldData<Json> : public FieldDataJsonImpl {
public:
static_assert(IsScalar<std::string> || std::is_same_v<std::string, PkType>);
explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0)
: FieldDataJsonImpl(data_type, buffered_num_rows) {
explicit FieldData(DataType data_type,
bool nullable,
int64_t buffered_num_rows = 0)
: FieldDataJsonImpl(data_type, nullable, buffered_num_rows) {
}
};

template <>
class FieldData<Array> : public FieldDataArrayImpl {
public:
static_assert(IsScalar<Array> || std::is_same_v<std::string, PkType>);
explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0)
: FieldDataArrayImpl(data_type, buffered_num_rows) {
explicit FieldData(DataType data_type,
bool nullable,
int64_t buffered_num_rows = 0)
: FieldDataArrayImpl(data_type, nullable, buffered_num_rows) {
}
};

Expand All @@ -75,7 +83,7 @@ class FieldData<FloatVector> : public FieldDataImpl<float, false> {
DataType data_type,
int64_t buffered_num_rows = 0)
: FieldDataImpl<float, false>::FieldDataImpl(
dim, data_type, buffered_num_rows) {
dim, data_type, false, buffered_num_rows) {
}
};

Expand All @@ -86,7 +94,7 @@ class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
DataType data_type,
int64_t buffered_num_rows = 0)
: binary_dim_(dim),
FieldDataImpl(dim / 8, data_type, buffered_num_rows) {
FieldDataImpl(dim / 8, data_type, false, buffered_num_rows) {
Assert(dim % 8 == 0);
}

Expand All @@ -106,7 +114,7 @@ class FieldData<Float16Vector> : public FieldDataImpl<float16, false> {
DataType data_type,
int64_t buffered_num_rows = 0)
: FieldDataImpl<float16, false>::FieldDataImpl(
dim, data_type, buffered_num_rows) {
dim, data_type, false, buffered_num_rows) {
}
};

Expand Down Expand Up @@ -134,6 +142,6 @@ using FieldDataChannel = Channel<FieldDataPtr>;
using FieldDataChannelPtr = std::shared_ptr<FieldDataChannel>;

FieldDataPtr
InitScalarFieldData(const DataType& type, int64_t cap_rows);
InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows);

} // namespace milvus
Loading

0 comments on commit 0022575

Please sign in to comment.