From eb3e4583ec6928e8681037883561a17b45256e3d Mon Sep 17 00:00:00 2001 From: smellthemoon <64083300+smellthemoon@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:14:30 +0800 Subject: [PATCH] enhance: all op(Null) is false in expr (#35527) #31728 --------- Signed-off-by: lixinguo Co-authored-by: lixinguo --- internal/core/src/common/FieldData.cpp | 2 +- internal/core/src/common/Vector.h | 23 +- .../src/exec/expression/AlwaysTrueExpr.cpp | 7 +- .../expression/BinaryArithOpEvalRangeExpr.cpp | 129 +- .../expression/BinaryArithOpEvalRangeExpr.h | 72 +- .../src/exec/expression/BinaryRangeExpr.cpp | 71 +- .../src/exec/expression/BinaryRangeExpr.h | 16 +- .../core/src/exec/expression/CompareExpr.cpp | 146 +- .../core/src/exec/expression/CompareExpr.h | 59 +- .../core/src/exec/expression/ExistsExpr.cpp | 14 +- internal/core/src/exec/expression/Expr.h | 196 +- .../src/exec/expression/JsonContainsExpr.cpp | 118 +- .../src/exec/expression/LogicalBinaryExpr.cpp | 4 + .../src/exec/expression/LogicalUnaryExpr.cpp | 3 + .../core/src/exec/expression/TermExpr.cpp | 93 +- .../core/src/exec/expression/UnaryExpr.cpp | 123 +- internal/core/src/exec/expression/UnaryExpr.h | 14 +- .../core/src/exec/operator/FilterBitsNode.cpp | 7 +- internal/core/src/exec/operator/MvccNode.cpp | 12 +- .../operator/groupby/SearchGroupByOperator.h | 4 +- internal/core/src/index/BitmapIndex.cpp | 37 +- internal/core/src/index/BitmapIndex.h | 6 +- internal/core/src/index/HybridScalarIndex.h | 8 +- .../core/src/index/InvertedIndexTantivy.h | 4 +- internal/core/src/index/ScalarIndex.h | 4 +- internal/core/src/index/ScalarIndexSort.cpp | 34 +- internal/core/src/index/ScalarIndexSort.h | 8 +- internal/core/src/index/StringIndexMarisa.cpp | 28 +- internal/core/src/index/StringIndexMarisa.h | 10 +- internal/core/src/query/ScalarIndex.h | 4 +- .../src/segcore/ChunkedSegmentSealedImpl.cpp | 16 +- internal/core/src/segcore/FieldIndexing.cpp | 1 + .../core/src/segcore/SegmentSealedImpl.cpp | 16 +- internal/core/src/segcore/Utils.cpp | 98 +- internal/core/unittest/test_expr.cpp | 14853 ++++++++++++---- internal/core/unittest/test_string_expr.cpp | 632 +- .../core/unittest/test_utils/AssertUtils.h | 16 +- internal/core/unittest/test_utils/DataGen.h | 8 +- tests/python_client/testcases/test_search.py | 1 - 39 files changed, 12688 insertions(+), 4209 deletions(-) diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index f64e677d9a036..af089fa4696e3 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -69,7 +69,7 @@ FieldDataImpl::FillFieldData( ssize_t byte_count = (element_count + 7) / 8; // Note: if 'nullable == true` and valid_data is nullptr // means null_count == 0, will fill it with 0xFF - if (!valid_data) { + if (valid_data == nullptr) { valid_data_.assign(byte_count, 0xFF); } else { std::copy_n(valid_data, byte_count, valid_data_.data()); diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index ac5f0b217b0c5..afc1d4766e079 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -19,6 +19,8 @@ #include #include +#include "EasyAssert.h" +#include "Types.h" #include "common/FieldData.h" namespace milvus { @@ -50,6 +52,7 @@ class BaseVector { protected: DataType type_kind_; size_t length_; + // todo: use null_count to skip some bitset operate std::optional null_count_; }; @@ -65,8 +68,8 @@ class ColumnVector final : public BaseVector { size_t length, std::optional null_count = std::nullopt) : BaseVector(data_type, length, null_count) { - //todo: support null expr values_ = InitScalarFieldData(data_type, false, length); + valid_values_ = InitScalarFieldData(data_type, false, length); } // ColumnVector(FixedVector&& data) @@ -75,15 +78,25 @@ class ColumnVector final : public BaseVector { // std::make_shared>(DataType::BOOL, std::move(data)); // } + // // the size is the number of bits + // ColumnVector(TargetBitmap&& bitmap) + // : BaseVector(DataType::INT8, bitmap.size()) { + // values_ = std::make_shared>( + // bitmap.size(), DataType::INT8, false, std::move(bitmap).into()); + // } + // the size is the number of bits - ColumnVector(TargetBitmap&& bitmap) + ColumnVector(TargetBitmap&& bitmap, TargetBitmap&& valid_bitmap) : BaseVector(DataType::INT8, bitmap.size()) { values_ = std::make_shared>(DataType::INT8, std::move(bitmap)); + valid_values_ = std::make_shared>( + DataType::INT8, std::move(valid_bitmap)); } virtual ~ColumnVector() override { values_.reset(); + valid_values_.reset(); } void* @@ -91,6 +104,11 @@ class ColumnVector final : public BaseVector { return values_->Data(); } + void* + GetValidRawData() { + return valid_values_->Data(); + } + template const As* RawAsValues() const { @@ -99,6 +117,7 @@ class ColumnVector final : public BaseVector { private: FieldDataPtr values_; + FieldDataPtr valid_values_; }; using ColumnVectorPtr = std::shared_ptr; diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp index 24789c429ac8a..920fc86ee6a17 100644 --- a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp @@ -25,16 +25,19 @@ PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) { ? active_count_ - current_pos_ : batch_size_; + // always true no need to skip null if (real_batch_size == 0) { result = nullptr; return; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); res.set(); + valid_res.set(); result = res_vec; current_pos_ += real_batch_size; diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index 7f64cae5b390e..e5b24ac4121ce 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -113,9 +113,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto op_type = expr_->op_type_; @@ -129,6 +131,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data != nullptr && !valid_data[i]) { \ + res[i] = false; \ + valid_res[i] = false; \ + continue; \ + } \ auto x = data[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -146,6 +153,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompareNotEqual(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data != nullptr && !valid_data[i]) { \ + res[i] = false; \ + valid_res[i] = false; \ + continue; \ + } \ auto x = data[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -161,8 +173,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } while (false) auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, ValueType right_operand, const std::string& pointer) { @@ -197,6 +211,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -246,6 +265,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -295,6 +319,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -344,6 +373,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -393,6 +427,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -442,6 +481,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = false; + valid_res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -471,6 +515,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { int64_t processed_size = ProcessDataChunks(execute_sub_batch, std::nullptr_t{}, res, + valid_res, value, right_operand, pointer); @@ -492,9 +537,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -511,6 +558,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { #define BinaryArithRangeArrayCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data != nullptr && !valid_data[i]) { \ + res[i] = false; \ + valid_res[i] = false; \ + continue; \ + } \ if (index >= data[i].length()) { \ res[i] = false; \ continue; \ @@ -521,8 +573,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } while (false) auto execute_sub_batch = [op_type, arith_type](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, ValueType right_operand, int index) { @@ -558,6 +612,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() == val; } break; @@ -601,6 +659,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() != val; } break; @@ -644,6 +706,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() > val; } break; @@ -687,6 +753,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() >= val; } break; @@ -730,6 +800,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() < val; } break; @@ -773,6 +847,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() <= val; } break; @@ -794,8 +872,14 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } }; - int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, value, right_operand, index); + int64_t processed_size = + ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + valid_res, + value, + right_operand, + index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -1185,12 +1269,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() { return res; }; auto res = ProcessIndexChunks(execute_sub_batch, value, right_operand); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + // return std::make_shared(std::move(res)); + return res; } template @@ -1209,16 +1294,20 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { auto value = GetValueFromProto(expr_->value_); auto right_operand = GetValueFromProto(expr_->right_operand_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto op_type = expr_->op_type_; auto arith_type = expr_->arith_op_type_; auto execute_sub_batch = [op_type, arith_type]( const T* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, HighPrecisionType value, HighPrecisionType right_operand) { switch (op_type) { @@ -1534,9 +1623,23 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { "arithmetic eval expr: {}", op_type); } + // there is a batch operation in ArithOpElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } + } }; - int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, value, right_operand); + int64_t processed_size = ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + valid_res, + value, + right_operand); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index 3c84819dc2b83..5eef111438591 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -239,7 +239,6 @@ struct ArithOpElementFunc { } } */ - if constexpr (!std::is_same_v::op), void>) { constexpr auto cmp_op_cvt = CmpOpHelper::op; @@ -282,22 +281,26 @@ struct ArithOpIndexFunc { HighPrecisonType right_operand) { TargetBitmap res(size); for (size_t i = 0; i < size; ++i) { + auto raw = index->Reverse_Lookup(i); + if (!raw.has_value()) { + res[i] = false; + continue; + } if constexpr (cmp_op == proto::plan::OpType::Equal) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) == val; + res[i] = (raw.value() + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) == val; + res[i] = (raw.value() - right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) == val; + res[i] = (raw.value() * right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) == val; + res[i] = (raw.value() / right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) == val; + res[i] = (fmod(raw.value(), right_operand)) == val; } else { PanicInfo( OpTypeInvalid, @@ -307,20 +310,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::NotEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) != val; + res[i] = (raw.value() + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) != val; + res[i] = (raw.value() - right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) != val; + res[i] = (raw.value() * right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) != val; + res[i] = (raw.value() / right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) != val; + res[i] = (fmod(raw.value(), right_operand)) != val; } else { PanicInfo( OpTypeInvalid, @@ -330,20 +332,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) > val; + res[i] = (raw.value() + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) > val; + res[i] = (raw.value() - right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) > val; + res[i] = (raw.value() * right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) > val; + res[i] = (raw.value() / right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) > val; + res[i] = (fmod(raw.value(), right_operand)) > val; } else { PanicInfo( OpTypeInvalid, @@ -353,20 +354,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) >= val; + res[i] = (raw.value() + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) >= val; + res[i] = (raw.value() - right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) >= val; + res[i] = (raw.value() * right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) >= val; + res[i] = (raw.value() / right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) >= val; + res[i] = (fmod(raw.value(), right_operand)) >= val; } else { PanicInfo( OpTypeInvalid, @@ -376,20 +376,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) < val; + res[i] = (raw.value() + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) < val; + res[i] = (raw.value() - right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) < val; + res[i] = (raw.value() * right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) < val; + res[i] = (raw.value() / right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) < val; + res[i] = (fmod(raw.value(), right_operand)) < val; } else { PanicInfo( OpTypeInvalid, @@ -399,20 +398,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) <= val; + res[i] = (raw.value() + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) <= val; + res[i] = (raw.value() - right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) <= val; + res[i] = (raw.value() * right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) <= val; + res[i] = (raw.value() / right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) <= val; + res[i] = (fmod(raw.value(), right_operand)) <= val; } else { PanicInfo( OpTypeInvalid, diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index be6aa576aaaee..26467cd4646a3 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "BinaryRangeExpr.h" +#include #include "query/Utils.h" @@ -150,8 +151,12 @@ PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1, cached_overflow_res_->size() == batch_size) { return cached_overflow_res_; } - auto res = std::make_shared(TargetBitmap(batch_size)); - return res; + auto valid_res = ProcessChunksForValid(is_index_mode_); + auto res_vec = std::make_shared(TargetBitmap(batch_size), + std::move(valid_res)); + cached_overflow_res_ = res_vec; + + return res_vec; }; if constexpr (std::is_integral_v && !std::is_same_v) { @@ -207,12 +212,12 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { func(index_ptr, val1, val2, lower_inclusive, upper_inclusive)); }; auto res = ProcessIndexChunks(execute_sub_batch, val1, val2); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template @@ -240,14 +245,18 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { PreCheckOverflow(val1, val2, lower_inclusive, upper_inclusive)) { return res; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const T* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, HighPrecisionType val1, HighPrecisionType val2) { if (lower_inclusive && upper_inclusive) { @@ -263,6 +272,16 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { BinaryRangeElementFunc func; func(val1, val2, data, size, res); } + // there is a batch operation in BinaryRangeElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } + } }; auto skip_index_func = [val1, val2, lower_inclusive, upper_inclusive]( @@ -282,7 +301,7 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, skip_index_func, res, val1, val2); + execute_sub_batch, skip_index_func, res, valid_res, val1, val2); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -301,9 +320,11 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -313,26 +334,28 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer]( const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val1, ValueType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val1, val2); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val1, val2); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -351,9 +374,11 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -366,27 +391,29 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val1, ValueType val2, int index) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val1, val2, index); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val1, val2, index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index 6484a40e5ef1e..145a8955ffe88 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -54,6 +54,10 @@ struct BinaryRangeElementFunc { #define BinaryRangeJSONCompare(cmp) \ do { \ + if (valid_data != nullptr && !valid_data[i]) { \ + res[i] = valid_res[i] = false; \ + break; \ + } \ auto x = src[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -81,8 +85,10 @@ struct BinaryRangeElementFuncForJson { ValueType val2, const std::string& pointer, const milvus::Json* src, + const bool* valid_data, size_t n, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (size_t i = 0; i < n; ++i) { if constexpr (lower_inclusive && upper_inclusive) { BinaryRangeJSONCompare(val1 <= value && value <= val2); @@ -107,9 +113,15 @@ struct BinaryRangeElementFuncForArray { ValueType val2, int index, const milvus::ArrayView* src, + const bool* valid_data, size_t n, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (size_t i = 0; i < n; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (lower_inclusive && upper_inclusive) { if (index >= src[i].length()) { res[i] = false; diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index 467df6654a929..5bc2e8dab15e1 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -16,6 +16,7 @@ #include "CompareExpr.h" #include "common/type_c.h" +#include #include "query/Relational.h" namespace milvus { @@ -58,12 +59,19 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_->chunk_scalar_index(field_id, current_chunk_id)); } - return indexing.Reverse_Lookup(current_chunk_pos++); + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } auto chunk_data = segment_->chunk_data(field_id, current_chunk_id).data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id).valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, ¤t_chunk_id, ¤t_chunk_pos]() mutable -> const number { @@ -72,10 +80,16 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, current_chunk_pos = 0; chunk_data = segment_->chunk_data(field_id, current_chunk_id).data(); + chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } - + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return chunk_data[current_chunk_pos++]; }; } @@ -103,7 +117,12 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_->chunk_scalar_index( field_id, current_chunk_id)); } - return indexing.Reverse_Lookup(current_chunk_pos++); + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } @@ -114,6 +133,9 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto chunk_data = segment_->chunk_data(field_id, current_chunk_id) .data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, @@ -126,16 +148,26 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_ ->chunk_data(field_id, current_chunk_id) .data(); + chunk_valid_data = + segment_ + ->chunk_data(field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } - + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return chunk_data[current_chunk_pos++]; }; } else { auto chunk_data = segment_->chunk_view(field_id, current_chunk_id) .first.data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, @@ -148,9 +180,17 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, ->chunk_view( field_id, current_chunk_id) .first.data(); + chunk_valid_data = segment_ + ->chunk_data( + field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return std::string(chunk_data[current_chunk_pos++]); }; @@ -203,9 +243,11 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto left = GetChunkData(expr_->left_data_type_, expr_->left_field_id_, @@ -218,8 +260,15 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { right_current_chunk_id_, right_current_chunk_pos_); for (int i = 0; i < real_batch_size; ++i) { - res[i] = boost::apply_visitor( - milvus::query::Relational{}, left(), right()); + if (!left().has_value() || !right().has_value()) { + res[i] = false; + valid_res[i] = false; + continue; + } + res[i] = + boost::apply_visitor(milvus::query::Relational{}, + left().value(), + right().value()); } return res_vec; } else { @@ -228,9 +277,11 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); @@ -255,10 +306,16 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; i < chunk_size; ++i) { - res[processed_rows++] = boost::apply_visitor( - milvus::query::Relational{}, - left(i), - right(i)); + if (!left(i).has_value() || !right(i).has_value()) { + res[processed_rows] = false; + valid_res[processed_rows] = false; + } else { + res[processed_rows] = boost::apply_visitor( + milvus::query::Relational{}, + left(i).value(), + right(i).value()); + } + processed_rows++; if (processed_rows >= batch_size_) { current_chunk_id_ = chunk_id; @@ -280,12 +337,23 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); - return [chunk_data](int i) -> const number { return chunk_data[i]; }; + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; } template <> @@ -297,8 +365,12 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { - return [&indexing](int i) -> const std::string { - return indexing.Reverse_Lookup(i); + return [&indexing](int i) -> const number { + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } @@ -308,12 +380,23 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, .growing_enable_mmap) { auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); - return [chunk_data](int i) -> const number { return chunk_data[i]; }; + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; } else { - auto chunk_data = - segment_->chunk_view(field_id, chunk_id) - .first.data(); - return [chunk_data](int i) -> const number { + auto chunk_info = + segment_->chunk_view(field_id, chunk_id); + auto chunk_data = chunk_info.first.data(); + auto chunk_valid_data = chunk_info.second.data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } return std::string(chunk_data[i]); }; } @@ -450,9 +533,11 @@ PhyCompareFilterExpr::ExecCompareRightType() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* left, @@ -491,15 +576,14 @@ PhyCompareFilterExpr::ExecCompareRightType() { break; } default: - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported operator type for compare column expr: {}", - expr_type)); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for " + "compare column expr: {}", + expr_type)); } }; int64_t processed_size = - ProcessBothDataChunks(execute_sub_batch, res); + ProcessBothDataChunks(execute_sub_batch, res, valid_res); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index fd9ef751387cb..8f4aaaed53709 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -18,6 +18,7 @@ #include #include +#include #include "common/EasyAssert.h" #include "common/Types.h" @@ -29,14 +30,17 @@ namespace milvus { namespace exec { -using number = boost::variant; +using number_type = boost::variant; + +using number = std::optional; + using ChunkDataAccessor = std::function; using MultipleChunkDataAccessor = std::function; @@ -264,16 +268,19 @@ class PhyCompareFilterExpr : public Expr { template int64_t - ProcessBothDataChunks(FUNC func, TargetBitmapView res, ValTypes... values) { + ProcessBothDataChunks(FUNC func, + TargetBitmapView res, + TargetBitmapView valid_res, + ValTypes... values) { if (segment_->is_chunked()) { return ProcessBothDataChunksForMultipleChunk( - func, res, values...); + func, res, valid_res, values...); } else { return ProcessBothDataChunksForSingleChunk( - func, res, values...); + func, res, valid_res, values...); } } @@ -281,6 +288,7 @@ class PhyCompareFilterExpr : public Expr { int64_t ProcessBothDataChunksForSingleChunk(FUNC func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -304,6 +312,20 @@ class PhyCompareFilterExpr : public Expr { const T* left_data = left_chunk.data() + data_pos; const U* right_data = right_chunk.data() + data_pos; func(left_data, right_data, size, res + processed_size, values...); + const bool* left_valid_data = left_chunk.valid_data(); + const bool* right_valid_data = right_chunk.valid_data(); + // mask with valid_data + for (int i = 0; i < size; ++i) { + if (left_valid_data && !left_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + continue; + } + if (right_valid_data && !right_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + } + } processed_size += size; if (processed_size >= batch_size_) { @@ -320,6 +342,7 @@ class PhyCompareFilterExpr : public Expr { int64_t ProcessBothDataChunksForMultipleChunk(FUNC func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -347,6 +370,20 @@ class PhyCompareFilterExpr : public Expr { const T* left_data = left_chunk.data() + data_pos; const U* right_data = right_chunk.data() + data_pos; func(left_data, right_data, size, res + processed_size, values...); + const bool* left_valid_data = left_chunk.valid_data(); + const bool* right_valid_data = right_chunk.valid_data(); + // mask with valid_data + for (int i = 0; i < size; ++i) { + if (left_valid_data && !left_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + continue; + } + if (right_valid_data && !right_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + } + } processed_size += size; if (processed_size >= batch_size_) { diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp index 6798eeedb4210..c73b4e007dc38 100644 --- a/internal/core/src/exec/expression/ExistsExpr.cpp +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -44,22 +44,30 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer) { for (int i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].exist(pointer); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 25f90db4a249f..307792a539ac2 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -248,6 +249,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { // For sealed segment, only single chunk Assert(num_data_chunk_ == 1); @@ -256,13 +258,16 @@ class SegmentExpr : public Expr { auto& skip_index = segment_->GetSkipIndex(); if (!skip_func || !skip_func(skip_index, field_id_, 0)) { - auto data_vec = - segment_ - ->get_batch_views( - field_id_, 0, current_data_chunk_pos_, need_size) - .first; - - func(data_vec.data(), need_size, res, values...); + auto views_info = segment_->get_batch_views( + field_id_, 0, current_data_chunk_pos_, need_size); + // first is the raw data, second is valid_data + // use valid_data to see if raw data is null + func(views_info.first.data(), + views_info.second.data(), + need_size, + res, + valid_res, + values...); } current_data_chunk_pos_ += need_size; return need_size; @@ -274,6 +279,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -281,7 +287,7 @@ class SegmentExpr : public Expr { std::is_same_v) { if (segment_->type() == SegmentType::Sealed) { return ProcessChunkForSealedSeg( - func, skip_func, res, values...); + func, skip_func, res, valid_res, values...); } } @@ -303,7 +309,16 @@ class SegmentExpr : public Expr { if (!skip_func || !skip_func(skip_index, field_id_, i)) { auto chunk = segment_->chunk_data(field_id_, i); const T* data = chunk.data() + data_pos; - func(data, size, res + processed_size, values...); + const bool* valid_data = chunk.valid_data(); + if (valid_data != nullptr) { + valid_data += data_pos; + } + func(data, + valid_data, + size, + res + processed_size, + valid_res + processed_size, + values...); } processed_size += size; @@ -322,6 +337,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -356,13 +372,21 @@ class SegmentExpr : public Expr { if constexpr (std::is_same_v || std::is_same_v) { if (segment_->type() == SegmentType::Sealed) { + // first is the raw data, second is valid_data + // use valid_data to see if raw data is null auto data_vec = segment_ ->get_batch_views( field_id_, i, data_pos, size) .first; + auto valid_data = segment_ + ->get_batch_views( + field_id_, i, data_pos, size) + .second; func(data_vec.data(), + valid_data.data(), size, res + processed_size, + valid_res + processed_size, values...); is_seal = true; } @@ -370,7 +394,16 @@ class SegmentExpr : public Expr { if (!is_seal) { auto chunk = segment_->chunk_data(field_id_, i); const T* data = chunk.data() + data_pos; - func(data, size, res + processed_size, values...); + const bool* valid_data = chunk.valid_data(); + if (valid_data != nullptr) { + valid_data += data_pos; + } + func(data, + valid_data, + size, + res + processed_size, + valid_res + processed_size, + values...); } } @@ -403,8 +436,10 @@ class SegmentExpr : public Expr { int ProcessIndexOneChunk(TargetBitmap& result, + TargetBitmap& valid_result, size_t chunk_id, const TargetBitmap& chunk_res, + const TargetBitmap& chunk_valid_res, int processed_rows) { auto data_pos = chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; @@ -416,33 +451,41 @@ class SegmentExpr : public Expr { // chunk_res.begin() + data_pos, // chunk_res.begin() + data_pos + size); result.append(chunk_res, data_pos, size); + valid_result.append(chunk_valid_res, data_pos, size); return size; } template - TargetBitmap + VectorPtr ProcessIndexChunks(FUNC func, ValTypes... values) { typedef std:: conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; TargetBitmap result; + TargetBitmap valid_result; int processed_rows = 0; for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { // This cache result help getting result for every batch loop. - // It avoids indexing execute for evevy batch because indexing + // It avoids indexing execute for every batch because indexing // executing costs quite much time. if (cached_index_chunk_id_ != i) { const Index& index = segment_->chunk_scalar_index(field_id_, i); auto* index_ptr = const_cast(&index); cached_index_chunk_res_ = std::move(func(index_ptr, values...)); + auto valid_result = index_ptr->IsNotNull(); + cached_index_chunk_valid_res_ = std::move(valid_result); cached_index_chunk_id_ = i; } - auto size = ProcessIndexOneChunk( - result, i, cached_index_chunk_res_, processed_rows); + auto size = ProcessIndexOneChunk(result, + valid_result, + i, + cached_index_chunk_res_, + cached_index_chunk_valid_res_, + processed_rows); if (processed_rows + size >= batch_size_) { current_index_chunk_ = i; @@ -454,23 +497,136 @@ class SegmentExpr : public Expr { processed_rows += size; } - return result; + return std::make_shared(std::move(result), + std::move(valid_result)); } - template + template + TargetBitmap + ProcessChunksForValid(bool use_index) { + if (use_index) { + return ProcessIndexChunksForValid(); + } else { + return ProcessDataChunksForValid(); + } + } + + template TargetBitmap + ProcessDataChunksForValid() { + TargetBitmap valid_result(batch_size_); + valid_result.set(); + int64_t processed_size = 0; + for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) { + auto data_pos = + (i == current_data_chunk_) ? current_data_chunk_pos_ : 0; + auto size = + (i == (num_data_chunk_ - 1)) + ? (segment_->type() == SegmentType::Growing + ? (active_count_ % size_per_chunk_ == 0 + ? size_per_chunk_ - data_pos + : active_count_ % size_per_chunk_ - data_pos) + : active_count_ - data_pos) + : size_per_chunk_ - data_pos; + + size = std::min(size, batch_size_ - processed_size); + + auto chunk = segment_->chunk_data(field_id_, i); + const bool* valid_data = chunk.valid_data(); + if (valid_data == nullptr) { + return valid_result; + } + valid_data += data_pos; + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + valid_result[i + data_pos] = false; + } + } + processed_size += size; + if (processed_size >= batch_size_) { + current_data_chunk_ = i; + current_data_chunk_pos_ = data_pos + size; + break; + } + } + return valid_result; + } + + int + ProcessIndexOneChunkForValid(TargetBitmap& valid_result, + size_t chunk_id, + const TargetBitmap& chunk_valid_res, + int processed_rows) { + auto data_pos = + chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; + auto size = std::min( + std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows), + int64_t(chunk_valid_res.size())); + + valid_result.append(chunk_valid_res, data_pos, size); + return size; + } + + template + TargetBitmap + ProcessIndexChunksForValid() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + int processed_rows = 0; + TargetBitmap valid_result; + valid_result.set(); + + for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { + // This cache result help getting result for every batch loop. + // It avoids indexing execute for every batch because indexing + // executing costs quite much time. + if (cached_index_chunk_id_ != i) { + const Index& index = + segment_->chunk_scalar_index(field_id_, i); + auto* index_ptr = const_cast(&index); + auto execute_sub_batch = [](Index* index_ptr) { + TargetBitmap res = index_ptr->IsNotNull(); + return res; + }; + cached_index_chunk_valid_res_ = execute_sub_batch(index_ptr); + cached_index_chunk_id_ = i; + } + + auto size = ProcessIndexOneChunkForValid( + valid_result, i, cached_index_chunk_valid_res_, processed_rows); + + if (processed_rows + size >= batch_size_) { + current_index_chunk_ = i; + current_index_chunk_pos_ = i == current_index_chunk_ + ? current_index_chunk_pos_ + size + : size; + break; + } + processed_rows += size; + } + return valid_result; + } + + template + VectorPtr ProcessTextMatchIndex(FUNC func, ValTypes... values) { TargetBitmap result; + TargetBitmap valid_result; if (cached_match_res_ == nullptr) { auto index = segment_->GetTextIndex(field_id_); auto res = std::move(func(index, values...)); + auto valid_res = index->IsNotNull(); cached_match_res_ = std::make_shared(std::move(res)); + cached_index_chunk_valid_res_ = std::move(valid_res); if (cached_match_res_->size() < active_count_) { // some entities are not visible in inverted index. // only happend on growing segment. TargetBitmap tail(active_count_ - cached_match_res_->size()); cached_match_res_->append(tail); + cached_index_chunk_valid_res_.append(tail); } } @@ -481,9 +637,13 @@ class SegmentExpr : public Expr { : batch_size_; result.append( *cached_match_res_, current_data_chunk_pos_, real_batch_size); + valid_result.append(cached_index_chunk_valid_res_, + current_data_chunk_pos_, + real_batch_size); current_data_chunk_pos_ += real_batch_size; - return result; + return std::make_shared(std::move(result), + std::move(valid_result)); } template @@ -581,6 +741,8 @@ class SegmentExpr : public Expr { // Cache for index scan to avoid search index every batch int64_t cached_index_chunk_id_{-1}; TargetBitmap cached_index_chunk_res_{}; + // Cache for chunk valid res. + TargetBitmap cached_index_chunk_valid_res_{}; // Cache for text match. std::shared_ptr cached_match_res_{nullptr}; diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index da9f3d6aaa895..b21714b4c8b6b 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "JsonContainsExpr.h" +#include #include "common/Types.h" namespace milvus { @@ -173,17 +174,21 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { AssertInfo(expr_->column_.nested_path_.size() == 0, "[ExecArrayContains]nested path must be null"); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; for (auto const& element : expr_->vals_) { elements.insert(GetValueFromProto(element)); } auto execute_sub_batch = [](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& elements) { auto executor = [&](size_t i) { const auto& array = data[i]; @@ -195,12 +200,16 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { return false; }; for (int i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -221,9 +230,11 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -231,8 +242,10 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { elements.insert(GetValueFromProto(element)); } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](size_t i) { @@ -253,12 +266,16 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -274,9 +291,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::vector elements; @@ -285,8 +304,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](size_t i) -> bool { @@ -316,12 +337,16 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -344,9 +369,11 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; for (auto const& element : expr_->vals_) { @@ -354,8 +381,10 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { } auto execute_sub_batch = [](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& elements) { auto executor = [&](size_t i) { std::unordered_set tmp_elements(elements); @@ -369,12 +398,16 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { return tmp_elements.size() == 0; }; for (int i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -395,9 +428,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set elements; @@ -406,8 +441,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](const size_t i) -> bool { @@ -431,12 +468,16 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { return tmp_elements.size() == 0; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -451,9 +492,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -467,8 +510,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements, const std::unordered_set elements_index) { @@ -553,6 +598,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { return tmp_elements_index.size() == 0; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -560,6 +609,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { int64_t processed_size = ProcessDataChunks(execute_sub_batch, std::nullptr_t{}, res, + valid_res, pointer, elements, elements_index); @@ -578,9 +628,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -590,8 +642,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { @@ -625,12 +679,16 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { return exist_elements_index.size() == elements.size(); }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -646,9 +704,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -662,8 +722,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { @@ -739,12 +801,16 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -832,12 +898,12 @@ PhyJsonContainsFilterExpr::ExecArrayContainsForIndexSegmentImpl() { } }; auto res = ProcessIndexChunks(execute_sub_batch, elems); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } } //namespace exec diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp index d388ab2454cc3..4267f770389ec 100644 --- a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp @@ -45,6 +45,10 @@ PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) { "unsupported logical operator: {}", expr_->GetOpTypeString()); } + TargetBitmapView lvalid_view(lflat->GetValidRawData(), size); + TargetBitmapView rvalid_view(rflat->GetValidRawData(), size); + LogicalElementFunc func; + func(lvalid_view, rvalid_view, size); result = std::move(left); } diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp index 4d4bb550691c2..14bde4decdab0 100644 --- a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp @@ -30,6 +30,9 @@ PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) { auto flat_vec = GetColumnVector(result); TargetBitmapView data(flat_vec->GetRawData(), flat_vec->size()); data.flip(); + TargetBitmapView valid_data(flat_vec->GetValidRawData(), + flat_vec->size()); + data &= valid_data; } } diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index 0aaf7a4e69f74..fcb27a1c747a2 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -15,6 +15,8 @@ // limitations under the License. #include "TermExpr.h" +#include +#include #include "query/Utils.h" namespace milvus { namespace exec { @@ -199,9 +201,12 @@ PhyTermFilterExpr::ExecPkTermImpl() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + // pk valid_bitmap is always all true + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); for (size_t i = 0; i < real_batch_size; ++i) { res[i] = cached_bits_[current_data_chunk_pos_++]; @@ -241,17 +246,21 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); ValueType target_val = GetValueFromProto(expr_->vals_[0]); auto execute_sub_batch = [](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const ValueType& target_val) { auto executor = [&](size_t i) { for (int i = 0; i < data[i].length(); i++) { @@ -263,12 +272,16 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { return false; }; for (int i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, target_val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, target_val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -289,9 +302,11 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -309,12 +324,18 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { } auto execute_sub_batch = [](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, int index, const std::unordered_set& term_set) { for (int i = 0; i < size; ++i) { - if (index >= data[i].length()) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } + if (term_set.empty() || index >= data[i].length()) { res[i] = false; continue; } @@ -324,7 +345,7 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, index, term_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, index, term_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -344,9 +365,11 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); @@ -354,8 +377,10 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string pointer, const ValueType& target_val) { auto executor = [&](size_t i) { @@ -375,11 +400,15 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -399,9 +428,11 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set term_set; @@ -416,8 +447,10 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { } auto execute_sub_batch = [](const Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string pointer, const std::unordered_set& terms) { auto executor = [&](size_t i) { @@ -439,11 +472,19 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { return terms.find(ValueType(x.value())) != terms.end(); }; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } + if (terms.empty()) { + res[i] = false; + continue; + } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, term_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, term_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -489,12 +530,12 @@ PhyTermFilterExpr::ExecVisitorImplForIndex() { return func(index_ptr, vals.size(), vals.data()); }; auto res = ProcessIndexChunks(execute_sub_batch, vals); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template <> @@ -516,7 +557,7 @@ PhyTermFilterExpr::ExecVisitorImplForIndex() { return std::move(func(index_ptr, vals.size(), (bool*)vals.data())); }; auto res = ProcessIndexChunks(execute_sub_batch, vals); - return std::make_shared(std::move(res)); + return res; } template @@ -527,9 +568,11 @@ PhyTermFilterExpr::ExecVisitorImplForData() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::vector vals; for (auto& val : expr_->vals_) { @@ -542,16 +585,22 @@ PhyTermFilterExpr::ExecVisitorImplForData() { } std::unordered_set vals_set(vals.begin(), vals.end()); auto execute_sub_batch = [](const T* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& vals) { TermElementFuncSet func; for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = func(vals, data[i]); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, vals_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, vals_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 3b7c2116244fb..ad3cd8cb294d1 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "UnaryExpr.h" +#include #include "common/Json.h" namespace milvus { @@ -260,9 +261,11 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); ValueType val = GetValueFromProto(expr_->val_); auto op_type = expr_->op_type_; @@ -271,48 +274,50 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { index = std::stoi(expr_->column_.nested_path_[0]); } auto execute_sub_batch = [op_type](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, int index) { switch (op_type) { case proto::plan::GreaterThan: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::GreaterEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::LessThan: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::LessEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::Equal: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::NotEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::PrefixMatch: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } default: @@ -323,7 +328,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val, index); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val, index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -432,14 +437,14 @@ PhyUnaryRangeFilterExpr::ExecArrayEqualForIndex(bool reverse) { } return res; }); - AssertInfo(batch_res.size() == real_batch_size, + AssertInfo(batch_res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - batch_res.size(), + batch_res->size(), real_batch_size); // return the result. - return std::make_shared(std::move(batch_res)); + return batch_res; } template @@ -455,9 +460,11 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } ExprValueType val = GetValueFromProto(expr_->val_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto op_type = expr_->op_type_; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -492,12 +499,18 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } while (false) auto execute_sub_batch = [op_type, pointer](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ExprValueType val) { switch (op_type) { case proto::plan::GreaterThan: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -508,6 +521,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::GreaterEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -518,6 +535,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessThan: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -528,6 +549,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -538,6 +563,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::Equal: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -554,6 +583,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::NotEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -570,6 +603,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::PrefixMatch: { for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -584,6 +621,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { auto regex_pattern = translator(val); RegexMatcher matcher(regex_pattern); for (size_t i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -601,7 +642,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -693,12 +734,12 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { }; auto val = GetValueFromProto(expr_->val_); auto res = ProcessIndexChunks(execute_sub_batch, val); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template @@ -720,10 +761,11 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { switch (expr_->op_type_) { case proto::plan::GreaterThan: case proto::plan::GreaterEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; if (milvus::query::lt_lb(val)) { res.set(); @@ -733,10 +775,11 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { } case proto::plan::LessThan: case proto::plan::LessEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; if (milvus::query::gt_ub(val)) { res.set(); @@ -745,19 +788,21 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { return res_vec; } case proto::plan::Equal: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; res.reset(); return res_vec; } case proto::plan::NotEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; res.set(); return res_vec; @@ -788,13 +833,17 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { return nullptr; } IndexInnerType val = GetValueFromProto(expr_->val_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* data, + const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, IndexInnerType val) { switch (expr_type) { case proto::plan::GreaterThan: { @@ -843,6 +892,16 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { fmt::format("unsupported operator type for unary expr: {}", expr_type)); } + // there is a batch operation in BinaryRangeElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } + } }; auto skip_index_func = [expr_type, val](const SkipIndex& skip_index, FieldId field_id, @@ -850,8 +909,8 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { return skip_index.CanSkipUnaryRange( field_id, chunk_id, expr_type, val); }; - int64_t processed_size = - ProcessDataChunks(execute_sub_batch, skip_index_func, res, val); + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, skip_index_func, res, valid_res, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}, related params[active_count:{}, " @@ -881,7 +940,7 @@ PhyUnaryRangeFilterExpr::ExecTextMatch() { return index->MatchQuery(query); }; auto res = ProcessTextMatchIndex(func, query); - return std::make_shared(std::move(res)); + return res; }; } // namespace exec diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index 83711f6d70dab..71a8869ecd291 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -148,11 +148,17 @@ struct UnaryElementFuncForArray { ValueType>; void operator()(const ArrayView* src, + const bool* valid_data, size_t size, ValueType val, int index, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (int i = 0; i < size; ++i) { + if (valid_data != nullptr && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } if constexpr (op == proto::plan::OpType::Equal) { if constexpr (std::is_same_v) { res[i] = src[i].is_same_array(val); @@ -224,7 +230,11 @@ struct UnaryIndexFuncForMatch { RegexMatcher matcher(regex_pattern); for (int64_t i = 0; i < cnt; i++) { auto raw = index->Reverse_Lookup(i); - res[i] = matcher(raw); + if (!raw.has_value()) { + res[i] = false; + continue; + } + res[i] = matcher(raw.value()); } return res; } diff --git a/internal/core/src/exec/operator/FilterBitsNode.cpp b/internal/core/src/exec/operator/FilterBitsNode.cpp index 7ad302cbec371..f7716a3fa19b0 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.cpp +++ b/internal/core/src/exec/operator/FilterBitsNode.cpp @@ -68,6 +68,7 @@ PhyFilterBitsNode::GetOutput() { operator_context_->get_exec_context(), exprs_.get(), input_.get()); TargetBitmap bitset; + TargetBitmap valid_bitset; while (num_processed_rows_ < need_process_rows_) { exprs_->Eval(0, 1, true, eval_ctx, results_); @@ -79,13 +80,17 @@ PhyFilterBitsNode::GetOutput() { auto col_vec_size = col_vec->size(); TargetBitmapView view(col_vec->GetRawData(), col_vec_size); bitset.append(view); + TargetBitmapView valid_view(col_vec->GetValidRawData(), col_vec_size); + valid_bitset.append(valid_view); num_processed_rows_ += col_vec_size; } bitset.flip(); Assert(bitset.size() == need_process_rows_); + Assert(valid_bitset.size() == need_process_rows_); // num_processed_rows_ = need_process_rows_; std::vector col_res; - col_res.push_back(std::make_shared(std::move(bitset))); + col_res.push_back(std::make_shared(std::move(bitset), + std::move(valid_bitset))); std::chrono::high_resolution_clock::time_point scalar_end = std::chrono::high_resolution_clock::now(); double scalar_cost = diff --git a/internal/core/src/exec/operator/MvccNode.cpp b/internal/core/src/exec/operator/MvccNode.cpp index eeae9ebf3748d..98d7b4862abff 100644 --- a/internal/core/src/exec/operator/MvccNode.cpp +++ b/internal/core/src/exec/operator/MvccNode.cpp @@ -51,13 +51,15 @@ PhyMvccNode::GetOutput() { is_finished_ = true; return nullptr; } - - auto col_input = - is_source_node_ - ? std::make_shared(TargetBitmap(active_count_)) - : GetColumnVector(input_); + // the first vector is filtering result and second bitset is a valid bitset + // if valid_bitset[i]==false, means result[i] is null + auto col_input = is_source_node_ ? std::make_shared( + TargetBitmap(active_count_), + TargetBitmap(active_count_)) + : GetColumnVector(input_); TargetBitmapView data(col_input->GetRawData(), col_input->size()); + // need to expose null? segment_->mask_with_timestamps(data, query_timestamp_); segment_->mask_with_delete(data, active_count_, query_timestamp_); is_finished_ = true; diff --git a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h index 78833a8d34cd5..e6a95c6603809 100644 --- a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h @@ -100,7 +100,9 @@ class SealedDataGetter : public DataGetter { } return field_data_->operator[](idx); } else { - return (*field_index_).Reverse_Lookup(idx); + auto raw = (*field_index_).Reverse_Lookup(idx); + AssertInfo(raw.has_value(), "field data not found"); + return raw.value(); } } }; diff --git a/internal/core/src/index/BitmapIndex.cpp b/internal/core/src/index/BitmapIndex.cpp index cc4de8e3bf358..b5bca930d57b5 100644 --- a/internal/core/src/index/BitmapIndex.cpp +++ b/internal/core/src/index/BitmapIndex.cpp @@ -80,7 +80,7 @@ BitmapIndex::Build(const Config& config) { template void -BitmapIndex::Build(size_t n, const T* data) { +BitmapIndex::Build(size_t n, const T* data, const bool* valid_data) { if (is_built_) { return; } @@ -89,12 +89,14 @@ BitmapIndex::Build(size_t n, const T* data) { } total_num_rows_ = n; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); T* p = const_cast(data); for (int i = 0; i < n; ++i, ++p) { - data_[*p].add(i); - valid_bitset.set(i); + if (valid_data == nullptr || valid_data[i]) { + data_[*p].add(i); + valid_bitset_.set(i); + } } if (data_.size() < DEFAULT_BITMAP_INDEX_BUILD_MODE_BOUND) { @@ -120,7 +122,7 @@ BitmapIndex::BuildPrimitiveField( if (data->is_valid(i)) { auto val = reinterpret_cast(data->RawValue(i)); data_[*val].add(offset); - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } @@ -139,7 +141,7 @@ BitmapIndex::BuildWithFieldData( PanicInfo(DataIsEmpty, "scalar bitmap index can not build null values"); } total_num_rows_ = total_num_rows; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); switch (schema_.data_type()) { case proto::schema::DataType::Bool: @@ -184,7 +186,7 @@ BitmapIndex::BuildArrayField(const std::vector& field_datas) { auto val = array->template get_data(j); data_[val].add(offset); } - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } @@ -359,7 +361,7 @@ BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, data_[key] = value; } for (const auto& v : value) { - valid_bitset.set(v); + valid_bitset_.set(v); } } } @@ -422,7 +424,7 @@ BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, data_[key] = value; } for (const auto& v : value) { - valid_bitset.set(v); + valid_bitset_.set(v); } } } @@ -516,7 +518,7 @@ BitmapIndex::LoadWithoutAssemble(const BinarySet& binary_set, index_meta_buffer->size); auto index_length = index_meta.first; total_num_rows_ = index_meta.second; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); auto index_data_buffer = binary_set.GetByName(BITMAP_INDEX_DATA); @@ -645,7 +647,7 @@ BitmapIndex::NotIn(const size_t n, const T* values) { } } // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - res &= valid_bitset; + res &= valid_bitset_; return res; } else { TargetBitmap res(total_num_rows_, false); @@ -657,7 +659,7 @@ BitmapIndex::NotIn(const size_t n, const T* values) { } res.flip(); // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - res &= valid_bitset; + res &= valid_bitset_; return res; } } @@ -667,7 +669,7 @@ const TargetBitmap BitmapIndex::IsNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap res(total_num_rows_, true); - res &= valid_bitset; + res &= valid_bitset_; res.flip(); return res; } @@ -677,7 +679,7 @@ const TargetBitmap BitmapIndex::IsNotNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap res(total_num_rows_, true); - res &= valid_bitset; + res &= valid_bitset_; return res; } @@ -1086,11 +1088,15 @@ BitmapIndex::Reverse_Lookup_InCache(size_t idx) const { } template -T +std::optional BitmapIndex::Reverse_Lookup(size_t idx) const { AssertInfo(is_built_, "index has not been built"); AssertInfo(idx < total_num_rows_, "out of range of total coun"); + if (!valid_bitset_[idx]) { + return std::nullopt; + } + if (use_offset_cache_) { return Reverse_Lookup_InCache(idx); } @@ -1125,6 +1131,7 @@ BitmapIndex::Reverse_Lookup(size_t idx) const { fmt::format( "scalar bitmap index can not lookup target value of index {}", idx)); + return std::nullopt; } template diff --git a/internal/core/src/index/BitmapIndex.h b/internal/core/src/index/BitmapIndex.h index eb11e75441348..fb677e6f3194f 100644 --- a/internal/core/src/index/BitmapIndex.h +++ b/internal/core/src/index/BitmapIndex.h @@ -77,7 +77,7 @@ class BitmapIndex : public ScalarIndex { } void - Build(size_t n, const T* values) override; + Build(size_t n, const T* values, const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -106,7 +106,7 @@ class BitmapIndex : public ScalarIndex { T upper_bound_value, bool ub_inclusive) override; - T + std::optional Reverse_Lookup(size_t offset) const override; int64_t @@ -267,7 +267,7 @@ class BitmapIndex : public ScalarIndex { std::shared_ptr file_manager_; // generate valid_bitset to speed up NotIn and IsNull and IsNotNull operate - TargetBitmap valid_bitset; + TargetBitmap valid_bitset_; }; } // namespace index diff --git a/internal/core/src/index/HybridScalarIndex.h b/internal/core/src/index/HybridScalarIndex.h index 0829afc963fbc..8b6e484c2a71f 100644 --- a/internal/core/src/index/HybridScalarIndex.h +++ b/internal/core/src/index/HybridScalarIndex.h @@ -67,10 +67,12 @@ class HybridScalarIndex : public ScalarIndex { } void - Build(size_t n, const T* values) override { + Build(size_t n, + const T* values, + const bool* valid_data = nullptr) override { SelectIndexBuildType(n, values); auto index = GetInternalIndex(); - index->Build(n, values); + index->Build(n, values, valid_data); is_built_ = true; } @@ -133,7 +135,7 @@ class HybridScalarIndex : public ScalarIndex { lower_bound_value, lb_inclusive, upper_bound_value, ub_inclusive); } - T + std::optional Reverse_Lookup(size_t offset) const override { return internal_index_->Reverse_Lookup(offset); } diff --git a/internal/core/src/index/InvertedIndexTantivy.h b/internal/core/src/index/InvertedIndexTantivy.h index 9d7febfd90942..62ecff9f29470 100644 --- a/internal/core/src/index/InvertedIndexTantivy.h +++ b/internal/core/src/index/InvertedIndexTantivy.h @@ -94,7 +94,7 @@ class InvertedIndexTantivy : public ScalarIndex { * deprecated, only used in small chunk index. */ void - Build(size_t n, const T* values) override { + Build(size_t n, const T* values, const bool* valid_data) override { PanicInfo(ErrorCode::NotImplemented, "Build should not be called"); } @@ -136,7 +136,7 @@ class InvertedIndexTantivy : public ScalarIndex { return false; } - T + std::optional Reverse_Lookup(size_t offset) const override { PanicInfo(ErrorCode::NotImplemented, "Reverse_Lookup should not be handled by inverted index"); diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index 6105ce4afb980..6f411179cc3ea 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -80,7 +80,7 @@ class ScalarIndex : public IndexBase { GetIndexType() const = 0; virtual void - Build(size_t n, const T* values) = 0; + Build(size_t n, const T* values, const bool* valid_data = nullptr) = 0; virtual const TargetBitmap In(size_t n, const T* values) = 0; @@ -117,7 +117,7 @@ class ScalarIndex : public IndexBase { T upper_bound_value, bool ub_inclusive) = 0; - virtual T + virtual std::optional Reverse_Lookup(size_t offset) const = 0; virtual const TargetBitmap diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index 56396d7d192f3..8d55832b1d4c7 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -61,7 +62,7 @@ ScalarIndexSort::Build(const Config& config) { template void -ScalarIndexSort::Build(size_t n, const T* values) { +ScalarIndexSort::Build(size_t n, const T* values, const bool* valid_data) { if (is_built_) return; if (n == 0) { @@ -69,13 +70,17 @@ ScalarIndexSort::Build(size_t n, const T* values) { } data_.reserve(n); total_num_rows_ = n; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); idx_to_offsets_.resize(n); + T* p = const_cast(values); - for (size_t i = 0; i < n; ++i) { - data_.emplace_back(IndexStructure(*p++, i)); - valid_bitset.set(i); + for (size_t i = 0; i < n; ++i, ++p) { + if (!valid_data || valid_data[i]) { + data_.emplace_back(IndexStructure(*p, i)); + valid_bitset_.set(i); + } } + std::sort(data_.begin(), data_.end()); for (size_t i = 0; i < data_.size(); ++i) { idx_to_offsets_[data_[i].idx_] = i; @@ -97,7 +102,7 @@ ScalarIndexSort::BuildWithFieldData( } data_.reserve(length); - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); int64_t offset = 0; for (const auto& data : field_datas) { auto slice_num = data->get_num_rows(); @@ -105,7 +110,7 @@ ScalarIndexSort::BuildWithFieldData( if (data->is_valid(i)) { auto value = reinterpret_cast(data->RawValue(i)); data_.emplace_back(IndexStructure(*value, offset)); - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } @@ -175,11 +180,11 @@ ScalarIndexSort::LoadWithoutAssemble(const BinarySet& index_binary, index_num_rows->data.get(), (size_t)index_num_rows->size); idx_to_offsets_.resize(total_num_rows_); - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); memcpy(data_.data(), index_data->data.get(), (size_t)index_data->size); for (size_t i = 0; i < data_.size(); ++i) { idx_to_offsets_[data_[i].idx_] = i; - valid_bitset.set(data_[i].idx_); + valid_bitset_.set(data_[i].idx_); } is_built_ = true; @@ -256,7 +261,7 @@ ScalarIndexSort::NotIn(const size_t n, const T* values) { } } // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - bitset &= valid_bitset; + bitset &= valid_bitset_; return bitset; } @@ -265,7 +270,7 @@ const TargetBitmap ScalarIndexSort::IsNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap bitset(total_num_rows_, true); - bitset &= valid_bitset; + bitset &= valid_bitset_; bitset.flip(); return bitset; } @@ -275,7 +280,7 @@ const TargetBitmap ScalarIndexSort::IsNotNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap bitset(total_num_rows_, true); - bitset &= valid_bitset; + bitset &= valid_bitset_; return bitset; } @@ -355,11 +360,14 @@ ScalarIndexSort::Range(T lower_bound_value, } template -T +std::optional ScalarIndexSort::Reverse_Lookup(size_t idx) const { AssertInfo(idx < idx_to_offsets_.size(), "out of range of total count"); AssertInfo(is_built_, "index has not been built"); + if (!valid_bitset_[idx]) { + return std::nullopt; + } auto offset = idx_to_offsets_[idx]; return data_[offset].a_; } diff --git a/internal/core/src/index/ScalarIndexSort.h b/internal/core/src/index/ScalarIndexSort.h index fb33f030c2a03..1370b9dff89d6 100644 --- a/internal/core/src/index/ScalarIndexSort.h +++ b/internal/core/src/index/ScalarIndexSort.h @@ -56,7 +56,7 @@ class ScalarIndexSort : public ScalarIndex { } void - Build(size_t n, const T* values) override; + Build(size_t n, const T* values, const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -82,7 +82,7 @@ class ScalarIndexSort : public ScalarIndex { T upper_bound_value, bool ub_inclusive) override; - T + std::optional Reverse_Lookup(size_t offset) const override; int64_t @@ -127,8 +127,8 @@ class ScalarIndexSort : public ScalarIndex { std::vector> data_; std::shared_ptr file_manager_; size_t total_num_rows_{0}; - // generate valid_bitset to speed up NotIn and IsNull and IsNotNull operate - TargetBitmap valid_bitset; + // generate valid_bitset_ to speed up NotIn and IsNull and IsNotNull operate + TargetBitmap valid_bitset_; }; template diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index e3c853193571a..289ba2409da86 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -118,7 +119,9 @@ StringIndexMarisa::BuildWithFieldData( } void -StringIndexMarisa::Build(size_t n, const std::string* values) { +StringIndexMarisa::Build(size_t n, + const std::string* values, + const bool* valid_data) { if (built_) { PanicInfo(IndexAlreadyBuild, "index has been built"); } @@ -127,12 +130,14 @@ StringIndexMarisa::Build(size_t n, const std::string* values) { { // fill key set. for (size_t i = 0; i < n; i++) { - keyset.push_back(values[i].c_str()); + if (valid_data == nullptr || valid_data[i]) { + keyset.push_back(values[i].c_str()); + } } } trie_.build(keyset, MARISA_LABEL_ORDER); - fill_str_ids(n, values); + fill_str_ids(n, values, valid_data); fill_offsets(); built_ = true; @@ -213,7 +218,7 @@ StringIndexMarisa::LoadWithoutAssemble(const BinarySet& set, auto str_ids = set.GetByName(MARISA_STR_IDS); auto str_ids_len = str_ids->size; - str_ids_.resize(str_ids_len / sizeof(size_t)); + str_ids_.resize(str_ids_len / sizeof(size_t), MARISA_NULL_KEY_ID); memcpy(str_ids_.data(), str_ids->data.get(), str_ids_len); fill_offsets(); @@ -491,9 +496,14 @@ StringIndexMarisa::PrefixMatch(std::string_view prefix) { } void -StringIndexMarisa::fill_str_ids(size_t n, const std::string* values) { - str_ids_.resize(n); +StringIndexMarisa::fill_str_ids(size_t n, + const std::string* values, + const bool* valid_data) { + str_ids_.resize(n, MARISA_NULL_KEY_ID); for (size_t i = 0; i < n; i++) { + if (valid_data != nullptr && !valid_data[i]) { + continue; + } auto str = values[i]; auto str_id = lookup(str); AssertInfo(valid_str_id(str_id), "invalid marisa key"); @@ -534,11 +544,13 @@ StringIndexMarisa::prefix_match(const std::string_view prefix) { } return ret; } - -std::string +std::optional StringIndexMarisa::Reverse_Lookup(size_t offset) const { AssertInfo(offset < str_ids_.size(), "out of range of total count"); marisa::Agent agent; + if (str_ids_[offset] < 0) { + return std::nullopt; + } agent.set_query(str_ids_[offset]); trie_.reverse_lookup(agent); return std::string(agent.key().ptr(), agent.key().length()); diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h index 72913d6675987..f3dff120897f0 100644 --- a/internal/core/src/index/StringIndexMarisa.h +++ b/internal/core/src/index/StringIndexMarisa.h @@ -55,7 +55,9 @@ class StringIndexMarisa : public StringIndex { } void - Build(size_t n, const std::string* values) override; + Build(size_t n, + const std::string* values, + const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -87,7 +89,7 @@ class StringIndexMarisa : public StringIndex { const TargetBitmap PrefixMatch(const std::string_view prefix) override; - std::string + std::optional Reverse_Lookup(size_t offset) const override; BinarySet @@ -100,7 +102,7 @@ class StringIndexMarisa : public StringIndex { private: void - fill_str_ids(size_t n, const std::string* values); + fill_str_ids(size_t n, const std::string* values, const bool* valid_data); void fill_offsets(); @@ -122,7 +124,7 @@ class StringIndexMarisa : public StringIndex { private: Config config_; marisa::Trie trie_; - std::vector str_ids_; // used to retrieve. + std::vector str_ids_; // used to retrieve. std::map> str_ids_to_offsets_; bool built_ = false; std::shared_ptr file_manager_; diff --git a/internal/core/src/query/ScalarIndex.h b/internal/core/src/query/ScalarIndex.h index eb9d0f3a18687..b72a68c6d5bc8 100644 --- a/internal/core/src/query/ScalarIndex.h +++ b/internal/core/src/query/ScalarIndex.h @@ -26,7 +26,7 @@ template inline index::ScalarIndexPtr generate_scalar_index(Span data) { auto indexing = std::make_unique>(); - indexing->Build(data.row_count(), data.data()); + indexing->Build(data.row_count(), data.data(), data.valid_data()); return indexing; } @@ -34,7 +34,7 @@ template <> inline index::ScalarIndexPtr generate_scalar_index(Span data) { auto indexing = index::CreateStringIndexSort(); - indexing->Build(data.row_count(), data.data()); + indexing->Build(data.row_count(), data.data(), data.valid_data()); return indexing; } diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index a95ae1ecd1665..9987347f3fe57 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -196,8 +196,9 @@ ChunkedSegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && int64_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(int64_index->Reverse_Lookup(i), - i); + auto raw = int64_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "pk not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -210,8 +211,9 @@ ChunkedSegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && string_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk( - string_index->Reverse_Lookup(i), i); + auto raw = string_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "pk not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -1630,7 +1632,11 @@ ChunkedSegmentSealedImpl::CreateTextIndex(FieldId field_id) { "converted to string index"); auto n = impl->Size(); for (size_t i = 0; i < n; i++) { - index->AddText(impl->Reverse_Lookup(i), i); + auto raw = impl->Reverse_Lookup(i); + if (!raw.has_value()) { + continue; + } + index->AddText(raw.value(), i); } } } diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index eb81947fcdba4..8c924e24ba01e 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -299,6 +299,7 @@ ScalarFieldIndexing::BuildIndexRange(int64_t ack_beg, for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) { auto chunk_data = source->get_chunk_data(chunk_id); // build index for chunk + // seem no lint, not pass valid_data here // TODO if constexpr (std::is_same_v) { auto indexing = index::CreateStringIndexSort(); diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 03a59dbf1cb0a..dea574191845d 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -198,8 +198,9 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && int64_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(int64_index->Reverse_Lookup(i), - i); + auto raw = int64_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "Primary key not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -212,8 +213,9 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && string_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk( - string_index->Reverse_Lookup(i), i); + auto raw = string_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "Primary key not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -2108,7 +2110,11 @@ SegmentSealedImpl::CreateTextIndex(FieldId field_id) { "converted to string index"); auto n = impl->Size(); for (size_t i = 0; i < n; i++) { - index->AddText(impl->Reverse_Lookup(i), i); + auto raw = impl->Reverse_Lookup(i); + if (!raw.has_value()) { + continue; + } + index->AddText(raw.value(), i); } } } diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index e0bd00007b461..30b01caa86a4d 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -683,6 +683,11 @@ ReverseDataFromIndex(const index::IndexBase* index, data_array->set_field_id(field_meta.get_id().get()); data_array->set_type(static_cast( field_meta.get_data_type())); + auto nullable = field_meta.is_nullable(); + std::vector valid_data; + if (nullable) { + valid_data.resize(count); + } auto scalar_array = data_array->mutable_scalars(); switch (data_type) { @@ -691,7 +696,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_bool_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -702,7 +716,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -713,7 +736,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -724,7 +756,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -735,7 +776,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_long_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -746,7 +796,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_float_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -757,7 +816,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_double_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -768,7 +836,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here + if (!raw.has_value()) { + valid_data[i] = false; + continue; + } + if (nullable) { + valid_data[i] = true; + } + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_string_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -780,6 +857,11 @@ ReverseDataFromIndex(const index::IndexBase* index, } } + if (nullable) { + *(data_array->mutable_valid_data()) = {valid_data.begin(), + valid_data.end()}; + } + return data_array; } diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 2bfc4646d10af..51079941a55ad 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -379,6 +379,255 @@ TEST_P(ExprTest, TestRange) { } } +TEST_P(ExprTest, TestRangeNullable) { + std::vector>> + testcases = { + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: false, + upper_inclusive: false, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 < v && v < 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: true, + upper_inclusive: false, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v < 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: false, + upper_inclusive: true, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 < v && v <= 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: true, + upper_inclusive: true, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v <= 3000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: GreaterEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v >= 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: GreaterThan, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v > 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: LessEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v <= 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: LessThan, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v < 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: Equal, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: NotEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v != 2000; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age", DataType::INT64); + schema->set_primary_field_id(i64_fid); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector data_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_data_col = raw_data.get_col(i64_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + data_col.insert( + data_col.end(), new_data_col.begin(), new_data_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = data_col[i]; + auto valid_data = valid_data_col[i]; + auto ref = ref_func(val, valid_data); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } + } +} + TEST_P(ExprTest, TestBinaryRangeJSON) { struct Testcase { bool lower_inclusive; @@ -475,25 +724,34 @@ TEST_P(ExprTest, TestBinaryRangeJSON) { } } -TEST_P(ExprTest, TestExistsJson) { +TEST_P(ExprTest, TestBinaryRangeJSONNullable) { struct Testcase { + bool lower_inclusive; + bool upper_inclusive; + int64_t lower; + int64_t upper; std::vector nested_path; }; std::vector testcases{ - {{"A"}}, - {{"int"}}, - {{"double"}}, - {{"B"}}, + {true, false, 10, 20, {"int"}}, + {true, true, 20, 30, {"int"}}, + {false, true, 30, 40, {"int"}}, + {false, false, 40, 50, {"int"}}, + {true, false, 10, 20, {"double"}}, + {true, true, 20, 30, {"double"}}, + {false, true, 30, 40, {"double"}}, + {false, false, 40, 50, {"double"}}, }; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); @@ -501,6 +759,7 @@ TEST_P(ExprTest, TestExistsJson) { json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col = raw_data.get_col_valid(json_fid); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -511,11 +770,32 @@ TEST_P(ExprTest, TestExistsJson) { auto seg_promote = dynamic_cast(seg.get()); for (auto testcase : testcases) { - auto check = [&](bool value) { return value; }; + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + int64_t lower = testcase.lower, upper = testcase.upper; + if (!testcase.lower_inclusive) { + lower++; + } + if (!testcase.upper_inclusive) { + upper--; + } + return lower <= value && value <= upper; + }; auto pointer = milvus::Json::pointer(testcase.nested_path); - auto expr = - std::make_shared(milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path)); + RetrievePlanNode plan; + milvus::proto::plan::GenericValue lower_val; + lower_val.set_int64_val(testcase.lower); + milvus::proto::plan::GenericValue upper_val; + upper_val.set_int64_val(testcase.upper); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + lower_val, + upper_val, + testcase.lower_inclusive, + testcase.upper_inclusive); BitsetType final; auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -525,33 +805,172 @@ TEST_P(ExprTest, TestExistsJson) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist(pointer); - auto ref = check(val); - ASSERT_EQ(ans, ref); + + if (testcase.nested_path[0] == "int") { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } else { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } } } } -template -T -GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kBoolVal); - return static_cast(value_proto.bool_val()); - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kInt64Val); - return static_cast(value_proto.int64_val()); - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kFloatVal); - return static_cast(value_proto.float_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kStringVal); - return static_cast(value_proto.string_val()); +TEST_P(ExprTest, TestExistsJson) { + struct Testcase { + std::vector nested_path; + }; + std::vector testcases{ + {{"A"}}, + {{"int"}}, + {{"double"}}, + {{"B"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](bool value) { return value; }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = + std::make_shared(milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path)); + auto plannode = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist(pointer); + auto ref = check(val); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestExistsJsonNullable) { + struct Testcase { + std::vector nested_path; + }; + std::vector testcases{ + {{"A"}}, + {{"int"}}, + {{"double"}}, + {{"B"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col = raw_data.get_col_valid(json_fid); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](bool value, bool valid) { + if (!valid) { + return false; + } + return value; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = + std::make_shared(milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path)); + auto plannode = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist(pointer); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } +} + +template +T +GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { + if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kBoolVal); + return static_cast(value_proto.bool_val()); + } else if constexpr (std::is_integral_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kInt64Val); + return static_cast(value_proto.int64_val()); + } else if constexpr (std::is_floating_point_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kFloatVal); + return static_cast(value_proto.float_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kStringVal); + return static_cast(value_proto.string_val()); } else if constexpr (std::is_same_v) { Assert(value_proto.val_case() == milvus::proto::plan::GenericValue::kArrayVal); @@ -660,7 +1079,6 @@ TEST_P(ExprTest, TestUnaryRangeJson) { auto final = ExecuteQueryExpr( plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); - EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; @@ -734,30 +1152,36 @@ TEST_P(ExprTest, TestUnaryRangeJson) { } } -TEST_P(ExprTest, TestTermJson) { +TEST_P(ExprTest, TestUnaryRangeJsonNullable) { struct Testcase { - std::vector term; + int64_t val; std::vector nested_path; }; std::vector testcases{ - {{1, 2, 3, 4}, {"int"}}, - {{10, 100, 1000, 10000}, {"int"}}, - {{100, 10000, 9999, 444}, {"int"}}, - {{23, 42, 66, 17, 25}, {"int"}}, + {10, {"int"}}, + {20, {"int"}}, + {30, {"int"}}, + {40, {"int"}}, + {10, {"double"}}, + {20, {"double"}}, + {30, {"double"}}, + {40, {"double"}}, }; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + FixedVector valid_data_col; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data_col = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -770,52 +1194,333 @@ TEST_P(ExprTest, TestTermJson) { } auto seg_promote = dynamic_cast(seg.get()); - for (auto testcase : testcases) { - auto check = [&](int64_t value) { - std::unordered_set term_set(testcase.term.begin(), - testcase.term.end()); - return term_set.find(value) != term_set.end(); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + std::vector ops{ + OpType::Equal, + OpType::NotEqual, + OpType::GreaterThan, + OpType::GreaterEqual, + OpType::LessThan, + OpType::LessEqual, + }; + for (const auto& testcase : testcases) { + auto check = [&](int64_t value, bool valid) { + return value == testcase.val; }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (const auto& val : testcase.term) { + std::function f = check; + for (auto& op : ops) { + switch (op) { + case OpType::Equal: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value == testcase.val; + }; + break; + } + case OpType::NotEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value != testcase.val; + }; + break; + } + case OpType::GreaterEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value >= testcase.val; + }; + break; + } + case OpType::GreaterThan: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value > testcase.val; + }; + break; + } + case OpType::LessEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value <= testcase.val; + }; + break; + } + case OpType::LessThan: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value < testcase.val; + }; + break; + } + default: { + PanicInfo(Unsupported, "unsupported range node"); + } + } + + auto pointer = milvus::Json::pointer(testcase.nested_path); proto::plan::GenericValue value; - value.set_int64_val(val); - values.push_back(value); - } - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - values); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); + value.set_int64_val(testcase.val); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + op, + value); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (testcase.nested_path[0] == "int") { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = f(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } else { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = f(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } } } -} -TEST_P(ExprTest, TestTerm) { - auto vec_2k_3k = [] { - std::string buf; - for (int i = 2000; i < 3000; ++i) { - buf += "values: < int64_val: " + std::to_string(i) + " >\n"; - } - return buf; - }(); + struct TestArrayCase { + proto::plan::GenericValue val; + std::vector nested_path; + }; - std::vector>> testcases = { - {R"(values: < + proto::plan::GenericValue value; + auto* arr = value.mutable_array_val(); + arr->set_same_type(true); + proto::plan::GenericValue int_val1; + int_val1.set_int64_val(int64_t(1)); + arr->add_array()->CopyFrom(int_val1); + + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(2)); + arr->add_array()->CopyFrom(int_val2); + + proto::plan::GenericValue int_val3; + int_val3.set_int64_val(int64_t(3)); + arr->add_array()->CopyFrom(int_val3); + + std::vector array_cases = {{value, {"array"}}}; + for (const auto& testcase : array_cases) { + auto check = [&](OpType op, bool valid) { + if (!valid) { + return false; + } + if (testcase.nested_path[0] == "array" && op == OpType::Equal) { + return true; + } + return false; + }; + for (auto& op : ops) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + op, + testcase.val); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = check(op, valid_data_col[i]); + ASSERT_EQ(ans, ref) << "@" << i << "op" << op; + } + } + } +} + +TEST_P(ExprTest, TestTermJson) { + struct Testcase { + std::vector term; + std::vector nested_path; + }; + std::vector testcases{ + {{1, 2, 3, 4}, {"int"}}, + {{10, 100, 1000, 10000}, {"int"}}, + {{100, 10000, 9999, 444}, {"int"}}, + {{23, 42, 66, 17, 25}, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](int64_t value) { + std::unordered_set term_set(testcase.term.begin(), + testcase.term.end()); + return term_set.find(value) != term_set.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue value; + value.set_int64_val(val); + values.push_back(value); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestTermJsonNullable) { + struct Testcase { + std::vector term; + std::vector nested_path; + }; + std::vector testcases{ + {{1, 2, 3, 4}, {"int"}}, + {{10, 100, 1000, 10000}, {"int"}}, + {{100, 10000, 9999, 444}, {"int"}}, + {{23, 42, 66, 17, 25}, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data_col; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + auto new_valid_data_col = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col.insert(valid_data_col.end(), + new_valid_data_col.begin(), + new_valid_data_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set(testcase.term.begin(), + testcase.term.end()); + return term_set.find(value) != term_set.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue value; + value.set_int64_val(val); + values.push_back(value); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestTerm) { + auto vec_2k_3k = [] { + std::string buf; + for (int i = 2000; i < 3000; ++i) { + buf += "values: < int64_val: " + std::to_string(i) + " >\n"; + } + return buf; + }(); + + std::vector>> testcases = { + {R"(values: < int64_val: 2000 > values: < @@ -901,30 +1606,73 @@ TEST_P(ExprTest, TestTerm) { } } -TEST_P(ExprTest, TestCompare) { - std::vector>> +TEST_P(ExprTest, TestTermNullable) { + auto vec_2k_3k = [] { + std::string buf; + for (int i = 2000; i < 3000; ++i) { + buf += "values: < int64_val: " + std::to_string(i) + " >\n"; + } + return buf; + }(); + + std::vector>> testcases = { - {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, - {R"(LessEqual)", [](int a, int64_t b) { return a <= b; }}, - {R"(GreaterThan)", [](int a, int64_t b) { return a > b; }}, - {R"(GreaterEqual)", [](int a, int64_t b) { return a >= b; }}, - {R"(Equal)", [](int a, int64_t b) { return a == b; }}, - {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, + {R"(values: < + int64_val: 2000 + > + values: < + int64_val: 3000 + > + )", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000 || v == 3000; + }}, + {R"(values: < + int64_val: 2000 + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000; + }}, + {R"(values: < + int64_val: 3000 + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 3000; + }}, + {R"()", + [](int v, bool valid) { + if (!valid) { + return false; + } + return false; + }}, + {vec_2k_3k, + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v < 3000; + }}, }; std::string raw_plan_tmp = R"(vector_anns: < field_id: 100 predicates: < - compare_expr: < - left_column_info: < - field_id: 101 - data_type: Int32 - > - right_column_info: < + term_expr: < + column_info: < field_id: 102 data_type: Int64 > - op: @@@@ + @@@@ > > query_info: < @@ -937,23 +1685,26 @@ TEST_P(ExprTest, TestCompare) { >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i32_fid = schema->AddDebugField("age1", DataType::INT32); - auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto i64_fid = schema->AddDebugField("age", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - std::vector age1_col; - std::vector age2_col; - int num_iters = 1; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 100; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); - auto new_age1_col = raw_data.get_col(i32_fid); - auto new_age2_col = raw_data.get_col(i64_fid); - age1_col.insert( - age1_col.end(), new_age1_col.begin(), new_age1_col.end()); - age2_col.insert( - age2_col.end(), new_age2_col.begin(), new_age2_col.end()); + auto new_nullable_col = raw_data.get_col(nullable_fid); + auto new_valid_data_col = raw_data.get_col_valid(nullable_fid); + valid_data_col.insert(valid_data_col.end(), + new_valid_data_col.begin(), + new_valid_data_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -981,16 +1732,14 @@ TEST_P(ExprTest, TestCompare) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val1 = age1_col[i]; - auto val2 = age2_col[i]; - auto ref = ref_func(val1, val2); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" - << boost::format("[%1%, %2%]") % val1 % val2; + auto val = nullable_col[i]; + auto ref = ref_func(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } } } -TEST_P(ExprTest, TestCompareWithScalarIndex) { +TEST_P(ExprTest, TestCompare) { std::vector>> testcases = { {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, @@ -1001,85 +1750,78 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, }; - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - compare_expr: < - left_column_info: < - field_id: %3% - data_type: %4% - > - right_column_info: < - field_id: %5% - data_type: %6% - > - op: %2% - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 101 + data_type: Int32 + > + right_column_info: < + field_id: 102 + data_type: Int64 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; - auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i32_fid = schema->AddDebugField("age32", DataType::INT32); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); + auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; - - // load index for int32 field - auto age32_col = raw_data.get_col(i32_fid); - age32_col[0] = 1000; - GenScalarIndexing(N, age32_col.data()); - auto age32_index = milvus::index::CreateScalarIndexSort(); - age32_index->Build(N, age32_col.data()); - load_index_info.field_id = i32_fid.get(); - load_index_info.field_type = DataType::INT32; - load_index_info.index = std::move(age32_index); - seg->LoadIndex(load_index_info); - - // load index for int64 field - auto age64_col = raw_data.get_col(i64_fid); - age64_col[0] = 2000; - GenScalarIndexing(N, age64_col.data()); - auto age64_index = milvus::index::CreateScalarIndexSort(); - age64_index->Build(N, age64_col.data()); - load_index_info.field_id = i64_fid.get(); - load_index_info.field_type = DataType::INT64; - load_index_info.index = std::move(age64_index); - seg->LoadIndex(load_index_info); + std::vector age1_col; + std::vector age2_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_age2_col = raw_data.get_col(i64_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + age2_col.insert( + age2_col.end(), new_age2_col.begin(), new_age2_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { - auto dsl_string = - boost::format(serialized_expr_plan) % vec_fid.get() % clause % - i32_fid.get() % proto::schema::DataType_Name(int(DataType::INT32)) % - i64_fid.get() % proto::schema::DataType_Name(int(DataType::INT64)); - auto binary_plan = - translate_text_plan_with_metric_type(dsl_string.str()); - auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); - // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; final = ExecuteQueryExpr( plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), - N, + seg_promote, + N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N; ++i) { + for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val1 = age32_col[i]; - auto val2 = age64_col[i]; + + auto val1 = age1_col[i]; + auto val2 = age2_col[i]; auto ref = ref_func(val1, val2); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << boost::format("[%1%, %2%]") % val1 % val2; @@ -1087,575 +1829,635 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { } } -TEST_P(ExprTest, TestCompareExpr) { +TEST_P(ExprTest, TestCompareNullable) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 101 + data_type: Int32 + > + right_column_info: < + field_id: 103 + data_type: Int64 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); - auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR); - schema->set_primary_field_id(str1_fid); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); - size_t N = 1000; - auto raw_data = DataGen(schema, N); - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age1_col; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_nullable_col = raw_data.get_col(nullable_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); - seg->LoadFieldData(FieldId(field_id), info); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val1 = age1_col[i]; + auto val2 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } } +} - auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { - switch (type) { - case DataType::BOOL: { - auto compare_expr = std::make_shared( - bool_fid, - bool_1_fid, - DataType::BOOL, - DataType::BOOL, - proto::plan::OpType::LessThan); - return compare_expr; - } - case DataType::INT8: { - auto compare_expr = - std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - return compare_expr; - } - case DataType::INT16: { - auto compare_expr = - std::make_shared(int16_fid, - int16_1_fid, - DataType::INT16, - DataType::INT16, - OpType::LessThan); - return compare_expr; - } - case DataType::INT32: { - auto compare_expr = - std::make_shared(int32_fid, - int32_1_fid, - DataType::INT32, - DataType::INT32, - OpType::LessThan); - return compare_expr; - } - case DataType::INT64: { - auto compare_expr = - std::make_shared(int64_fid, - int64_1_fid, - DataType::INT64, - DataType::INT64, - OpType::LessThan); - return compare_expr; - } - case DataType::FLOAT: { - auto compare_expr = - std::make_shared(float_fid, - float_1_fid, - DataType::FLOAT, - DataType::FLOAT, - OpType::LessThan); - return compare_expr; - } - case DataType::DOUBLE: { - auto compare_expr = - std::make_shared(double_fid, - double_1_fid, - DataType::DOUBLE, - DataType::DOUBLE, - OpType::LessThan); - return compare_expr; - } - case DataType::VARCHAR: { - auto compare_expr = - std::make_shared(str2_fid, - str3_fid, - DataType::VARCHAR, - DataType::VARCHAR, - OpType::LessThan); - return compare_expr; - } - default: - return std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - } - }; - std::cout << "start compare test" << std::endl; - auto expr = build_expr(DataType::BOOL); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT8); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT16); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT32); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT64); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::FLOAT); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::DOUBLE); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - std::cout << "end compare test" << std::endl; -} +TEST_P(ExprTest, TestCompareNullable2) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; -TEST(Expr, TestExprPerformance) { - GTEST_SKIP() << "Skip performance test, open it when test performance"; + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 103 + data_type: Int64 + > + right_column_info: < + field_id: 101 + data_type: Int32 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; auto schema = std::make_shared(); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); - - std::map fids = {{DataType::INT8, int8_fid}, - {DataType::INT16, int16_fid}, - {DataType::INT32, int32_fid}, - {DataType::INT64, int64_fid}, - {DataType::VARCHAR, str2_fid}, - {DataType::FLOAT, float_fid}, - {DataType::DOUBLE, double_fid}}; + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age1_col; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_nullable_col = raw_data.get_col(nullable_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; - seg->LoadFieldData(FieldId(field_id), info); + auto val2 = age1_col[i]; + auto val1 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } } +} - enum ExprType { - UnaryRangeExpr = 0, - TermExprImpl = 1, - CompareExpr = 2, - LogicalUnaryExpr = 3, - BinaryRangeExpr = 4, - LogicalBinaryExpr = 5, - BinaryArithOpEvalRangeExpr = 6, - }; +TEST_P(ExprTest, TestCompareWithScalarIndex) { + std::vector>> + testcases = { + {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, + {R"(LessEqual)", [](int a, int64_t b) { return a <= b; }}, + {R"(GreaterThan)", [](int a, int64_t b) { return a > b; }}, + {R"(GreaterEqual)", [](int a, int64_t b) { return a >= b; }}, + {R"(Equal)", [](int a, int64_t b) { return a == b; }}, + {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, + }; - auto build_unary_range_expr = [&](DataType data_type, - int64_t value) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val; - val.set_int64_val(value); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val; - val.set_float_val(float(value)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else if (IsStringDataType(data_type)) { - proto::plan::GenericValue val; - val.set_string_val(std::to_string(value)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else { - throw std::runtime_error("not supported type"); - } - }; - - auto build_binary_range_expr = [&](DataType data_type, - int64_t low, - int64_t high) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_int64_val(low); - proto::plan::GenericValue val2; - val2.set_int64_val(high); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_float_val(float(low)); - proto::plan::GenericValue val2; - val2.set_float_val(float(high)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else if (IsStringDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_string_val(std::to_string(low)); - proto::plan::GenericValue val2; - val2.set_string_val(std::to_string(low)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else { - throw std::runtime_error("not supported type"); - } - }; - - auto build_term_expr = - [&](DataType data_type, - std::vector in_vals) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_int64_val(v); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else if (IsFloatDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_float_val(float(v)); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else if (IsStringDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_string_val(std::to_string(v)); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else { - throw std::runtime_error("not supported type"); - } - }; - - auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || - IsStringDataType(data_type)) { - return std::make_shared( - fids[data_type], - fids[data_type], - data_type, - data_type, - proto::plan::OpType::LessThan); - } else { - throw std::runtime_error("not supported type"); - } - }; + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; - auto build_logical_unary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child_expr = build_unary_range_expr(data_type, 10); - return std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - }; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); - auto build_logical_binary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child1_expr = build_unary_range_expr(data_type, 10); - auto child2_expr = build_unary_range_expr(data_type, 10); - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - }; + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - auto build_multi_logical_binary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child1_expr = build_unary_range_expr(data_type, 100); - auto child2_expr = build_unary_range_expr(data_type, 100); - auto child3_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - auto child4_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - auto child5_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); - auto child6_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); - }; + // load index for int32 field + auto age32_col = raw_data.get_col(i32_fid); + age32_col[0] = 1000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data()); + load_index_info.field_id = i32_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); - auto build_arith_op_expr = [&](DataType data_type, - int64_t right_val, - int64_t val) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_int64_val(right_val); - proto::plan::GenericValue val2; - val2.set_int64_val(val); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val1, - val2); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_float_val(float(right_val)); - proto::plan::GenericValue val2; - val2.set_float_val(float(val)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val1, - val2); - } else { - throw std::runtime_error("not supported type"); - } - }; + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); - auto test_case_base = [=, &seg](expr::TypedExprPtr expr) { - std::cout << expr->ToString() << std::endl; + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = + boost::format(serialized_expr_plan) % vec_fid.get() % clause % + i32_fid.get() % proto::schema::DataType_Name(int(DataType::INT32)) % + i64_fid.get() % proto::schema::DataType_Name(int(DataType::INT64)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - auto start = std::chrono::steady_clock::now(); - for (int i = 0; i < 100; i++) { - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = age32_col[i]; + auto val2 = age64_col[i]; + auto ref = ref_func(val1, val2); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; } - std::cout << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() / - 100.0 - << "us" << std::endl; - }; + } +} - std::cout << "test unary range operator" << std::endl; - auto expr = build_unary_range_expr(DataType::INT8, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT16, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT32, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT64, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::FLOAT, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::DOUBLE, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::VARCHAR, 10); - test_case_base(expr); +TEST_P(ExprTest, TestCompareWithScalarIndexNullable) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; - std::cout << "test binary range operator" << std::endl; - expr = build_binary_range_expr(DataType::INT8, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT16, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT32, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT64, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::FLOAT, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); - test_case_base(expr); + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; - std::cout << "test compare expr operator" << std::endl; - expr = build_compare_expr(DataType::INT8); - test_case_base(expr); - expr = build_compare_expr(DataType::INT16); - test_case_base(expr); - expr = build_compare_expr(DataType::INT32); - test_case_base(expr); - expr = build_compare_expr(DataType::INT64); - test_case_base(expr); - expr = build_compare_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_compare_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_compare_expr(DataType::VARCHAR); - test_case_base(expr); + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT32, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); - std::cout << "test artih op val operator" << std::endl; - expr = build_arith_op_expr(DataType::INT8, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT16, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT32, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT64, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::FLOAT, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); - test_case_base(expr); + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - std::cout << "test logical unary expr operator" << std::endl; - expr = build_logical_unary_expr(DataType::INT8); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT16); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT32); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT64); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::VARCHAR); - test_case_base(expr); + // load index for int32 field + auto nullable_col = raw_data.get_col(nullable_fid); + nullable_col[0] = 1000; + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto nullable_index = milvus::index::CreateScalarIndexSort(); + nullable_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(nullable_index); + seg->LoadIndex(load_index_info); - std::cout << "test logical binary expr operator" << std::endl; - expr = build_logical_binary_expr(DataType::INT8); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT16); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT32); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT64); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::VARCHAR); - test_case_base(expr); + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); - std::cout << "test multi logical binary expr operator" << std::endl; - expr = build_multi_logical_binary_expr(DataType::INT8); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT16); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT32); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT64); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::VARCHAR); - test_case_base(expr); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)) % + i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = nullable_col[i]; + auto val2 = age64_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } } -TEST_P(ExprTest, test_term_pk) { +TEST_P(ExprTest, TestCompareWithScalarIndexNullable2) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); - schema->AddField( - FieldName("Timestamp"), FieldId(1), DataType::INT64, false); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - schema->set_primary_field_id(int64_fid); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT32, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); auto seg = CreateSealedSegment(schema); - int N = 100000; + int N = 1000; auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - // load field data - auto fields = schema->get_fields(); - - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + // load index for int32 field + auto nullable_col = raw_data.get_col(nullable_fid); + nullable_col[0] = 1000; + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto nullable_index = milvus::index::CreateScalarIndexSort(); + nullable_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(nullable_index); + seg->LoadIndex(load_index_info); - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); - seg->LoadFieldData(FieldId(field_id), info); - } - - std::vector retrieve_ints; - for (int i = 0; i < 10; ++i) { - proto::plan::GenericValue val; - val.set_int64_val(i); - retrieve_ints.push_back(val); - } - auto expr = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)) % + nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(final[i], true); - } - for (int i = 10; i < N; ++i) { - EXPECT_EQ(final[i], false); - } - retrieve_ints.clear(); - for (int i = 0; i < 10; ++i) { - proto::plan::GenericValue val; - val.set_int64_val(i + N); - retrieve_ints.push_back(val); - } - expr = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); - for (int i = 0; i < N; ++i) { - EXPECT_EQ(final[i], false); + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val2 = nullable_col[i]; + auto val1 = age64_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } } } @@ -1724,9 +2526,11 @@ TEST_P(ExprTest, test_term_pk_with_sorted) { } } -TEST_P(ExprTest, TestConjuctExpr) { +TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); @@ -1735,16 +2539,18 @@ TEST_P(ExprTest, TestConjuctExpr) { auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + size_t N = 1000; auto raw_data = DataGen(schema, N); - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1758,60 +2564,149 @@ TEST_P(ExprTest, TestConjuctExpr) { seg->LoadFieldData(FieldId(field_id), info); } - auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { - ::milvus::proto::plan::GenericValue value; - value.set_int64_val(l); - auto left = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - value); - value.set_int64_val(r); - auto right = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::LessThan, - value); - - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, left, right); - }; - - std::vector> test_case = { - {100, 0}, {0, 100}, {8192, 8194}}; - for (auto& pair : test_case) { - std::cout << pair.first << "|" << pair.second << std::endl; - auto expr = build_expr(pair.first, pair.second); - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - BitsetType final; - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - for (int i = 0; i < N; ++i) { - EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_1_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_fid, + int16_1_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_fid, + int32_1_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_fid, + int64_1_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_fid, + float_1_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_fid, + double_1_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str2_fid, + str3_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); } - } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT8); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + std::cout << "end compare test" << std::endl; } -TEST_P(ExprTest, TestUnaryBenchTest) { +TEST_P(ExprTest, TestCompareExprNullable) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_nullable_fid = + schema->AddDebugField("bool1", DataType::BOOL, true); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int81", DataType::INT8, true); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int161", DataType::INT16, true); auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int321", DataType::INT32, true); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto int64_nullable_fid = + schema->AddDebugField("int641", DataType::INT64, true); auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_nullable_fid = + schema->AddDebugField("float1", DataType::FLOAT, true); auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_nullable_fid = + schema->AddDebugField("double1", DataType::DOUBLE, true); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto str_nullable_fid = + schema->AddDebugField("string3", DataType::VARCHAR, true); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + size_t N = 1000; auto raw_data = DataGen(schema, N); - - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1825,62 +2720,149 @@ TEST_P(ExprTest, TestUnaryBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(10); - } else { - val.set_int64_val(10); - } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::GreaterThan, - val); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; - } -} - -TEST_P(ExprTest, TestBinaryRangeBenchTest) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_nullable_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_fid, + int8_nullable_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_fid, + int16_nullable_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_fid, + int32_nullable_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_fid, + int64_nullable_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_fid, + float_nullable_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_fid, + double_nullable_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str2_fid, + str_nullable_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_fid, + int8_nullable_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT8); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + std::cout << "end compare test" << std::endl; +} + +TEST_P(ExprTest, TestCompareExprNullable2) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_nullable_fid = + schema->AddDebugField("bool1", DataType::BOOL, true); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int81", DataType::INT8, true); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int161", DataType::INT16, true); auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int321", DataType::INT32, true); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto int64_nullable_fid = + schema->AddDebugField("int641", DataType::INT64, true); auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_nullable_fid = + schema->AddDebugField("float1", DataType::FLOAT, true); auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_nullable_fid = + schema->AddDebugField("double1", DataType::DOUBLE, true); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto str_nullable_fid = + schema->AddDebugField("string3", DataType::VARCHAR, true); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + size_t N = 1000; auto raw_data = DataGen(schema, N); - - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1894,52 +2876,119 @@ TEST_P(ExprTest, TestBinaryRangeBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue lower; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - lower.set_float_val(10); - } else { - lower.set_int64_val(10); - } - proto::plan::GenericValue upper; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - upper.set_float_val(45); - } else { - upper.set_int64_val(45); - } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - lower, - upper, - true, - true); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_nullable_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_nullable_fid, + int8_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_nullable_fid, + int16_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_nullable_fid, + int32_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_nullable_fid, + int64_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_nullable_fid, + float_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_nullable_fid, + double_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str_nullable_fid, + str2_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_nullable_fid, + int8_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); } - std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; - } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT8); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + std::cout << "end compare test" << std::endl; } -TEST_P(ExprTest, TestLogicalUnaryBenchTest) { +TEST(Expr, TestExprPerformance) { + GTEST_SKIP() << "Skip performance test, open it when test performance"; auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); @@ -1954,80 +3003,16 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); schema->set_primary_field_id(str1_fid); - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); - - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } - - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(10); - } else { - val.set_int64_val(10); - } - auto child_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::GreaterThan, - val); - auto expr = std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; - } -} - -TEST_P(ExprTest, TestBinaryLogicalBenchTest) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); + std::map fids = {{DataType::INT8, int8_fid}, + {DataType::INT16, int16_fid}, + {DataType::INT32, int32_fid}, + {DataType::INT64, int64_fid}, + {DataType::VARCHAR, str2_fid}, + {DataType::FLOAT, float_fid}, + {DataType::DOUBLE, double_fid}}; auto seg = CreateSealedSegment(schema); - int N = 10000; + int N = 1000; auto raw_data = DataGen(schema, N); // load field data @@ -2044,222 +3029,366 @@ TEST_P(ExprTest, TestBinaryLogicalBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(-1000000); + auto build_unary_range_expr = [&](DataType data_type, + int64_t value) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val; + val.set_int64_val(value); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val; + val.set_float_val(float(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); } else { - val.set_int64_val(-1000000); + throw std::runtime_error("not supported type"); } - proto::plan::GenericValue val1; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val1.set_float_val(-100); + }; + + auto build_binary_range_expr = [&](DataType data_type, + int64_t low, + int64_t high) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(low); + proto::plan::GenericValue val2; + val2.set_int64_val(high); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(low)); + proto::plan::GenericValue val2; + val2.set_float_val(float(high)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_string_val(std::to_string(low)); + proto::plan::GenericValue val2; + val2.set_string_val(std::to_string(low)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); } else { - val1.set_int64_val(-100); + throw std::runtime_error("not supported type"); } - auto child1_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::LessThan, - val); - auto child2_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::NotEqual, - val1); - auto expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + }; + + auto build_term_expr = + [&](DataType data_type, + std::vector in_vals) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_int64_val(v); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsFloatDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_float_val(float(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsStringDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else { + throw std::runtime_error("not supported type"); } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; - } -} + }; -TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); + auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || + IsStringDataType(data_type)) { + return std::make_shared( + fids[data_type], + fids[data_type], + data_type, + data_type, + proto::plan::OpType::LessThan); + } else { + throw std::runtime_error("not supported type"); + } + }; - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); + auto build_logical_unary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + }; - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } + auto build_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 10); + auto child2_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + }; - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; + auto build_multi_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 100); + auto child2_expr = build_unary_range_expr(data_type, 100); + auto child3_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child4_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child5_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + auto child6_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); + }; - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(100); - } else { - val.set_int64_val(100); - } - proto::plan::GenericValue right; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - right.set_float_val(10); + auto build_arith_op_expr = [&](DataType data_type, + int64_t right_val, + int64_t val) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(right_val); + proto::plan::GenericValue val2; + val2.set_int64_val(val); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(right_val)); + proto::plan::GenericValue val2; + val2.set_float_val(float(val)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); } else { - right.set_int64_val(10); + throw std::runtime_error("not supported type"); } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val, - right); + }; + + auto test_case_base = [=, &seg](expr::TypedExprPtr expr) { + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + std::cout << expr->ToString() << std::endl; BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); + auto start = std::chrono::steady_clock::now(); + for (int i = 0; i < 100; i++) { final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + EXPECT_EQ(final.size(), N); } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; - } + std::cout << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() / + 100.0 + << "us" << std::endl; + }; + + std::cout << "test unary range operator" << std::endl; + auto expr = build_unary_range_expr(DataType::INT8, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT16, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT32, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT64, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::FLOAT, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::DOUBLE, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::VARCHAR, 10); + test_case_base(expr); + + std::cout << "test binary range operator" << std::endl; + expr = build_binary_range_expr(DataType::INT8, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT16, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT32, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT64, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::FLOAT, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); + test_case_base(expr); + + std::cout << "test compare expr operator" << std::endl; + expr = build_compare_expr(DataType::INT8); + test_case_base(expr); + expr = build_compare_expr(DataType::INT16); + test_case_base(expr); + expr = build_compare_expr(DataType::INT32); + test_case_base(expr); + expr = build_compare_expr(DataType::INT64); + test_case_base(expr); + expr = build_compare_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_compare_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_compare_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test artih op val operator" << std::endl; + expr = build_arith_op_expr(DataType::INT8, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT16, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT32, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT64, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::FLOAT, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); + test_case_base(expr); + + std::cout << "test logical unary expr operator" << std::endl; + expr = build_logical_unary_expr(DataType::INT8); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT16); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT32); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT64); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test logical binary expr operator" << std::endl; + expr = build_logical_binary_expr(DataType::INT8); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT16); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT32); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT64); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test multi logical binary expr operator" << std::endl; + expr = build_multi_logical_binary_expr(DataType::INT8); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT16); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT32); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT64); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::VARCHAR); + test_case_base(expr); } -TEST_P(ExprTest, TestCompareExprBenchTest) { +TEST(Expr, TestExprNOT) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8, true); auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16, true); auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32, true); auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true); auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); - + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR, true); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT, true); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE, true); schema->set_primary_field_id(str1_fid); + std::map fids = {{DataType::INT8, int8_fid}, + {DataType::INT16, int16_fid}, + {DataType::INT32, int32_fid}, + {DataType::INT64, int64_fid}, + {DataType::VARCHAR, str2_fid}, + {DataType::FLOAT, float_fid}, + {DataType::DOUBLE, double_fid}}; + auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); - - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } - - std::vector< - std::pair, std::pair>> - test_cases = { - {{int8_fid, DataType::INT8}, {int8_1_fid, DataType::INT8}}, - {{int16_fid, DataType::INT16}, {int16_fid, DataType::INT16}}, - {{int32_fid, DataType::INT32}, {int32_1_fid, DataType::INT32}}, - {{int64_fid, DataType::INT64}, {int64_1_fid, DataType::INT64}}, - {{float_fid, DataType::FLOAT}, {float_1_fid, DataType::FLOAT}}, - {{double_fid, DataType::DOUBLE}, {double_1_fid, DataType::DOUBLE}}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.first.second) << std::endl; - proto::plan::GenericValue lower; - auto expr = std::make_shared(pair.first.first, - pair.second.first, - pair.first.second, - pair.second.second, - OpType::LessThan); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 10 << "us" << std::endl; - } -} - -TEST_P(ExprTest, TestRefactorExprs) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); - - auto seg = CreateSealedSegment(schema); - int N = 10000; + FixedVector valid_data_i8; + FixedVector valid_data_i16; + FixedVector valid_data_i32; + FixedVector valid_data_i64; + FixedVector valid_data_str; + FixedVector valid_data_float; + FixedVector valid_data_double; + int N = 1000; auto raw_data = DataGen(schema, N); + valid_data_i8 = raw_data.get_col_valid(int8_fid); + valid_data_i16 = raw_data.get_col_valid(int16_fid); + valid_data_i32 = raw_data.get_col_valid(int32_fid); + valid_data_i64 = raw_data.get_col_valid(int64_fid); + valid_data_str = raw_data.get_col_valid(str2_fid); + valid_data_float = raw_data.get_col_valid(float_fid); + valid_data_double = raw_data.get_col_valid(double_fid); // load field data auto fields = schema->get_fields(); @@ -2285,382 +3414,4241 @@ TEST_P(ExprTest, TestRefactorExprs) { BinaryArithOpEvalRangeExpr = 6, }; - auto build_expr = [&](enum ExprType test_type, - int n) -> expr::TypedExprPtr { - switch (test_type) { - case UnaryRangeExpr: { - proto::plan::GenericValue val; - val.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val); - } - case TermExprImpl: { - std::vector retrieve_ints; - // for (int i = 0; i < n; ++i) { - // retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); - // } - // return std::make_shared>( - // ColumnInfo(str1_fid, DataType::VARCHAR), - // retrieve_ints, - // proto::plan::GenericValue::ValCase::kStringVal); - for (int i = 0; i < n; ++i) { - proto::plan::GenericValue val; - val.set_float_val(i); - retrieve_ints.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(double_fid, DataType::DOUBLE), - retrieve_ints); - } - case CompareExpr: { - auto compare_expr = - std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - return compare_expr; - } - case BinaryRangeExpr: { - proto::plan::GenericValue lower; - lower.set_int64_val(10); - proto::plan::GenericValue upper; - upper.set_int64_val(45); - return std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - lower, - upper, - true, - true); - } - case LogicalUnaryExpr: { - proto::plan::GenericValue val; - val.set_int64_val(10); - auto child_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - return std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - } - case LogicalBinaryExpr: { + auto build_unary_range_expr = [&](DataType data_type, + int64_t value) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val; + val.set_int64_val(value); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val; + val.set_float_val(float(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_binary_range_expr = [&](DataType data_type, + int64_t low, + int64_t high) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(low); + proto::plan::GenericValue val2; + val2.set_int64_val(high); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(low)); + proto::plan::GenericValue val2; + val2.set_float_val(float(high)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_string_val(std::to_string(low)); + proto::plan::GenericValue val2; + val2.set_string_val(std::to_string(low)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_term_expr = + [&](DataType data_type, + std::vector in_vals) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { proto::plan::GenericValue val; - val.set_int64_val(10); - auto child1_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - auto child2_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::NotEqual, - val); - ; - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, - child1_expr, - child2_expr); + val.set_int64_val(v); + vals.push_back(val); } - case BinaryArithOpEvalRangeExpr: { + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsFloatDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { proto::plan::GenericValue val; - val.set_int64_val(100); - proto::plan::GenericValue right; - right.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val, - right); + val.set_float_val(float(v)); + vals.push_back(val); } - default: { + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsStringDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { proto::plan::GenericValue val; - val.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); + val.set_string_val(std::to_string(v)); + vals.push_back(val); } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else { + throw std::runtime_error("not supported type"); } }; - auto test_case = [&](int n) { - auto expr = build_expr(UnaryRangeExpr, n); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - std::cout << "start test" << std::endl; - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - std::cout << n << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() - << "us" << std::endl; - }; - test_case(3); - test_case(10); - test_case(20); - test_case(30); - test_case(50); - test_case(100); - test_case(200); - // test_case(500); -} + auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || + IsStringDataType(data_type)) { + return std::make_shared( + fids[data_type], + fids[data_type], + data_type, + data_type, + proto::plan::OpType::LessThan); + } else { + throw std::runtime_error("not supported type"); + } + }; -TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { - std::vector< - std::tuple>> - testcases = { - {R"(LessThan)", - [](std::string a, std::string b) { return a.compare(b) < 0; }}, - {R"(LessEqual)", - [](std::string a, std::string b) { return a.compare(b) <= 0; }}, - {R"(GreaterThan)", - [](std::string a, std::string b) { return a.compare(b) > 0; }}, - {R"(GreaterEqual)", - [](std::string a, std::string b) { return a.compare(b) >= 0; }}, - {R"(Equal)", - [](std::string a, std::string b) { return a.compare(b) == 0; }}, - {R"(NotEqual)", - [](std::string a, std::string b) { return a.compare(b) != 0; }}, - }; + auto build_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 10); + auto child2_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + }; - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - compare_expr: < - left_column_info: < - field_id: %3% - data_type: VarChar - > - right_column_info: < - field_id: %4% - data_type: VarChar - > - op: %2% - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; + auto build_multi_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 100); + auto child2_expr = build_unary_range_expr(data_type, 100); + auto child3_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child4_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child5_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + auto child6_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); + }; + + auto build_arith_op_expr = [&](DataType data_type, + int64_t right_val, + int64_t val) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(right_val); + proto::plan::GenericValue val2; + val2.set_int64_val(val); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(right_val)); + proto::plan::GenericValue val2; + val2.set_float_val(float(val)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_logical_unary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + }; + + auto test_ans = [=, &seg](expr::TypedExprPtr expr, + FixedVector valid_data) { + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, expr); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < N; i++) { + if (!valid_data[i]) { + EXPECT_EQ(final[i], false); + } + } + }; + + auto expr = build_unary_range_expr(DataType::INT8, 10); + test_ans(expr, valid_data_i8); + expr = build_unary_range_expr(DataType::INT16, 10); + test_ans(expr, valid_data_i16); + expr = build_unary_range_expr(DataType::INT32, 10); + test_ans(expr, valid_data_i32); + expr = build_unary_range_expr(DataType::INT64, 10); + test_ans(expr, valid_data_i64); + expr = build_unary_range_expr(DataType::FLOAT, 10); + test_ans(expr, valid_data_float); + expr = build_unary_range_expr(DataType::DOUBLE, 10); + test_ans(expr, valid_data_double); + expr = build_unary_range_expr(DataType::VARCHAR, 10); + test_ans(expr, valid_data_str); + + expr = build_binary_range_expr(DataType::INT8, 10, 100); + test_ans(expr, valid_data_i8); + expr = build_binary_range_expr(DataType::INT16, 10, 100); + test_ans(expr, valid_data_i16); + expr = build_binary_range_expr(DataType::INT32, 10, 100); + test_ans(expr, valid_data_i32); + expr = build_binary_range_expr(DataType::INT64, 10, 100); + test_ans(expr, valid_data_i64); + expr = build_binary_range_expr(DataType::FLOAT, 10, 100); + test_ans(expr, valid_data_float); + expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); + test_ans(expr, valid_data_double); + expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); + test_ans(expr, valid_data_str); + + expr = build_compare_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_compare_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_compare_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_compare_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_compare_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_compare_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_compare_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_arith_op_expr(DataType::INT8, 10, 100); + test_ans(expr, valid_data_i8); + expr = build_arith_op_expr(DataType::INT16, 10, 100); + test_ans(expr, valid_data_i16); + expr = build_arith_op_expr(DataType::INT32, 10, 100); + test_ans(expr, valid_data_i32); + expr = build_arith_op_expr(DataType::INT64, 10, 100); + test_ans(expr, valid_data_i64); + expr = build_arith_op_expr(DataType::FLOAT, 10, 100); + test_ans(expr, valid_data_float); + expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); + test_ans(expr, valid_data_double); + + expr = build_logical_unary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_logical_unary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_logical_unary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_logical_unary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_logical_unary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_logical_unary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_logical_unary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_logical_binary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_logical_binary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_logical_binary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_logical_binary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_logical_binary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_logical_binary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_logical_binary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_multi_logical_binary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_multi_logical_binary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_multi_logical_binary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_multi_logical_binary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_multi_logical_binary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_multi_logical_binary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_multi_logical_binary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); +} +TEST_P(ExprTest, test_term_pk) { auto schema = std::make_shared(); + schema->AddField( + FieldName("Timestamp"), FieldId(1), DataType::INT64, false); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - schema->set_primary_field_id(str1_fid); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(int64_fid); auto seg = CreateSealedSegment(schema); int N = 1000; auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; - // load index for int32 field - auto str1_col = raw_data.get_col(str1_fid); - GenScalarIndexing(N, str1_col.data()); - auto str1_index = milvus::index::CreateScalarIndexSort(); - str1_index->Build(N, str1_col.data()); - load_index_info.field_id = str1_fid.get(); - load_index_info.field_type = DataType::VARCHAR; - load_index_info.index = std::move(str1_index); - seg->LoadIndex(load_index_info); + // load field data + auto fields = schema->get_fields(); - // load index for int64 field - auto str2_col = raw_data.get_col(str2_fid); - GenScalarIndexing(N, str2_col.data()); - auto str2_index = milvus::index::CreateScalarIndexSort(); - str2_index->Build(N, str2_col.data()); - load_index_info.field_id = str2_fid.get(); - load_index_info.field_type = DataType::VARCHAR; - load_index_info.index = std::move(str2_index); - seg->LoadIndex(load_index_info); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); - for (auto [clause, ref_func] : testcases) { - auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % - clause % str1_fid.get() % str2_fid.get(); - auto binary_plan = - translate_text_plan_with_metric_type(dsl_string.str()); - auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); - // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; - BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), - N, - MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); - for (int i = 0; i < N; ++i) { - auto ans = final[i]; - auto val1 = str1_col[i]; - auto val2 = str2_col[i]; - auto ref = ref_func(val1, val2); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" - << boost::format("[%1%, %2%]") % val1 % val2; - } + seg->LoadFieldData(FieldId(field_id), info); } -} -TEST_P(ExprTest, TestBinaryArithOpEvalRange) { - std::vector, DataType>> testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 101 - data_type: Int8 - > - arith_op: Add - right_operand: < - int64_val: 4 - > - op: Equal - value: < - int64_val: 8 - > - >)", + std::vector retrieve_ints; + for (int i = 0; i < 10; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i); + retrieve_ints.push_back(val); + } + auto expr = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(final[i], true); + } + for (int i = 10; i < N; ++i) { + EXPECT_EQ(final[i], false); + } + retrieve_ints.clear(); + for (int i = 0; i < 10; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i + N); + retrieve_ints.push_back(val); + } + expr = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(final[i], false); + } +} + +TEST_P(ExprTest, TestGrowingSegmentGetBatchSize) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + auto raw_data = DataGen(schema, N); + seg->PreInsert(N); + seg->Insert(0, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + + proto::plan::GenericValue val; + val.set_int64_val(10); + auto expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto plan_node = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + std::vector test_batch_size = { + 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; + + for (const auto& batch_size : test_batch_size) { + EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; + auto plan = plan::PlanFragment(plan_node); + auto query_context = std::make_shared( + "query id", seg.get(), N, MAX_TIMESTAMP); + + auto task = + milvus::exec::Task::Create("task_expr", plan, 0, query_context); + auto last_num = N % batch_size; + auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; + int iter = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + if (++iter != iter_num) { + EXPECT_EQ(childrens[0]->size(), batch_size); + } else { + EXPECT_EQ(childrens[0]->size(), last_num); + } + } + } +} + +TEST_P(ExprTest, TestConjuctExpr) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(l); + auto left = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + value); + value.set_int64_val(r); + auto right = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::LessThan, + value); + + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, left, right); + }; + + std::vector> test_case = { + {100, 0}, {0, 100}, {8192, 8194}}; + for (auto& pair : test_case) { + std::cout << pair.first << "|" << pair.second << std::endl; + auto expr = build_expr(pair.first, pair.second); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; + } + } +} + +TEST_P(ExprTest, TestConjuctExprNullable) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int8_nullable", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int16_nullable", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int32_nullable", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_nullable_fid = + schema->AddDebugField("int64_nullable", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(l); + auto left = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + value); + value.set_int64_val(r); + auto right = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::LessThan, + value); + + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, left, right); + }; + + std::vector> test_case = { + {100, 0}, {0, 100}, {8192, 8194}}; + for (auto& pair : test_case) { + std::cout << pair.first << "|" << pair.second << std::endl; + auto expr = build_expr(pair.first, pair.second); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; + } + } +} + +TEST_P(ExprTest, TestUnaryBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryRangeBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue lower; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + lower.set_float_val(10); + } else { + lower.set_int64_val(10); + } + proto::plan::GenericValue upper; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + upper.set_float_val(45); + } else { + upper.set_int64_val(45); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + lower, + upper, + true, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestLogicalUnaryBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto child_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + auto expr = std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryLogicalBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(-1000000); + } else { + val.set_int64_val(-1000000); + } + proto::plan::GenericValue val1; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val1.set_float_val(-100); + } else { + val1.set_int64_val(-100); + } + auto child1_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::LessThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::NotEqual, + val1); + auto expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(100); + } else { + val.set_int64_val(100); + } + proto::plan::GenericValue right; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + right.set_float_val(10); + } else { + right.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestCompareExprBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector< + std::pair, std::pair>> + test_cases = { + {{int8_fid, DataType::INT8}, {int8_1_fid, DataType::INT8}}, + {{int16_fid, DataType::INT16}, {int16_fid, DataType::INT16}}, + {{int32_fid, DataType::INT32}, {int32_1_fid, DataType::INT32}}, + {{int64_fid, DataType::INT64}, {int64_1_fid, DataType::INT64}}, + {{float_fid, DataType::FLOAT}, {float_1_fid, DataType::FLOAT}}, + {{double_fid, DataType::DOUBLE}, {double_1_fid, DataType::DOUBLE}}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.first.second) << std::endl; + proto::plan::GenericValue lower; + auto expr = std::make_shared(pair.first.first, + pair.second.first, + pair.first.second, + pair.second.second, + OpType::LessThan); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestRefactorExprs) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; + + auto build_expr = [&](enum ExprType test_type, + int n) -> expr::TypedExprPtr { + switch (test_type) { + case UnaryRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + } + case TermExprImpl: { + std::vector retrieve_ints; + // for (int i = 0; i < n; ++i) { + // retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); + // } + // return std::make_shared>( + // ColumnInfo(str1_fid, DataType::VARCHAR), + // retrieve_ints, + // proto::plan::GenericValue::ValCase::kStringVal); + for (int i = 0; i < n; ++i) { + proto::plan::GenericValue val; + val.set_float_val(i); + retrieve_ints.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(double_fid, DataType::DOUBLE), + retrieve_ints); + } + case CompareExpr: { + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case BinaryRangeExpr: { + proto::plan::GenericValue lower; + lower.set_int64_val(10); + proto::plan::GenericValue upper; + upper.set_int64_val(45); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + lower, + upper, + true, + true); + } + case LogicalUnaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + } + case LogicalBinaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child1_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::NotEqual, + val); + ; + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, + child1_expr, + child2_expr); + } + case BinaryArithOpEvalRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(100); + proto::plan::GenericValue right; + right.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + } + default: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + } + } + }; + auto test_case = [&](int n) { + auto expr = build_expr(UnaryRangeExpr, n); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); + std::cout << n << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + }; + test_case(3); + test_case(10); + test_case(20); + test_case(30); + test_case(50); + test_case(100); + test_case(200); + // test_case(500); +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b) { return a.compare(b) < 0; }}, + {R"(LessEqual)", + [](std::string a, std::string b) { return a.compare(b) <= 0; }}, + {R"(GreaterThan)", + [](std::string a, std::string b) { return a.compare(b) > 0; }}, + {R"(GreaterEqual)", + [](std::string a, std::string b) { return a.compare(b) >= 0; }}, + {R"(Equal)", + [](std::string a, std::string b) { return a.compare(b) == 0; }}, + {R"(NotEqual)", + [](std::string a, std::string b) { return a.compare(b) != 0; }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto str2_col = raw_data.get_col(str2_fid); + auto str2_index = milvus::index::CreateStringIndexMarisa(); + str2_index->Build(N, str2_col.data()); + load_index_info.field_id = str2_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % str1_fid.get() % str2_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = str1_col[i]; + auto val2 = str2_col[i]; + auto ref = ref_func(val1, val2); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable) { + std::vector>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) < 0; + }}, + {R"(LessEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) <= 0; + }}, + {R"(GreaterThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) > 0; + }}, + {R"(GreaterEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) >= 0; + }}, + {R"(Equal)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) == 0; + }}, + {R"(NotEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) != 0; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto nullable_fid = + schema->AddDebugField("nullable_fid", DataType::VARCHAR, true); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto nullable_col = raw_data.get_col(nullable_fid); + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto str2_index = milvus::index::CreateStringIndexMarisa(); + str2_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % str1_fid.get() % nullable_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = str1_col[i]; + auto val2 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable2) { + std::vector>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) < 0; + }}, + {R"(LessEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) <= 0; + }}, + {R"(GreaterThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) > 0; + }}, + {R"(GreaterEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) >= 0; + }}, + {R"(Equal)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) == 0; + }}, + {R"(NotEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) != 0; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto nullable_fid = + schema->AddDebugField("nullable_fid", DataType::VARCHAR, true); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto nullable_col = raw_data.get_col(nullable_fid); + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto str2_index = milvus::index::CreateStringIndexMarisa(); + str2_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % nullable_fid.get() % str1_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = nullable_col[i]; + auto val2 = str1_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRange) { + std::vector, DataType>> testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + > + >)", [](int8_t v) { return (v + 4) == 8; }, DataType::INT8}, {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id: 102 + data_type: Int16 + > + arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + > + >)", + [](int16_t v) { return (v - 500) == 1500; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + > + >)", + [](int32_t v) { return (v * 2) == 4000; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + > + >)", + [](int64_t v) { return (v / 2) == 1000; }, + DataType::INT64}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) == 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) == 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](double v) { return (v + 500) == 2500; }, + DataType::DOUBLE}, + // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) != 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) != 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) != 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) != 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) != 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) != 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) > 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) > 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) > 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) > 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) > 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) >= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) >= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) >= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) >= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) >= 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) < 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) < 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) < 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) < 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) < 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) <= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) <= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) <= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) <= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) <= 2500; }, + DataType::INT64}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_fid = schema->AddDebugField("age8", DataType::INT8); + auto i16_fid = schema->AddDebugField("age16", DataType::INT16); + auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age8_col; + std::vector age16_col; + std::vector age32_col; + std::vector age64_col; + std::vector age_float_col; + std::vector age_double_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto new_age8_col = raw_data.get_col(i8_fid); + auto new_age16_col = raw_data.get_col(i16_fid); + auto new_age32_col = raw_data.get_col(i32_fid); + auto new_age64_col = raw_data.get_col(i64_fid); + auto new_age_float_col = raw_data.get_col(float_fid); + auto new_age_double_col = raw_data.get_col(double_fid); + + age8_col.insert( + age8_col.end(), new_age8_col.begin(), new_age8_col.end()); + age16_col.insert( + age16_col.end(), new_age16_col.begin(), new_age16_col.end()); + age32_col.insert( + age32_col.end(), new_age32_col.begin(), new_age32_col.end()); + age64_col.insert( + age64_col.end(), new_age64_col.begin(), new_age64_col.end()); + age_float_col.insert(age_float_col.end(), + new_age_float_col.begin(), + new_age_float_col.end()); + age_double_col.insert(age_double_col.end(), + new_age_double_col.begin(), + new_age_double_col.end()); + + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + // if (dtype == DataType::INT8) { + // dsl_string.replace(loc, 5, dsl_string_int8); + // } else if (dtype == DataType::INT16) { + // dsl_string.replace(loc, 5, dsl_string_int16); + // } else if (dtype == DataType::INT32) { + // dsl_string.replace(loc, 5, dsl_string_int32); + // } else if (dtype == DataType::INT64) { + // dsl_string.replace(loc, 5, dsl_string_int64); + // } else if (dtype == DataType::FLOAT) { + // dsl_string.replace(loc, 5, dsl_string_float); + // } else if (dtype == DataType::DOUBLE) { + // dsl_string.replace(loc, 5, dsl_string_double); + // } else { + // ASSERT_TRUE(false) << "No test case defined for this data type"; + // } + // loc = dsl_string.find("@@@@"); + // dsl_string.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << clause << "@" << i << "!!" << val << std::endl; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) == 8; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) == 1500; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) == 4000; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) == 1000; + }, + DataType::INT64}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) == 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::DOUBLE}, + // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) != 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) != 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) != 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) != 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) > 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) > 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) > 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) > 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) > 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) > 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) >= 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) >= 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) >= 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) >= 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) >= 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) >= 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) < 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) < 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) < 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) < 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) < 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) < 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) <= 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) <= 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) <= 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) <= 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) <= 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) <= 2500; + }, + DataType::INT64}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_nullable_fid = schema->AddDebugField("age8", DataType::INT8, true); + auto i16_nullable_fid = + schema->AddDebugField("age16", DataType::INT16, true); + auto i32_nullable_fid = + schema->AddDebugField("age32", DataType::INT32, true); + auto i64_nullable_fid = + schema->AddDebugField("age64_nullable", DataType::INT64, true); + auto float_nullable_fid = + schema->AddDebugField("age_float", DataType::FLOAT, true); + auto double_nullable_fid = + schema->AddDebugField("age_double", DataType::DOUBLE, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age8_col; + std::vector age16_col; + std::vector age32_col; + std::vector age64_col; + std::vector age_float_col; + std::vector age_double_col; + FixedVector age8_valid_col; + FixedVector age16_valid_col; + FixedVector age32_valid_col; + FixedVector age64_valid_col; + FixedVector age_float_valid_col; + FixedVector age_double_valid_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto new_age8_col = raw_data.get_col(i8_nullable_fid); + auto new_age16_col = raw_data.get_col(i16_nullable_fid); + auto new_age32_col = raw_data.get_col(i32_nullable_fid); + auto new_age64_col = raw_data.get_col(i64_nullable_fid); + auto new_age_float_col = raw_data.get_col(float_nullable_fid); + auto new_age_double_col = raw_data.get_col(double_nullable_fid); + age8_valid_col = raw_data.get_col_valid(i8_nullable_fid); + age16_valid_col = raw_data.get_col_valid(i16_nullable_fid); + age32_valid_col = raw_data.get_col_valid(i32_nullable_fid); + age64_valid_col = raw_data.get_col_valid(i64_nullable_fid); + age_float_valid_col = raw_data.get_col_valid(float_nullable_fid); + age_double_valid_col = raw_data.get_col_valid(double_nullable_fid); + + age8_col.insert( + age8_col.end(), new_age8_col.begin(), new_age8_col.end()); + age16_col.insert( + age16_col.end(), new_age16_col.begin(), new_age16_col.end()); + age32_col.insert( + age32_col.end(), new_age32_col.begin(), new_age32_col.end()); + age64_col.insert( + age64_col.end(), new_age64_col.begin(), new_age64_col.end()); + age_float_col.insert(age_float_col.end(), + new_age_float_col.begin(), + new_age_float_col.end()); + age_double_col.insert(age_double_col.end(), + new_age_double_col.begin(), + new_age_double_col.end()); + + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + // if (dtype == DataType::INT8) { + // dsl_string.replace(loc, 5, dsl_string_int8); + // } else if (dtype == DataType::INT16) { + // dsl_string.replace(loc, 5, dsl_string_int16); + // } else if (dtype == DataType::INT32) { + // dsl_string.replace(loc, 5, dsl_string_int32); + // } else if (dtype == DataType::INT64) { + // dsl_string.replace(loc, 5, dsl_string_int64); + // } else if (dtype == DataType::FLOAT) { + // dsl_string.replace(loc, 5, dsl_string_float); + // } else if (dtype == DataType::DOUBLE) { + // dsl_string.replace(loc, 5, dsl_string_double); + // } else { + // ASSERT_TRUE(false) << "No test case defined for this data type"; + // } + // loc = dsl_string.find("@@@@"); + // dsl_string.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val, age8_valid_col[i]); + ASSERT_EQ(ans, ref) + << clause << "@" << i << "!!" << val << std::endl; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val, age16_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val, age32_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val, age64_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val, age_float_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val, age_double_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + std::vector< + std::tuple>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == 4; + }}, + // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length > 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length >= 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length < 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - int64_val: 500 + int64_val: 1 > - op: Equal + op: LessEqual value: < - int64_val: 1500 + int64_val: 2 > >)", - [](int16_t v) { return (v - 500) == 1500; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < int64_val: 2 > - op: Equal + op: LessEqual value: < - int64_val: 4000 + int64_val: 4 > >)", - [](int32_t v) { return (v * 2) == 4000; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < int64_val: 2 > - op: Equal + op: LessEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int64_t v) { return (v / 2) == 1000; }, - DataType::INT64}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > - op: Equal + op: LessEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) == 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length <= 4; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = + ref_func(milvus::Json(simdjson::padded_string(json_col[i]))); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONNullable) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + std::vector< + std::tuple>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: Equal value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) == 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > - arith_op: Add + arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 > op: Equal - value: < - float_val: 2500 + value: + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength + op: Equal + value: >)", - [](double v) { return (v + 500) == 2500; }, - DataType::DOUBLE}, - // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == 4; + }}, + // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: NotEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) != 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: NotEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) != 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2668,15 +7656,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: NotEqual value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) != 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2684,80 +7679,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: NotEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) != 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: NotEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) != 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: NotEqual value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) != 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) > 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) > 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2765,15 +7798,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterThan value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) > 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2781,80 +7821,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterThan value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) > 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: GreaterThan value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) > 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: GreaterThan value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) > 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length > 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) >= 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) >= 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2862,15 +7940,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterEqual value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) >= 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2878,80 +7963,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) >= 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: GreaterEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) >= 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: GreaterEqual value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) >= 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length >= 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: LessThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) < 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: LessThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) < 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2959,15 +8082,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessThan value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) < 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2975,80 +8105,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessThan value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) < 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: LessThan value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) < 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: LessThan value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) < 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length < 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: LessEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) <= 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: LessEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) <= 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -3056,15 +8224,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessEqual value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) <= 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -3072,102 +8247,1727 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) <= 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: LessEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) <= 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: LessEqual value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) <= 2500; }, - DataType::INT64}, + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length <= 4; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto nullable_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(nullable_fid); + valid_data = raw_data.get_col_valid(nullable_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = + ref_func(milvus::Json(simdjson::padded_string(json_col[i])), + valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { + struct Testcase { + double right_operand; + double value; + OpType op; + std::vector nested_path; + }; + std::vector testcases{ + {10, 20, OpType::Equal, {"double"}}, + {20, 30, OpType::Equal, {"double"}}, + {30, 40, OpType::NotEqual, {"double"}}, + {40, 50, OpType::NotEqual, {"double"}}, + {10, 20, OpType::Equal, {"int"}}, + {20, 30, OpType::Equal, {"int"}}, + {30, 40, OpType::NotEqual, {"int"}}, + {40, 50, OpType::NotEqual, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](double value) { + if (testcase.op == OpType::Equal) { + return value + testcase.right_operand == testcase.value; + } + return value + testcase.right_operand != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_float_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_float_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) + << testcase.value << " " << val << " " << testcase.op; + } + } + + std::vector array_testcases{ + {0, 3, OpType::Equal, {"array"}}, + {0, 5, OpType::NotEqual, {"array"}}, + }; + + for (auto testcase : array_testcases) { + auto check = [&](int64_t value) { + if (testcase.op == OpType::Equal) { + return value == testcase.value; + } + return value != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::ArrayLength, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto json = milvus::Json(simdjson::padded_string(json_col[i])); + int64_t array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + auto ref = check(array_length); + ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloatNullable) { + struct Testcase { + double right_operand; + double value; + OpType op; + std::vector nested_path; + }; + std::vector testcases{ + {10, 20, OpType::Equal, {"double"}}, + {20, 30, OpType::Equal, {"double"}}, + {30, 40, OpType::NotEqual, {"double"}}, + {40, 50, OpType::NotEqual, {"double"}}, + {10, 20, OpType::Equal, {"int"}}, + {20, 30, OpType::Equal, {"int"}}, + {30, 40, OpType::NotEqual, {"int"}}, + {40, 50, OpType::NotEqual, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](double value, bool valid) { + if (!valid) { + return false; + } + if (testcase.op == OpType::Equal) { + return value + testcase.right_operand == testcase.value; + } + return value + testcase.right_operand != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_float_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_float_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data[i]); + ASSERT_EQ(ans, ref) + << testcase.value << " " << val << " " << testcase.op; + } + } + + std::vector array_testcases{ + {0, 3, OpType::Equal, {"array"}}, + {0, 5, OpType::NotEqual, {"array"}}, }; - std::string raw_plan_tmp = R"(vector_anns: < - field_id: 100 - predicates: < - @@@@@ - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + for (auto testcase : array_testcases) { + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + if (testcase.op == OpType::Equal) { + return value == testcase.value; + } + return value != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::ArrayLength, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto json = milvus::Json(simdjson::padded_string(json_col[i])); + int64_t array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + auto ref = check(array_length, valid_data[i]); + ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { + std::vector, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) == 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) == 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) == 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) == 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) == 0; }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](float v) { return (v + 500) == 2500; }, + DataType::FLOAT}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](double v) { return (v + 500) == 2500; }, + DataType::DOUBLE}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2000 + >)", + [](float v) { return (v + 500) != 2000; }, + DataType::FLOAT}, + {R"(arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + >)", + [](double v) { return (v - 500) != 2000; }, + DataType::DOUBLE}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + >)", + [](int8_t v) { return (v * 2) != 2; }, + DataType::INT8}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int16_t v) { return (v / 2) != 2000; }, + DataType::INT16}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 1 + >)", + [](int32_t v) { return (v % 100) != 1; }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int64_t v) { return (v + 500) != 2000; }, + DataType::INT64}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) > 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) > 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) > 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) > 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) >= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) >= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) >= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) >= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) < 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) < 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) < 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) < 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) <= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) <= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) <= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) <= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + binary_arith_op_eval_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_fid = schema->AddDebugField("age8", DataType::INT8); + auto i16_fid = schema->AddDebugField("age16", DataType::INT16); + auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int8 field + auto age8_col = raw_data.get_col(i8_fid); + age8_col[0] = 4; + auto age8_index = milvus::index::CreateScalarIndexSort(); + age8_index->Build(N, age8_col.data(), nullptr); + load_index_info.field_id = i8_fid.get(); + load_index_info.field_type = DataType::INT8; + load_index_info.index = std::move(age8_index); + seg->LoadIndex(load_index_info); + + // load index for 16 field + auto age16_col = raw_data.get_col(i16_fid); + age16_col[0] = 2000; + auto age16_index = milvus::index::CreateScalarIndexSort(); + age16_index->Build(N, age16_col.data(), nullptr); + load_index_info.field_id = i16_fid.get(); + load_index_info.field_type = DataType::INT16; + load_index_info.index = std::move(age16_index); + seg->LoadIndex(load_index_info); + + // load index for int32 field + auto age32_col = raw_data.get_col(i32_fid); + age32_col[0] = 2000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data(), nullptr); + load_index_info.field_id = i32_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data(), nullptr); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); + + // load index for float field + auto age_float_col = raw_data.get_col(float_fid); + age_float_col[0] = 2000; + auto age_float_index = milvus::index::CreateScalarIndexSort(); + age_float_index->Build(N, age_float_col.data(), nullptr); + load_index_info.field_id = float_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_float_index); + seg->LoadIndex(load_index_info); + + // load index for double field + auto age_double_col = raw_data.get_col(double_fid); + age_double_col[0] = 2000; + auto age_double_index = milvus::index::CreateScalarIndexSort(); + age_double_index->Build(N, age_double_col.data(), nullptr); + load_index_info.field_id = double_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_double_index); + seg->LoadIndex(load_index_info); + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + if (dtype == DataType::INT8) { + expr = boost::format(expr_plan) % vec_fid.get() % i8_fid.get() % + proto::schema::DataType_Name(int(DataType::INT8)); + } else if (dtype == DataType::INT16) { + expr = boost::format(expr_plan) % vec_fid.get() % i16_fid.get() % + proto::schema::DataType_Name(int(DataType::INT16)); + } else if (dtype == DataType::INT32) { + expr = boost::format(expr_plan) % vec_fid.get() % i32_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + } else if (dtype == DataType::INT64) { + expr = boost::format(expr_plan) % vec_fid.get() % i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + } else if (dtype == DataType::FLOAT) { + expr = boost::format(expr_plan) % vec_fid.get() % float_fid.get() % + proto::schema::DataType_Name(int(DataType::FLOAT)); + } else if (dtype == DataType::DOUBLE) { + expr = boost::format(expr_plan) % vec_fid.get() % double_fid.get() % + proto::schema::DataType_Name(int(DataType::DOUBLE)); + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + + auto binary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndexNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) == 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) == 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) == 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) == 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) == 0; + }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::FLOAT}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::DOUBLE}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2000 + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2000; + }, + DataType::FLOAT}, + {R"(arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) != 2000; + }, + DataType::DOUBLE}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) != 2; + }, + DataType::INT8}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) != 2000; + }, + DataType::INT16}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 1 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) != 1; + }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2000; + }, + DataType::INT64}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterThan + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) > 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) > 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) > 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) > 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) > 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterEqual + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) >= 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) >= 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) >= 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) >= 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) >= 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessThan + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) < 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) < 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) < 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) < 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) < 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessEqual + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) <= 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) <= 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) <= 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) <= 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) <= 0; + }, + DataType::INT32}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + binary_arith_op_eval_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + > + @@@@)"; + auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i8_fid = schema->AddDebugField("age8", DataType::INT8); - auto i16_fid = schema->AddDebugField("age16", DataType::INT16); - auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i8_nullable_fid = schema->AddDebugField("age8", DataType::INT8, true); + auto i16_nullable_fid = + schema->AddDebugField("age16", DataType::INT16, true); + auto i32_nullable_fid = + schema->AddDebugField("age32", DataType::INT32, true); auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + auto i64_nullable_fid = + schema->AddDebugField("age641", DataType::INT64, true); + auto float_nullable_fid = + schema->AddDebugField("age_float", DataType::FLOAT, true); + auto double_nullable_fid = + schema->AddDebugField("age_double", DataType::DOUBLE, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + auto i8_valid_data = raw_data.get_col_valid(i8_nullable_fid); + auto i16_valid_data = raw_data.get_col_valid(i16_nullable_fid); + auto i32_valid_data = raw_data.get_col_valid(i32_nullable_fid); + auto i64_valid_data = raw_data.get_col_valid(i64_nullable_fid); + auto float_valid_data = raw_data.get_col_valid(float_nullable_fid); + auto double_valid_data = raw_data.get_col_valid(double_nullable_fid); + + // load index for int8 field + auto age8_col = raw_data.get_col(i8_nullable_fid); + age8_col[0] = 4; + auto age8_index = milvus::index::CreateScalarIndexSort(); + age8_index->Build(N, age8_col.data(), i8_valid_data.data()); + load_index_info.field_id = i8_nullable_fid.get(); + load_index_info.field_type = DataType::INT8; + load_index_info.index = std::move(age8_index); + seg->LoadIndex(load_index_info); + + // load index for 16 field + auto age16_col = raw_data.get_col(i16_nullable_fid); + age16_col[0] = 2000; + auto age16_index = milvus::index::CreateScalarIndexSort(); + age16_index->Build(N, age16_col.data(), i16_valid_data.data()); + load_index_info.field_id = i16_nullable_fid.get(); + load_index_info.field_type = DataType::INT16; + load_index_info.index = std::move(age16_index); + seg->LoadIndex(load_index_info); + + // load index for int32 field + auto age32_col = raw_data.get_col(i32_nullable_fid); + age32_col[0] = 2000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data(), i32_valid_data.data()); + load_index_info.field_id = i32_nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto age64_col = raw_data.get_col(i64_nullable_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data(), i64_valid_data.data()); + load_index_info.field_id = i64_nullable_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); + + // load index for float field + auto age_float_col = raw_data.get_col(float_nullable_fid); + age_float_col[0] = 2000; + auto age_float_index = milvus::index::CreateScalarIndexSort(); + age_float_index->Build(N, age_float_col.data(), float_valid_data.data()); + load_index_info.field_id = float_nullable_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_float_index); + seg->LoadIndex(load_index_info); + + // load index for double field + auto age_double_col = raw_data.get_col(double_nullable_fid); + age_double_col[0] = 2000; + auto age_double_index = milvus::index::CreateScalarIndexSort(); + age_double_index->Build(N, age_double_col.data(), double_valid_data.data()); + load_index_info.field_id = double_nullable_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_double_index); + seg->LoadIndex(load_index_info); + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + if (dtype == DataType::INT8) { + expr = boost::format(expr_plan) % vec_fid.get() % + i8_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT8)); + } else if (dtype == DataType::INT16) { + expr = boost::format(expr_plan) % vec_fid.get() % + i16_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT16)); + } else if (dtype == DataType::INT32) { + expr = boost::format(expr_plan) % vec_fid.get() % + i32_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + } else if (dtype == DataType::INT64) { + expr = boost::format(expr_plan) % vec_fid.get() % + i64_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + } else if (dtype == DataType::FLOAT) { + expr = boost::format(expr_plan) % vec_fid.get() % + float_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::FLOAT)); + } else if (dtype == DataType::DOUBLE) { + expr = boost::format(expr_plan) % vec_fid.get() % + double_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::DOUBLE)); + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + + auto binary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val, i8_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val, i16_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val, i32_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val, i64_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val, float_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val, double_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestUnaryRangeWithJSON) { + std::vector< + std::tuple)>, + DataType>> + testcases = { + {R"(op: Equal + value: < + bool_val: true + >)", + [](std::variant v) { + return std::get(v); + }, + DataType::BOOL}, + {R"(op: LessEqual + value: < + int64_val: 1500 + >)", + [](std::variant v) { + return std::get(v) < 1500; + }, + DataType::INT64}, + {R"(op: LessEqual + value: < + float_val: 4000 + >)", + [](std::variant v) { + return std::get(v) <= 4000; + }, + DataType::DOUBLE}, + {R"(op: GreaterThan + value: < + float_val: 1000 + >)", + [](std::variant v) { + return std::get(v) > 1000; + }, + DataType::DOUBLE}, + {R"(op: GreaterEqual + value: < + int64_val: 0 + >)", + [](std::variant v) { + return std::get(v) >= 0; + }, + DataType::INT64}, + {R"(op: NotEqual + value: < + bool_val: true + >)", + [](std::variant v) { + return !std::get(v); + }, + DataType::BOOL}, + {R"(op: Equal + value: < + string_val: "test" + >)", + [](std::variant v) { + return std::get(v) == "test"; + }, + DataType::STRING}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + unary_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - std::vector age8_col; - std::vector age16_col; - std::vector age32_col; - std::vector age64_col; - std::vector age_float_col; - std::vector age_double_col; + std::vector json_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); - auto new_age8_col = raw_data.get_col(i8_fid); - auto new_age16_col = raw_data.get_col(i16_fid); - auto new_age32_col = raw_data.get_col(i32_fid); - auto new_age64_col = raw_data.get_col(i64_fid); - auto new_age_float_col = raw_data.get_col(float_fid); - auto new_age_double_col = raw_data.get_col(double_fid); - - age8_col.insert( - age8_col.end(), new_age8_col.begin(), new_age8_col.end()); - age16_col.insert( - age16_col.end(), new_age16_col.begin(), new_age16_col.end()); - age32_col.insert( - age32_col.end(), new_age32_col.begin(), new_age32_col.end()); - age64_col.insert( - age64_col.end(), new_age64_col.begin(), new_age64_col.end()); - age_float_col.insert(age_float_col.end(), - new_age_float_col.begin(), - new_age_float_col.end()); - age_double_col.insert(age_double_col.end(), - new_age_double_col.begin(), - new_age_double_col.end()); - + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -3177,64 +9977,83 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { } auto seg_promote = dynamic_cast(seg.get()); - + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { - auto loc = raw_plan_tmp.find("@@@@@"); - auto raw_plan = raw_plan_tmp; - raw_plan.replace(loc, 5, clause); - // if (dtype == DataType::INT8) { - // dsl_string.replace(loc, 5, dsl_string_int8); - // } else if (dtype == DataType::INT16) { - // dsl_string.replace(loc, 5, dsl_string_int16); - // } else if (dtype == DataType::INT32) { - // dsl_string.replace(loc, 5, dsl_string_int32); - // } else if (dtype == DataType::INT64) { - // dsl_string.replace(loc, 5, dsl_string_int64); - // } else if (dtype == DataType::FLOAT) { - // dsl_string.replace(loc, 5, dsl_string_float); - // } else if (dtype == DataType::DOUBLE) { - // dsl_string.replace(loc, 5, dsl_string_double); - // } else { - // ASSERT_TRUE(false) << "No test case defined for this data type"; - // } - // loc = dsl_string.find("@@@@"); - // dsl_string.replace(loc, 4, clause); - auto plan_str = translate_text_plan_with_metric_type(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + BitsetType final; final = ExecuteQueryExpr( plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), + seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::INT8) { - auto val = age8_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) - << clause << "@" << i << "!!" << val << std::endl; - } else if (dtype == DataType::INT16) { - auto val = age16_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT32) { - auto val = age32_col[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } else if (dtype == DataType::INT64) { - auto val = age64_col[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::FLOAT) { - auto val = age_float_col[i]; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = age_double_col[i]; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } else { @@ -3244,781 +10063,852 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { } } -TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; - - std::vector< - std::tuple>> +TEST_P(ExprTest, TestUnaryRangeWithJSONNullable) { + std::vector, bool)>, + DataType>> testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: Equal - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) == 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: Equal - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) == 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: Equal - value: - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: Equal - value: - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + {R"(op: Equal + value: < + bool_val: true + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length == 4; - }}, - // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: NotEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) != 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: NotEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) != 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return std::get(v); + }, + DataType::BOOL}, + {R"(op: LessEqual + value: < + int64_val: 1500 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length != 4; - }}, - - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: GreaterThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) > 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: GreaterThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) > 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return std::get(v) < 1500; + }, + DataType::INT64}, + {R"(op: LessEqual + value: < + float_val: 4000 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length > 4; - }}, - - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: GreaterEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) >= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: GreaterEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) >= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return std::get(v) <= 4000; + }, + DataType::DOUBLE}, + {R"(op: GreaterThan + value: < + float_val: 1000 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length >= 4; - }}, + return std::get(v) > 1000; + }, + DataType::DOUBLE}, + {R"(op: GreaterEqual + value: < + int64_val: 0 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return std::get(v) >= 0; + }, + DataType::INT64}, + {R"(op: NotEqual + value: < + bool_val: true + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return !std::get(v); + }, + DataType::BOOL}, + {R"(op: Equal + value: < + string_val: "test" + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return std::get(v) == "test"; + }, + DataType::STRING}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + unary_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: LessThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) < 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: LessThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) < 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestTermWithJSON) { + std::vector< + std::tuple)>, + DataType>> + testcases = { + {R"(values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {true, false}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::BOOL}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {1500, 2048, 3216}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::INT64}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {1500.0, 4000, 235.14}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::DOUBLE}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {"aaa", "abc", "235.14"}; + return term_set.find(std::get(v)) != + term_set.end(); + }, + DataType::STRING}, + {R"()", + [](std::variant v) { + return false; + }, + DataType::INT64}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + term_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestTermWithJSONNullable) { + std::vector, bool)>, + DataType>> + testcases = { + {R"(values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length < 4; - }}, + std::unordered_set term_set; + term_set = {true, false}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::BOOL}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {1500, 2048, 3216}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::INT64}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {1500.0, 4000, 235.14}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::DOUBLE}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {"aaa", "abc", "235.14"}; + return term_set.find(std::get(v)) != + term_set.end(); + }, + DataType::STRING}, + {R"()", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return false; + }, + DataType::INT64}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + term_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestExistsWithJSON) { + std::vector, DataType>> + testcases = { + {R"()", [](bool v) { return v; }, DataType::BOOL}, + {R"()", [](bool v) { return v; }, DataType::INT64}, + {R"()", [](bool v) { return v; }, DataType::STRING}, + {R"()", [](bool v) { return v; }, DataType::VARCHAR}, + {R"()", [](bool v) { return v; }, DataType::DOUBLE}, + }; - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: LessEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) <= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: LessEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) <= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + exists_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + case DataType::VARCHAR: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "varchar"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/bool"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/int"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/double"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/string"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::VARCHAR) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/varchar"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestExistsWithJSONNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; } - return array_length <= 4; - }}, + return v; + }, + DataType::BOOL}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::INT64}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::STRING}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::VARCHAR}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::DOUBLE}, }; - std::string raw_plan_tmp = R"(vector_anns: < - field_id: 100 - predicates: < - @@@@@ - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + exists_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; + + std::string arith_expr = R"( + info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); - auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -4031,14 +10921,58 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + case DataType::VARCHAR: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "varchar"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); - for (auto [clause, ref_func] : testcases) { - auto loc = raw_plan_tmp.find("@@@@@"); - auto raw_plan = raw_plan_tmp; - raw_plan.replace(loc, 5, clause); - auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; final = ExecuteQueryExpr( plan->plan_node_->plannodes_->sources()[0]->sources()[0], @@ -4049,31 +10983,46 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto ref = - ref_func(milvus::Json(simdjson::padded_string(json_col[i]))); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/bool"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/int"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/double"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/string"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::VARCHAR) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/varchar"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } } } } -TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { - struct Testcase { - double right_operand; - double value; - OpType op; - std::vector nested_path; - }; - std::vector testcases{ - {10, 20, OpType::Equal, {"double"}}, - {20, 30, OpType::Equal, {"double"}}, - {30, 40, OpType::NotEqual, {"double"}}, - {40, 50, OpType::NotEqual, {"double"}}, - {10, 20, OpType::Equal, {"int"}}, - {20, 30, OpType::Equal, {"int"}}, - {30, 40, OpType::NotEqual, {"int"}}, - {40, 50, OpType::NotEqual, {"int"}}, - }; +template +struct Testcase { + std::vector term; + std::vector nested_path; + bool res; +}; +TEST_P(ExprTest, TestTermInFieldJson) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); @@ -4084,7 +11033,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { std::vector json_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); json_col.insert( @@ -4098,709 +11047,581 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - for (auto testcase : testcases) { - auto check = [&](double value) { - if (testcase.op == OpType::Equal) { - return value + testcase.right_operand == testcase.value; - } - return value + testcase.right_operand != testcase.value; - }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - proto::plan::GenericValue value; - value.set_float_val(testcase.value); - proto::plan::GenericValue right_operand; - right_operand.set_float_val(testcase.right_operand); - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - testcase.op, - ArithOpType::Add, - value, - right_operand); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); - - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) - << testcase.value << " " << val << " " << testcase.op; - } - } - - std::vector array_testcases{ - {0, 3, OpType::Equal, {"array"}}, - {0, 5, OpType::NotEqual, {"array"}}, - }; + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; - for (auto testcase : array_testcases) { - auto check = [&](int64_t value) { - if (testcase.op == OpType::Equal) { - return value == testcase.value; - } - return value != testcase.value; + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); }; auto pointer = milvus::Json::pointer(testcase.nested_path); - proto::plan::GenericValue value; - value.set_int64_val(testcase.value); - proto::plan::GenericValue right_operand; - right_operand.set_int64_val(testcase.right_operand); - auto expr = std::make_shared( + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - testcase.op, - ArithOpType::ArrayLength, - value, - right_operand); + values, + true); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + // std::cout << "cost" + // << std::chrono::duration_cast( + // std::chrono::steady_clock::now() - start) + // .count() + // << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - - auto json = milvus::Json(simdjson::padded_string(json_col[i])); - int64_t array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - auto ref = check(array_length); - ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + ASSERT_EQ(ans, check(res)); } } -} - -TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { - std::vector, DataType>> - testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: Equal - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) == 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: Equal - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) == 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) == 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) == 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: Equal - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) == 0; }, - DataType::INT32}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - >)", - [](float v) { return (v + 500) == 2500; }, - DataType::FLOAT}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - >)", - [](double v) { return (v + 500) == 2500; }, - DataType::DOUBLE}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2000 - >)", - [](float v) { return (v + 500) != 2000; }, - DataType::FLOAT}, - {R"(arith_op: Sub - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2500 - >)", - [](double v) { return (v - 500) != 2000; }, - DataType::DOUBLE}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 2 - >)", - [](int8_t v) { return (v * 2) != 2; }, - DataType::INT8}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 2000 - >)", - [](int16_t v) { return (v / 2) != 2000; }, - DataType::INT16}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: NotEqual - value: < - int64_val: 1 - >)", - [](int32_t v) { return (v % 100) != 1; }, - DataType::INT32}, - {R"(arith_op: Add - right_operand: < - int64_val: 500 - > - op: NotEqual - value: < - int64_val: 2000 - >)", - [](int64_t v) { return (v + 500) != 2000; }, - DataType::INT64}, - - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: GreaterThan - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) > 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: GreaterThan - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) > 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) > 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) > 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: GreaterThan - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) > 0; }, - DataType::INT32}, - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: GreaterEqual - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) >= 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: GreaterEqual - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) >= 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) >= 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) >= 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: GreaterEqual - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) >= 0; }, - DataType::INT32}, + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: LessThan - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) < 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: LessThan - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) < 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) < 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) < 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: LessThan - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) < 0; }, - DataType::INT32}, + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: LessEqual - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) <= 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: LessEqual - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) <= 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) <= 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) <= 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: LessEqual - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) <= 0; }, - DataType::INT32}, + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - binary_arith_op_eval_range_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } - std::string arith_expr = R"( - column_info: < - field_id: %2% - data_type: %3% - > - @@@@)"; + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } +} +TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i8_fid = schema->AddDebugField("age8", DataType::INT8); - auto i16_fid = schema->AddDebugField("age16", DataType::INT16); - auto i32_fid = schema->AddDebugField("age32", DataType::INT32); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); + auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); - // load index for int8 field - auto age8_col = raw_data.get_col(i8_fid); - age8_col[0] = 4; - GenScalarIndexing(N, age8_col.data()); - auto age8_index = milvus::index::CreateScalarIndexSort(); - age8_index->Build(N, age8_col.data()); - load_index_info.field_id = i8_fid.get(); - load_index_info.field_type = DataType::INT8; - load_index_info.index = std::move(age8_index); - seg->LoadIndex(load_index_info); + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - // load index for 16 field - auto age16_col = raw_data.get_col(i16_fid); - age16_col[0] = 2000; - GenScalarIndexing(N, age16_col.data()); - auto age16_index = milvus::index::CreateScalarIndexSort(); - age16_index->Build(N, age16_col.data()); - load_index_info.field_id = i16_fid.get(); - load_index_info.field_type = DataType::INT16; - load_index_info.index = std::move(age16_index); - seg->LoadIndex(load_index_info); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - // load index for int32 field - auto age32_col = raw_data.get_col(i32_fid); - age32_col[0] = 2000; - GenScalarIndexing(N, age32_col.data()); - auto age32_index = milvus::index::CreateScalarIndexSort(); - age32_index->Build(N, age32_col.data()); - load_index_info.field_id = i32_fid.get(); - load_index_info.field_type = DataType::INT32; - load_index_info.index = std::move(age32_index); - seg->LoadIndex(load_index_info); + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; - // load index for int64 field - auto age64_col = raw_data.get_col(i64_fid); - age64_col[0] = 2000; - GenScalarIndexing(N, age64_col.data()); - auto age64_index = milvus::index::CreateScalarIndexSort(); - age64_index->Build(N, age64_col.data()); - load_index_info.field_id = i64_fid.get(); - load_index_info.field_type = DataType::INT64; - load_index_info.index = std::move(age64_index); - seg->LoadIndex(load_index_info); + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + // std::cout << "cost" + // << std::chrono::duration_cast( + // std::chrono::steady_clock::now() - start) + // .count() + // << std::endl; + EXPECT_EQ(final.size(), N * num_iters); - // load index for float field - auto age_float_col = raw_data.get_col(float_fid); - age_float_col[0] = 2000; - GenScalarIndexing(N, age_float_col.data()); - auto age_float_index = milvus::index::CreateScalarIndexSort(); - age_float_index->Build(N, age_float_col.data()); - load_index_info.field_id = float_fid.get(); - load_index_info.field_type = DataType::FLOAT; - load_index_info.index = std::move(age_float_index); - seg->LoadIndex(load_index_info); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } - // load index for double field - auto age_double_col = raw_data.get_col(double_fid); - age_double_col[0] = 2000; - GenScalarIndexing(N, age_double_col.data()); - auto age_double_index = milvus::index::CreateScalarIndexSort(); - age_double_index->Build(N, age_double_col.data()); - load_index_info.field_id = double_fid.get(); - load_index_info.field_type = DataType::FLOAT; - load_index_info.index = std::move(age_double_index); - seg->LoadIndex(load_index_info); + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; - auto seg_promote = dynamic_cast(seg.get()); + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - if (dtype == DataType::INT8) { - expr = boost::format(expr_plan) % vec_fid.get() % i8_fid.get() % - proto::schema::DataType_Name(int(DataType::INT8)); - } else if (dtype == DataType::INT16) { - expr = boost::format(expr_plan) % vec_fid.get() % i16_fid.get() % - proto::schema::DataType_Name(int(DataType::INT16)); - } else if (dtype == DataType::INT32) { - expr = boost::format(expr_plan) % vec_fid.get() % i32_fid.get() % - proto::schema::DataType_Name(int(DataType::INT32)); - } else if (dtype == DataType::INT64) { - expr = boost::format(expr_plan) % vec_fid.get() % i64_fid.get() % - proto::schema::DataType_Name(int(DataType::INT64)); - } else if (dtype == DataType::FLOAT) { - expr = boost::format(expr_plan) % vec_fid.get() % float_fid.get() % - proto::schema::DataType_Name(int(DataType::FLOAT)); - } else if (dtype == DataType::DOUBLE) { - expr = boost::format(expr_plan) % vec_fid.get() % double_fid.get() % - proto::schema::DataType_Name(int(DataType::DOUBLE)); - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); } + } - auto binary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N, - MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N; ++i) { + for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::INT8) { - auto val = age8_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT16) { - auto val = age16_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT32) { - auto val = age32_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = age64_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::FLOAT) { - auto val = age_float_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = age_double_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } } } -TEST_P(ExprTest, TestUnaryRangeWithJSON) { - std::vector< - std::tuple)>, - DataType>> - testcases = { - {R"(op: Equal - value: < - bool_val: true - >)", - [](std::variant v) { - return std::get(v); - }, - DataType::BOOL}, - {R"(op: LessEqual - value: < - int64_val: 1500 - >)", - [](std::variant v) { - return std::get(v) < 1500; - }, - DataType::INT64}, - {R"(op: LessEqual - value: < - float_val: 4000 - >)", - [](std::variant v) { - return std::get(v) <= 4000; - }, - DataType::DOUBLE}, - {R"(op: GreaterThan - value: < - float_val: 1000 - >)", - [](std::variant v) { - return std::get(v) > 1000; - }, - DataType::DOUBLE}, - {R"(op: GreaterEqual - value: < - int64_val: 0 - >)", - [](std::variant v) { - return std::get(v) >= 0; - }, - DataType::INT64}, - {R"(op: NotEqual - value: < - bool_val: true - >)", - [](std::variant v) { - return !std::get(v); - }, - DataType::BOOL}, - {R"(op: Equal - value: < - string_val: "test" - >)", - [](std::variant v) { - return std::get(v) == "test"; - }, - DataType::STRING}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - unary_range_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; +TEST_P(ExprTest, PraseJsonContainsExpr) { + std::vector raw_plans{ + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAny + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: + elements: + elements: + elements: + op:ContainsAll + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + }; - std::string arith_expr = R"( - column_info: < - field_id: %2% - data_type: %3% - nested_path:"%4%" - > - @@@@)"; + for (auto& raw_plan : raw_plans) { + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto schema = std::make_shared(); + schema->AddDebugField("fakevec", data_type, 16, metric_type); + schema->AddDebugField("json", DataType::JSON); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + } +} +TEST_P(ExprTest, TestJsonContainsAny) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); @@ -4809,7 +11630,7 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { std::vector json_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); json_col.insert( @@ -4823,172 +11644,220 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - switch (dtype) { - case DataType::BOOL: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; - break; - } - case DataType::INT64: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "int"; - break; - } - case DataType::DOUBLE: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "double"; - break; + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - case DataType::STRING: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "string"; - break; + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - default: { - ASSERT_TRUE(false) << "No test case defined for this data type"; + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res)); } + } - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/bool") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/int") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/double") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/string") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res)); } } } -TEST_P(ExprTest, TestTermWithJSON) { - std::vector< - std::tuple)>, - DataType>> - testcases = { - {R"(values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {true, false}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::BOOL}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {1500, 2048, 3216}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::INT64}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {1500.0, 4000, 235.14}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::DOUBLE}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {"aaa", "abc", "235.14"}; - return term_set.find(std::get(v)) != - term_set.end(); - }, - DataType::STRING}, - {R"()", - [](std::variant v) { - return false; - }, - DataType::INT64}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - term_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; - - std::string arith_expr = R"( - column_info: < - field_id: %2% - data_type: %3% - nested_path:"%4%" - > - @@@@)"; - +TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -5001,253 +11870,219 @@ TEST_P(ExprTest, TestTermWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - switch (dtype) { - case DataType::BOOL: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; - break; - } - case DataType::INT64: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "int"; - break; + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; } - case DataType::DOUBLE: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "double"; - break; + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - case DataType::STRING: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "string"; - break; + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; } - default: { - ASSERT_TRUE(false) << "No test case defined for this data type"; + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } + } - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/bool") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/int") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/double") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/string") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } } -} - -TEST_P(ExprTest, TestExistsWithJSON) { - std::vector, DataType>> - testcases = { - {R"()", [](bool v) { return v; }, DataType::BOOL}, - {R"()", [](bool v) { return v; }, DataType::INT64}, - {R"()", [](bool v) { return v; }, DataType::STRING}, - {R"()", [](bool v) { return v; }, DataType::VARCHAR}, - {R"()", [](bool v) { return v; }, DataType::DOUBLE}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - exists_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; - - std::string arith_expr = R"( - info: < - field_id: %2% - data_type: %3% - nested_path:"%4%" - > - @@@@)"; - - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); - schema->set_primary_field_id(i64_fid); - - auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 1000; - std::vector json_col; - int num_iters = 1; - for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); - auto new_json_col = raw_data.get_col(json_fid); - - json_col.insert( - json_col.end(), new_json_col.begin(), new_json_col.end()); - seg->PreInsert(N); - seg->Insert(iter * N, - N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } - auto seg_promote = dynamic_cast(seg.get()); + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - switch (dtype) { - case DataType::BOOL: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; - break; - } - case DataType::INT64: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "int"; - break; - } - case DataType::DOUBLE: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "double"; - break; - } - case DataType::STRING: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "string"; - break; - } - case DataType::VARCHAR: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "varchar"; - break; - } - default: { - ASSERT_TRUE(false) << "No test case defined for this data type"; + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); } - - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); - + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/bool"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/int"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/double"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/string"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::VARCHAR) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/varchar"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } } } -template -struct Testcase { - std::vector term; - std::vector nested_path; - bool res; -}; - -TEST_P(ExprTest, TestTermInFieldJson) { +TEST_P(ExprTest, TestJsonContainsAll) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); @@ -5273,13 +12108,18 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto seg_promote = dynamic_cast(seg.get()); - std::vector> bool_testcases{{{true}, {"bool"}}, - {{false}, {"bool"}}}; + std::vector> bool_testcases{{{true, true}, {"bool"}}, + {{false, false}, {"bool"}}}; for (auto testcase : bool_testcases) { auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5288,22 +12128,23 @@ TEST_P(ExprTest, TestTermInFieldJson) { val.set_bool_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - // std::cout << "cost" - // << std::chrono::duration_cast( - // std::chrono::steady_clock::now() - start) - // .count() - // << std::endl; + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -5319,29 +12160,37 @@ TEST_P(ExprTest, TestTermInFieldJson) { } std::vector> double_testcases{ - {{1.123}, {"double"}}, - {{10.34}, {"double"}}, - {{100.234}, {"double"}}, - {{1000.4546}, {"double"}}, + {{1.123, 10.34}, {"double"}}, + {{10.34, 100.234}, {"double"}}, + {{100.234, 1000.4546}, {"double"}}, + {{1000.4546, 1.123}, {"double"}}, + {{1000.4546, 10.34}, {"double"}}, + {{1.123, 100.234}, {"double"}}, }; for (auto testcase : double_testcases) { auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; - for (auto v : testcase.term) { + for (auto& v : testcase.term) { proto::plan::GenericValue val; val.set_float_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -5369,16 +12218,23 @@ TEST_P(ExprTest, TestTermInFieldJson) { } std::vector> testcases{ - {{1}, {"int"}}, - {{10}, {"int"}}, - {{100}, {"int"}}, - {{1000}, {"int"}}, + {{1, 10}, {"int"}}, + {{10, 100}, {"int"}}, + {{100, 1000}, {"int"}}, + {{1000, 10}, {"int"}}, + {{2, 4, 6, 8, 10}, {"int"}}, + {{1, 2, 3, 4, 5}, {"int"}}, }; for (auto testcase : testcases) { auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5387,11 +12243,12 @@ TEST_P(ExprTest, TestTermInFieldJson) { val.set_int64_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -5418,210 +12275,76 @@ TEST_P(ExprTest, TestTermInFieldJson) { } std::vector> testcases_string = { - {{"1sads"}, {"string"}}, - {{"10dsf"}, {"string"}}, - {{"100"}, {"string"}}, - {{"100ddfdsssdfdsfsd0"}, {"string"}}, + {{"1sads", "10dsf"}, {"string"}}, + {{"10dsf", "100"}, {"string"}}, + {{"100", "10dsf", "1sads"}, {"string"}}, + {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, }; for (auto testcase : testcases_string) { auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_string_val(v); - values.push_back(val); - } - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - values, - true); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - auto start = std::chrono::steady_clock::now(); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - std::cout << "cost" - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() - << std::endl; - EXPECT_EQ(final.size(), N * num_iters); - - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); - } - } -} - -TEST_P(ExprTest, PraseJsonContainsExpr) { - std::vector raw_plans{ - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAny - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: - elements: - elements: - elements: - op:ContainsAll - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - }; - - for (auto& raw_plan : raw_plans) { - auto plan_str = translate_text_plan_with_metric_type(raw_plan); - auto schema = std::make_shared(); - schema->AddDebugField("fakevec", data_type, 16, metric_type); - schema->AddDebugField("json", DataType::JSON); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } } } -TEST_P(ExprTest, TestJsonContainsAny) { +TEST_P(ExprTest, TestJsonContainsAllNullable) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -5635,13 +12358,21 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto seg_promote = dynamic_cast(seg.get()); - std::vector> bool_testcases{{{true}, {"bool"}}, - {{false}, {"bool"}}}; + std::vector> bool_testcases{{{true, true}, {"bool"}}, + {{false, false}, {"bool"}}}; for (auto testcase : bool_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5653,7 +12384,7 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, values); BitsetType final; @@ -5677,21 +12408,31 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (const auto& element : array) { res.push_back(element.template get()); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res, valid_data[i])); } } std::vector> double_testcases{ - {{1.123}, {"double"}}, - {{10.34}, {"double"}}, - {{100.234}, {"double"}}, - {{1000.4546}, {"double"}}, + {{1.123, 10.34}, {"double"}}, + {{10.34, 100.234}, {"double"}}, + {{100.234, 1000.4546}, {"double"}}, + {{1000.4546, 1.123}, {"double"}}, + {{1000.4546, 10.34}, {"double"}}, + {{1.123, 100.234}, {"double"}}, }; for (auto testcase : double_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5703,7 +12444,7 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, values); BitsetType final; @@ -5727,38 +12468,214 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (const auto& element : array) { res.push_back(element.template get()); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res, valid_data[i])); } } std::vector> testcases{ - {{1}, {"int"}}, - {{10}, {"int"}}, - {{100}, {"int"}}, - {{1000}, {"int"}}, + {{1, 10}, {"int"}}, + {{10, 100}, {"int"}}, + {{100, 1000}, {"int"}}, + {{1000, 10}, {"int"}}, + {{2, 4, 6, 8, 10}, {"int"}}, + {{1, 2, 3, 4, 5}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases_string = { + {{"1sads", "10dsf"}, {"string"}}, + {{"10dsf", "100"}, {"string"}}, + {{"100", "10dsf", "1sads"}, {"string"}}, + {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, }; - for (auto testcase : testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } +} + +TEST_P(ExprTest, TestJsonContainsArray) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue generic_a; + auto* a = generic_a.mutable_array_val(); + a->set_same_type(false); + for (int i = 0; i < 4; ++i) { + if (i % 4 == 0) { + proto::plan::GenericValue int_val; + int_val.set_int64_val(int64_t(i)); + a->add_array()->CopyFrom(int_val); + } else if ((i - 1) % 4 == 0) { + proto::plan::GenericValue bool_val; + bool_val.set_bool_val(bool(i)); + a->add_array()->CopyFrom(bool_val); + } else if ((i - 2) % 4 == 0) { + proto::plan::GenericValue float_val; + float_val.set_float_val(double(i)); + a->add_array()->CopyFrom(float_val); + } else if ((i - 3) % 4 == 0) { + proto::plan::GenericValue string_val; + string_val.set_string_val(std::to_string(i)); + a->add_array()->CopyFrom(string_val); + } + } + proto::plan::GenericValue generic_b; + auto* b = generic_b.mutable_array_val(); + b->set_same_type(true); + proto::plan::GenericValue int_val1; + int_val1.set_int64_val(int64_t(1)); + b->add_array()->CopyFrom(int_val1); + + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(2)); + b->add_array()->CopyFrom(int_val2); + + proto::plan::GenericValue int_val3; + int_val3.set_int64_val(int64_t(3)); + b->add_array()->CopyFrom(int_val3); + + std::vector> diff_testcases{ + {{generic_a}, {"string"}}, {{generic_b}, {"array"}}}; + + for (auto& testcase : diff_testcases) { + auto check = [&](const std::vector& values, int i) { + if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { + return true; + } + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_int64_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5771,44 +12688,28 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } - std::vector> testcases_string = { - {{"1sads"}, {"string"}}, - {{"10dsf"}, {"string"}}, - {{"100"}, {"string"}}, - {{"100ddfdsssdfdsfsd0"}, {"string"}}, - }; - - for (auto testcase : testcases_string) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto& testcase : diff_testcases) { + auto check = [&](const std::vector& values, int i) { + if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { + return true; + } + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_string_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5821,72 +12722,47 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } -} - -TEST_P(ExprTest, TestJsonContainsAll) { - auto schema = std::make_shared(); - auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); - schema->set_primary_field_id(i64_fid); - auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 1000; - std::vector json_col; - int num_iters = 1; - for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGenForJsonArray(schema, N, iter); - auto new_json_col = raw_data.get_col(json_fid); + proto::plan::GenericValue g_sub_arr1; + auto* sub_arr1 = g_sub_arr1.mutable_array_val(); + sub_arr1->set_same_type(true); + proto::plan::GenericValue int_val11; + int_val11.set_int64_val(int64_t(1)); + sub_arr1->add_array()->CopyFrom(int_val11); - json_col.insert( - json_col.end(), new_json_col.begin(), new_json_col.end()); - seg->PreInsert(N); - seg->Insert(iter * N, - N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } + proto::plan::GenericValue int_val12; + int_val12.set_int64_val(int64_t(2)); + sub_arr1->add_array()->CopyFrom(int_val12); - auto seg_promote = dynamic_cast(seg.get()); + proto::plan::GenericValue g_sub_arr2; + auto* sub_arr2 = g_sub_arr2.mutable_array_val(); + sub_arr2->set_same_type(true); + proto::plan::GenericValue int_val21; + int_val21.set_int64_val(int64_t(3)); + sub_arr2->add_array()->CopyFrom(int_val21); - std::vector> bool_testcases{{{true, true}, {"bool"}}, - {{false, false}, {"bool"}}}; + proto::plan::GenericValue int_val22; + int_val22.set_int64_val(int64_t(4)); + sub_arr2->add_array()->CopyFrom(int_val22); + std::vector> diff_testcases2{ + {{g_sub_arr1, g_sub_arr2}, {"array2"}}}; - for (auto testcase : bool_testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; - }; + for (auto& testcase : diff_testcases2) { + auto check = [&]() { return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto v : testcase.term) { - proto::plan::GenericValue val; - val.set_bool_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5899,51 +12775,24 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check()); } } - std::vector> double_testcases{ - {{1.123, 10.34}, {"double"}}, - {{10.34, 100.234}, {"double"}}, - {{100.234, 1000.4546}, {"double"}}, - {{1000.4546, 1.123}, {"double"}}, - {{1000.4546, 10.34}, {"double"}}, - {{1.123, 100.234}, {"double"}}, - }; - - for (auto testcase : double_testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_float_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5956,51 +12805,49 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } - std::vector> testcases{ - {{1, 10}, {"int"}}, - {{10, 100}, {"int"}}, - {{100, 1000}, {"int"}}, - {{1000, 10}, {"int"}}, - {{2, 4, 6, 8, 10}, {"int"}}, - {{1, 2, 3, 4, 5}, {"int"}}, - }; + proto::plan::GenericValue g_sub_arr3; + auto* sub_arr3 = g_sub_arr3.mutable_array_val(); + sub_arr3->set_same_type(true); + proto::plan::GenericValue int_val31; + int_val31.set_int64_val(int64_t(5)); + sub_arr3->add_array()->CopyFrom(int_val31); - for (auto testcase : testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; + proto::plan::GenericValue int_val32; + int_val32.set_int64_val(int64_t(6)); + sub_arr3->add_array()->CopyFrom(int_val32); + + proto::plan::GenericValue g_sub_arr4; + auto* sub_arr4 = g_sub_arr4.mutable_array_val(); + sub_arr4->set_same_type(true); + proto::plan::GenericValue int_val41; + int_val41.set_int64_val(int64_t(7)); + sub_arr4->add_array()->CopyFrom(int_val41); + + proto::plan::GenericValue int_val42; + int_val42.set_int64_val(int64_t(8)); + sub_arr4->add_array()->CopyFrom(int_val42); + std::vector> diff_testcases3{ + {{g_sub_arr3, g_sub_arr4}, {"array2"}}}; + + for (auto& testcase : diff_testcases3) { + auto check = [&](const std::vector& values, int i) { + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_int64_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -6013,49 +12860,25 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } - } - - std::vector> testcases_string = { - {{"1sads", "10dsf"}, {"string"}}, - {{"10dsf", "100"}, {"string"}}, - {{"100", "10dsf", "1sads"}, {"string"}}, - {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, - }; + } - for (auto testcase : testcases_string) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; + for (auto& testcase : diff_testcases3) { + auto check = [&](const std::vector& values, int i) { + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_string_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -6068,30 +12891,27 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } } -TEST_P(ExprTest, TestJsonContainsArray) { +TEST_P(ExprTest, TestJsonContainsArrayNullable) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -6146,7 +12966,10 @@ TEST_P(ExprTest, TestJsonContainsArray) { {{generic_a}, {"string"}}, {{generic_b}, {"array"}}}; for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { return true; } @@ -6175,12 +12998,15 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { return true; } @@ -6209,7 +13035,7 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } @@ -6238,7 +13064,12 @@ TEST_P(ExprTest, TestJsonContainsArray) { {{g_sub_arr1, g_sub_arr2}, {"array2"}}}; for (auto& testcase : diff_testcases2) { - auto check = [&]() { return true; }; + auto check = [&](bool valid) { + if (!valid) { + return false; + } + return true; + }; auto pointer = milvus::Json::pointer(testcase.nested_path); auto expr = std::make_shared( milvus::expr::ColumnInfo( @@ -6261,12 +13092,15 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - ASSERT_EQ(ans, check()); + ASSERT_EQ(ans, check(valid_data[i])); } } for (auto& testcase : diff_testcases2) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); @@ -6292,7 +13126,7 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } @@ -6514,6 +13348,117 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArray) { } } +TEST_P(ExprTest, TestJsonContainsDiffTypeArrayNullable) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue int_value; + int_value.set_int64_val(1); + auto diff_type_array1 = + generatedArrayWithFourDiffType(1, 2.2, false, "abc"); + auto diff_type_array2 = + generatedArrayWithFourDiffType(1, 2.2, false, "def"); + auto diff_type_array3 = generatedArrayWithFourDiffType(1, 2.2, true, "abc"); + auto diff_type_array4 = + generatedArrayWithFourDiffType(1, 3.3, false, "abc"); + auto diff_type_array5 = + generatedArrayWithFourDiffType(2, 2.2, false, "abc"); + + std::vector> diff_testcases{ + {{diff_type_array1, int_value}, {"array3"}, true}, + {{diff_type_array2, int_value}, {"array3"}, false}, + {{diff_type_array3, int_value}, {"array3"}, false}, + {{diff_type_array4, int_value}, {"array3"}, false}, + {{diff_type_array5, int_value}, {"array3"}, false}, + }; + + for (auto& testcase : diff_testcases) { + auto check = [&](bool valid) { + if (!valid) { + return false; + } + return testcase.res; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check(valid_data[i])); + } + } + + for (auto& testcase : diff_testcases) { + auto check = [&]() { return false; }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check()); + } + } +} + TEST_P(ExprTest, TestJsonContainsDiffType) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); @@ -6621,3 +13566,121 @@ TEST_P(ExprTest, TestJsonContainsDiffType) { } } } +TEST_P(ExprTest, TestJsonContainsDiffTypeNullable) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue int_val; + int_val.set_int64_val(int64_t(3)); + proto::plan::GenericValue bool_val; + bool_val.set_bool_val(bool(false)); + proto::plan::GenericValue float_val; + float_val.set_float_val(double(100.34)); + proto::plan::GenericValue string_val; + string_val.set_string_val("10dsf"); + + proto::plan::GenericValue string_val2; + string_val2.set_string_val("abc"); + proto::plan::GenericValue bool_val2; + bool_val2.set_bool_val(bool(true)); + proto::plan::GenericValue float_val2; + float_val2.set_float_val(double(2.2)); + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(1)); + + std::vector> diff_testcases{ + {{int_val, bool_val, float_val, string_val}, + {"diff_type_array"}, + false}, + {{string_val2, bool_val2, float_val2, int_val2}, + {"diff_type_array"}, + true}, + }; + + for (auto& testcase : diff_testcases) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + } else { + ASSERT_EQ(ans, testcase.res); + } + } + } + + for (auto& testcase : diff_testcases) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + } else { + ASSERT_EQ(ans, testcase.res); + } + } + } +} diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index a406ab8e86bb0..cb4ccf4131cbd 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -166,7 +166,8 @@ GenAlwaysFalseExpr(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { } auto -GenAlwaysTrueExpr(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { +GenAlwaysTrueExprIfValid(const FieldMeta& fvec_meta, + const FieldMeta& str_meta) { auto always_false_expr = GenAlwaysFalseExpr(fvec_meta, str_meta); auto not_expr = GenNotExpr(); not_expr->set_allocated_child(always_false_expr); @@ -196,7 +197,7 @@ GenAlwaysFalsePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto GenAlwaysTruePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { - auto always_true_expr = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto always_true_expr = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); proto::plan::VectorType vector_type; if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { vector_type = proto::plan::VectorType::FloatVector; @@ -299,6 +300,82 @@ TEST(StringExpr, Term) { } } +TEST(StringExpr, TermNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto vec_2k_3k = []() -> std::vector { + std::vector ret; + for (int i = 2000; i < 3000; i++) { + ret.push_back(std::to_string(i)); + } + return ret; + }(); + + std::map> terms = { + {0, {"2000", "3000"}}, + {1, {"2000"}}, + {2, {"3000"}}, + {3, {}}, + {4, {vec_2k_3k}}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [_, term] : terms) { + auto plan_proto = GenTermPlan(fvec_meta, str_meta, term); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto ref = std::find(term.begin(), term.end(), val) != term.end(); + ASSERT_EQ(ans, ref) << "@" << i << "!!" << val; + } + } +} + TEST(StringExpr, Compare) { auto schema = GenTestSchema(); const auto& fvec_meta = schema->operator[](FieldName("fvec")); @@ -395,6 +472,267 @@ TEST(StringExpr, Compare) { for (const auto& [op, ref_func] : testcases) { auto plan_proto = gen_compare_plan(op); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = str_col[i]; + auto another_val = another_str_col[i]; + auto ref = ref_func(val, another_val); + ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + } + } +} + +TEST(StringExpr, CompareNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + const auto& another_str_meta = schema->operator[](FieldName("another_str")); + + auto gen_compare_plan = + [&, fvec_meta, str_meta, another_str_meta]( + proto::plan::OpType op) -> std::unique_ptr { + auto str_col_info = + test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto another_str_col_info = + test::GenColumnInfo(another_str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + + auto compare_expr = GenCompareExpr(op); + compare_expr->set_allocated_left_column_info(str_col_info); + compare_expr->set_allocated_right_column_info(another_str_col_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_compare_expr(compare_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + [](std::string& v1, std::string& v2) { return v1 > v2; }}, + {proto::plan::OpType::GreaterEqual, + [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + {proto::plan::OpType::LessThan, + [](std::string& v1, std::string& v2) { return v1 < v2; }}, + {proto::plan::OpType::LessEqual, + [](std::string& v1, std::string& v2) { return v1 <= v2; }}, + {proto::plan::OpType::Equal, + [](std::string& v1, std::string& v2) { return v1 == v2; }}, + {proto::plan::OpType::NotEqual, + [](std::string& v1, std::string& v2) { return v1 != v2; }}, + {proto::plan::OpType::PrefixMatch, + [](std::string& v1, std::string& v2) { + return PrefixMatch(v1, v2); + }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + std::vector another_str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto reserve_col = [&, raw_data](const FieldMeta& field_meta, + std::vector& str_col) { + auto new_str_col = raw_data.get_col(field_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + }; + + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + + reserve_col(str_meta, str_col); + reserve_col(another_str_meta, another_str_col); + + { + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, ref_func] : testcases) { + auto plan_proto = gen_compare_plan(op); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto another_val = another_str_col[i]; + auto ref = ref_func(val, another_val); + ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + } + } +} + +TEST(StringExpr, CompareNullable2) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR); + schema->AddDebugField("another_str", DataType::VARCHAR, true); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + const auto& another_str_meta = schema->operator[](FieldName("another_str")); + + auto gen_compare_plan = + [&, fvec_meta, str_meta, another_str_meta]( + proto::plan::OpType op) -> std::unique_ptr { + auto str_col_info = + test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto another_str_col_info = + test::GenColumnInfo(another_str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + + auto compare_expr = GenCompareExpr(op); + compare_expr->set_allocated_left_column_info(str_col_info); + compare_expr->set_allocated_right_column_info(another_str_col_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_compare_expr(compare_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + [](std::string& v1, std::string& v2) { return v1 > v2; }}, + {proto::plan::OpType::GreaterEqual, + [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + {proto::plan::OpType::LessThan, + [](std::string& v1, std::string& v2) { return v1 < v2; }}, + {proto::plan::OpType::LessEqual, + [](std::string& v1, std::string& v2) { return v1 <= v2; }}, + {proto::plan::OpType::Equal, + [](std::string& v1, std::string& v2) { return v1 == v2; }}, + {proto::plan::OpType::NotEqual, + [](std::string& v1, std::string& v2) { return v1 != v2; }}, + {proto::plan::OpType::PrefixMatch, + [](std::string& v1, std::string& v2) { + return PrefixMatch(v1, v2); + }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + std::vector another_str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto reserve_col = [&, raw_data](const FieldMeta& field_meta, + std::vector& str_col) { + auto new_str_col = raw_data.get_col(field_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + }; + + auto new_str_valid_col = + raw_data.get_col_valid(another_str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + + reserve_col(str_meta, str_col); + reserve_col(another_str_meta, another_str_col); + + { + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, ref_func] : testcases) { + auto plan_proto = gen_compare_plan(op); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; final = ExecuteQueryExpr( plan->plan_node_->plannodes_->sources()[0]->sources()[0], @@ -405,7 +743,10 @@ TEST(StringExpr, Compare) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } auto val = str_col[i]; auto another_val = another_str_col[i]; auto ref = ref_func(val, another_val); @@ -510,6 +851,116 @@ TEST(StringExpr, UnaryRange) { } } +TEST(StringExpr, UnaryRangeNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto gen_unary_range_plan = + [&, fvec_meta, str_meta]( + proto::plan::OpType op, + std::string value) -> std::unique_ptr { + auto column_info = test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto unary_range_expr = test::GenUnaryRangeExpr(op, value); + unary_range_expr->set_allocated_column_info(column_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_unary_range_expr(unary_range_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + "2000", + [](std::string& val) { return val > "2000"; }}, + {proto::plan::OpType::GreaterEqual, + "2000", + [](std::string& val) { return val >= "2000"; }}, + {proto::plan::OpType::LessThan, + "3000", + [](std::string& val) { return val < "3000"; }}, + {proto::plan::OpType::LessEqual, + "3000", + [](std::string& val) { return val <= "3000"; }}, + {proto::plan::OpType::PrefixMatch, + "a", + [](std::string& val) { return PrefixMatch(val, "a"); }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, value, ref_func] : testcases) { + auto plan_proto = gen_unary_range_plan(op, value); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << "@" << op << "@" << value << "@" << i << "!!" << val; + } + } +} + TEST(StringExpr, BinaryRange) { auto schema = GenTestSchema(); const auto& fvec_meta = schema->operator[](FieldName("fvec")); @@ -625,6 +1076,136 @@ TEST(StringExpr, BinaryRange) { } } +TEST(StringExpr, BinaryRangeNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto gen_binary_range_plan = + [&, fvec_meta, str_meta]( + bool lb_inclusive, + bool ub_inclusive, + std::string lb, + std::string ub) -> std::unique_ptr { + auto column_info = test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto binary_range_expr = + GenBinaryRangeExpr(lb_inclusive, ub_inclusive, lb, ub); + binary_range_expr->set_allocated_column_info(column_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_binary_range_expr(binary_range_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + // bool lb_inclusive, bool ub_inclusive, std::string lb, std::string ub + std::vector>> + testcases{ + {false, + false, + "2000", + "3000", + [](std::string& val) { return val > "2000" && val < "3000"; }}, + {false, + true, + "2000", + "3000", + [](std::string& val) { return val > "2000" && val <= "3000"; }}, + {true, + false, + "2000", + "3000", + [](std::string& val) { return val >= "2000" && val < "3000"; }}, + {true, + true, + "2000", + "3000", + [](std::string& val) { return val >= "2000" && val <= "3000"; }}, + {true, + true, + "2000", + "1000", + [](std::string& val) { return false; }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [lb_inclusive, ub_inclusive, lb, ub, ref_func] : + testcases) { + auto plan_proto = + gen_binary_range_plan(lb_inclusive, ub_inclusive, lb, ub); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << "@" << lb_inclusive << "@" << ub_inclusive << "@" << lb + << "@" << ub << "@" << i << "!!" << val; + } + } +} + TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { auto schema = GenStrPKSchema(); const auto& fvec_meta = schema->operator[](FieldName("fvec")); @@ -718,7 +1299,7 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) { dataset.timestamps_.data(), dataset.raw_); - auto expr_proto = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto expr_proto = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); auto plan_proto = GenPlanNode(); plan_proto->mutable_query()->set_allocated_predicates(expr_proto); SetTargetEntry(plan_proto, {str_meta.get_id().get()}); @@ -733,4 +1314,47 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) { ASSERT_EQ(retrieved->fields_data().size(), 1); ASSERT_EQ(retrieved->fields_data(0).scalars().string_data().data().size(), N); + ASSERT_EQ(retrieved->fields_data(0).valid_data_size(), 0); +} + +TEST(AlwaysTrueStringPlan, QueryWithOutputFieldsNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto N = 10000; + auto dataset = DataGen(schema, N); + auto vec_col = dataset.get_col(fvec_meta.get_id()); + auto str_col = + dataset.get_col(str_meta.get_id())->scalars().string_data().data(); + auto valid_data = dataset.get_col_valid(str_meta.get_id()); + auto segment = CreateGrowingSegment(schema, empty_index_meta); + segment->PreInsert(N); + segment->Insert(0, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + + auto expr_proto = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); + auto plan_proto = GenPlanNode(); + plan_proto->mutable_query()->set_allocated_predicates(expr_proto); + SetTargetEntry(plan_proto, {str_meta.get_id().get()}); + auto plan = ProtoParser(*schema).CreateRetrievePlan(*plan_proto); + + Timestamp time = MAX_TIMESTAMP; + + auto retrieved = segment->Retrieve( + nullptr, plan.get(), time, DEFAULT_MAX_OUTPUT_SIZE, false); + ASSERT_EQ(retrieved->offset().size(), N / 2); + ASSERT_EQ(retrieved->fields_data().size(), 1); + ASSERT_EQ(retrieved->fields_data(0).scalars().string_data().data().size(), + N / 2); + ASSERT_EQ(retrieved->fields_data(0).valid_data().size(), N / 2); } diff --git a/internal/core/unittest/test_utils/AssertUtils.h b/internal/core/unittest/test_utils/AssertUtils.h index 5e92369b90436..837e44fe768e1 100644 --- a/internal/core/unittest/test_utils/AssertUtils.h +++ b/internal/core/unittest/test_utils/AssertUtils.h @@ -139,7 +139,9 @@ template inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_EQ(index->Reverse_Lookup(offset), arr[offset]); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_EQ(raw.value(), arr[offset]); } } @@ -147,7 +149,9 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(compare_float(index->Reverse_Lookup(offset), arr[offset])); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(compare_float(raw.value(), arr[offset])); } } @@ -155,7 +159,9 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(compare_double(index->Reverse_Lookup(offset), arr[offset])); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(compare_double(raw.value(), arr[offset])); } } @@ -164,7 +170,9 @@ inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(arr[offset].compare(index->Reverse_Lookup(offset)) == 0); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(arr[offset].compare(raw.value()) == 0); } } diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 6e4faa28b1394..48af55d6ef7d6 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -667,8 +667,14 @@ DataGenForJsonArray(SchemaPtr schema, auto insert_data = std::make_unique(); auto insert_cols = [&insert_data]( auto& data, int64_t count, auto& field_meta) { + FixedVector valid_data(count); + if (field_meta.is_nullable()) { + for (int i = 0; i < count; ++i) { + valid_data[i] = i % 2 == 0 ? true : false; + } + } auto array = milvus::segcore::CreateDataArrayFrom( - data.data(), nullptr, count, field_meta); + data.data(), valid_data.data(), count, field_meta); insert_data->mutable_fields_data()->AddAllocated(array.release()); }; for (auto field_id : schema->get_field_ids()) { diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index a1dd1d8bba3fb..9332828960b53 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -13019,7 +13019,6 @@ def test_search_collection_with_non_default_data_after_release_load(self, nq, _a @pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.GPU) - @pytest.mark.skip(reason="issue #36184") def test_search_after_different_index_with_params_none_default_data(self, varchar_scalar_index, numeric_scalar_index, null_data_percent, _async): """