From a50634200a307beccec6ec296f677d4de278afef Mon Sep 17 00:00:00 2001 From: lixinguo Date: Sat, 17 Aug 2024 00:44:31 +0800 Subject: [PATCH 1/4] enhance: all op(Null) is false in expr Signed-off-by: lixinguo --- .../src/exec/expression/AlwaysTrueExpr.cpp | 1 + .../expression/BinaryArithOpEvalRangeExpr.cpp | 119 +- .../expression/BinaryArithOpEvalRangeExpr.h | 167 +- .../src/exec/expression/BinaryRangeExpr.cpp | 27 +- .../src/exec/expression/BinaryRangeExpr.h | 72 +- .../core/src/exec/expression/CompareExpr.cpp | 218 +- .../core/src/exec/expression/CompareExpr.h | 32 +- .../core/src/exec/expression/ExistsExpr.cpp | 5 + internal/core/src/exec/expression/Expr.h | 20 +- .../src/exec/expression/JsonContainsExpr.cpp | 40 + .../core/src/exec/expression/TermExpr.cpp | 25 + .../core/src/exec/expression/UnaryExpr.cpp | 65 +- internal/core/src/exec/expression/UnaryExpr.h | 89 +- .../operator/groupby/SearchGroupByOperator.h | 4 +- internal/core/src/index/BitmapIndex.cpp | 9 +- internal/core/src/index/BitmapIndex.h | 4 +- internal/core/src/index/HybridScalarIndex.h | 8 +- .../core/src/index/InvertedIndexTantivy.h | 4 +- internal/core/src/index/ScalarIndex.h | 4 +- internal/core/src/index/ScalarIndexSort.cpp | 24 +- internal/core/src/index/ScalarIndexSort.h | 4 +- internal/core/src/index/StringIndexMarisa.cpp | 15 +- internal/core/src/index/StringIndexMarisa.h | 6 +- internal/core/src/query/ScalarIndex.h | 4 +- internal/core/src/segcore/FieldIndexing.cpp | 1 + .../core/src/segcore/SegmentSealedImpl.cpp | 10 +- internal/core/src/segcore/Utils.cpp | 48 +- internal/core/unittest/test_expr.cpp | 14446 +++++++++++----- internal/core/unittest/test_string_expr.cpp | 738 +- .../core/unittest/test_utils/AssertUtils.h | 11 +- internal/core/unittest/test_utils/DataGen.h | 8 +- 31 files changed, 12011 insertions(+), 4217 deletions(-) diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp index 24789c429ac8a..94ff5b96986ba 100644 --- a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp @@ -25,6 +25,7 @@ PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) { ? active_count_ - current_pos_ : batch_size_; + // always true no need to skip null if (real_batch_size == 0) { result = nullptr; return; diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index 7f64cae5b390e..055acb66e5950 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -129,6 +129,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data && !valid_data[i]) { \ + res[i] = false; \ + continue; \ + } \ auto x = data[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -146,6 +150,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompareNotEqual(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data && !valid_data[i]) { \ + res[i] = false; \ + continue; \ + } \ auto x = data[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -161,6 +169,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } while (false) auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, ValueType val, @@ -197,6 +206,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -246,6 +259,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -295,6 +312,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -344,6 +365,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -393,6 +418,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -442,6 +471,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } int array_length = 0; auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -511,6 +544,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { #define BinaryArithRangeArrayCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ + if (valid_data && !valid_data[i]) { \ + res[i] = false; \ + continue; \ + } \ if (index >= data[i].length()) { \ res[i] = false; \ continue; \ @@ -521,6 +558,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } while (false) auto execute_sub_batch = [op_type, arith_type](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, ValueType val, @@ -601,6 +639,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].length() != val; } break; @@ -644,6 +686,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].length() > val; } break; @@ -687,6 +733,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].length() >= val; } break; @@ -730,6 +780,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].length() < val; } break; @@ -773,6 +827,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].length() <= val; } break; @@ -1217,6 +1275,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { auto arith_type = expr_->arith_op_type_; auto execute_sub_batch = [op_type, arith_type]( const T* data, + const bool* valid_data, const int size, TargetBitmapView res, HighPrecisionType value, @@ -1229,7 +1288,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1237,7 +1296,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1245,7 +1304,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1253,7 +1312,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1261,7 +1320,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: @@ -1280,7 +1339,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1288,7 +1347,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1296,7 +1355,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1304,7 +1363,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1312,7 +1371,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: @@ -1331,7 +1390,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1339,7 +1398,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1347,7 +1406,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1355,7 +1414,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1363,7 +1422,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: @@ -1382,7 +1441,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1390,7 +1449,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1398,7 +1457,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1406,7 +1465,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1414,7 +1473,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: @@ -1433,7 +1492,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1441,7 +1500,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1449,7 +1508,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1457,7 +1516,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1465,7 +1524,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: @@ -1484,7 +1543,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Add> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1492,7 +1551,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Sub> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1500,7 +1559,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mul> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1508,7 +1567,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Div> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1516,7 +1575,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mod> func; - func(data, size, value, right_operand, res); + func(data, valid_data, size, value, right_operand, res); break; } default: diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index 3c84819dc2b83..a75c5c3f4dbb1 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -97,6 +97,7 @@ struct ArithOpElementFunc { HighPrecisonType; void operator()(const T* src, + const bool* valid_data, size_t size, HighPrecisonType val, HighPrecisonType right_operand, @@ -239,28 +240,58 @@ struct ArithOpElementFunc { } } */ - - if constexpr (!std::is_same_v::op), - void>) { - constexpr auto cmp_op_cvt = CmpOpHelper::op; - if constexpr (!std::is_same_v::op), + auto execute_sub_batch = [](const T* src, + size_t size, + HighPrecisonType val, + HighPrecisonType right_operand, + TargetBitmapView res) { + if (size == 0) { + return; + } + if constexpr (!std::is_same_v::op), void>) { - constexpr auto arith_op_cvt = ArithOpHelper::op; + constexpr auto cmp_op_cvt = CmpOpHelper::op; + if constexpr (!std::is_same_v< + decltype(ArithOpHelper::op), + void>) { + constexpr auto arith_op_cvt = ArithOpHelper::op; - res.inplace_arith_compare( - src, right_operand, val, size); + res.inplace_arith_compare( + src, right_operand, val, size); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } } else { - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported arith type:{} for ArithOpElementFunc", - arith_op)); + PanicInfo(OpTypeInvalid, + fmt::format( + "unsupported cmp type:{} for ArithOpElementFunc", + cmp_op)); + } + }; + if (valid_data == nullptr) { + return execute_sub_batch(src, size, val, right_operand, res); + } + for (int left = 0; left < size; left++) { + for (int right = left; right < size; right++) { + if (valid_data[right]) { + if (right == size - 1) { + execute_sub_batch(src + left, + right - left, + val, + right_operand, + res + left); + } + continue; + } + execute_sub_batch( + src + left, right - left, val, right_operand, res + left); + left = right; + break; } - } else { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported cmp type:{} for ArithOpElementFunc", - cmp_op)); } } }; @@ -282,22 +313,30 @@ struct ArithOpIndexFunc { HighPrecisonType right_operand) { TargetBitmap res(size); for (size_t i = 0; i < size; ++i) { + if (!index->Reverse_Lookup(i).has_value()) { + res[i] = false; + continue; + } if constexpr (cmp_op == proto::plan::OpType::Equal) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) == val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) == val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) == val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) == val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) == val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) == val; } else { PanicInfo( OpTypeInvalid, @@ -307,20 +346,24 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::NotEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) != val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) != val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) != val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) != val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) != val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) != val; } else { PanicInfo( OpTypeInvalid, @@ -330,20 +373,24 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) > val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) > val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) > val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) > val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) > val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) > val; } else { PanicInfo( OpTypeInvalid, @@ -353,20 +400,24 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) >= val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) >= val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) >= val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) >= val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) >= val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) >= val; } else { PanicInfo( OpTypeInvalid, @@ -376,20 +427,24 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) < val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) < val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) < val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) < val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) < val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) < val; } else { PanicInfo( OpTypeInvalid, @@ -399,20 +454,24 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i) + right_operand) <= val; + res[i] = (index->Reverse_Lookup(i).value() + + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i) - right_operand) <= val; + res[i] = (index->Reverse_Lookup(i).value() - + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i) * right_operand) <= val; + res[i] = (index->Reverse_Lookup(i).value() * + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i) / right_operand) <= val; + res[i] = (index->Reverse_Lookup(i).value() / + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = - (fmod(index->Reverse_Lookup(i), right_operand)) <= val; + res[i] = (fmod(index->Reverse_Lookup(i).value(), + right_operand)) <= val; } else { PanicInfo( OpTypeInvalid, diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index be6aa576aaaee..94afd3d6abda9 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -246,22 +246,23 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const T* data, + const bool* valid_data, const int size, TargetBitmapView res, HighPrecisionType val1, HighPrecisionType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, size, res); + func(val1, val2, data, valid_data, size, res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, size, res); + func(val1, val2, data, valid_data, size, res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, size, res); + func(val1, val2, data, valid_data, size, res); } else { BinaryRangeElementFunc func; - func(val1, val2, data, size, res); + func(val1, val2, data, valid_data, size, res); } }; auto skip_index_func = @@ -313,22 +314,23 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer]( const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, ValueType val1, ValueType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res); } else { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, size, res); + func(val1, val2, pointer, data, valid_data, size, res); } }; int64_t processed_size = ProcessDataChunks( @@ -366,6 +368,7 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, ValueType val1, @@ -373,16 +376,16 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { int index) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res); } else { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, size, res); + func(val1, val2, index, data, valid_data, size, res); } }; int64_t processed_size = ProcessDataChunks( diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index 6484a40e5ef1e..e359224e82b35 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -35,25 +35,65 @@ struct BinaryRangeElementFunc { T> HighPrecisionType; void - operator()(T val1, T val2, const T* src, size_t n, TargetBitmapView res) { - if constexpr (lower_inclusive && upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else if constexpr (lower_inclusive && !upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else if constexpr (!lower_inclusive && upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else { - res.inplace_within_range_val( - val1, val2, src, n); + operator()(T val1, + T val2, + const T* src, + const bool* valid_data, + size_t n, + TargetBitmapView res) { + auto execute_sub_batch = [](T val1, + T val2, + const T* src, + size_t n, + TargetBitmapView res) { + if (n == 0) { + return; + } + if constexpr (lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (lower_inclusive && !upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (!lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else { + res.inplace_within_range_val( + val1, val2, src, n); + } + }; + if (valid_data == nullptr) { + return execute_sub_batch(val1, val2, src, n, res); + } + for (int left = 0; left < n; left++) { + for (int right = left; right < n; right++) { + if (valid_data[right]) { + if (right == n - 1) { + execute_sub_batch( + val1, val2, src + left, right - left, res + left); + } + continue; + } + execute_sub_batch( + val1, val2, src + left, right - left, res + left); + left = right; + break; + } } } }; #define BinaryRangeJSONCompare(cmp) \ do { \ + if (valid_data && !valid_data[i]) { \ + res[i] = false; \ + continue; \ + } \ auto x = src[i].template at(pointer); \ if (x.error()) { \ if constexpr (std::is_same_v) { \ @@ -81,6 +121,7 @@ struct BinaryRangeElementFuncForJson { ValueType val2, const std::string& pointer, const milvus::Json* src, + const bool* valid_data, size_t n, TargetBitmapView res) { for (size_t i = 0; i < n; ++i) { @@ -107,9 +148,14 @@ struct BinaryRangeElementFuncForArray { ValueType val2, int index, const milvus::ArrayView* src, + const bool* valid_data, size_t n, TargetBitmapView res) { for (size_t i = 0; i < n; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (lower_inclusive && upper_inclusive) { if (index >= src[i].length()) { res[i] = false; diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index 467df6654a929..bb1a11800aa7f 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -16,6 +16,7 @@ #include "CompareExpr.h" #include "common/type_c.h" +#include #include "query/Relational.h" namespace milvus { @@ -197,80 +198,133 @@ PhyCompareFilterExpr::GetChunkData(DataType data_type, template VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { - if (segment_->is_chunked()) { - auto real_batch_size = GetNextBatchSize(); - if (real_batch_size == 0) { - return nullptr; - } + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); - TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); + auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk_ - 1 + ? active_count_ - chunk_id * size_per_chunk_ + : size_per_chunk_; auto left = GetChunkData(expr_->left_data_type_, expr_->left_field_id_, - is_left_indexed_, - left_current_chunk_id_, - left_current_chunk_pos_); + chunk_id, + left_data_barrier); auto right = GetChunkData(expr_->right_data_type_, expr_->right_field_id_, - is_right_indexed_, - right_current_chunk_id_, - right_current_chunk_pos_); - for (int i = 0; i < real_batch_size; ++i) { - res[i] = boost::apply_visitor( - milvus::query::Relational{}, left(), right()); - } - return res_vec; - } else { - auto real_batch_size = GetNextBatchSize(); - if (real_batch_size == 0) { - return nullptr; - } - - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); - TargetBitmapView res(res_vec->GetRawData(), real_batch_size); - - auto left_data_barrier = - segment_->num_chunk_data(expr_->left_field_id_); - auto right_data_barrier = - segment_->num_chunk_data(expr_->right_field_id_); - - int64_t processed_rows = 0; - for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; - ++chunk_id) { - auto chunk_size = chunk_id == num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; - auto left = GetChunkData(expr_->left_data_type_, - expr_->left_field_id_, - chunk_id, - left_data_barrier); - auto right = GetChunkData(expr_->right_data_type_, - expr_->right_field_id_, - chunk_id, - right_data_barrier); - - for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; - i < chunk_size; - ++i) { - res[processed_rows++] = boost::apply_visitor( + chunk_id, + right_data_barrier); + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + if (!left(i).has_value() || !right(i).has_value()) { + res[processed_rows] = false; + } else { + res[processed_rows] = boost::apply_visitor( milvus::query::Relational{}, - left(i), - right(i)); + left(i).value(), + right(i).value()); + } + processed_rows++; - if (processed_rows >= batch_size_) { - current_chunk_id_ = chunk_id; - current_chunk_pos_ = i + 1; - return res_vec; - } + if (processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + return res_vec; } } - return res_vec; } + return res_vec; } +// template +// VectorPtr +// PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { +// if (segment_->is_chunked()) { +// auto real_batch_size = GetNextBatchSize(); +// if (real_batch_size == 0) { +// return nullptr; +// } + +// auto res_vec = +// std::make_shared(TargetBitmap(real_batch_size)); +// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + +// auto left = GetChunkData(expr_->left_data_type_, +// expr_->left_field_id_, +// is_left_indexed_, +// left_current_chunk_id_, +// left_current_chunk_pos_); +// auto right = GetChunkData(expr_->right_data_type_, +// expr_->right_field_id_, +// is_right_indexed_, +// right_current_chunk_id_, +// right_current_chunk_pos_); +// for (int i = 0; i < real_batch_size; ++i) { +// res[i] = boost::apply_visitor( +// milvus::query::Relational{}, left(), right()); +// } +// return res_vec; +// } else { +// auto real_batch_size = GetNextBatchSize(); +// if (real_batch_size == 0) { +// return nullptr; +// } + +// auto res_vec = +// std::make_shared(TargetBitmap(real_batch_size)); +// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + +// auto left_data_barrier = +// segment_->num_chunk_data(expr_->left_field_id_); +// auto right_data_barrier = +// segment_->num_chunk_data(expr_->right_field_id_); + +// int64_t processed_rows = 0; +// for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; +// ++chunk_id) { +// auto chunk_size = chunk_id == num_chunk_ - 1 +// ? active_count_ - chunk_id * size_per_chunk_ +// : size_per_chunk_; +// auto left = GetChunkData(expr_->left_data_type_, +// expr_->left_field_id_, +// chunk_id, +// left_data_barrier); +// auto right = GetChunkData(expr_->right_data_type_, +// expr_->right_field_id_, +// chunk_id, +// right_data_barrier); + +// for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; +// i < chunk_size; +// ++i) { +// res[processed_rows++] = boost::apply_visitor( +// milvus::query::Relational{}, +// left(i), +// right(i)); + +// if (processed_rows >= batch_size_) { +// current_chunk_id_ = chunk_id; +// current_chunk_pos_ = i + 1; +// return res_vec; +// } +// } +// } +// return res_vec; +// } +// } + template ChunkDataAccessor PhyCompareFilterExpr::GetChunkData(FieldId field_id, @@ -280,12 +334,22 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); + if (!indexing.Reverse_Lookup(i).has_value()) { + return std::nullopt; + } + return indexing.Reverse_Lookup(i).value(); }; } } auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); - return [chunk_data](int i) -> const number { return chunk_data[i]; }; + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; } template <> @@ -297,8 +361,11 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { - return [&indexing](int i) -> const std::string { - return indexing.Reverse_Lookup(i); + return [&indexing](int i) -> const number { + if (!indexing.Reverse_Lookup(i).has_value()) { + return std::nullopt; + } + return indexing.Reverse_Lookup(i).value(); }; } } @@ -308,12 +375,23 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, .growing_enable_mmap) { auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); - return [chunk_data](int i) -> const number { return chunk_data[i]; }; + auto chunk_valid_data = + segment_->chunk_data(field_id, chunk_id).valid_data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } + return chunk_data[i]; + }; } else { - auto chunk_data = - segment_->chunk_view(field_id, chunk_id) - .first.data(); - return [chunk_data](int i) -> const number { + auto chunk_info = + segment_->chunk_view(field_id, chunk_id); + auto chunk_data = chunk_info.first.data(); + auto chunk_valid_data = chunk_info.second.data(); + return [chunk_data, chunk_valid_data](int i) -> const number { + if (chunk_valid_data && !chunk_valid_data[i]) { + return std::nullopt; + } return std::string(chunk_data[i]); }; } diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index fd9ef751387cb..a4b5dfaab77b2 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -18,6 +18,7 @@ #include #include +#include #include "common/EasyAssert.h" #include "common/Types.h" @@ -29,14 +30,17 @@ namespace milvus { namespace exec { -using number = boost::variant; +using number_type = boost::variant; + +using number = std::optional; + using ChunkDataAccessor = std::function; using MultipleChunkDataAccessor = std::function; @@ -304,6 +308,18 @@ class PhyCompareFilterExpr : public Expr { const T* left_data = left_chunk.data() + data_pos; const U* right_data = right_chunk.data() + data_pos; func(left_data, right_data, size, res + processed_size, values...); + const bool* left_valid_data = left_chunk.valid_data(); + const bool* right_valid_data = right_chunk.valid_data(); + // mask with valid_data + for (int i = 0; i < size; ++i) { + if (left_valid_data && !left_valid_data[i + data_pos]) { + res[processed_size + i] = false; + continue; + } + if (right_valid_data && !right_valid_data[i + data_pos]) { + res[processed_size + i] = false; + } + } processed_size += size; if (processed_size >= batch_size_) { diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp index 6798eeedb4210..8f56d4197ba35 100644 --- a/internal/core/src/exec/expression/ExistsExpr.cpp +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -50,10 +50,15 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer) { for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = data[i].exist(pointer); } }; diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 25f90db4a249f..3806b74b9bc72 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -256,13 +256,13 @@ class SegmentExpr : public Expr { auto& skip_index = segment_->GetSkipIndex(); if (!skip_func || !skip_func(skip_index, field_id_, 0)) { - auto data_vec = - segment_ - ->get_batch_views( - field_id_, 0, current_data_chunk_pos_, need_size) - .first; - - func(data_vec.data(), need_size, res, values...); + auto views_info = segment_->get_batch_views( + field_id_, 0, current_data_chunk_pos_, need_size); + func(views_info.first.data(), + views_info.second.data(), + need_size, + res, + values...); } current_data_chunk_pos_ += need_size; return need_size; @@ -303,7 +303,11 @@ class SegmentExpr : public Expr { if (!skip_func || !skip_func(skip_index, field_id_, i)) { auto chunk = segment_->chunk_data(field_id_, i); const T* data = chunk.data() + data_pos; - func(data, size, res + processed_size, values...); + const bool* valid_data = chunk.valid_data(); + if (valid_data != nullptr) { + valid_data += data_pos; + } + func(data, valid_data, size, res + processed_size, values...); } processed_size += size; diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index da9f3d6aaa895..c91420e577182 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -182,6 +182,7 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { elements.insert(GetValueFromProto(element)); } auto execute_sub_batch = [](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::unordered_set& elements) { @@ -195,6 +196,10 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { return false; }; for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -231,6 +236,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { elements.insert(GetValueFromProto(element)); } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -253,6 +259,10 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -285,6 +295,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -316,6 +327,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -354,6 +369,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { } auto execute_sub_batch = [](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::unordered_set& elements) { @@ -369,6 +385,10 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { return tmp_elements.size() == 0; }; for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -406,6 +426,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -431,6 +452,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { return tmp_elements.size() == 0; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -467,6 +492,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -553,6 +579,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { return tmp_elements_index.size() == 0; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -590,6 +620,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { } auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -625,6 +656,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { return exist_elements_index.size() == elements.size(); }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -662,6 +697,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string& pointer, @@ -739,6 +775,10 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index 0aaf7a4e69f74..e5322a84a93fa 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -250,6 +250,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { ValueType target_val = GetValueFromProto(expr_->vals_[0]); auto execute_sub_batch = [](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, const ValueType& target_val) { @@ -263,6 +264,10 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { return false; }; for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } executor(i); } }; @@ -309,11 +314,16 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { } auto execute_sub_batch = [](const ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, int index, const std::unordered_set& term_set) { for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if (index >= data[i].length()) { res[i] = false; continue; @@ -354,6 +364,7 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string pointer, @@ -375,6 +386,10 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { return false; }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -416,6 +431,7 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { } auto execute_sub_batch = [](const Json* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::string pointer, @@ -439,6 +455,10 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { return terms.find(ValueType(x.value())) != terms.end(); }; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = executor(i); } }; @@ -542,11 +562,16 @@ PhyTermFilterExpr::ExecVisitorImplForData() { } std::unordered_set vals_set(vals.begin(), vals.end()); auto execute_sub_batch = [](const T* data, + const bool* valid_data, const int size, TargetBitmapView res, const std::unordered_set& vals) { TermElementFuncSet func; for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = func(vals, data[i]); } }; diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 3b7c2116244fb..38cbb37ba5a31 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -271,6 +271,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { index = std::stoi(expr_->column_.nested_path_[0]); } auto execute_sub_batch = [op_type](const milvus::ArrayView* data, + const bool* valid_data, const int size, TargetBitmapView res, ValueType val, @@ -279,40 +280,40 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { case proto::plan::GreaterThan: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::GreaterEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::LessThan: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::LessEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::Equal: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::NotEqual: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } case proto::plan::PrefixMatch: { UnaryElementFuncForArray func; - func(data, size, val, index, res); + func(data, valid_data, size, val, index, res); break; } default: @@ -492,12 +493,17 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } while (false) auto execute_sub_batch = [op_type, pointer](const milvus::Json* data, + const bool* valid_data, const int size, TargetBitmapView res, ExprValueType val) { switch (op_type) { case proto::plan::GreaterThan: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -508,6 +514,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::GreaterEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -518,6 +528,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessThan: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -528,6 +542,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -538,6 +556,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::Equal: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -554,6 +576,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::NotEqual: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { auto doc = data[i].doc(); auto array = doc.at_pointer(pointer).get_array(); @@ -570,6 +596,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::PrefixMatch: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -584,6 +614,10 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { auto regex_pattern = translator(val); RegexMatcher matcher(regex_pattern); for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (std::is_same_v) { res[i] = false; } else { @@ -793,48 +827,49 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* data, + const bool* valid_data, const int size, TargetBitmapView res, IndexInnerType val) { switch (expr_type) { case proto::plan::GreaterThan: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::GreaterEqual: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::LessThan: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::LessEqual: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::Equal: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::NotEqual: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::PrefixMatch: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } case proto::plan::Match: { UnaryElementFunc func; - func(data, size, val, res); + func(data, valid_data, size, val, res); break; } default: diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index 83711f6d70dab..612b5ba1d70f9 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -60,6 +60,7 @@ struct UnaryElementFunc { IndexInnerType; void operator()(const T* src, + const bool* valid_data, size_t size, IndexInnerType val, TargetBitmapView res) { @@ -96,33 +97,66 @@ struct UnaryElementFunc { } */ + auto execute_sub_batch = [](const T* src, + size_t size, + IndexInnerType val, + TargetBitmapView res) { + if (size == 0) { + return; + } + if constexpr (op == proto::plan::OpType::Equal) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res.inplace_compare_val( + src, size, val); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for UnaryElementFunc", + op)); + } + }; + if constexpr (op == proto::plan::OpType::PrefixMatch) { for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } res[i] = milvus::query::Match( src[i], val, proto::plan::OpType::PrefixMatch); } - } else if constexpr (op == proto::plan::OpType::Equal) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::NotEqual) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::GreaterThan) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::LessThan) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::GreaterEqual) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::LessEqual) { - res.inplace_compare_val( - src, size, val); - } else { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported op_type:{} for UnaryElementFunc", op)); + return; + } + if (!valid_data) { + return execute_sub_batch(src, size, val, res); + } + for (int left = 0; left < size; left++) { + for (int right = left; right < size; right++) { + if (valid_data[right]) { + if (right == size - 1) { + execute_sub_batch( + src + left, right - left, val, res + left); + } + continue; + } + execute_sub_batch(src + left, right - left, val, res + left); + left = right; + break; + } } } }; @@ -148,11 +182,16 @@ struct UnaryElementFuncForArray { ValueType>; void operator()(const ArrayView* src, + const bool* valid_data, size_t size, ValueType val, int index, TargetBitmapView res) { for (int i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = false; + continue; + } if constexpr (op == proto::plan::OpType::Equal) { if constexpr (std::is_same_v) { res[i] = src[i].is_same_array(val); @@ -224,7 +263,11 @@ struct UnaryIndexFuncForMatch { RegexMatcher matcher(regex_pattern); for (int64_t i = 0; i < cnt; i++) { auto raw = index->Reverse_Lookup(i); - res[i] = matcher(raw); + if (!raw.has_value()) { + res[i] = false; + continue; + } + res[i] = matcher(raw.value()); } return res; } diff --git a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h index 78833a8d34cd5..e266b5b34ce82 100644 --- a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h @@ -100,7 +100,9 @@ class SealedDataGetter : public DataGetter { } return field_data_->operator[](idx); } else { - return (*field_index_).Reverse_Lookup(idx); + auto value = (*field_index_).Reverse_Lookup(idx); + AssertInfo(value.has_value(), "field data not found"); + return value.value(); } } }; diff --git a/internal/core/src/index/BitmapIndex.cpp b/internal/core/src/index/BitmapIndex.cpp index cc4de8e3bf358..00cd6a58cbcfe 100644 --- a/internal/core/src/index/BitmapIndex.cpp +++ b/internal/core/src/index/BitmapIndex.cpp @@ -80,7 +80,7 @@ BitmapIndex::Build(const Config& config) { template void -BitmapIndex::Build(size_t n, const T* data) { +BitmapIndex::Build(size_t n, const T* data, const bool* valid_data) { if (is_built_) { return; } @@ -93,8 +93,10 @@ BitmapIndex::Build(size_t n, const T* data) { T* p = const_cast(data); for (int i = 0; i < n; ++i, ++p) { - data_[*p].add(i); - valid_bitset.set(i); + if (!valid_data || valid_data[i]) { + data_[*p].add(i); + valid_bitset.set(i); + } } if (data_.size() < DEFAULT_BITMAP_INDEX_BUILD_MODE_BOUND) { @@ -1121,6 +1123,7 @@ BitmapIndex::Reverse_Lookup(size_t idx) const { } } } + return std::nullopt; PanicInfo(UnexpectedError, fmt::format( "scalar bitmap index can not lookup target value of index {}", diff --git a/internal/core/src/index/BitmapIndex.h b/internal/core/src/index/BitmapIndex.h index eb11e75441348..714ddf878c4d4 100644 --- a/internal/core/src/index/BitmapIndex.h +++ b/internal/core/src/index/BitmapIndex.h @@ -77,7 +77,7 @@ class BitmapIndex : public ScalarIndex { } void - Build(size_t n, const T* values) override; + Build(size_t n, const T* values, const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -106,7 +106,7 @@ class BitmapIndex : public ScalarIndex { T upper_bound_value, bool ub_inclusive) override; - T + std::optional Reverse_Lookup(size_t offset) const override; int64_t diff --git a/internal/core/src/index/HybridScalarIndex.h b/internal/core/src/index/HybridScalarIndex.h index 0829afc963fbc..8b6e484c2a71f 100644 --- a/internal/core/src/index/HybridScalarIndex.h +++ b/internal/core/src/index/HybridScalarIndex.h @@ -67,10 +67,12 @@ class HybridScalarIndex : public ScalarIndex { } void - Build(size_t n, const T* values) override { + Build(size_t n, + const T* values, + const bool* valid_data = nullptr) override { SelectIndexBuildType(n, values); auto index = GetInternalIndex(); - index->Build(n, values); + index->Build(n, values, valid_data); is_built_ = true; } @@ -133,7 +135,7 @@ class HybridScalarIndex : public ScalarIndex { lower_bound_value, lb_inclusive, upper_bound_value, ub_inclusive); } - T + std::optional Reverse_Lookup(size_t offset) const override { return internal_index_->Reverse_Lookup(offset); } diff --git a/internal/core/src/index/InvertedIndexTantivy.h b/internal/core/src/index/InvertedIndexTantivy.h index 9d7febfd90942..62ecff9f29470 100644 --- a/internal/core/src/index/InvertedIndexTantivy.h +++ b/internal/core/src/index/InvertedIndexTantivy.h @@ -94,7 +94,7 @@ class InvertedIndexTantivy : public ScalarIndex { * deprecated, only used in small chunk index. */ void - Build(size_t n, const T* values) override { + Build(size_t n, const T* values, const bool* valid_data) override { PanicInfo(ErrorCode::NotImplemented, "Build should not be called"); } @@ -136,7 +136,7 @@ class InvertedIndexTantivy : public ScalarIndex { return false; } - T + std::optional Reverse_Lookup(size_t offset) const override { PanicInfo(ErrorCode::NotImplemented, "Reverse_Lookup should not be handled by inverted index"); diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index 6105ce4afb980..6f411179cc3ea 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -80,7 +80,7 @@ class ScalarIndex : public IndexBase { GetIndexType() const = 0; virtual void - Build(size_t n, const T* values) = 0; + Build(size_t n, const T* values, const bool* valid_data = nullptr) = 0; virtual const TargetBitmap In(size_t n, const T* values) = 0; @@ -117,7 +117,7 @@ class ScalarIndex : public IndexBase { T upper_bound_value, bool ub_inclusive) = 0; - virtual T + virtual std::optional Reverse_Lookup(size_t offset) const = 0; virtual const TargetBitmap diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index 56396d7d192f3..873f475e06df9 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -61,7 +62,7 @@ ScalarIndexSort::Build(const Config& config) { template void -ScalarIndexSort::Build(size_t n, const T* values) { +ScalarIndexSort::Build(size_t n, const T* values, const bool* valid_data) { if (is_built_) return; if (n == 0) { @@ -70,12 +71,16 @@ ScalarIndexSort::Build(size_t n, const T* values) { data_.reserve(n); total_num_rows_ = n; valid_bitset = TargetBitmap(total_num_rows_, false); - idx_to_offsets_.resize(n); + idx_to_offsets_.resize(n, -1); + T* p = const_cast(values); - for (size_t i = 0; i < n; ++i) { - data_.emplace_back(IndexStructure(*p++, i)); - valid_bitset.set(i); + for (size_t i = 0; i < n; ++i, ++p) { + if (!valid_data || valid_data[i]) { + data_.emplace_back(IndexStructure(*p, i)); + valid_bitset.set(i); + } } + std::sort(data_.begin(), data_.end()); for (size_t i = 0; i < data_.size(); ++i) { idx_to_offsets_[data_[i].idx_] = i; @@ -112,7 +117,7 @@ ScalarIndexSort::BuildWithFieldData( } std::sort(data_.begin(), data_.end()); - idx_to_offsets_.resize(total_num_rows_); + idx_to_offsets_.resize(total_num_rows_, -1); for (size_t i = 0; i < length; ++i) { idx_to_offsets_[data_[i].idx_] = i; } @@ -174,7 +179,7 @@ ScalarIndexSort::LoadWithoutAssemble(const BinarySet& index_binary, memcpy(&total_num_rows_, index_num_rows->data.get(), (size_t)index_num_rows->size); - idx_to_offsets_.resize(total_num_rows_); + idx_to_offsets_.resize(total_num_rows_, -1); valid_bitset = TargetBitmap(total_num_rows_, false); memcpy(data_.data(), index_data->data.get(), (size_t)index_data->size); for (size_t i = 0; i < data_.size(); ++i) { @@ -355,12 +360,15 @@ ScalarIndexSort::Range(T lower_bound_value, } template -T +std::optional ScalarIndexSort::Reverse_Lookup(size_t idx) const { AssertInfo(idx < idx_to_offsets_.size(), "out of range of total count"); AssertInfo(is_built_, "index has not been built"); auto offset = idx_to_offsets_[idx]; + if (offset < 0) { + return std::nullopt; + } return data_[offset].a_; } diff --git a/internal/core/src/index/ScalarIndexSort.h b/internal/core/src/index/ScalarIndexSort.h index fb33f030c2a03..f9ae09b3b3c69 100644 --- a/internal/core/src/index/ScalarIndexSort.h +++ b/internal/core/src/index/ScalarIndexSort.h @@ -56,7 +56,7 @@ class ScalarIndexSort : public ScalarIndex { } void - Build(size_t n, const T* values) override; + Build(size_t n, const T* values, const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -82,7 +82,7 @@ class ScalarIndexSort : public ScalarIndex { T upper_bound_value, bool ub_inclusive) override; - T + std::optional Reverse_Lookup(size_t offset) const override; int64_t diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index e3c853193571a..3d71ecd8e6e35 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -118,7 +119,9 @@ StringIndexMarisa::BuildWithFieldData( } void -StringIndexMarisa::Build(size_t n, const std::string* values) { +StringIndexMarisa::Build(size_t n, + const std::string* values, + const bool* valid_data) { if (built_) { PanicInfo(IndexAlreadyBuild, "index has been built"); } @@ -127,7 +130,9 @@ StringIndexMarisa::Build(size_t n, const std::string* values) { { // fill key set. for (size_t i = 0; i < n; i++) { - keyset.push_back(values[i].c_str()); + if (!valid_data || valid_data[i]) { + keyset.push_back(values[i].c_str()); + } } } @@ -534,11 +539,13 @@ StringIndexMarisa::prefix_match(const std::string_view prefix) { } return ret; } - -std::string +std::optional StringIndexMarisa::Reverse_Lookup(size_t offset) const { AssertInfo(offset < str_ids_.size(), "out of range of total count"); marisa::Agent agent; + if (str_ids_[offset] < 0) { + return std::nullopt; + } agent.set_query(str_ids_[offset]); trie_.reverse_lookup(agent); return std::string(agent.key().ptr(), agent.key().length()); diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h index 72913d6675987..9bf44d73619c8 100644 --- a/internal/core/src/index/StringIndexMarisa.h +++ b/internal/core/src/index/StringIndexMarisa.h @@ -55,7 +55,9 @@ class StringIndexMarisa : public StringIndex { } void - Build(size_t n, const std::string* values) override; + Build(size_t n, + const std::string* values, + const bool* valid_data = nullptr) override; void Build(const Config& config = {}) override; @@ -87,7 +89,7 @@ class StringIndexMarisa : public StringIndex { const TargetBitmap PrefixMatch(const std::string_view prefix) override; - std::string + std::optional Reverse_Lookup(size_t offset) const override; BinarySet diff --git a/internal/core/src/query/ScalarIndex.h b/internal/core/src/query/ScalarIndex.h index eb9d0f3a18687..b72a68c6d5bc8 100644 --- a/internal/core/src/query/ScalarIndex.h +++ b/internal/core/src/query/ScalarIndex.h @@ -26,7 +26,7 @@ template inline index::ScalarIndexPtr generate_scalar_index(Span data) { auto indexing = std::make_unique>(); - indexing->Build(data.row_count(), data.data()); + indexing->Build(data.row_count(), data.data(), data.valid_data()); return indexing; } @@ -34,7 +34,7 @@ template <> inline index::ScalarIndexPtr generate_scalar_index(Span data) { auto indexing = index::CreateStringIndexSort(); - indexing->Build(data.row_count(), data.data()); + indexing->Build(data.row_count(), data.data(), data.valid_data()); return indexing; } diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index eb81947fcdba4..8c924e24ba01e 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -299,6 +299,7 @@ ScalarFieldIndexing::BuildIndexRange(int64_t ack_beg, for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) { auto chunk_data = source->get_chunk_data(chunk_id); // build index for chunk + // seem no lint, not pass valid_data here // TODO if constexpr (std::is_same_v) { auto indexing = index::CreateStringIndexSort(); diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 9fff1a9d09410..7a2a2dd264e14 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -198,8 +198,10 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && int64_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(int64_index->Reverse_Lookup(i), - i); + AssertInfo(int64_index->Reverse_Lookup(i).has_value(), + "Primary key not found"); + insert_record_.insert_pk( + int64_index->Reverse_Lookup(i).value(), i); } insert_record_.seal_pks(); } @@ -212,8 +214,10 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && string_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { + AssertInfo(string_index->Reverse_Lookup(i).has_value(), + "Primary key not found"); insert_record_.insert_pk( - string_index->Reverse_Lookup(i), i); + string_index->Reverse_Lookup(i).value(), i); } insert_record_.seal_pks(); } diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index e0bd00007b461..a778d86cb1f43 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -691,7 +691,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_bool_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -702,7 +706,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -713,7 +721,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -724,7 +736,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -735,7 +751,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_long_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -746,7 +766,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_float_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -757,7 +781,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_double_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -768,7 +796,11 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]); + auto value = ptr->Reverse_Lookup(seg_offsets[i]); + if (!value.has_value()) { + continue; + } + raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_string_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 2bfc4646d10af..c45cc4e3b12fc 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -379,6 +379,254 @@ TEST_P(ExprTest, TestRange) { } } +TEST_P(ExprTest, TestRangeNullable) { + std::vector>> + testcases = { + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: false, + upper_inclusive: false, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 < v && v < 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: true, + upper_inclusive: false, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v < 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: false, + upper_inclusive: true, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 < v && v <= 3000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + lower_inclusive: true, + upper_inclusive: true, + lower_value: < + int64_val: 2000 + > + upper_value: < + int64_val: 3000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v <= 3000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: GreaterEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v >= 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: GreaterThan, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v > 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: LessEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v <= 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: LessThan, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v < 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: Equal, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Int64 + > + op: NotEqual, + value: < + int64_val: 2000 + > + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v != 2000; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age", DataType::INT64); + schema->set_primary_field_id(i64_fid); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector data_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_data_col = raw_data.get_col(i64_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + data_col.insert( + data_col.end(), new_data_col.begin(), new_data_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = data_col[i]; + auto valid_data = valid_data_col[i]; + auto ref = ref_func(val, valid_data); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } + } +} + TEST_P(ExprTest, TestBinaryRangeJSON) { struct Testcase { bool lower_inclusive; @@ -475,25 +723,34 @@ TEST_P(ExprTest, TestBinaryRangeJSON) { } } -TEST_P(ExprTest, TestExistsJson) { +TEST_P(ExprTest, TestBinaryRangeJSONNullable) { struct Testcase { + bool lower_inclusive; + bool upper_inclusive; + int64_t lower; + int64_t upper; std::vector nested_path; }; std::vector testcases{ - {{"A"}}, - {{"int"}}, - {{"double"}}, - {{"B"}}, + {true, false, 10, 20, {"int"}}, + {true, true, 20, 30, {"int"}}, + {false, true, 30, 40, {"int"}}, + {false, false, 40, 50, {"int"}}, + {true, false, 10, 20, {"double"}}, + {true, true, 20, 30, {"double"}}, + {false, true, 30, 40, {"double"}}, + {false, false, 40, 50, {"double"}}, }; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); @@ -501,6 +758,7 @@ TEST_P(ExprTest, TestExistsJson) { json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col = raw_data.get_col_valid(json_fid); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -511,11 +769,32 @@ TEST_P(ExprTest, TestExistsJson) { auto seg_promote = dynamic_cast(seg.get()); for (auto testcase : testcases) { - auto check = [&](bool value) { return value; }; + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + int64_t lower = testcase.lower, upper = testcase.upper; + if (!testcase.lower_inclusive) { + lower++; + } + if (!testcase.upper_inclusive) { + upper--; + } + return lower <= value && value <= upper; + }; auto pointer = milvus::Json::pointer(testcase.nested_path); - auto expr = - std::make_shared(milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path)); + RetrievePlanNode plan; + milvus::proto::plan::GenericValue lower_val; + lower_val.set_int64_val(testcase.lower); + milvus::proto::plan::GenericValue upper_val; + upper_val.set_int64_val(testcase.upper); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + lower_val, + upper_val, + testcase.lower_inclusive, + testcase.upper_inclusive); BitsetType final; auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -525,38 +804,179 @@ TEST_P(ExprTest, TestExistsJson) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist(pointer); - auto ref = check(val); - ASSERT_EQ(ans, ref); + + if (testcase.nested_path[0] == "int") { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } else { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } } } } -template -T -GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kBoolVal); - return static_cast(value_proto.bool_val()); - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kInt64Val); - return static_cast(value_proto.int64_val()); - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kFloatVal); - return static_cast(value_proto.float_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kStringVal); - return static_cast(value_proto.string_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == - milvus::proto::plan::GenericValue::kArrayVal); - return static_cast(value_proto.array_val()); - } else if constexpr (std::is_same_v) { +TEST_P(ExprTest, TestExistsJson) { + struct Testcase { + std::vector nested_path; + }; + std::vector testcases{ + {{"A"}}, + {{"int"}}, + {{"double"}}, + {{"B"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](bool value) { return value; }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = + std::make_shared(milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path)); + BitsetType final; + plan.filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode( + plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist(pointer); + auto ref = check(val); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestExistsJsonNullable) { + struct Testcase { + std::vector nested_path; + }; + std::vector testcases{ + {{"A"}}, + {{"int"}}, + {{"double"}}, + {{"B"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col = raw_data.get_col_valid(json_fid); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](bool value, bool valid) { + if (!valid) { + return false; + } + return value; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = + std::make_shared(milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path)); + BitsetType final; + plan.filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode( + plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist(pointer); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } +} + +template +T +GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { + if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kBoolVal); + return static_cast(value_proto.bool_val()); + } else if constexpr (std::is_integral_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kInt64Val); + return static_cast(value_proto.int64_val()); + } else if constexpr (std::is_floating_point_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kFloatVal); + return static_cast(value_proto.float_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kStringVal); + return static_cast(value_proto.string_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kArrayVal); + return static_cast(value_proto.array_val()); + } else if constexpr (std::is_same_v) { return static_cast(value_proto); } else { PanicInfo(milvus::ErrorCode::UnexpectedError, @@ -734,30 +1154,36 @@ TEST_P(ExprTest, TestUnaryRangeJson) { } } -TEST_P(ExprTest, TestTermJson) { +TEST_P(ExprTest, TestUnaryRangeJsonNullable) { struct Testcase { - std::vector term; + int64_t val; std::vector nested_path; }; std::vector testcases{ - {{1, 2, 3, 4}, {"int"}}, - {{10, 100, 1000, 10000}, {"int"}}, - {{100, 10000, 9999, 444}, {"int"}}, - {{23, 42, 66, 17, 25}, {"int"}}, + {10, {"int"}}, + {20, {"int"}}, + {30, {"int"}}, + {40, {"int"}}, + {10, {"double"}}, + {20, {"double"}}, + {30, {"double"}}, + {40, {"double"}}, }; auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + FixedVector valid_data_col; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data_col = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -770,53 +1196,331 @@ TEST_P(ExprTest, TestTermJson) { } auto seg_promote = dynamic_cast(seg.get()); - for (auto testcase : testcases) { - auto check = [&](int64_t value) { - std::unordered_set term_set(testcase.term.begin(), - testcase.term.end()); - return term_set.find(value) != term_set.end(); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + std::vector ops{ + OpType::Equal, + OpType::NotEqual, + OpType::GreaterThan, + OpType::GreaterEqual, + OpType::LessThan, + OpType::LessEqual, + }; + for (const auto& testcase : testcases) { + auto check = [&](int64_t value, bool valid) { + return value == testcase.val; }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (const auto& val : testcase.term) { + std::function f = check; + for (auto& op : ops) { + switch (op) { + case OpType::Equal: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value == testcase.val; + }; + break; + } + case OpType::NotEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value != testcase.val; + }; + break; + } + case OpType::GreaterEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value >= testcase.val; + }; + break; + } + case OpType::GreaterThan: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value > testcase.val; + }; + break; + } + case OpType::LessEqual: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value <= testcase.val; + }; + break; + } + case OpType::LessThan: { + f = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + return value < testcase.val; + }; + break; + } + default: { + PanicInfo(Unsupported, "unsupported range node"); + } + } + + auto pointer = milvus::Json::pointer(testcase.nested_path); proto::plan::GenericValue value; - value.set_int64_val(val); - values.push_back(value); - } - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - values); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); + value.set_int64_val(testcase.val); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + op, + value); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (testcase.nested_path[0] == "int") { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = f(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } else { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = f(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } } } -} -TEST_P(ExprTest, TestTerm) { - auto vec_2k_3k = [] { - std::string buf; - for (int i = 2000; i < 3000; ++i) { - buf += "values: < int64_val: " + std::to_string(i) + " >\n"; - } - return buf; - }(); + struct TestArrayCase { + proto::plan::GenericValue val; + std::vector nested_path; + }; - std::vector>> testcases = { - {R"(values: < - int64_val: 2000 + proto::plan::GenericValue value; + auto* arr = value.mutable_array_val(); + arr->set_same_type(true); + proto::plan::GenericValue int_val1; + int_val1.set_int64_val(int64_t(1)); + arr->add_array()->CopyFrom(int_val1); + + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(2)); + arr->add_array()->CopyFrom(int_val2); + + proto::plan::GenericValue int_val3; + int_val3.set_int64_val(int64_t(3)); + arr->add_array()->CopyFrom(int_val3); + + std::vector array_cases = {{value, {"array"}}}; + for (const auto& testcase : array_cases) { + auto check = [&](OpType op, bool valid) { + if (!valid) { + return false; + } + if (testcase.nested_path[0] == "array" && op == OpType::Equal) { + return true; + } + return false; + }; + for (auto& op : ops) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + op, + testcase.val); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = check(op, valid_data_col[i]); + ASSERT_EQ(ans, ref) << "@" << i << "op" << op; + } + } + } +} + +TEST_P(ExprTest, TestTermJson) { + struct Testcase { + std::vector term; + std::vector nested_path; + }; + std::vector testcases{ + {{1, 2, 3, 4}, {"int"}}, + {{10, 100, 1000, 10000}, {"int"}}, + {{100, 10000, 9999, 444}, {"int"}}, + {{23, 42, 66, 17, 25}, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](int64_t value) { + std::unordered_set term_set(testcase.term.begin(), + testcase.term.end()); + return term_set.find(value) != term_set.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue value; + value.set_int64_val(val); + values.push_back(value); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestTermJsonNullable) { + struct Testcase { + std::vector term; + std::vector nested_path; + }; + std::vector testcases{ + {{1, 2, 3, 4}, {"int"}}, + {{10, 100, 1000, 10000}, {"int"}}, + {{100, 10000, 9999, 444}, {"int"}}, + {{23, 42, 66, 17, 25}, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data_col; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + auto new_valid_data_col = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + valid_data_col.insert(valid_data_col.end(), + new_valid_data_col.begin(), + new_valid_data_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set(testcase.term.begin(), + testcase.term.end()); + return term_set.find(value) != term_set.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue value; + value.set_int64_val(val); + values.push_back(value); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data_col[i]); + ASSERT_EQ(ans, ref); + } + } +} + +TEST_P(ExprTest, TestTerm) { + auto vec_2k_3k = [] { + std::string buf; + for (int i = 2000; i < 3000; ++i) { + buf += "values: < int64_val: " + std::to_string(i) + " >\n"; + } + return buf; + }(); + + std::vector>> testcases = { + {R"(values: < + int64_val: 2000 > values: < int64_val: 3000 @@ -901,30 +1605,73 @@ TEST_P(ExprTest, TestTerm) { } } -TEST_P(ExprTest, TestCompare) { - std::vector>> +TEST_P(ExprTest, TestTermNullable) { + auto vec_2k_3k = [] { + std::string buf; + for (int i = 2000; i < 3000; ++i) { + buf += "values: < int64_val: " + std::to_string(i) + " >\n"; + } + return buf; + }(); + + std::vector>> testcases = { - {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, - {R"(LessEqual)", [](int a, int64_t b) { return a <= b; }}, - {R"(GreaterThan)", [](int a, int64_t b) { return a > b; }}, - {R"(GreaterEqual)", [](int a, int64_t b) { return a >= b; }}, - {R"(Equal)", [](int a, int64_t b) { return a == b; }}, - {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, + {R"(values: < + int64_val: 2000 + > + values: < + int64_val: 3000 + > + )", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000 || v == 3000; + }}, + {R"(values: < + int64_val: 2000 + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 2000; + }}, + {R"(values: < + int64_val: 3000 + >)", + [](int v, bool valid) { + if (!valid) { + return false; + } + return v == 3000; + }}, + {R"()", + [](int v, bool valid) { + if (!valid) { + return false; + } + return false; + }}, + {vec_2k_3k, + [](int v, bool valid) { + if (!valid) { + return false; + } + return 2000 <= v && v < 3000; + }}, }; std::string raw_plan_tmp = R"(vector_anns: < field_id: 100 predicates: < - compare_expr: < - left_column_info: < - field_id: 101 - data_type: Int32 - > - right_column_info: < + term_expr: < + column_info: < field_id: 102 data_type: Int64 > - op: @@@@ + @@@@ > > query_info: < @@ -937,23 +1684,26 @@ TEST_P(ExprTest, TestCompare) { >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i32_fid = schema->AddDebugField("age1", DataType::INT32); - auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto i64_fid = schema->AddDebugField("age", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - std::vector age1_col; - std::vector age2_col; - int num_iters = 1; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 100; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); - auto new_age1_col = raw_data.get_col(i32_fid); - auto new_age2_col = raw_data.get_col(i64_fid); - age1_col.insert( - age1_col.end(), new_age1_col.begin(), new_age1_col.end()); - age2_col.insert( - age2_col.end(), new_age2_col.begin(), new_age2_col.end()); + auto new_nullable_col = raw_data.get_col(nullable_fid); + auto new_valid_data_col = raw_data.get_col_valid(nullable_fid); + valid_data_col.insert(valid_data_col.end(), + new_valid_data_col.begin(), + new_valid_data_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -981,16 +1731,14 @@ TEST_P(ExprTest, TestCompare) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val1 = age1_col[i]; - auto val2 = age2_col[i]; - auto ref = ref_func(val1, val2); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" - << boost::format("[%1%, %2%]") % val1 % val2; + auto val = nullable_col[i]; + auto ref = ref_func(val, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } } } -TEST_P(ExprTest, TestCompareWithScalarIndex) { +TEST_P(ExprTest, TestCompare) { std::vector>> testcases = { {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, @@ -1001,85 +1749,77 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, }; - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - compare_expr: < - left_column_info: < - field_id: %3% - data_type: %4% - > - right_column_info: < - field_id: %5% - data_type: %6% - > - op: %2% - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 101 + data_type: Int32 + > + right_column_info: < + field_id: 102 + data_type: Int64 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; - auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i32_fid = schema->AddDebugField("age32", DataType::INT32); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); + auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; - - // load index for int32 field - auto age32_col = raw_data.get_col(i32_fid); - age32_col[0] = 1000; - GenScalarIndexing(N, age32_col.data()); - auto age32_index = milvus::index::CreateScalarIndexSort(); - age32_index->Build(N, age32_col.data()); - load_index_info.field_id = i32_fid.get(); - load_index_info.field_type = DataType::INT32; - load_index_info.index = std::move(age32_index); - seg->LoadIndex(load_index_info); - - // load index for int64 field - auto age64_col = raw_data.get_col(i64_fid); - age64_col[0] = 2000; - GenScalarIndexing(N, age64_col.data()); - auto age64_index = milvus::index::CreateScalarIndexSort(); - age64_index->Build(N, age64_col.data()); - load_index_info.field_id = i64_fid.get(); - load_index_info.field_type = DataType::INT64; - load_index_info.index = std::move(age64_index); - seg->LoadIndex(load_index_info); + std::vector age1_col; + std::vector age2_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_age2_col = raw_data.get_col(i64_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + age2_col.insert( + age2_col.end(), new_age2_col.begin(), new_age2_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { - auto dsl_string = - boost::format(serialized_expr_plan) % vec_fid.get() % clause % - i32_fid.get() % proto::schema::DataType_Name(int(DataType::INT32)) % - i64_fid.get() % proto::schema::DataType_Name(int(DataType::INT64)); - auto binary_plan = - translate_text_plan_with_metric_type(dsl_string.str()); - auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); - // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), - N, - MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N; ++i) { + for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val1 = age32_col[i]; - auto val2 = age64_col[i]; + + auto val1 = age1_col[i]; + auto val2 = age2_col[i]; auto ref = ref_func(val1, val2); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << boost::format("[%1%, %2%]") % val1 % val2; @@ -1087,515 +1827,628 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { } } -TEST_P(ExprTest, TestCompareExpr) { +TEST_P(ExprTest, TestCompareNullable) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 101 + data_type: Int32 + > + right_column_info: < + field_id: 103 + data_type: Int64 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); - auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR); - schema->set_primary_field_id(str1_fid); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); - size_t N = 1000; - auto raw_data = DataGen(schema, N); - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age1_col; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_nullable_col = raw_data.get_col(nullable_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); - seg->LoadFieldData(FieldId(field_id), info); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val1 = age1_col[i]; + auto val2 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } } +} - auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { - switch (type) { - case DataType::BOOL: { - auto compare_expr = std::make_shared( - bool_fid, - bool_1_fid, - DataType::BOOL, - DataType::BOOL, - proto::plan::OpType::LessThan); - return compare_expr; - } - case DataType::INT8: { - auto compare_expr = - std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - return compare_expr; - } - case DataType::INT16: { - auto compare_expr = - std::make_shared(int16_fid, - int16_1_fid, - DataType::INT16, - DataType::INT16, - OpType::LessThan); - return compare_expr; - } - case DataType::INT32: { - auto compare_expr = - std::make_shared(int32_fid, - int32_1_fid, - DataType::INT32, - DataType::INT32, - OpType::LessThan); - return compare_expr; - } - case DataType::INT64: { - auto compare_expr = - std::make_shared(int64_fid, - int64_1_fid, - DataType::INT64, - DataType::INT64, - OpType::LessThan); - return compare_expr; - } - case DataType::FLOAT: { - auto compare_expr = - std::make_shared(float_fid, - float_1_fid, - DataType::FLOAT, - DataType::FLOAT, - OpType::LessThan); - return compare_expr; - } - case DataType::DOUBLE: { - auto compare_expr = - std::make_shared(double_fid, - double_1_fid, - DataType::DOUBLE, - DataType::DOUBLE, - OpType::LessThan); - return compare_expr; - } - case DataType::VARCHAR: { - auto compare_expr = - std::make_shared(str2_fid, - str3_fid, - DataType::VARCHAR, - DataType::VARCHAR, - OpType::LessThan); - return compare_expr; - } - default: - return std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - } - }; - std::cout << "start compare test" << std::endl; - auto expr = build_expr(DataType::BOOL); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT8); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT16); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT32); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::INT64); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::FLOAT); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - expr = build_expr(DataType::DOUBLE); - plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - std::cout << "end compare test" << std::endl; -} +TEST_P(ExprTest, TestCompareNullable2) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; -TEST(Expr, TestExprPerformance) { - GTEST_SKIP() << "Skip performance test, open it when test performance"; + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + compare_expr: < + left_column_info: < + field_id: 103 + data_type: Int64 + > + right_column_info: < + field_id: 101 + data_type: Int32 + > + op: @@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; auto schema = std::make_shared(); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); - - std::map fids = {{DataType::INT8, int8_fid}, - {DataType::INT16, int16_fid}, - {DataType::INT32, int32_fid}, - {DataType::INT64, int64_fid}, - {DataType::VARCHAR, str2_fid}, - {DataType::FLOAT, float_fid}, - {DataType::DOUBLE, double_fid}}; + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i32_fid = schema->AddDebugField("age1", DataType::INT32); + auto i64_fid = schema->AddDebugField("age2", DataType::INT64); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT64, true); + schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age1_col; + std::vector nullable_col; + FixedVector valid_data_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_age1_col = raw_data.get_col(i32_fid); + auto new_nullable_col = raw_data.get_col(nullable_fid); + valid_data_col = raw_data.get_col_valid(nullable_fid); + age1_col.insert( + age1_col.end(), new_age1_col.begin(), new_age1_col.end()); + nullable_col.insert(nullable_col.end(), + new_nullable_col.begin(), + new_nullable_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; - seg->LoadFieldData(FieldId(field_id), info); + auto val2 = age1_col[i]; + auto val1 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } } +} - enum ExprType { - UnaryRangeExpr = 0, - TermExprImpl = 1, - CompareExpr = 2, - LogicalUnaryExpr = 3, - BinaryRangeExpr = 4, - LogicalBinaryExpr = 5, - BinaryArithOpEvalRangeExpr = 6, - }; - - auto build_unary_range_expr = [&](DataType data_type, - int64_t value) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val; - val.set_int64_val(value); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val; - val.set_float_val(float(value)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else if (IsStringDataType(data_type)) { - proto::plan::GenericValue val; - val.set_string_val(std::to_string(value)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::LessThan, - val); - } else { - throw std::runtime_error("not supported type"); - } - }; +TEST_P(ExprTest, TestCompareWithScalarIndex) { + std::vector>> + testcases = { + {R"(LessThan)", [](int a, int64_t b) { return a < b; }}, + {R"(LessEqual)", [](int a, int64_t b) { return a <= b; }}, + {R"(GreaterThan)", [](int a, int64_t b) { return a > b; }}, + {R"(GreaterEqual)", [](int a, int64_t b) { return a >= b; }}, + {R"(Equal)", [](int a, int64_t b) { return a == b; }}, + {R"(NotEqual)", [](int a, int64_t b) { return a != b; }}, + }; - auto build_binary_range_expr = [&](DataType data_type, - int64_t low, - int64_t high) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_int64_val(low); - proto::plan::GenericValue val2; - val2.set_int64_val(high); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_float_val(float(low)); - proto::plan::GenericValue val2; - val2.set_float_val(float(high)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else if (IsStringDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_string_val(std::to_string(low)); - proto::plan::GenericValue val2; - val2.set_string_val(std::to_string(low)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - val1, - val2, - true, - true); - } else { - throw std::runtime_error("not supported type"); - } - }; + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; - auto build_term_expr = - [&](DataType data_type, - std::vector in_vals) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_int64_val(v); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else if (IsFloatDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_float_val(float(v)); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else if (IsStringDataType(data_type)) { - std::vector vals; - for (auto& v : in_vals) { - proto::plan::GenericValue val; - val.set_string_val(std::to_string(v)); - vals.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), vals, false); - } else { - throw std::runtime_error("not supported type"); - } - }; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); - auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || - IsStringDataType(data_type)) { - return std::make_shared( - fids[data_type], - fids[data_type], - data_type, - data_type, - proto::plan::OpType::LessThan); - } else { - throw std::runtime_error("not supported type"); - } - }; + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - auto build_logical_unary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child_expr = build_unary_range_expr(data_type, 10); - return std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - }; + // load index for int32 field + auto age32_col = raw_data.get_col(i32_fid); + age32_col[0] = 1000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data()); + load_index_info.field_id = i32_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); - auto build_logical_binary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child1_expr = build_unary_range_expr(data_type, 10); - auto child2_expr = build_unary_range_expr(data_type, 10); - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - }; + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); - auto build_multi_logical_binary_expr = - [&](DataType data_type) -> expr::TypedExprPtr { - auto child1_expr = build_unary_range_expr(data_type, 100); - auto child2_expr = build_unary_range_expr(data_type, 100); - auto child3_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - auto child4_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - auto child5_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); - auto child6_expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); - }; + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = + boost::format(serialized_expr_plan) % vec_fid.get() % clause % + i32_fid.get() % proto::schema::DataType_Name(int(DataType::INT32)) % + i64_fid.get() % proto::schema::DataType_Name(int(DataType::INT64)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); - auto build_arith_op_expr = [&](DataType data_type, - int64_t right_val, - int64_t val) -> expr::TypedExprPtr { - if (IsIntegerDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_int64_val(right_val); - proto::plan::GenericValue val2; - val2.set_int64_val(val); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val1, - val2); - } else if (IsFloatDataType(data_type)) { - proto::plan::GenericValue val1; - val1.set_float_val(float(right_val)); - proto::plan::GenericValue val2; - val2.set_float_val(float(val)); - return std::make_shared( - expr::ColumnInfo(fids[data_type], data_type), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val1, - val2); - } else { - throw std::runtime_error("not supported type"); + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = age32_col[i]; + auto val2 = age64_col[i]; + auto ref = ref_func(val1, val2); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; } - }; + } +} - auto test_case_base = [=, &seg](expr::TypedExprPtr expr) { - std::cout << expr->ToString() << std::endl; - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - auto start = std::chrono::steady_clock::now(); - for (int i = 0; i < 100; i++) { - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); - } - std::cout << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() / - 100.0 - << "us" << std::endl; - }; +TEST_P(ExprTest, TestCompareWithScalarIndexNullable) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; - std::cout << "test unary range operator" << std::endl; - auto expr = build_unary_range_expr(DataType::INT8, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT16, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT32, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::INT64, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::FLOAT, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::DOUBLE, 10); - test_case_base(expr); - expr = build_unary_range_expr(DataType::VARCHAR, 10); - test_case_base(expr); + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; - std::cout << "test binary range operator" << std::endl; - expr = build_binary_range_expr(DataType::INT8, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT16, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT32, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::INT64, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::FLOAT, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); - test_case_base(expr); - expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); - test_case_base(expr); + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT32, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); - std::cout << "test compare expr operator" << std::endl; - expr = build_compare_expr(DataType::INT8); - test_case_base(expr); - expr = build_compare_expr(DataType::INT16); - test_case_base(expr); - expr = build_compare_expr(DataType::INT32); - test_case_base(expr); - expr = build_compare_expr(DataType::INT64); - test_case_base(expr); - expr = build_compare_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_compare_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_compare_expr(DataType::VARCHAR); - test_case_base(expr); + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - std::cout << "test artih op val operator" << std::endl; - expr = build_arith_op_expr(DataType::INT8, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT16, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT32, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::INT64, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::FLOAT, 10, 100); - test_case_base(expr); - expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); - test_case_base(expr); + // load index for int32 field + auto nullable_col = raw_data.get_col(nullable_fid); + nullable_col[0] = 1000; + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto nullable_index = milvus::index::CreateScalarIndexSort(); + nullable_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(nullable_index); + seg->LoadIndex(load_index_info); - std::cout << "test logical unary expr operator" << std::endl; - expr = build_logical_unary_expr(DataType::INT8); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT16); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT32); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::INT64); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_logical_unary_expr(DataType::VARCHAR); - test_case_base(expr); + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); - std::cout << "test logical binary expr operator" << std::endl; - expr = build_logical_binary_expr(DataType::INT8); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT16); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT32); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::INT64); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_logical_binary_expr(DataType::VARCHAR); - test_case_base(expr); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)) % + i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); - std::cout << "test multi logical binary expr operator" << std::endl; - expr = build_multi_logical_binary_expr(DataType::INT8); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT16); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT32); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::INT64); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::FLOAT); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::DOUBLE); - test_case_base(expr); - expr = build_multi_logical_binary_expr(DataType::VARCHAR); - test_case_base(expr); + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = nullable_col[i]; + auto val2 = age64_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } } -TEST_P(ExprTest, test_term_pk) { +TEST_P(ExprTest, TestCompareWithScalarIndexNullable2) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a < b; + }}, + {R"(LessEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a <= b; + }}, + {R"(GreaterThan)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a > b; + }}, + {R"(GreaterEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a >= b; + }}, + {R"(Equal)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a == b; + }}, + {R"(NotEqual)", + [](int a, int64_t b, bool valid) { + if (!valid) { + return false; + } + return a != b; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: %4% + > + right_column_info: < + field_id: %5% + data_type: %6% + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto nullable_fid = + schema->AddDebugField("nullable", DataType::INT32, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto nullable_col = raw_data.get_col(nullable_fid); + nullable_col[0] = 1000; + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto nullable_index = milvus::index::CreateScalarIndexSort(); + nullable_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(nullable_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data()); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)) % + nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val2 = nullable_col[i]; + auto val1 = age64_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, test_term_pk_with_sorted) { auto schema = std::make_shared(); schema->AddField( FieldName("Timestamp"), FieldId(1), DataType::INT64, false); @@ -1604,7 +2457,8 @@ TEST_P(ExprTest, test_term_pk) { auto int64_fid = schema->AddDebugField("int64", DataType::INT64); schema->set_primary_field_id(int64_fid); - auto seg = CreateSealedSegment(schema); + auto seg = CreateSealedSegment( + schema, nullptr, 1, SegcoreConfig::default_config(), false, true); int N = 100000; auto raw_data = DataGen(schema, N); @@ -1631,7 +2485,7 @@ TEST_P(ExprTest, test_term_pk) { } auto expr = std::make_shared( expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); - + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -1659,23 +2513,32 @@ TEST_P(ExprTest, test_term_pk) { } } -TEST_P(ExprTest, test_term_pk_with_sorted) { +TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { auto schema = std::make_shared(); - schema->AddField( - FieldName("Timestamp"), FieldId(1), DataType::INT64, false); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - schema->set_primary_field_id(int64_fid); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); - auto seg = CreateSealedSegment( - schema, nullptr, 1, SegcoreConfig::default_config(), false, true); - int N = 100000; + auto seg = CreateSealedSegment(schema); + size_t N = 1000; auto raw_data = DataGen(schema, N); - - // load field data auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1688,130 +2551,149 @@ TEST_P(ExprTest, test_term_pk_with_sorted) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector retrieve_ints; - for (int i = 0; i < 10; ++i) { - proto::plan::GenericValue val; - val.set_int64_val(i); - retrieve_ints.push_back(val); - } - auto expr = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_1_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_fid, + int16_1_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_fid, + int32_1_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_fid, + int64_1_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_fid, + float_1_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_fid, + double_1_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str2_fid, + str3_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(final[i], true); - } - for (int i = 10; i < N; ++i) { - EXPECT_EQ(final[i], false); - } - retrieve_ints.clear(); - for (int i = 0; i < 10; ++i) { - proto::plan::GenericValue val; - val.set_int64_val(i + N); - retrieve_ints.push_back(val); - } - expr = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT8); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); - for (int i = 0; i < N; ++i) { - EXPECT_EQ(final[i], false); - } + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + std::cout << "end compare test" << std::endl; } -TEST_P(ExprTest, TestConjuctExpr) { +TEST_P(ExprTest, TestCompareExprNullable) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_nullable_fid = + schema->AddDebugField("bool1", DataType::BOOL, true); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int81", DataType::INT8, true); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int161", DataType::INT16, true); auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int321", DataType::INT32, true); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto int64_nullable_fid = + schema->AddDebugField("int641", DataType::INT64, true); auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_nullable_fid = + schema->AddDebugField("float1", DataType::FLOAT, true); auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); - - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } - - auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { - ::milvus::proto::plan::GenericValue value; - value.set_int64_val(l); - auto left = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - value); - value.set_int64_val(r); - auto right = std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::LessThan, - value); - - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, left, right); - }; - - std::vector> test_case = { - {100, 0}, {0, 100}, {8192, 8194}}; - for (auto& pair : test_case) { - std::cout << pair.first << "|" << pair.second << std::endl; - auto expr = build_expr(pair.first, pair.second); - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - BitsetType final; - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - for (int i = 0; i < N; ++i) { - EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; - } - } -} - -TEST_P(ExprTest, TestUnaryBenchTest) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto double_nullable_fid = + schema->AddDebugField("double1", DataType::DOUBLE, true); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto str_nullable_fid = + schema->AddDebugField("string3", DataType::VARCHAR, true); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + size_t N = 1000; auto raw_data = DataGen(schema, N); - - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1825,62 +2707,149 @@ TEST_P(ExprTest, TestUnaryBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(10); - } else { - val.set_int64_val(10); - } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::GreaterThan, - val); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; - } -} - -TEST_P(ExprTest, TestBinaryRangeBenchTest) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_nullable_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_fid, + int8_nullable_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_fid, + int16_nullable_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_fid, + int32_nullable_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_fid, + int64_nullable_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_fid, + float_nullable_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_fid, + double_nullable_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str2_fid, + str_nullable_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_fid, + int8_nullable_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT8); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + std::cout << "end compare test" << std::endl; +} + +TEST_P(ExprTest, TestCompareExprNullable2) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + auto bool_nullable_fid = + schema->AddDebugField("bool1", DataType::BOOL, true); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int81", DataType::INT8, true); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int161", DataType::INT16, true); auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int321", DataType::INT32, true); auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto int64_nullable_fid = + schema->AddDebugField("int641", DataType::INT64, true); auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_nullable_fid = + schema->AddDebugField("float1", DataType::FLOAT, true); auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_nullable_fid = + schema->AddDebugField("double1", DataType::DOUBLE, true); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto str_nullable_fid = + schema->AddDebugField("string3", DataType::VARCHAR, true); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + size_t N = 1000; auto raw_data = DataGen(schema, N); - - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -1894,52 +2863,119 @@ TEST_P(ExprTest, TestBinaryRangeBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue lower; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - lower.set_float_val(10); - } else { - lower.set_int64_val(10); - } - proto::plan::GenericValue upper; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - upper.set_float_val(45); - } else { - upper.set_int64_val(45); - } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - lower, - upper, - true, - true); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { + switch (type) { + case DataType::BOOL: { + auto compare_expr = std::make_shared( + bool_fid, + bool_nullable_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); + return compare_expr; + } + case DataType::INT8: { + auto compare_expr = + std::make_shared(int8_nullable_fid, + int8_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case DataType::INT16: { + auto compare_expr = + std::make_shared(int16_nullable_fid, + int16_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); + return compare_expr; + } + case DataType::INT32: { + auto compare_expr = + std::make_shared(int32_nullable_fid, + int32_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); + return compare_expr; + } + case DataType::INT64: { + auto compare_expr = + std::make_shared(int64_nullable_fid, + int64_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); + return compare_expr; + } + case DataType::FLOAT: { + auto compare_expr = + std::make_shared(float_nullable_fid, + float_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); + return compare_expr; + } + case DataType::DOUBLE: { + auto compare_expr = + std::make_shared(double_nullable_fid, + double_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); + return compare_expr; + } + case DataType::VARCHAR: { + auto compare_expr = + std::make_shared(str_nullable_fid, + str2_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); + return compare_expr; + } + default: + return std::make_shared(int8_nullable_fid, + int8_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); } - std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; - } + }; + std::cout << "start compare test" << std::endl; + auto expr = build_expr(DataType::BOOL); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT8); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT16); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT32); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::INT64); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::FLOAT); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + expr = build_expr(DataType::DOUBLE); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + std::cout << "end compare test" << std::endl; } -TEST_P(ExprTest, TestLogicalUnaryBenchTest) { +TEST(Expr, TestExprPerformance) { + GTEST_SKIP() << "Skip performance test, open it when test performance"; auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); auto int16_fid = schema->AddDebugField("int16", DataType::INT16); @@ -1954,8 +2990,16 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); schema->set_primary_field_id(str1_fid); + std::map fids = {{DataType::INT8, int8_fid}, + {DataType::INT16, int16_fid}, + {DataType::INT32, int32_fid}, + {DataType::INT64, int64_fid}, + {DataType::VARCHAR, str2_fid}, + {DataType::FLOAT, float_fid}, + {DataType::DOUBLE, double_fid}}; + auto seg = CreateSealedSegment(schema); - int N = 10000; + int N = 1000; auto raw_data = DataGen(schema, N); // load field data @@ -1972,148 +3016,341 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(10); + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; + + auto build_unary_range_expr = [&](DataType data_type, + int64_t value) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val; + val.set_int64_val(value); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val; + val.set_float_val(float(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); } else { - val.set_int64_val(10); - } - auto child_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::GreaterThan, - val); - auto expr = std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + throw std::runtime_error("not supported type"); } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; - } -} + }; -TEST_P(ExprTest, TestBinaryLogicalBenchTest) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); + auto build_binary_range_expr = [&](DataType data_type, + int64_t low, + int64_t high) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(low); + proto::plan::GenericValue val2; + val2.set_int64_val(high); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(low)); + proto::plan::GenericValue val2; + val2.set_float_val(float(high)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_string_val(std::to_string(low)); + proto::plan::GenericValue val2; + val2.set_string_val(std::to_string(low)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else { + throw std::runtime_error("not supported type"); + } + }; - auto seg = CreateSealedSegment(schema); - int N = 10000; - auto raw_data = DataGen(schema, N); + auto build_term_expr = + [&](DataType data_type, + std::vector in_vals) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_int64_val(v); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsFloatDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_float_val(float(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsStringDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else { + throw std::runtime_error("not supported type"); + } + }; - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); + auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || + IsStringDataType(data_type)) { + return std::make_shared( + fids[data_type], + fids[data_type], + data_type, + data_type, + proto::plan::OpType::LessThan); + } else { + throw std::runtime_error("not supported type"); + } + }; - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); + auto build_logical_unary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + }; - seg->LoadFieldData(FieldId(field_id), info); - } + auto build_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 10); + auto child2_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + }; - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; + auto build_multi_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 100); + auto child2_expr = build_unary_range_expr(data_type, 100); + auto child3_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child4_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child5_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + auto child6_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); + }; - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; - proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(-1000000); + auto build_arith_op_expr = [&](DataType data_type, + int64_t right_val, + int64_t val) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(right_val); + proto::plan::GenericValue val2; + val2.set_int64_val(val); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(right_val)); + proto::plan::GenericValue val2; + val2.set_float_val(float(val)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); } else { - val.set_int64_val(-1000000); + throw std::runtime_error("not supported type"); } - proto::plan::GenericValue val1; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val1.set_float_val(-100); - } else { - val1.set_int64_val(-100); + }; + + auto test_case_base = [=, &seg](expr::TypedExprPtr expr) { + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + std::cout << expr->ToString() << std::endl; + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + for (int i = 0; i < 100; i++) { + visitor.ExecuteExprNode(plan, seg.get(), N, final); + EXPECT_EQ(final.size(), N); } - auto child1_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::LessThan, - val); - auto child2_expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::NotEqual, - val1); - auto expr = std::make_shared( - expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; - } + std::cout << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() / + 100.0 + << "us" << std::endl; + }; + + std::cout << "test unary range operator" << std::endl; + auto expr = build_unary_range_expr(DataType::INT8, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT16, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT32, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::INT64, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::FLOAT, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::DOUBLE, 10); + test_case_base(expr); + expr = build_unary_range_expr(DataType::VARCHAR, 10); + test_case_base(expr); + + std::cout << "test binary range operator" << std::endl; + expr = build_binary_range_expr(DataType::INT8, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT16, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT32, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::INT64, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::FLOAT, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); + test_case_base(expr); + expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); + test_case_base(expr); + + std::cout << "test compare expr operator" << std::endl; + expr = build_compare_expr(DataType::INT8); + test_case_base(expr); + expr = build_compare_expr(DataType::INT16); + test_case_base(expr); + expr = build_compare_expr(DataType::INT32); + test_case_base(expr); + expr = build_compare_expr(DataType::INT64); + test_case_base(expr); + expr = build_compare_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_compare_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_compare_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test artih op val operator" << std::endl; + expr = build_arith_op_expr(DataType::INT8, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT16, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT32, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::INT64, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::FLOAT, 10, 100); + test_case_base(expr); + expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); + test_case_base(expr); + + std::cout << "test logical unary expr operator" << std::endl; + expr = build_logical_unary_expr(DataType::INT8); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT16); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT32); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::INT64); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_logical_unary_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test logical binary expr operator" << std::endl; + expr = build_logical_binary_expr(DataType::INT8); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT16); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT32); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::INT64); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_logical_binary_expr(DataType::VARCHAR); + test_case_base(expr); + + std::cout << "test multi logical binary expr operator" << std::endl; + expr = build_multi_logical_binary_expr(DataType::INT8); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT16); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT32); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::INT64); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::FLOAT); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::DOUBLE); + test_case_base(expr); + expr = build_multi_logical_binary_expr(DataType::VARCHAR); + test_case_base(expr); } -TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { +TEST_P(ExprTest, test_term_pk) { auto schema = std::make_shared(); + schema->AddField( + FieldName("Timestamp"), FieldId(1), DataType::INT64, false); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(int64_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + int N = 1000; auto raw_data = DataGen(schema, N); // load field data auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -2126,73 +3363,52 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector> test_cases = { - {int8_fid, DataType::INT8}, - {int16_fid, DataType::INT16}, - {int32_fid, DataType::INT32}, - {int64_fid, DataType::INT64}, - {float_fid, DataType::FLOAT}, - {double_fid, DataType::DOUBLE}}; - - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.second) << std::endl; + std::vector retrieve_ints; + for (int i = 0; i < 10; ++i) { proto::plan::GenericValue val; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - val.set_float_val(100); - } else { - val.set_int64_val(100); - } - proto::plan::GenericValue right; - if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { - right.set_float_val(10); - } else { - right.set_int64_val(10); - } - auto expr = std::make_shared( - expr::ColumnInfo(pair.first, pair.second), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val, - right); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 50; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - } - std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + val.set_int64_val(i); + retrieve_ints.push_back(val); + } + auto expr = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(final[i], true); + } + for (int i = 10; i < N; ++i) { + EXPECT_EQ(final[i], false); + } + retrieve_ints.clear(); + for (int i = 0; i < 10; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i + N); + retrieve_ints.push_back(val); + } + expr = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(final[i], false); } } -TEST_P(ExprTest, TestCompareExprBenchTest) { +TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); - auto int16_fid = schema->AddDebugField("int16", DataType::INT16); - auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); - auto int32_fid = schema->AddDebugField("int32", DataType::INT32); - auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); - auto int64_fid = schema->AddDebugField("int64", DataType::INT64); - auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); - auto float_fid = schema->AddDebugField("float", DataType::FLOAT); - auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); - auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); - auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); - schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + int N = 1000; auto raw_data = DataGen(schema, N); - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { @@ -2207,40 +3423,99 @@ TEST_P(ExprTest, TestCompareExprBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - std::vector< - std::pair, std::pair>> - test_cases = { - {{int8_fid, DataType::INT8}, {int8_1_fid, DataType::INT8}}, - {{int16_fid, DataType::INT16}, {int16_fid, DataType::INT16}}, - {{int32_fid, DataType::INT32}, {int32_1_fid, DataType::INT32}}, - {{int64_fid, DataType::INT64}, {int64_1_fid, DataType::INT64}}, - {{float_fid, DataType::FLOAT}, {float_1_fid, DataType::FLOAT}}, - {{double_fid, DataType::DOUBLE}, {double_1_fid, DataType::DOUBLE}}}; + proto::plan::GenericValue val; + val.set_int64_val(10); + auto expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto plan_node = + std::make_shared(DEFAULT_PLANNODE_ID, expr); - for (const auto& pair : test_cases) { - std::cout << "start test type:" << int(pair.first.second) << std::endl; - proto::plan::GenericValue lower; - auto expr = std::make_shared(pair.first.first, - pair.second.first, - pair.first.second, - pair.second.second, - OpType::LessThan); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - int64_t all_cost = 0; - for (int i = 0; i < 10; i++) { - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - all_cost += std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); + std::vector test_batch_size = { + 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; + for (const auto& batch_size : test_batch_size) { + EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; + auto plan = plan::PlanFragment(plan_node); + auto query_context = std::make_shared( + "query id", seg.get(), N, MAX_TIMESTAMP); + + auto task = + milvus::exec::Task::Create("task_expr", plan, 0, query_context); + auto last_num = N % batch_size; + auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; + int iter = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + if (++iter != iter_num) { + EXPECT_EQ(childrens[0]->size(), batch_size); + } else { + EXPECT_EQ(childrens[0]->size(), last_num); + } } - std::cout << " cost: " << all_cost / 10 << "us" << std::endl; } } -TEST_P(ExprTest, TestRefactorExprs) { +TEST_P(ExprTest, TestGrowingSegmentGetBatchSize) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + auto raw_data = DataGen(schema, N); + seg->PreInsert(N); + seg->Insert(0, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + + proto::plan::GenericValue val; + val.set_int64_val(10); + auto expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto plan_node = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + std::vector test_batch_size = { + 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; + + for (const auto& batch_size : test_batch_size) { + EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; + auto plan = plan::PlanFragment(plan_node); + auto query_context = std::make_shared( + "query id", seg.get(), N, MAX_TIMESTAMP); + + auto task = + milvus::exec::Task::Create("task_expr", plan, 0, query_context); + auto last_num = N % batch_size; + auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; + int iter = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + if (++iter != iter_num) { + EXPECT_EQ(childrens[0]->size(), batch_size); + } else { + EXPECT_EQ(childrens[0]->size(), last_num); + } + } + } +} + +TEST_P(ExprTest, TestConjuctExpr) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto int8_fid = schema->AddDebugField("int8", DataType::INT8); @@ -2258,9 +3533,8 @@ TEST_P(ExprTest, TestRefactorExprs) { schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 10000; + int N = 1000; auto raw_data = DataGen(schema, N); - // load field data auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { @@ -2274,490 +3548,3918 @@ TEST_P(ExprTest, TestRefactorExprs) { seg->LoadFieldData(FieldId(field_id), info); } + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - enum ExprType { - UnaryRangeExpr = 0, - TermExprImpl = 1, - CompareExpr = 2, - LogicalUnaryExpr = 3, - BinaryRangeExpr = 4, - LogicalBinaryExpr = 5, - BinaryArithOpEvalRangeExpr = 6, - }; + auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(l); + auto left = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + value); + value.set_int64_val(r); + auto right = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::LessThan, + value); - auto build_expr = [&](enum ExprType test_type, - int n) -> expr::TypedExprPtr { - switch (test_type) { - case UnaryRangeExpr: { - proto::plan::GenericValue val; - val.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - proto::plan::OpType::GreaterThan, - val); - } - case TermExprImpl: { - std::vector retrieve_ints; - // for (int i = 0; i < n; ++i) { - // retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); - // } - // return std::make_shared>( - // ColumnInfo(str1_fid, DataType::VARCHAR), - // retrieve_ints, - // proto::plan::GenericValue::ValCase::kStringVal); - for (int i = 0; i < n; ++i) { - proto::plan::GenericValue val; - val.set_float_val(i); - retrieve_ints.push_back(val); - } - return std::make_shared( - expr::ColumnInfo(double_fid, DataType::DOUBLE), - retrieve_ints); - } - case CompareExpr: { - auto compare_expr = - std::make_shared(int8_fid, - int8_1_fid, - DataType::INT8, - DataType::INT8, - OpType::LessThan); - return compare_expr; - } - case BinaryRangeExpr: { - proto::plan::GenericValue lower; - lower.set_int64_val(10); - proto::plan::GenericValue upper; - upper.set_int64_val(45); - return std::make_shared( - expr::ColumnInfo(int64_fid, DataType::INT64), - lower, - upper, - true, - true); - } - case LogicalUnaryExpr: { - proto::plan::GenericValue val; - val.set_int64_val(10); - auto child_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - return std::make_shared( - expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); - } - case LogicalBinaryExpr: { - proto::plan::GenericValue val; - val.set_int64_val(10); - auto child1_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - auto child2_expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::NotEqual, - val); - ; - return std::make_shared( - expr::LogicalBinaryExpr::OpType::And, - child1_expr, - child2_expr); - } - case BinaryArithOpEvalRangeExpr: { - proto::plan::GenericValue val; - val.set_int64_val(100); - proto::plan::GenericValue right; - right.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::Equal, - proto::plan::ArithOpType::Add, - val, - right); - } - default: { - proto::plan::GenericValue val; - val.set_int64_val(10); - return std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - } - } + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, left, right); }; - auto test_case = [&](int n) { - auto expr = build_expr(UnaryRangeExpr, n); - BitsetType final; + std::vector> test_case = { + {100, 0}, {0, 100}, {8192, 8194}}; + for (auto& pair : test_case) { + std::cout << pair.first << "|" << pair.second << std::endl; + auto expr = build_expr(pair.first, pair.second); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - std::cout << "start test" << std::endl; - auto start = std::chrono::steady_clock::now(); - final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); - std::cout << n << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() - << "us" << std::endl; - }; - test_case(3); - test_case(10); - test_case(20); - test_case(30); - test_case(50); - test_case(100); - test_case(200); - // test_case(500); + BitsetType final; + visitor.ExecuteExprNode(plan, seg.get(), N, final); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; + } + } } -TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { - std::vector< - std::tuple>> - testcases = { - {R"(LessThan)", - [](std::string a, std::string b) { return a.compare(b) < 0; }}, - {R"(LessEqual)", - [](std::string a, std::string b) { return a.compare(b) <= 0; }}, - {R"(GreaterThan)", - [](std::string a, std::string b) { return a.compare(b) > 0; }}, - {R"(GreaterEqual)", - [](std::string a, std::string b) { return a.compare(b) >= 0; }}, - {R"(Equal)", - [](std::string a, std::string b) { return a.compare(b) == 0; }}, - {R"(NotEqual)", - [](std::string a, std::string b) { return a.compare(b) != 0; }}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - compare_expr: < - left_column_info: < - field_id: %3% - data_type: VarChar - > - right_column_info: < - field_id: %4% - data_type: VarChar - > - op: %2% - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; - +TEST_P(ExprTest, TestConjuctExprNullable) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_nullable_fid = + schema->AddDebugField("int8_nullable", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_nullable_fid = + schema->AddDebugField("int16_nullable", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_nullable_fid = + schema->AddDebugField("int32_nullable", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_nullable_fid = + schema->AddDebugField("int64_nullable", DataType::INT64); auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); int N = 1000; auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); - // load index for int32 field - auto str1_col = raw_data.get_col(str1_fid); - GenScalarIndexing(N, str1_col.data()); - auto str1_index = milvus::index::CreateScalarIndexSort(); - str1_index->Build(N, str1_col.data()); - load_index_info.field_id = str1_fid.get(); - load_index_info.field_type = DataType::VARCHAR; - load_index_info.index = std::move(str1_index); - seg->LoadIndex(load_index_info); + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); - // load index for int64 field - auto str2_col = raw_data.get_col(str2_fid); - GenScalarIndexing(N, str2_col.data()); - auto str2_index = milvus::index::CreateScalarIndexSort(); - str2_index->Build(N, str2_col.data()); - load_index_info.field_id = str2_fid.get(); - load_index_info.field_type = DataType::VARCHAR; - load_index_info.index = std::move(str2_index); - seg->LoadIndex(load_index_info); + seg->LoadFieldData(FieldId(field_id), info); + } + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - for (auto [clause, ref_func] : testcases) { - auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % - clause % str1_fid.get() % str2_fid.get(); - auto binary_plan = - translate_text_plan_with_metric_type(dsl_string.str()); - auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); - // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; - BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), - N, - MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(l); + auto left = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + value); + value.set_int64_val(r); + auto right = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::LessThan, + value); + + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, left, right); + }; + std::vector> test_case = { + {100, 0}, {0, 100}, {8192, 8194}}; + for (auto& pair : test_case) { + std::cout << pair.first << "|" << pair.second << std::endl; + auto expr = build_expr(pair.first, pair.second); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; + visitor.ExecuteExprNode(plan, seg.get(), N, final); for (int i = 0; i < N; ++i) { - auto ans = final[i]; - auto val1 = str1_col[i]; - auto val2 = str2_col[i]; - auto ref = ref_func(val1, val2); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" - << boost::format("[%1%, %2%]") % val1 % val2; + EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; } } } -TEST_P(ExprTest, TestBinaryArithOpEvalRange) { - std::vector, DataType>> testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 101 - data_type: Int8 - > - arith_op: Add - right_operand: < - int64_val: 4 - > - op: Equal - value: < - int64_val: 8 - > - >)", - [](int8_t v) { return (v + 4) == 8; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 102 - data_type: Int16 - > - arith_op: Sub - right_operand: < - int64_val: 500 - > - op: Equal - value: < - int64_val: 1500 - > - >)", - [](int16_t v) { return (v - 500) == 1500; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 103 - data_type: Int32 - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4000 - > - >)", - [](int32_t v) { return (v * 2) == 4000; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 1000 - > - >)", - [](int64_t v) { return (v / 2) == 1000; }, - DataType::INT64}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 103 - data_type: Int32 - > - arith_op: Mod - right_operand: < - int64_val: 100 - > - op: Equal - value: < - int64_val: 0 - > - >)", - [](int32_t v) { return (v % 100) == 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 105 - data_type: Float - > - arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - > - >)", - [](float v) { return (v + 500) == 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 106 - data_type: Double - > - arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - > - >)", - [](double v) { return (v + 500) == 2500; }, - DataType::DOUBLE}, - // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 105 - data_type: Float - > - arith_op: Add - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2500 - > - >)", - [](float v) { return (v + 500) != 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 106 - data_type: Double - > - arith_op: Sub - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2500 - > - >)", - [](double v) { return (v - 500) != 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 101 - data_type: Int8 - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 2 - > - >)", - [](int8_t v) { return (v * 2) != 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 102 - data_type: Int16 - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 1000 - > - >)", - [](int16_t v) { return (v / 2) != 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 103 - data_type: Int32 - > - arith_op: Mod - right_operand: < - int64_val: 100 - > - op: NotEqual - value: < - int64_val: 0 - > - >)", - [](int32_t v) { return (v % 100) != 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 104 - data_type: Int64 - > +TEST_P(ExprTest, TestUnaryBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryRangeBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue lower; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + lower.set_float_val(10); + } else { + lower.set_int64_val(10); + } + proto::plan::GenericValue upper; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + upper.set_float_val(45); + } else { + upper.set_int64_val(45); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + lower, + upper, + true, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestLogicalUnaryBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto child_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + auto expr = std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryLogicalBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(-1000000); + } else { + val.set_int64_val(-1000000); + } + proto::plan::GenericValue val1; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val1.set_float_val(-100); + } else { + val1.set_int64_val(-100); + } + auto child1_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::LessThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::NotEqual, + val1); + auto expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(100); + } else { + val.set_int64_val(100); + } + proto::plan::GenericValue right; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + right.set_float_val(10); + } else { + right.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestCompareExprBenchTest) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector< + std::pair, std::pair>> + test_cases = { + {{int8_fid, DataType::INT8}, {int8_1_fid, DataType::INT8}}, + {{int16_fid, DataType::INT16}, {int16_fid, DataType::INT16}}, + {{int32_fid, DataType::INT32}, {int32_1_fid, DataType::INT32}}, + {{int64_fid, DataType::INT64}, {int64_1_fid, DataType::INT64}}, + {{float_fid, DataType::FLOAT}, {float_1_fid, DataType::FLOAT}}, + {{double_fid, DataType::DOUBLE}, {double_1_fid, DataType::DOUBLE}}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.first.second) << std::endl; + proto::plan::GenericValue lower; + auto expr = std::make_shared(pair.first.first, + pair.second.first, + pair.first.second, + pair.second.second, + OpType::LessThan); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10 << "us" << std::endl; + } +} + +TEST_P(ExprTest, TestRefactorExprs) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; + + auto build_expr = [&](enum ExprType test_type, + int n) -> expr::TypedExprPtr { + switch (test_type) { + case UnaryRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + } + case TermExprImpl: { + std::vector retrieve_ints; + // for (int i = 0; i < n; ++i) { + // retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); + // } + // return std::make_shared>( + // ColumnInfo(str1_fid, DataType::VARCHAR), + // retrieve_ints, + // proto::plan::GenericValue::ValCase::kStringVal); + for (int i = 0; i < n; ++i) { + proto::plan::GenericValue val; + val.set_float_val(i); + retrieve_ints.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(double_fid, DataType::DOUBLE), + retrieve_ints); + } + case CompareExpr: { + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case BinaryRangeExpr: { + proto::plan::GenericValue lower; + lower.set_int64_val(10); + proto::plan::GenericValue upper; + upper.set_int64_val(45); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + lower, + upper, + true, + true); + } + case LogicalUnaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + } + case LogicalBinaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child1_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::NotEqual, + val); + ; + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, + child1_expr, + child2_expr); + } + case BinaryArithOpEvalRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(100); + proto::plan::GenericValue right; + right.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + } + default: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + } + } + }; + auto test_case = [&](int n) { + auto expr = build_expr(UnaryRangeExpr, n); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + std::cout << n << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + }; + test_case(3); + test_case(10); + test_case(20); + test_case(30); + test_case(50); + test_case(100); + test_case(200); + // test_case(500); +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { + std::vector< + std::tuple>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b) { return a.compare(b) < 0; }}, + {R"(LessEqual)", + [](std::string a, std::string b) { return a.compare(b) <= 0; }}, + {R"(GreaterThan)", + [](std::string a, std::string b) { return a.compare(b) > 0; }}, + {R"(GreaterEqual)", + [](std::string a, std::string b) { return a.compare(b) >= 0; }}, + {R"(Equal)", + [](std::string a, std::string b) { return a.compare(b) == 0; }}, + {R"(NotEqual)", + [](std::string a, std::string b) { return a.compare(b) != 0; }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto str2_col = raw_data.get_col(str2_fid); + auto str2_index = milvus::index::CreateScalarIndexSort(); + str2_index->Build(N, str2_col.data()); + load_index_info.field_id = str2_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % str1_fid.get() % str2_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = str1_col[i]; + auto val2 = str2_col[i]; + auto ref = ref_func(val1, val2); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable) { + std::vector>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) < 0; + }}, + {R"(LessEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) <= 0; + }}, + {R"(GreaterThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) > 0; + }}, + {R"(GreaterEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) >= 0; + }}, + {R"(Equal)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) == 0; + }}, + {R"(NotEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) != 0; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto nullable_fid = + schema->AddDebugField("nullable_fid", DataType::VARCHAR, true); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto nullable_col = raw_data.get_col(nullable_fid); + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto str2_index = milvus::index::CreateScalarIndexSort(); + str2_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % str1_fid.get() % nullable_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = str1_col[i]; + auto val2 = nullable_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable2) { + std::vector>> + testcases = { + {R"(LessThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) < 0; + }}, + {R"(LessEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) <= 0; + }}, + {R"(GreaterThan)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) > 0; + }}, + {R"(GreaterEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) >= 0; + }}, + {R"(Equal)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) == 0; + }}, + {R"(NotEqual)", + [](std::string a, std::string b, bool valid) { + if (!valid) { + return false; + } + return a.compare(b) != 0; + }}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + compare_expr: < + left_column_info: < + field_id: %3% + data_type: VarChar + > + right_column_info: < + field_id: %4% + data_type: VarChar + > + op: %2% + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto nullable_fid = + schema->AddDebugField("nullable_fid", DataType::VARCHAR, true); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + // load index for int32 field + auto str1_col = raw_data.get_col(str1_fid); + auto str1_index = milvus::index::CreateScalarIndexSort(); + str1_index->Build(N, str1_col.data()); + load_index_info.field_id = str1_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str1_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto nullable_col = raw_data.get_col(nullable_fid); + auto valid_data_col = raw_data.get_col_valid(nullable_fid); + auto str2_index = milvus::index::CreateScalarIndexSort(); + str2_index->Build(N, nullable_col.data(), valid_data_col.data()); + load_index_info.field_id = nullable_fid.get(); + load_index_info.field_type = DataType::VARCHAR; + load_index_info.index = std::move(str2_index); + seg->LoadIndex(load_index_info); + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % + clause % nullable_fid.get() % str1_fid.get(); + auto binary_plan = + translate_text_plan_with_metric_type(dsl_string.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + auto val1 = nullable_col[i]; + auto val2 = str1_col[i]; + auto ref = ref_func(val1, val2, valid_data_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" + << boost::format("[%1%, %2%]") % val1 % val2; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRange) { + std::vector, DataType>> testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + > + >)", + [](int8_t v) { return (v + 4) == 8; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + > + >)", + [](int16_t v) { return (v - 500) == 1500; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + > + >)", + [](int32_t v) { return (v * 2) == 4000; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + > + >)", + [](int64_t v) { return (v / 2) == 1000; }, + DataType::INT64}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) == 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) == 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](double v) { return (v + 500) == 2500; }, + DataType::DOUBLE}, + // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) != 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) != 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) != 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) != 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) != 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) != 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) > 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) > 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) > 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) > 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) > 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) >= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) >= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) >= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) >= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) >= 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) < 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) < 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) < 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) < 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) < 2500; }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](float v) { return (v + 500) <= 2500; }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](double v) { return (v - 500) <= 2500; }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v) { return (v * 2) <= 2; }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v) { return (v / 2) <= 1000; }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v) { return (v + 500) <= 2500; }, + DataType::INT64}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_fid = schema->AddDebugField("age8", DataType::INT8); + auto i16_fid = schema->AddDebugField("age16", DataType::INT16); + auto i32_fid = schema->AddDebugField("age32", DataType::INT32); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age8_col; + std::vector age16_col; + std::vector age32_col; + std::vector age64_col; + std::vector age_float_col; + std::vector age_double_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto new_age8_col = raw_data.get_col(i8_fid); + auto new_age16_col = raw_data.get_col(i16_fid); + auto new_age32_col = raw_data.get_col(i32_fid); + auto new_age64_col = raw_data.get_col(i64_fid); + auto new_age_float_col = raw_data.get_col(float_fid); + auto new_age_double_col = raw_data.get_col(double_fid); + + age8_col.insert( + age8_col.end(), new_age8_col.begin(), new_age8_col.end()); + age16_col.insert( + age16_col.end(), new_age16_col.begin(), new_age16_col.end()); + age32_col.insert( + age32_col.end(), new_age32_col.begin(), new_age32_col.end()); + age64_col.insert( + age64_col.end(), new_age64_col.begin(), new_age64_col.end()); + age_float_col.insert(age_float_col.end(), + new_age_float_col.begin(), + new_age_float_col.end()); + age_double_col.insert(age_double_col.end(), + new_age_double_col.begin(), + new_age_double_col.end()); + + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + // if (dtype == DataType::INT8) { + // dsl_string.replace(loc, 5, dsl_string_int8); + // } else if (dtype == DataType::INT16) { + // dsl_string.replace(loc, 5, dsl_string_int16); + // } else if (dtype == DataType::INT32) { + // dsl_string.replace(loc, 5, dsl_string_int32); + // } else if (dtype == DataType::INT64) { + // dsl_string.replace(loc, 5, dsl_string_int64); + // } else if (dtype == DataType::FLOAT) { + // dsl_string.replace(loc, 5, dsl_string_float); + // } else if (dtype == DataType::DOUBLE) { + // dsl_string.replace(loc, 5, dsl_string_double); + // } else { + // ASSERT_TRUE(false) << "No test case defined for this data type"; + // } + // loc = dsl_string.find("@@@@"); + // dsl_string.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << clause << "@" << i << "!!" << val << std::endl; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) == 8; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) == 1500; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) == 4000; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) == 1000; + }, + DataType::INT64}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) == 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::DOUBLE}, + // Add test cases for BinaryArithOpEvalRangeExpr NE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) != 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) != 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) != 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) != 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) > 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterThan + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) > 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) > 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) > 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) > 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) > 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) >= 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: GreaterEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) >= 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) >= 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) >= 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) >= 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) >= 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) < 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessThan + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) < 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) < 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) < 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) < 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) < 2500; + }, + DataType::INT64}, + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 105 + data_type: Float + > + arith_op: Add + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) <= 2500; + }, + DataType::FLOAT}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 106 + data_type: Double + > + arith_op: Sub + right_operand: < + float_val: 500 + > + op: LessEqual + value: < + float_val: 2500 + > + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) <= 2500; + }, + DataType::DOUBLE}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Int8 + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) <= 2; + }, + DataType::INT8}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Int16 + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + > + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) <= 1000; + }, + DataType::INT16}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Int32 + > + arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + > + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) <= 0; + }, + DataType::INT32}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Int64 + > + arith_op: Mod + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 2500 + > + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) <= 2500; + }, + DataType::INT64}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_nullable_fid = schema->AddDebugField("age8", DataType::INT8, true); + auto i16_nullable_fid = + schema->AddDebugField("age16", DataType::INT16, true); + auto i32_nullable_fid = + schema->AddDebugField("age32", DataType::INT32, true); + auto i64_nullable_fid = + schema->AddDebugField("age64_nullable", DataType::INT64, true); + auto float_nullable_fid = + schema->AddDebugField("age_float", DataType::FLOAT, true); + auto double_nullable_fid = + schema->AddDebugField("age_double", DataType::DOUBLE, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + schema->set_primary_field_id(i64_fid); + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector age8_col; + std::vector age16_col; + std::vector age32_col; + std::vector age64_col; + std::vector age_float_col; + std::vector age_double_col; + FixedVector age8_valid_col; + FixedVector age16_valid_col; + FixedVector age32_valid_col; + FixedVector age64_valid_col; + FixedVector age_float_valid_col; + FixedVector age_double_valid_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto new_age8_col = raw_data.get_col(i8_nullable_fid); + auto new_age16_col = raw_data.get_col(i16_nullable_fid); + auto new_age32_col = raw_data.get_col(i32_nullable_fid); + auto new_age64_col = raw_data.get_col(i64_nullable_fid); + auto new_age_float_col = raw_data.get_col(float_nullable_fid); + auto new_age_double_col = raw_data.get_col(double_nullable_fid); + age8_valid_col = raw_data.get_col_valid(i8_nullable_fid); + age16_valid_col = raw_data.get_col_valid(i16_nullable_fid); + age32_valid_col = raw_data.get_col_valid(i32_nullable_fid); + age64_valid_col = raw_data.get_col_valid(i64_nullable_fid); + age_float_valid_col = raw_data.get_col_valid(float_nullable_fid); + age_double_valid_col = raw_data.get_col_valid(double_nullable_fid); + + age8_col.insert( + age8_col.end(), new_age8_col.begin(), new_age8_col.end()); + age16_col.insert( + age16_col.end(), new_age16_col.begin(), new_age16_col.end()); + age32_col.insert( + age32_col.end(), new_age32_col.begin(), new_age32_col.end()); + age64_col.insert( + age64_col.end(), new_age64_col.begin(), new_age64_col.end()); + age_float_col.insert(age_float_col.end(), + new_age_float_col.begin(), + new_age_float_col.end()); + age_double_col.insert(age_double_col.end(), + new_age_double_col.begin(), + new_age_double_col.end()); + + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + // if (dtype == DataType::INT8) { + // dsl_string.replace(loc, 5, dsl_string_int8); + // } else if (dtype == DataType::INT16) { + // dsl_string.replace(loc, 5, dsl_string_int16); + // } else if (dtype == DataType::INT32) { + // dsl_string.replace(loc, 5, dsl_string_int32); + // } else if (dtype == DataType::INT64) { + // dsl_string.replace(loc, 5, dsl_string_int64); + // } else if (dtype == DataType::FLOAT) { + // dsl_string.replace(loc, 5, dsl_string_float); + // } else if (dtype == DataType::DOUBLE) { + // dsl_string.replace(loc, 5, dsl_string_double); + // } else { + // ASSERT_TRUE(false) << "No test case defined for this data type"; + // } + // loc = dsl_string.find("@@@@"); + // dsl_string.replace(loc, 4, clause); + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val, age8_valid_col[i]); + ASSERT_EQ(ans, ref) + << clause << "@" << i << "!!" << val << std::endl; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val, age16_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val, age32_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val, age64_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val, age_float_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val, age_double_valid_col[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + std::vector< + std::tuple>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: Equal + value: + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == 4; + }}, + // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length > 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: GreaterEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: GreaterEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length >= 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: LessThan + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessThan + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length < 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: LessEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: LessEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json) { + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length <= 4; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = + ref_func(milvus::Json(simdjson::padded_string(json_col[i]))); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONNullable) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + std::vector< + std::tuple>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: Equal + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) == 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > arith_op: Mod right_operand: < - int64_val: 500 + int64_val: 2 + > + op: Equal + value: + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) == 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: Equal + value: + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == 4; + }}, + // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Add + right_operand: < + int64_val: 1 > op: NotEqual value: < - int64_val: 2500 + int64_val: 2 > >)", - [](int64_t v) { return (v + 500) != 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Sub + right_operand: < + int64_val: 1 + > + op: NotEqual + value: < + int64_val: 2 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" + > + arith_op: Mod + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) != 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" + > + arith_op: ArrayLength + op: NotEqual + value: < + int64_val: 4 + > + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) > 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) > 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) > 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2765,15 +7467,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterThan value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) > 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2781,80 +7490,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterThan value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) > 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: GreaterThan value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) > 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) > 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: GreaterThan value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) > 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length > 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) >= 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: GreaterEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) >= 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) >= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2862,15 +7609,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterEqual value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) >= 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2878,80 +7632,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: GreaterEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) >= 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: GreaterEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) >= 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) >= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: GreaterEqual value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) >= 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length >= 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: LessThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) < 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: LessThan value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) < 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) < 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -2959,15 +7751,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessThan value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) < 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -2975,80 +7774,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessThan value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) < 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: LessThan - value: < - int64_val: 0 - > - >)", - [](int32_t v) { return (v % 100) < 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id: 104 - data_type: Int64 + value: < + int64_val: 4 > - arith_op: Mod - right_operand: < - int64_val: 500 + >)", + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) < 4; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: LessThan value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) < 2500; }, - DataType::INT64}, - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length < 4; + }}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 105 - data_type: Float + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Add right_operand: < - float_val: 500 + int64_val: 1 > op: LessEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](float v) { return (v + 500) <= 2500; }, - DataType::FLOAT}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val + 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 106 - data_type: Double + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Sub right_operand: < - float_val: 500 + int64_val: 1 > op: LessEqual value: < - float_val: 2500 + int64_val: 2 > >)", - [](double v) { return (v - 500) <= 2500; }, - DataType::DOUBLE}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val - 1) <= 2; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 101 - data_type: Int8 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mul right_operand: < @@ -3056,15 +7893,22 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessEqual value: < - int64_val: 2 + int64_val: 4 > >)", - [](int8_t v) { return (v * 2) <= 2; }, - DataType::INT8}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val * 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 102 - data_type: Int16 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Div right_operand: < @@ -3072,58 +7916,761 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { > op: LessEqual value: < - int64_val: 1000 + int64_val: 4 > >)", - [](int16_t v) { return (v / 2) <= 1000; }, - DataType::INT16}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val / 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 103 - data_type: Int32 + field_id:102 + data_type:JSON + nested_path:"int" > arith_op: Mod right_operand: < - int64_val: 100 + int64_val: 2 > op: LessEqual value: < - int64_val: 0 + int64_val: 4 > >)", - [](int32_t v) { return (v % 100) <= 0; }, - DataType::INT32}, - {R"(binary_arith_op_eval_range_expr: < + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"int"}); + auto val = json.template at(pointer).value(); + return (val % 2) <= 4; + }}, + {R"(binary_arith_op_eval_range_expr: < column_info: < - field_id: 104 - data_type: Int64 - > - arith_op: Mod - right_operand: < - int64_val: 500 + field_id:102 + data_type:JSON + nested_path:"array" > + arith_op: ArrayLength op: LessEqual value: < - int64_val: 2500 + int64_val: 4 > >)", - [](int64_t v) { return (v + 500) <= 2500; }, - DataType::INT64}, + [](const milvus::Json& json, bool valid) { + if (!valid) { + return false; + } + auto pointer = milvus::Json::pointer({"array"}); + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length <= 4; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto nullable_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(nullable_fid); + valid_data = raw_data.get_col_valid(nullable_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 5, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = + ref_func(milvus::Json(simdjson::padded_string(json_col[i])), + valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { + struct Testcase { + double right_operand; + double value; + OpType op; + std::vector nested_path; + }; + std::vector testcases{ + {10, 20, OpType::Equal, {"double"}}, + {20, 30, OpType::Equal, {"double"}}, + {30, 40, OpType::NotEqual, {"double"}}, + {40, 50, OpType::NotEqual, {"double"}}, + {10, 20, OpType::Equal, {"int"}}, + {20, 30, OpType::Equal, {"int"}}, + {30, 40, OpType::NotEqual, {"int"}}, + {40, 50, OpType::NotEqual, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](double value) { + if (testcase.op == OpType::Equal) { + return value + testcase.right_operand == testcase.value; + } + return value + testcase.right_operand != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_float_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_float_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) + << testcase.value << " " << val << " " << testcase.op; + } + } + + std::vector array_testcases{ + {0, 3, OpType::Equal, {"array"}}, + {0, 5, OpType::NotEqual, {"array"}}, + }; + + for (auto testcase : array_testcases) { + auto check = [&](int64_t value) { + if (testcase.op == OpType::Equal) { + return value == testcase.value; + } + return value != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::ArrayLength, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto json = milvus::Json(simdjson::padded_string(json_col[i])); + int64_t array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + auto ref = check(array_length); + ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloatNullable) { + struct Testcase { + double right_operand; + double value; + OpType op; + std::vector nested_path; + }; + std::vector testcases{ + {10, 20, OpType::Equal, {"double"}}, + {20, 30, OpType::Equal, {"double"}}, + {30, 40, OpType::NotEqual, {"double"}}, + {40, 50, OpType::NotEqual, {"double"}}, + {10, 20, OpType::Equal, {"int"}}, + {20, 30, OpType::Equal, {"int"}}, + {30, 40, OpType::NotEqual, {"int"}}, + {40, 50, OpType::NotEqual, {"int"}}, + }; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { + auto check = [&](double value, bool valid) { + if (!valid) { + return false; + } + if (testcase.op == OpType::Equal) { + return value + testcase.right_operand == testcase.value; + } + return value + testcase.right_operand != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_float_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_float_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(pointer) + .value(); + auto ref = check(val, valid_data[i]); + ASSERT_EQ(ans, ref) + << testcase.value << " " << val << " " << testcase.op; + } + } + + std::vector array_testcases{ + {0, 3, OpType::Equal, {"array"}}, + {0, 5, OpType::NotEqual, {"array"}}, }; - std::string raw_plan_tmp = R"(vector_anns: < - field_id: 100 - predicates: < - @@@@@ - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + for (auto testcase : array_testcases) { + auto check = [&](int64_t value, bool valid) { + if (!valid) { + return false; + } + if (testcase.op == OpType::Equal) { + return value == testcase.value; + } + return value != testcase.value; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::ArrayLength, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto json = milvus::Json(simdjson::padded_string(json_col[i])); + int64_t array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + auto ref = check(array_length, valid_data[i]); + ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + } + } +} + +TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { + std::vector, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) == 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) == 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) == 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) == 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) == 0; }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](float v) { return (v + 500) == 2500; }, + DataType::FLOAT}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](double v) { return (v + 500) == 2500; }, + DataType::DOUBLE}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2000 + >)", + [](float v) { return (v + 500) != 2000; }, + DataType::FLOAT}, + {R"(arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + >)", + [](double v) { return (v - 500) != 2000; }, + DataType::DOUBLE}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + >)", + [](int8_t v) { return (v * 2) != 2; }, + DataType::INT8}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int16_t v) { return (v / 2) != 2000; }, + DataType::INT16}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 1 + >)", + [](int32_t v) { return (v % 100) != 1; }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int64_t v) { return (v + 500) != 2000; }, + DataType::INT64}, + + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) > 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) > 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) > 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) > 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) > 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) >= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) >= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) >= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) >= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) >= 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessThan + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) < 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) < 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) < 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) < 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) < 0; }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessEqual + value: < + int64_val: 8 + >)", + [](int8_t v) { return (v + 4) <= 8; }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 1500 + >)", + [](int16_t v) { return (v - 500) <= 1500; }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4000 + >)", + [](int32_t v) { return (v * 2) <= 4000; }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + >)", + [](int64_t v) { return (v / 2) <= 1000; }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + >)", + [](int32_t v) { return (v % 100) <= 0; }, + DataType::INT32}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + binary_arith_op_eval_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + > + @@@@)"; + auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto i8_fid = schema->AddDebugField("age8", DataType::INT8); @@ -3134,89 +8681,118 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); schema->set_primary_field_id(i64_fid); - auto seg = CreateGrowingSegment(schema, empty_index_meta); + auto seg = CreateSealedSegment(schema); int N = 1000; - std::vector age8_col; - std::vector age16_col; - std::vector age32_col; - std::vector age64_col; - std::vector age_float_col; - std::vector age_double_col; - int num_iters = 1; - for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; - auto new_age8_col = raw_data.get_col(i8_fid); - auto new_age16_col = raw_data.get_col(i16_fid); - auto new_age32_col = raw_data.get_col(i32_fid); - auto new_age64_col = raw_data.get_col(i64_fid); - auto new_age_float_col = raw_data.get_col(float_fid); - auto new_age_double_col = raw_data.get_col(double_fid); + // load index for int8 field + auto age8_col = raw_data.get_col(i8_fid); + age8_col[0] = 4; + auto age8_index = milvus::index::CreateScalarIndexSort(); + age8_index->Build(N, age8_col.data(), nullptr); + load_index_info.field_id = i8_fid.get(); + load_index_info.field_type = DataType::INT8; + load_index_info.index = std::move(age8_index); + seg->LoadIndex(load_index_info); - age8_col.insert( - age8_col.end(), new_age8_col.begin(), new_age8_col.end()); - age16_col.insert( - age16_col.end(), new_age16_col.begin(), new_age16_col.end()); - age32_col.insert( - age32_col.end(), new_age32_col.begin(), new_age32_col.end()); - age64_col.insert( - age64_col.end(), new_age64_col.begin(), new_age64_col.end()); - age_float_col.insert(age_float_col.end(), - new_age_float_col.begin(), - new_age_float_col.end()); - age_double_col.insert(age_double_col.end(), - new_age_double_col.begin(), - new_age_double_col.end()); + // load index for 16 field + auto age16_col = raw_data.get_col(i16_fid); + age16_col[0] = 2000; + auto age16_index = milvus::index::CreateScalarIndexSort(); + age16_index->Build(N, age16_col.data(), nullptr); + load_index_info.field_id = i16_fid.get(); + load_index_info.field_type = DataType::INT16; + load_index_info.index = std::move(age16_index); + seg->LoadIndex(load_index_info); - seg->PreInsert(N); - seg->Insert(iter * N, - N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } + // load index for int32 field + auto age32_col = raw_data.get_col(i32_fid); + age32_col[0] = 2000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data(), nullptr); + load_index_info.field_id = i32_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); - auto seg_promote = dynamic_cast(seg.get()); + // load index for int64 field + auto age64_col = raw_data.get_col(i64_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data(), nullptr); + load_index_info.field_id = i64_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); + + // load index for float field + auto age_float_col = raw_data.get_col(float_fid); + age_float_col[0] = 2000; + auto age_float_index = milvus::index::CreateScalarIndexSort(); + age_float_index->Build(N, age_float_col.data(), nullptr); + load_index_info.field_id = float_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_float_index); + seg->LoadIndex(load_index_info); + + // load index for double field + auto age_double_col = raw_data.get_col(double_fid); + age_double_col[0] = 2000; + auto age_double_index = milvus::index::CreateScalarIndexSort(); + age_double_index->Build(N, age_double_col.data(), nullptr); + load_index_info.field_id = double_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_double_index); + seg->LoadIndex(load_index_info); + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { - auto loc = raw_plan_tmp.find("@@@@@"); - auto raw_plan = raw_plan_tmp; - raw_plan.replace(loc, 5, clause); - // if (dtype == DataType::INT8) { - // dsl_string.replace(loc, 5, dsl_string_int8); - // } else if (dtype == DataType::INT16) { - // dsl_string.replace(loc, 5, dsl_string_int16); - // } else if (dtype == DataType::INT32) { - // dsl_string.replace(loc, 5, dsl_string_int32); - // } else if (dtype == DataType::INT64) { - // dsl_string.replace(loc, 5, dsl_string_int64); - // } else if (dtype == DataType::FLOAT) { - // dsl_string.replace(loc, 5, dsl_string_float); - // } else if (dtype == DataType::DOUBLE) { - // dsl_string.replace(loc, 5, dsl_string_double); - // } else { - // ASSERT_TRUE(false) << "No test case defined for this data type"; - // } - // loc = dsl_string.find("@@@@"); - // dsl_string.replace(loc, 4, clause); - auto plan_str = translate_text_plan_with_metric_type(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + if (dtype == DataType::INT8) { + expr = boost::format(expr_plan) % vec_fid.get() % i8_fid.get() % + proto::schema::DataType_Name(int(DataType::INT8)); + } else if (dtype == DataType::INT16) { + expr = boost::format(expr_plan) % vec_fid.get() % i16_fid.get() % + proto::schema::DataType_Name(int(DataType::INT16)); + } else if (dtype == DataType::INT32) { + expr = boost::format(expr_plan) % vec_fid.get() % i32_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + } else if (dtype == DataType::INT64) { + expr = boost::format(expr_plan) % vec_fid.get() % i64_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + } else if (dtype == DataType::FLOAT) { + expr = boost::format(expr_plan) % vec_fid.get() % float_fid.get() % + proto::schema::DataType_Name(int(DataType::FLOAT)); + } else if (dtype == DataType::DOUBLE) { + expr = boost::format(expr_plan) % vec_fid.get() % double_fid.get() % + proto::schema::DataType_Name(int(DataType::DOUBLE)); + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + + auto binary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg.get(), - N * num_iters, - MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, N, final); + EXPECT_EQ(final.size(), N); - for (int i = 0; i < N * num_iters; ++i) { + for (int i = 0; i < N; ++i) { auto ans = final[i]; if (dtype == DataType::INT8) { auto val = age8_col[i]; auto ref = ref_func(val); - ASSERT_EQ(ans, ref) - << clause << "@" << i << "!!" << val << std::endl; + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } else if (dtype == DataType::INT16) { auto val = age16_col[i]; auto ref = ref_func(val); @@ -3244,771 +8820,1205 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { } } -TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; +TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndexNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: Equal + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) == 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: Equal + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) == 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) == 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: Equal + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) == 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: Equal + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) == 0; + }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::FLOAT}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: Equal + value: < + float_val: 2500 + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) == 2500; + }, + DataType::DOUBLE}, + {R"(arith_op: Add + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2000 + >)", + [](float v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2000; + }, + DataType::FLOAT}, + {R"(arith_op: Sub + right_operand: < + float_val: 500 + > + op: NotEqual + value: < + float_val: 2500 + >)", + [](double v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) != 2000; + }, + DataType::DOUBLE}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) != 2; + }, + DataType::INT8}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) != 2000; + }, + DataType::INT16}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: NotEqual + value: < + int64_val: 1 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) != 1; + }, + DataType::INT32}, + {R"(arith_op: Add + right_operand: < + int64_val: 500 + > + op: NotEqual + value: < + int64_val: 2000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 500) != 2000; + }, + DataType::INT64}, - std::vector< - std::tuple>> - testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: Equal - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) == 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: Equal - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) == 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: Equal - value: - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) == 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: Equal - value: - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterThan + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) > 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterThan + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) > 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; } - return array_length == 4; - }}, - // Add test cases for BinaryArithOpEvalRangeExpr NQ of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: NotEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) != 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: NotEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) != 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) != 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: NotEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return (v * 2) > 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterThan + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) > 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterThan + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) > 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: GreaterEqual + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) >= 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: GreaterEqual + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) >= 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) >= 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: GreaterEqual + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) >= 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: GreaterEqual + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) >= 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessThan + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) < 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessThan + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) < 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) < 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessThan + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) < 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessThan + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; } - return array_length != 4; - }}, + return (v % 100) < 0; + }, + DataType::INT32}, + + // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types + {R"(arith_op: Add + right_operand: < + int64_val: 4 + > + op: LessEqual + value: < + int64_val: 8 + >)", + [](int8_t v, bool valid) { + if (!valid) { + return false; + } + return (v + 4) <= 8; + }, + DataType::INT8}, + {R"(arith_op: Sub + right_operand: < + int64_val: 500 + > + op: LessEqual + value: < + int64_val: 1500 + >)", + [](int16_t v, bool valid) { + if (!valid) { + return false; + } + return (v - 500) <= 1500; + }, + DataType::INT16}, + {R"(arith_op: Mul + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 4000 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v * 2) <= 4000; + }, + DataType::INT32}, + {R"(arith_op: Div + right_operand: < + int64_val: 2 + > + op: LessEqual + value: < + int64_val: 1000 + >)", + [](int64_t v, bool valid) { + if (!valid) { + return false; + } + return (v / 2) <= 1000; + }, + DataType::INT64}, + {R"(arith_op: Mod + right_operand: < + int64_val: 100 + > + op: LessEqual + value: < + int64_val: 0 + >)", + [](int32_t v, bool valid) { + if (!valid) { + return false; + } + return (v % 100) <= 0; + }, + DataType::INT32}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + binary_arith_op_eval_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i8_nullable_fid = schema->AddDebugField("age8", DataType::INT8, true); + auto i16_nullable_fid = + schema->AddDebugField("age16", DataType::INT16, true); + auto i32_nullable_fid = + schema->AddDebugField("age32", DataType::INT32, true); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto i64_nullable_fid = + schema->AddDebugField("age641", DataType::INT64, true); + auto float_nullable_fid = + schema->AddDebugField("age_float", DataType::FLOAT, true); + auto double_nullable_fid = + schema->AddDebugField("age_double", DataType::DOUBLE, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000; + auto raw_data = DataGen(schema, N); + segcore::LoadIndexInfo load_index_info; + + auto i8_valid_data = raw_data.get_col_valid(i8_nullable_fid); + auto i16_valid_data = raw_data.get_col_valid(i16_nullable_fid); + auto i32_valid_data = raw_data.get_col_valid(i32_nullable_fid); + auto i64_valid_data = raw_data.get_col_valid(i64_nullable_fid); + auto float_valid_data = raw_data.get_col_valid(float_nullable_fid); + auto double_valid_data = raw_data.get_col_valid(double_nullable_fid); + + // load index for int8 field + auto age8_col = raw_data.get_col(i8_nullable_fid); + age8_col[0] = 4; + auto age8_index = milvus::index::CreateScalarIndexSort(); + age8_index->Build(N, age8_col.data(), i8_valid_data.data()); + load_index_info.field_id = i8_nullable_fid.get(); + load_index_info.field_type = DataType::INT8; + load_index_info.index = std::move(age8_index); + seg->LoadIndex(load_index_info); + + // load index for 16 field + auto age16_col = raw_data.get_col(i16_nullable_fid); + age16_col[0] = 2000; + auto age16_index = milvus::index::CreateScalarIndexSort(); + age16_index->Build(N, age16_col.data(), i16_valid_data.data()); + load_index_info.field_id = i16_nullable_fid.get(); + load_index_info.field_type = DataType::INT16; + load_index_info.index = std::move(age16_index); + seg->LoadIndex(load_index_info); + + // load index for int32 field + auto age32_col = raw_data.get_col(i32_nullable_fid); + age32_col[0] = 2000; + auto age32_index = milvus::index::CreateScalarIndexSort(); + age32_index->Build(N, age32_col.data(), i32_valid_data.data()); + load_index_info.field_id = i32_nullable_fid.get(); + load_index_info.field_type = DataType::INT32; + load_index_info.index = std::move(age32_index); + seg->LoadIndex(load_index_info); + + // load index for int64 field + auto age64_col = raw_data.get_col(i64_nullable_fid); + age64_col[0] = 2000; + auto age64_index = milvus::index::CreateScalarIndexSort(); + age64_index->Build(N, age64_col.data(), i64_valid_data.data()); + load_index_info.field_id = i64_nullable_fid.get(); + load_index_info.field_type = DataType::INT64; + load_index_info.index = std::move(age64_index); + seg->LoadIndex(load_index_info); + + // load index for float field + auto age_float_col = raw_data.get_col(float_nullable_fid); + age_float_col[0] = 2000; + auto age_float_index = milvus::index::CreateScalarIndexSort(); + age_float_index->Build(N, age_float_col.data(), float_valid_data.data()); + load_index_info.field_id = float_nullable_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_float_index); + seg->LoadIndex(load_index_info); + + // load index for double field + auto age_double_col = raw_data.get_col(double_nullable_fid); + age_double_col[0] = 2000; + auto age_double_index = milvus::index::CreateScalarIndexSort(); + age_double_index->Build(N, age_double_col.data(), double_valid_data.data()); + load_index_info.field_id = double_nullable_fid.get(); + load_index_info.field_type = DataType::FLOAT; + load_index_info.index = std::move(age_double_index); + seg->LoadIndex(load_index_info); + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + if (dtype == DataType::INT8) { + expr = boost::format(expr_plan) % vec_fid.get() % + i8_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT8)); + } else if (dtype == DataType::INT16) { + expr = boost::format(expr_plan) % vec_fid.get() % + i16_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT16)); + } else if (dtype == DataType::INT32) { + expr = boost::format(expr_plan) % vec_fid.get() % + i32_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT32)); + } else if (dtype == DataType::INT64) { + expr = boost::format(expr_plan) % vec_fid.get() % + i64_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::INT64)); + } else if (dtype == DataType::FLOAT) { + expr = boost::format(expr_plan) % vec_fid.get() % + float_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::FLOAT)); + } else if (dtype == DataType::DOUBLE) { + expr = boost::format(expr_plan) % vec_fid.get() % + double_nullable_fid.get() % + proto::schema::DataType_Name(int(DataType::DOUBLE)); + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + + auto binary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, binary_plan.data(), binary_plan.size()); + + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, N, final); + EXPECT_EQ(final.size(), N); + + for (int i = 0; i < N; ++i) { + auto ans = final[i]; + if (dtype == DataType::INT8) { + auto val = age8_col[i]; + auto ref = ref_func(val, i8_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT16) { + auto val = age16_col[i]; + auto ref = ref_func(val, i16_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT32) { + auto val = age32_col[i]; + auto ref = ref_func(val, i32_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = age64_col[i]; + auto ref = ref_func(val, i64_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::FLOAT) { + auto val = age_float_col[i]; + auto ref = ref_func(val, float_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = age_double_col[i]; + auto ref = ref_func(val, double_valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestUnaryRangeWithJSON) { + std::vector< + std::tuple)>, + DataType>> + testcases = { + {R"(op: Equal + value: < + bool_val: true + >)", + [](std::variant v) { + return std::get(v); + }, + DataType::BOOL}, + {R"(op: LessEqual + value: < + int64_val: 1500 + >)", + [](std::variant v) { + return std::get(v) < 1500; + }, + DataType::INT64}, + {R"(op: LessEqual + value: < + float_val: 4000 + >)", + [](std::variant v) { + return std::get(v) <= 4000; + }, + DataType::DOUBLE}, + {R"(op: GreaterThan + value: < + float_val: 1000 + >)", + [](std::variant v) { + return std::get(v) > 1000; + }, + DataType::DOUBLE}, + {R"(op: GreaterEqual + value: < + int64_val: 0 + >)", + [](std::variant v) { + return std::get(v) >= 0; + }, + DataType::INT64}, + {R"(op: NotEqual + value: < + bool_val: true + >)", + [](std::variant v) { + return !std::get(v); + }, + DataType::BOOL}, + {R"(op: Equal + value: < + string_val: "test" + >)", + [](std::variant v) { + return std::get(v) == "test"; + }, + DataType::STRING}, + }; - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: GreaterThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) > 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: GreaterThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) > 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) > 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: GreaterThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); - } - return array_length > 4; - }}, + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + unary_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: GreaterEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) >= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: GreaterEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) >= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) >= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: GreaterEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); +TEST_P(ExprTest, TestUnaryRangeWithJSONNullable) { + std::vector, bool)>, + DataType>> + testcases = { + {R"(op: Equal + value: < + bool_val: true + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length >= 4; - }}, - - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: LessThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) < 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: LessThan - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) < 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) < 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: LessThan - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return std::get(v); + }, + DataType::BOOL}, + {R"(op: LessEqual + value: < + int64_val: 1500 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length < 4; - }}, - - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Add - right_operand: < - int64_val: 1 - > - op: LessEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val + 1) <= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Sub - right_operand: < - int64_val: 1 - > - op: LessEqual - value: < - int64_val: 2 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val - 1) <= 2; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val * 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val / 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"int" - > - arith_op: Mod - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"int"}); - auto val = json.template at(pointer).value(); - return (val % 2) <= 4; - }}, - {R"(binary_arith_op_eval_range_expr: < - column_info: < - field_id:102 - data_type:JSON - nested_path:"array" - > - arith_op: ArrayLength - op: LessEqual - value: < - int64_val: 4 - > - >)", - [](const milvus::Json& json) { - auto pointer = milvus::Json::pointer({"array"}); - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); + return std::get(v) < 1500; + }, + DataType::INT64}, + {R"(op: LessEqual + value: < + float_val: 4000 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; } - return array_length <= 4; - }}, + return std::get(v) <= 4000; + }, + DataType::DOUBLE}, + {R"(op: GreaterThan + value: < + float_val: 1000 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return std::get(v) > 1000; + }, + DataType::DOUBLE}, + {R"(op: GreaterEqual + value: < + int64_val: 0 + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return std::get(v) >= 0; + }, + DataType::INT64}, + {R"(op: NotEqual + value: < + bool_val: true + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return !std::get(v); + }, + DataType::BOOL}, + {R"(op: Equal + value: < + string_val: "test" + >)", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return std::get(v) == "test"; + }, + DataType::STRING}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + unary_range_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestTermWithJSON) { + std::vector< + std::tuple)>, + DataType>> + testcases = { + {R"(values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {true, false}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::BOOL}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {1500, 2048, 3216}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::INT64}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {1500.0, 4000, 235.14}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::DOUBLE}, + {R"(values: , values: , values: )", + [](std::variant v) { + std::unordered_set term_set; + term_set = {"aaa", "abc", "235.14"}; + return term_set.find(std::get(v)) != + term_set.end(); + }, + DataType::STRING}, + {R"()", + [](std::variant v) { + return false; + }, + DataType::INT64}, }; - std::string raw_plan_tmp = R"(vector_anns: < - field_id: 100 - predicates: < - @@@@@ - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + term_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; + auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); - auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); @@ -4031,507 +10041,304 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); - for (auto [clause, ref_func] : testcases) { - auto loc = raw_plan_tmp.find("@@@@@"); - auto raw_plan = raw_plan_tmp; - raw_plan.replace(loc, 5, clause); - auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto ref = - ref_func(milvus::Json(simdjson::padded_string(json_col[i]))); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << json_col[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } } } } -TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { - struct Testcase { - double right_operand; - double value; - OpType op; - std::vector nested_path; - }; - std::vector testcases{ - {10, 20, OpType::Equal, {"double"}}, - {20, 30, OpType::Equal, {"double"}}, - {30, 40, OpType::NotEqual, {"double"}}, - {40, 50, OpType::NotEqual, {"double"}}, - {10, 20, OpType::Equal, {"int"}}, - {20, 30, OpType::Equal, {"int"}}, - {30, 40, OpType::NotEqual, {"int"}}, - {40, 50, OpType::NotEqual, {"int"}}, - }; +TEST_P(ExprTest, TestTermWithJSONNullable) { + std::vector, bool)>, + DataType>> + testcases = { + {R"(values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {true, false}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::BOOL}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {1500, 2048, 3216}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::INT64}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {1500.0, 4000, 235.14}; + return term_set.find(std::get(v)) != term_set.end(); + }, + DataType::DOUBLE}, + {R"(values: , values: , values: )", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + std::unordered_set term_set; + term_set = {"aaa", "abc", "235.14"}; + return term_set.find(std::get(v)) != + term_set.end(); + }, + DataType::STRING}, + {R"()", + [](std::variant v, + bool valid) { + if (!valid) { + return false; + } + return false; + }, + DataType::INT64}, + }; + + std::string serialized_expr_plan = R"(vector_anns: < + field_id: %1% + predicates: < + term_expr: < + @@@@@ + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + std::string arith_expr = R"( + column_info: < + field_id: %2% + data_type: %3% + nested_path:"%4%" + > + @@@@)"; auto schema = std::make_shared(); - auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } - - auto seg_promote = dynamic_cast(seg.get()); - - for (auto testcase : testcases) { - auto check = [&](double value) { - if (testcase.op == OpType::Equal) { - return value + testcase.right_operand == testcase.value; - } - return value + testcase.right_operand != testcase.value; - }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - proto::plan::GenericValue value; - value.set_float_val(testcase.value); - proto::plan::GenericValue right_operand; - right_operand.set_float_val(testcase.right_operand); - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - testcase.op, - ArithOpType::Add, - value, - right_operand); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); - - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(pointer) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) - << testcase.value << " " << val << " " << testcase.op; - } - } - - std::vector array_testcases{ - {0, 3, OpType::Equal, {"array"}}, - {0, 5, OpType::NotEqual, {"array"}}, - }; - - for (auto testcase : array_testcases) { - auto check = [&](int64_t value) { - if (testcase.op == OpType::Equal) { - return value == testcase.value; - } - return value != testcase.value; - }; - auto pointer = milvus::Json::pointer(testcase.nested_path); - proto::plan::GenericValue value; - value.set_int64_val(testcase.value); - proto::plan::GenericValue right_operand; - right_operand.set_int64_val(testcase.right_operand); - auto expr = std::make_shared( - milvus::expr::ColumnInfo( - json_fid, DataType::JSON, testcase.nested_path), - testcase.op, - ArithOpType::ArrayLength, - value, - right_operand); - BitsetType final; - auto plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N * num_iters); - - for (int i = 0; i < N * num_iters; ++i) { - auto ans = final[i]; - - auto json = milvus::Json(simdjson::padded_string(json_col[i])); - int64_t array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); - } - auto ref = check(array_length); - ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; - } - } -} - -TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { - std::vector, DataType>> - testcases = { - // Add test cases for BinaryArithOpEvalRangeExpr EQ of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: Equal - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) == 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: Equal - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) == 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) == 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: Equal - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) == 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: Equal - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) == 0; }, - DataType::INT32}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - >)", - [](float v) { return (v + 500) == 2500; }, - DataType::FLOAT}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: Equal - value: < - float_val: 2500 - >)", - [](double v) { return (v + 500) == 2500; }, - DataType::DOUBLE}, - {R"(arith_op: Add - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2000 - >)", - [](float v) { return (v + 500) != 2000; }, - DataType::FLOAT}, - {R"(arith_op: Sub - right_operand: < - float_val: 500 - > - op: NotEqual - value: < - float_val: 2500 - >)", - [](double v) { return (v - 500) != 2000; }, - DataType::DOUBLE}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 2 - >)", - [](int8_t v) { return (v * 2) != 2; }, - DataType::INT8}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: NotEqual - value: < - int64_val: 2000 - >)", - [](int16_t v) { return (v / 2) != 2000; }, - DataType::INT16}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: NotEqual - value: < - int64_val: 1 - >)", - [](int32_t v) { return (v % 100) != 1; }, - DataType::INT32}, - {R"(arith_op: Add - right_operand: < - int64_val: 500 - > - op: NotEqual - value: < - int64_val: 2000 - >)", - [](int64_t v) { return (v + 500) != 2000; }, - DataType::INT64}, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } - // Add test cases for BinaryArithOpEvalRangeExpr GT of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: GreaterThan - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) > 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: GreaterThan - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) > 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) > 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterThan - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) > 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: GreaterThan - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) > 0; }, - DataType::INT32}, + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { + auto loc = serialized_expr_plan.find("@@@@@"); + auto expr_plan = serialized_expr_plan; + expr_plan.replace(loc, 5, arith_expr); + loc = expr_plan.find("@@@@"); + expr_plan.replace(loc, 4, clause); + boost::format expr; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } - // Add test cases for BinaryArithOpEvalRangeExpr GE of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: GreaterEqual - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) >= 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: GreaterEqual - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) >= 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) >= 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: GreaterEqual - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) >= 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: GreaterEqual - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) >= 0; }, - DataType::INT32}, + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); - // Add test cases for BinaryArithOpEvalRangeExpr LT of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: LessThan - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) < 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: LessThan - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) < 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) < 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessThan - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) < 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: LessThan - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) < 0; }, - DataType::INT32}, + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); - // Add test cases for BinaryArithOpEvalRangeExpr LE of various data types - {R"(arith_op: Add - right_operand: < - int64_val: 4 - > - op: LessEqual - value: < - int64_val: 8 - >)", - [](int8_t v) { return (v + 4) <= 8; }, - DataType::INT8}, - {R"(arith_op: Sub - right_operand: < - int64_val: 500 - > - op: LessEqual - value: < - int64_val: 1500 - >)", - [](int16_t v) { return (v - 500) <= 1500; }, - DataType::INT16}, - {R"(arith_op: Mul - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 4000 - >)", - [](int32_t v) { return (v * 2) <= 4000; }, - DataType::INT32}, - {R"(arith_op: Div - right_operand: < - int64_val: 2 - > - op: LessEqual - value: < - int64_val: 1000 - >)", - [](int64_t v) { return (v / 2) <= 1000; }, - DataType::INT64}, - {R"(arith_op: Mod - right_operand: < - int64_val: 100 - > - op: LessEqual - value: < - int64_val: 0 - >)", - [](int32_t v) { return (v % 100) <= 0; }, - DataType::INT32}, + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/bool") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/int") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/double") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at("/string") + .value(); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestExistsWithJSON) { + std::vector, DataType>> + testcases = { + {R"()", [](bool v) { return v; }, DataType::BOOL}, + {R"()", [](bool v) { return v; }, DataType::INT64}, + {R"()", [](bool v) { return v; }, DataType::STRING}, + {R"()", [](bool v) { return v; }, DataType::VARCHAR}, + {R"()", [](bool v) { return v; }, DataType::DOUBLE}, }; std::string serialized_expr_plan = R"(vector_anns: < field_id: %1% predicates: < - binary_arith_op_eval_range_expr: < + exists_expr: < @@@@@ > > @@ -4545,95 +10352,39 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { >)"; std::string arith_expr = R"( - column_info: < + info: < field_id: %2% data_type: %3% + nested_path:"%4%" > @@@@)"; auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i8_fid = schema->AddDebugField("age8", DataType::INT8); - auto i16_fid = schema->AddDebugField("age16", DataType::INT16); - auto i32_fid = schema->AddDebugField("age32", DataType::INT32); auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto float_fid = schema->AddDebugField("age_float", DataType::FLOAT); - auto double_fid = schema->AddDebugField("age_double", DataType::DOUBLE); + auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); - auto seg = CreateSealedSegment(schema); + auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; - auto raw_data = DataGen(schema, N); - segcore::LoadIndexInfo load_index_info; - - // load index for int8 field - auto age8_col = raw_data.get_col(i8_fid); - age8_col[0] = 4; - GenScalarIndexing(N, age8_col.data()); - auto age8_index = milvus::index::CreateScalarIndexSort(); - age8_index->Build(N, age8_col.data()); - load_index_info.field_id = i8_fid.get(); - load_index_info.field_type = DataType::INT8; - load_index_info.index = std::move(age8_index); - seg->LoadIndex(load_index_info); - - // load index for 16 field - auto age16_col = raw_data.get_col(i16_fid); - age16_col[0] = 2000; - GenScalarIndexing(N, age16_col.data()); - auto age16_index = milvus::index::CreateScalarIndexSort(); - age16_index->Build(N, age16_col.data()); - load_index_info.field_id = i16_fid.get(); - load_index_info.field_type = DataType::INT16; - load_index_info.index = std::move(age16_index); - seg->LoadIndex(load_index_info); - - // load index for int32 field - auto age32_col = raw_data.get_col(i32_fid); - age32_col[0] = 2000; - GenScalarIndexing(N, age32_col.data()); - auto age32_index = milvus::index::CreateScalarIndexSort(); - age32_index->Build(N, age32_col.data()); - load_index_info.field_id = i32_fid.get(); - load_index_info.field_type = DataType::INT32; - load_index_info.index = std::move(age32_index); - seg->LoadIndex(load_index_info); - - // load index for int64 field - auto age64_col = raw_data.get_col(i64_fid); - age64_col[0] = 2000; - GenScalarIndexing(N, age64_col.data()); - auto age64_index = milvus::index::CreateScalarIndexSort(); - age64_index->Build(N, age64_col.data()); - load_index_info.field_id = i64_fid.get(); - load_index_info.field_type = DataType::INT64; - load_index_info.index = std::move(age64_index); - seg->LoadIndex(load_index_info); - - // load index for float field - auto age_float_col = raw_data.get_col(float_fid); - age_float_col[0] = 2000; - GenScalarIndexing(N, age_float_col.data()); - auto age_float_index = milvus::index::CreateScalarIndexSort(); - age_float_index->Build(N, age_float_col.data()); - load_index_info.field_id = float_fid.get(); - load_index_info.field_type = DataType::FLOAT; - load_index_info.index = std::move(age_float_index); - seg->LoadIndex(load_index_info); - - // load index for double field - auto age_double_col = raw_data.get_col(double_fid); - age_double_col[0] = 2000; - GenScalarIndexing(N, age_double_col.data()); - auto age_double_index = milvus::index::CreateScalarIndexSort(); - age_double_index->Build(N, age_double_col.data()); - load_index_info.field_id = double_fid.get(); - load_index_info.field_type = DataType::FLOAT; - load_index_info.index = std::move(age_double_index); - seg->LoadIndex(load_index_info); + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); - auto seg_promote = dynamic_cast(seg.get()); + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -4642,142 +10393,141 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { loc = expr_plan.find("@@@@"); expr_plan.replace(loc, 4, clause); boost::format expr; - if (dtype == DataType::INT8) { - expr = boost::format(expr_plan) % vec_fid.get() % i8_fid.get() % - proto::schema::DataType_Name(int(DataType::INT8)); - } else if (dtype == DataType::INT16) { - expr = boost::format(expr_plan) % vec_fid.get() % i16_fid.get() % - proto::schema::DataType_Name(int(DataType::INT16)); - } else if (dtype == DataType::INT32) { - expr = boost::format(expr_plan) % vec_fid.get() % i32_fid.get() % - proto::schema::DataType_Name(int(DataType::INT32)); - } else if (dtype == DataType::INT64) { - expr = boost::format(expr_plan) % vec_fid.get() % i64_fid.get() % - proto::schema::DataType_Name(int(DataType::INT64)); - } else if (dtype == DataType::FLOAT) { - expr = boost::format(expr_plan) % vec_fid.get() % float_fid.get() % - proto::schema::DataType_Name(int(DataType::FLOAT)); - } else if (dtype == DataType::DOUBLE) { - expr = boost::format(expr_plan) % vec_fid.get() % double_fid.get() % - proto::schema::DataType_Name(int(DataType::DOUBLE)); - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + switch (dtype) { + case DataType::BOOL: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; + break; + } + case DataType::INT64: { + expr = + boost::format(expr_plan) % vec_fid.get() % json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % "int"; + break; + } + case DataType::DOUBLE: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "double"; + break; + } + case DataType::STRING: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "string"; + break; + } + case DataType::VARCHAR: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "varchar"; + break; + } + default: { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } } - auto binary_plan = translate_text_plan_with_metric_type(expr.str()); + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); auto plan = CreateSearchPlanByExpr( - *schema, binary_plan.data(), binary_plan.size()); + *schema, unary_plan.data(), unary_plan.size()); BitsetType final; final = ExecuteQueryExpr( plan->plan_node_->plannodes_->sources()[0]->sources()[0], seg_promote, - N, + N * num_iters, MAX_TIMESTAMP); - EXPECT_EQ(final.size(), N); + EXPECT_EQ(final.size(), N * num_iters); - for (int i = 0; i < N; ++i) { + for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::INT8) { - auto val = age8_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT16) { - auto val = age16_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT32) { - auto val = age32_col[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/bool"); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } else if (dtype == DataType::INT64) { - auto val = age64_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::FLOAT) { - auto val = age_float_col[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/int"); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } else if (dtype == DataType::DOUBLE) { - auto val = age_double_col[i]; + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/double"); auto ref = ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; - } - } - } -} - -TEST_P(ExprTest, TestUnaryRangeWithJSON) { - std::vector< - std::tuple)>, - DataType>> - testcases = { - {R"(op: Equal - value: < - bool_val: true - >)", - [](std::variant v) { - return std::get(v); - }, - DataType::BOOL}, - {R"(op: LessEqual - value: < - int64_val: 1500 - >)", - [](std::variant v) { - return std::get(v) < 1500; - }, - DataType::INT64}, - {R"(op: LessEqual - value: < - float_val: 4000 - >)", - [](std::variant v) { - return std::get(v) <= 4000; - }, - DataType::DOUBLE}, - {R"(op: GreaterThan - value: < - float_val: 1000 - >)", - [](std::variant v) { - return std::get(v) > 1000; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/string"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::VARCHAR) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/varchar"); + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +TEST_P(ExprTest, TestExistsWithJSONNullable) { + std::vector< + std::tuple, DataType>> + testcases = { + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; }, - DataType::DOUBLE}, - {R"(op: GreaterEqual - value: < - int64_val: 0 - >)", - [](std::variant v) { - return std::get(v) >= 0; + DataType::BOOL}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; }, DataType::INT64}, - {R"(op: NotEqual - value: < - bool_val: true - >)", - [](std::variant v) { - return !std::get(v); - }, - DataType::BOOL}, - {R"(op: Equal - value: < - string_val: "test" - >)", - [](std::variant v) { - return std::get(v) == "test"; + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; }, DataType::STRING}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::VARCHAR}, + {R"()", + [](bool v, bool valid) { + if (!valid) { + return false; + } + return v; + }, + DataType::DOUBLE}, }; std::string serialized_expr_plan = R"(vector_anns: < field_id: %1% predicates: < - unary_range_expr: < + exists_expr: < @@@@@ > > @@ -4791,7 +10541,7 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { >)"; std::string arith_expr = R"( - column_info: < + info: < field_id: %2% data_type: %3% nested_path:"%4%" @@ -4801,16 +10551,18 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -4823,7 +10575,7 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -4859,136 +10611,301 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { "string"; break; } + case DataType::VARCHAR: { + expr = boost::format(expr_plan) % vec_fid.get() % + json_fid.get() % + proto::schema::DataType_Name(int(DataType::JSON)) % + "varchar"; + break; + } default: { ASSERT_TRUE(false) << "No test case defined for this data type"; } } - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); + auto unary_plan = translate_text_plan_with_metric_type(expr.str()); + auto plan = CreateSearchPlanByExpr( + *schema, unary_plan.data(), unary_plan.size()); + + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (dtype == DataType::BOOL) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/bool"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::INT64) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/int"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::DOUBLE) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/double"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::STRING) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/string"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else if (dtype == DataType::VARCHAR) { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .exist("/varchar"); + auto ref = ref_func(val, valid_data[i]); + ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + } else { + ASSERT_TRUE(false) << "No test case defined for this data type"; + } + } + } +} + +template +struct Testcase { + std::vector term; + std::vector nested_path; + bool res; +}; + +TEST_P(ExprTest, TestTermInFieldJson) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + // std::cout << "cost" + // << std::chrono::duration_cast( + // std::chrono::steady_clock::now() - start) + // .count() + // << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/bool") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/int") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/double") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/string") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res)); } } } -TEST_P(ExprTest, TestTermWithJSON) { - std::vector< - std::tuple)>, - DataType>> - testcases = { - {R"(values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {true, false}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::BOOL}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {1500, 2048, 3216}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::INT64}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {1500.0, 4000, 235.14}; - return term_set.find(std::get(v)) != term_set.end(); - }, - DataType::DOUBLE}, - {R"(values: , values: , values: )", - [](std::variant v) { - std::unordered_set term_set; - term_set = {"aaa", "abc", "235.14"}; - return term_set.find(std::get(v)) != - term_set.end(); - }, - DataType::STRING}, - {R"()", - [](std::variant v) { - return false; - }, - DataType::INT64}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - term_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; - - std::string arith_expr = R"( - column_info: < - field_id: %2% - data_type: %3% - nested_path:"%4%" - > - @@@@)"; - +TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -5001,129 +10918,355 @@ TEST_P(ExprTest, TestTermWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - switch (dtype) { - case DataType::BOOL: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; - break; + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; } - case DataType::INT64: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "int"; - break; + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + // std::cout << "cost" + // << std::chrono::duration_cast( + // std::chrono::steady_clock::now() - start) + // .count() + // << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - case DataType::DOUBLE: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "double"; - break; + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - case DataType::STRING: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "string"; - break; + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; } - default: { - ASSERT_TRUE(false) << "No test case defined for this data type"; + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } + } - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, + true); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/bool") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/int") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/double") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at("/string") - .value(); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res, valid_data[i])); } } } -TEST_P(ExprTest, TestExistsWithJSON) { - std::vector, DataType>> - testcases = { - {R"()", [](bool v) { return v; }, DataType::BOOL}, - {R"()", [](bool v) { return v; }, DataType::INT64}, - {R"()", [](bool v) { return v; }, DataType::STRING}, - {R"()", [](bool v) { return v; }, DataType::VARCHAR}, - {R"()", [](bool v) { return v; }, DataType::DOUBLE}, - }; - - std::string serialized_expr_plan = R"(vector_anns: < - field_id: %1% - predicates: < - exists_expr: < - @@@@@ - > - > - query_info: < - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > - placeholder_tag: "$0" - >)"; - - std::string arith_expr = R"( - info: < - field_id: %2% - data_type: %3% - nested_path:"%4%" - > - @@@@)"; - +TEST_P(ExprTest, PraseJsonContainsExpr) { + std::vector raw_plans{ + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAny + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:JSON + nested_path:"A" + > + elements: + elements: + elements: + elements: + op:ContainsAll + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + }; + + for (auto& raw_plan : raw_plans) { + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto schema = std::make_shared(); + schema->AddDebugField("fakevec", data_type, 16, metric_type); + schema->AddDebugField("json", DataType::JSON); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + } +} + +TEST_P(ExprTest, TestJsonContainsAny) { auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); auto json_fid = schema->AddDebugField("json", DataType::JSON); schema->set_primary_field_id(i64_fid); @@ -5132,7 +11275,7 @@ TEST_P(ExprTest, TestExistsWithJSON) { std::vector json_col; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); + auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); json_col.insert( @@ -5146,120 +11289,216 @@ TEST_P(ExprTest, TestExistsWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - int offset = 0; - for (auto [clause, ref_func, dtype] : testcases) { - auto loc = serialized_expr_plan.find("@@@@@"); - auto expr_plan = serialized_expr_plan; - expr_plan.replace(loc, 5, arith_expr); - loc = expr_plan.find("@@@@"); - expr_plan.replace(loc, 4, clause); - boost::format expr; - switch (dtype) { - case DataType::BOOL: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "bool"; - break; - } - case DataType::INT64: { - expr = - boost::format(expr_plan) % vec_fid.get() % json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % "int"; - break; - } - case DataType::DOUBLE: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "double"; - break; - } - case DataType::STRING: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "string"; - break; + std::vector> bool_testcases{{{true}, {"bool"}}, + {{false}, {"bool"}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - case DataType::VARCHAR: { - expr = boost::format(expr_plan) % vec_fid.get() % - json_fid.get() % - proto::schema::DataType_Name(int(DataType::JSON)) % - "varchar"; - break; + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } - default: { - ASSERT_TRUE(false) << "No test case defined for this data type"; + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); } - - auto unary_plan = translate_text_plan_with_metric_type(expr.str()); - auto plan = CreateSearchPlanByExpr( - *schema, unary_plan.data(), unary_plan.size()); - + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); BitsetType final; - final = ExecuteQueryExpr( - plan->plan_node_->plannodes_->sources()[0]->sources()[0], - seg_promote, - N * num_iters, - MAX_TIMESTAMP); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - if (dtype == DataType::BOOL) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/bool"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::INT64) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/int"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::DOUBLE) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/double"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::STRING) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/string"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else if (dtype == DataType::VARCHAR) { - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .exist("/varchar"); - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } else { - ASSERT_TRUE(false) << "No test case defined for this data type"; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); } + ASSERT_EQ(ans, check(res)); } } } -template -struct Testcase { - std::vector term; - std::vector nested_path; - bool res; -}; - -TEST_P(ExprTest, TestTermInFieldJson) { +TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -5272,12 +11511,16 @@ TEST_P(ExprTest, TestTermInFieldJson) { } auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true}, {"bool"}}, {{false}, {"bool"}}}; for (auto testcase : bool_testcases) { - auto check = [&](const std::vector& values) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; @@ -5288,22 +11531,22 @@ TEST_P(ExprTest, TestTermInFieldJson) { val.set_bool_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - // std::cout << "cost" - // << std::chrono::duration_cast( - // std::chrono::steady_clock::now() - start) - // .count() - // << std::endl; + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -5314,41 +11557,225 @@ TEST_P(ExprTest, TestTermInFieldJson) { for (const auto& element : array) { res.push_back(element.template get()); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> double_testcases{ + {{1.123}, {"double"}}, + {{10.34}, {"double"}}, + {{100.234}, {"double"}}, + {{1000.4546}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases{ + {{1}, {"int"}}, + {{10}, {"int"}}, + {{100}, {"int"}}, + {{1000}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases_string = { + {{"1sads"}, {"string"}}, + {{"10dsf"}, {"string"}}, + {{"100"}, {"string"}}, + {{"100ddfdsssdfdsfsd0"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; + } + return std::find(values.begin(), values.end(), testcase.term[0]) != + values.end(); + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); } } +} + +TEST_P(ExprTest, TestJsonContainsAll) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); - std::vector> double_testcases{ - {{1.123}, {"double"}}, - {{10.34}, {"double"}}, - {{100.234}, {"double"}}, - {{1000.4546}, {"double"}}, - }; + std::vector> bool_testcases{{{true, true}, {"bool"}}, + {{false, false}, {"bool"}}}; - for (auto testcase : double_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; for (auto v : testcase.term) { proto::plan::GenericValue val; - val.set_float_val(v); + val.set_bool_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - final = - ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); - + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5360,44 +11787,53 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto ans = final[i]; auto array = milvus::Json(simdjson::padded_string(json_col[i])) .array_at(pointer); - std::vector res; + std::vector res; for (const auto& element : array) { - res.push_back(element.template get()); + res.push_back(element.template get()); } ASSERT_EQ(ans, check(res)); } } - std::vector> testcases{ - {{1}, {"int"}}, - {{10}, {"int"}}, - {{100}, {"int"}}, - {{1000}, {"int"}}, + std::vector> double_testcases{ + {{1.123, 10.34}, {"double"}}, + {{10.34, 100.234}, {"double"}}, + {{100.234, 1000.4546}, {"double"}}, + {{1000.4546, 1.123}, {"double"}}, + {{1000.4546, 10.34}, {"double"}}, + {{1.123, 100.234}, {"double"}}, }; - for (auto testcase : testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; for (auto& v : testcase.term) { proto::plan::GenericValue val; - val.set_int64_val(v); + val.set_float_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5409,38 +11845,46 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto ans = final[i]; auto array = milvus::Json(simdjson::padded_string(json_col[i])) .array_at(pointer); - std::vector res; + std::vector res; for (const auto& element : array) { - res.push_back(element.template get()); + res.push_back(element.template get()); } ASSERT_EQ(ans, check(res)); } } - std::vector> testcases_string = { - {{"1sads"}, {"string"}}, - {{"10dsf"}, {"string"}}, - {{"100"}, {"string"}}, - {{"100ddfdsssdfdsfsd0"}, {"string"}}, + std::vector> testcases{ + {{1, 10}, {"int"}}, + {{10, 100}, {"int"}}, + {{100, 1000}, {"int"}}, + {{1000, 10}, {"int"}}, + {{2, 4, 6, 8, 10}, {"int"}}, + {{1, 2, 3, 4, 5}, {"int"}}, }; - for (auto testcase : testcases_string) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; for (auto& v : testcase.term) { proto::plan::GenericValue val; - val.set_string_val(v); + val.set_int64_val(v); values.push_back(val); } - auto expr = std::make_shared( + auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - values, - true); + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); @@ -5458,170 +11902,84 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto ans = final[i]; auto array = milvus::Json(simdjson::padded_string(json_col[i])) .array_at(pointer); - std::vector res; + std::vector res; for (const auto& element : array) { - res.push_back(element.template get()); + res.push_back(element.template get()); } ASSERT_EQ(ans, check(res)); } } -} -TEST_P(ExprTest, PraseJsonContainsExpr) { - std::vector raw_plans{ - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAny - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: elements: elements: - op:ContainsAll - elements_same_type:true - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", - R"(vector_anns:< - field_id:100 - predicates:< - json_contains_expr:< - column_info:< - field_id:101 - data_type:JSON - nested_path:"A" - > - elements: - elements: - elements: - elements: - op:ContainsAll - > - > - query_info:< - topk: 10 - round_decimal: 3 - metric_type: "L2" - search_params: "{\"nprobe\": 10}" - > placeholder_tag:"$0" - >)", + std::vector> testcases_string = { + {{"1sads", "10dsf"}, {"string"}}, + {{"10dsf", "100"}, {"string"}}, + {{"100", "10dsf", "1sads"}, {"string"}}, + {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, }; - - for (auto& raw_plan : raw_plans) { - auto plan_str = translate_text_plan_with_metric_type(raw_plan); - auto schema = std::make_shared(); - schema->AddDebugField("fakevec", data_type, 16, metric_type); - schema->AddDebugField("json", DataType::JSON); + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res)); + } } } -TEST_P(ExprTest, TestJsonContainsAny) { +TEST_P(ExprTest, TestJsonContainsAllNullable) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -5635,13 +11993,21 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto seg_promote = dynamic_cast(seg.get()); - std::vector> bool_testcases{{{true}, {"bool"}}, - {{false}, {"bool"}}}; + std::vector> bool_testcases{{{true, true}, {"bool"}}, + {{false, false}, {"bool"}}}; for (auto testcase : bool_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5653,7 +12019,7 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, values); BitsetType final; @@ -5677,21 +12043,31 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (const auto& element : array) { res.push_back(element.template get()); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res, valid_data[i])); } } std::vector> double_testcases{ - {{1.123}, {"double"}}, - {{10.34}, {"double"}}, - {{100.234}, {"double"}}, - {{1000.4546}, {"double"}}, + {{1.123, 10.34}, {"double"}}, + {{10.34, 100.234}, {"double"}}, + {{100.234, 1000.4546}, {"double"}}, + {{1000.4546, 1.123}, {"double"}}, + {{1000.4546, 10.34}, {"double"}}, + {{1.123, 100.234}, {"double"}}, }; for (auto testcase : double_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); std::vector values; @@ -5703,7 +12079,7 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, values); BitsetType final; @@ -5727,38 +12103,212 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (const auto& element : array) { res.push_back(element.template get()); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases{ + {{1, 10}, {"int"}}, + {{10, 100}, {"int"}}, + {{100, 1000}, {"int"}}, + {{1000, 10}, {"int"}}, + {{2, 4, 6, 8, 10}, {"int"}}, + {{1, 2, 3, 4, 5}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values, bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } + + std::vector> testcases_string = { + {{"1sads", "10dsf"}, {"string"}}, + {{"10dsf", "100"}, {"string"}}, + {{"100", "10dsf", "1sads"}, {"string"}}, + {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values, + bool valid) { + if (!valid) { + return false; + } + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Json(simdjson::padded_string(json_col[i])) + .array_at(pointer); + std::vector res; + for (const auto& element : array) { + res.push_back(element.template get()); + } + ASSERT_EQ(ans, check(res, valid_data[i])); + } + } +} + +TEST_P(ExprTest, TestJsonContainsArray) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue generic_a; + auto* a = generic_a.mutable_array_val(); + a->set_same_type(false); + for (int i = 0; i < 4; ++i) { + if (i % 4 == 0) { + proto::plan::GenericValue int_val; + int_val.set_int64_val(int64_t(i)); + a->add_array()->CopyFrom(int_val); + } else if ((i - 1) % 4 == 0) { + proto::plan::GenericValue bool_val; + bool_val.set_bool_val(bool(i)); + a->add_array()->CopyFrom(bool_val); + } else if ((i - 2) % 4 == 0) { + proto::plan::GenericValue float_val; + float_val.set_float_val(double(i)); + a->add_array()->CopyFrom(float_val); + } else if ((i - 3) % 4 == 0) { + proto::plan::GenericValue string_val; + string_val.set_string_val(std::to_string(i)); + a->add_array()->CopyFrom(string_val); } } + proto::plan::GenericValue generic_b; + auto* b = generic_b.mutable_array_val(); + b->set_same_type(true); + proto::plan::GenericValue int_val1; + int_val1.set_int64_val(int64_t(1)); + b->add_array()->CopyFrom(int_val1); + + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(2)); + b->add_array()->CopyFrom(int_val2); + + proto::plan::GenericValue int_val3; + int_val3.set_int64_val(int64_t(3)); + b->add_array()->CopyFrom(int_val3); - std::vector> testcases{ - {{1}, {"int"}}, - {{10}, {"int"}}, - {{100}, {"int"}}, - {{1000}, {"int"}}, - }; + std::vector> diff_testcases{ + {{generic_a}, {"string"}}, {{generic_b}, {"array"}}}; - for (auto testcase : testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto& testcase : diff_testcases) { + auto check = [&](const std::vector& values, int i) { + if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { + return true; + } + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_int64_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5771,44 +12321,28 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } - std::vector> testcases_string = { - {{"1sads"}, {"string"}}, - {{"10dsf"}, {"string"}}, - {{"100"}, {"string"}}, - {{"100ddfdsssdfdsfsd0"}, {"string"}}, - }; - - for (auto testcase : testcases_string) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), values.end(), testcase.term[0]) != - values.end(); + for (auto& testcase : diff_testcases) { + auto check = [&](const std::vector& values, int i) { + if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { + return true; + } + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_string_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5821,72 +12355,47 @@ TEST_P(ExprTest, TestJsonContainsAny) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } -} - -TEST_P(ExprTest, TestJsonContainsAll) { - auto schema = std::make_shared(); - auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); - schema->set_primary_field_id(i64_fid); - auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 1000; - std::vector json_col; - int num_iters = 1; - for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGenForJsonArray(schema, N, iter); - auto new_json_col = raw_data.get_col(json_fid); + proto::plan::GenericValue g_sub_arr1; + auto* sub_arr1 = g_sub_arr1.mutable_array_val(); + sub_arr1->set_same_type(true); + proto::plan::GenericValue int_val11; + int_val11.set_int64_val(int64_t(1)); + sub_arr1->add_array()->CopyFrom(int_val11); - json_col.insert( - json_col.end(), new_json_col.begin(), new_json_col.end()); - seg->PreInsert(N); - seg->Insert(iter * N, - N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } + proto::plan::GenericValue int_val12; + int_val12.set_int64_val(int64_t(2)); + sub_arr1->add_array()->CopyFrom(int_val12); - auto seg_promote = dynamic_cast(seg.get()); + proto::plan::GenericValue g_sub_arr2; + auto* sub_arr2 = g_sub_arr2.mutable_array_val(); + sub_arr2->set_same_type(true); + proto::plan::GenericValue int_val21; + int_val21.set_int64_val(int64_t(3)); + sub_arr2->add_array()->CopyFrom(int_val21); - std::vector> bool_testcases{{{true, true}, {"bool"}}, - {{false, false}, {"bool"}}}; + proto::plan::GenericValue int_val22; + int_val22.set_int64_val(int64_t(4)); + sub_arr2->add_array()->CopyFrom(int_val22); + std::vector> diff_testcases2{ + {{g_sub_arr1, g_sub_arr2}, {"array2"}}}; - for (auto testcase : bool_testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; - }; + for (auto& testcase : diff_testcases2) { + auto check = [&]() { return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto v : testcase.term) { - proto::plan::GenericValue val; - val.set_bool_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5899,51 +12408,24 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check()); } } - std::vector> double_testcases{ - {{1.123, 10.34}, {"double"}}, - {{10.34, 100.234}, {"double"}}, - {{100.234, 1000.4546}, {"double"}}, - {{1000.4546, 1.123}, {"double"}}, - {{1000.4546, 10.34}, {"double"}}, - {{1.123, 100.234}, {"double"}}, - }; - - for (auto testcase : double_testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_float_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -5956,51 +12438,49 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } - std::vector> testcases{ - {{1, 10}, {"int"}}, - {{10, 100}, {"int"}}, - {{100, 1000}, {"int"}}, - {{1000, 10}, {"int"}}, - {{2, 4, 6, 8, 10}, {"int"}}, - {{1, 2, 3, 4, 5}, {"int"}}, - }; + proto::plan::GenericValue g_sub_arr3; + auto* sub_arr3 = g_sub_arr3.mutable_array_val(); + sub_arr3->set_same_type(true); + proto::plan::GenericValue int_val31; + int_val31.set_int64_val(int64_t(5)); + sub_arr3->add_array()->CopyFrom(int_val31); - for (auto testcase : testcases) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; + proto::plan::GenericValue int_val32; + int_val32.set_int64_val(int64_t(6)); + sub_arr3->add_array()->CopyFrom(int_val32); + + proto::plan::GenericValue g_sub_arr4; + auto* sub_arr4 = g_sub_arr4.mutable_array_val(); + sub_arr4->set_same_type(true); + proto::plan::GenericValue int_val41; + int_val41.set_int64_val(int64_t(7)); + sub_arr4->add_array()->CopyFrom(int_val41); + + proto::plan::GenericValue int_val42; + int_val42.set_int64_val(int64_t(8)); + sub_arr4->add_array()->CopyFrom(int_val42); + std::vector> diff_testcases3{ + {{g_sub_arr3, g_sub_arr4}, {"array2"}}}; + + for (auto& testcase : diff_testcases3) { + auto check = [&](const std::vector& values, int i) { + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_int64_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -6013,49 +12493,25 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } - std::vector> testcases_string = { - {{"1sads", "10dsf"}, {"string"}}, - {{"10dsf", "100"}, {"string"}}, - {{"100", "10dsf", "1sads"}, {"string"}}, - {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, - }; - - for (auto testcase : testcases_string) { - auto check = [&](const std::vector& values) { - for (auto const& e : testcase.term) { - if (std::find(values.begin(), values.end(), e) == - values.end()) { - return false; - } - } - return true; + for (auto& testcase : diff_testcases3) { + auto check = [&](const std::vector& values, int i) { + return false; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - std::vector values; - for (auto& v : testcase.term) { - proto::plan::GenericValue val; - val.set_string_val(v); - values.push_back(val); - } auto expr = std::make_shared( milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, true, - values); - BitsetType final; + testcase.term); auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); @@ -6068,30 +12524,27 @@ TEST_P(ExprTest, TestJsonContainsAll) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - ASSERT_EQ(ans, check(res)); + std::vector res; + ASSERT_EQ(ans, check(res, i)); } } } -TEST_P(ExprTest, TestJsonContainsArray) { +TEST_P(ExprTest, TestJsonContainsArrayNullable) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); - auto json_fid = schema->AddDebugField("json", DataType::JSON); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; + FixedVector valid_data; int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); json_col.insert( json_col.end(), new_json_col.begin(), new_json_col.end()); @@ -6146,7 +12599,10 @@ TEST_P(ExprTest, TestJsonContainsArray) { {{generic_a}, {"string"}}, {{generic_b}, {"array"}}}; for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { return true; } @@ -6175,12 +12631,15 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } if (testcase.nested_path[0] == "array" && (i == 1 || i == N + 1)) { return true; } @@ -6209,7 +12668,7 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } @@ -6238,7 +12697,12 @@ TEST_P(ExprTest, TestJsonContainsArray) { {{g_sub_arr1, g_sub_arr2}, {"array2"}}}; for (auto& testcase : diff_testcases2) { - auto check = [&]() { return true; }; + auto check = [&](bool valid) { + if (!valid) { + return false; + } + return true; + }; auto pointer = milvus::Json::pointer(testcase.nested_path); auto expr = std::make_shared( milvus::expr::ColumnInfo( @@ -6261,12 +12725,15 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - ASSERT_EQ(ans, check()); + ASSERT_EQ(ans, check(valid_data[i])); } } for (auto& testcase : diff_testcases2) { - auto check = [&](const std::vector& values, int i) { + auto check = [&](const std::vector& values, int i, bool valid) { + if (!valid) { + return false; + } return true; }; auto pointer = milvus::Json::pointer(testcase.nested_path); @@ -6292,7 +12759,7 @@ TEST_P(ExprTest, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check(res, i, valid_data[i])); } } @@ -6514,6 +12981,115 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArray) { } } +TEST_P(ExprTest, TestJsonContainsDiffTypeArrayNullable) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue int_value; + int_value.set_int64_val(1); + auto diff_type_array1 = + generatedArrayWithFourDiffType(1, 2.2, false, "abc"); + auto diff_type_array2 = + generatedArrayWithFourDiffType(1, 2.2, false, "def"); + auto diff_type_array3 = generatedArrayWithFourDiffType(1, 2.2, true, "abc"); + auto diff_type_array4 = + generatedArrayWithFourDiffType(1, 3.3, false, "abc"); + auto diff_type_array5 = + generatedArrayWithFourDiffType(2, 2.2, false, "abc"); + + std::vector> diff_testcases{ + {{diff_type_array1, int_value}, {"array3"}, true}, + {{diff_type_array2, int_value}, {"array3"}, false}, + {{diff_type_array3, int_value}, {"array3"}, false}, + {{diff_type_array4, int_value}, {"array3"}, false}, + {{diff_type_array5, int_value}, {"array3"}, false}, + }; + + for (auto& testcase : diff_testcases) { + auto check = [&](bool valid) { + if (!valid) { + return false; + } + return testcase.res; + }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check(valid_data[i])); + } + } + + for (auto& testcase : diff_testcases) { + auto check = [&]() { return false; }; + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check()); + } + } +} + TEST_P(ExprTest, TestJsonContainsDiffType) { auto schema = std::make_shared(); auto i64_fid = schema->AddDebugField("id", DataType::INT64); @@ -6621,3 +13197,119 @@ TEST_P(ExprTest, TestJsonContainsDiffType) { } } } +TEST_P(ExprTest, TestJsonContainsDiffTypeNullable) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON, true); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + FixedVector valid_data; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + valid_data = raw_data.get_col_valid(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + + proto::plan::GenericValue int_val; + int_val.set_int64_val(int64_t(3)); + proto::plan::GenericValue bool_val; + bool_val.set_bool_val(bool(false)); + proto::plan::GenericValue float_val; + float_val.set_float_val(double(100.34)); + proto::plan::GenericValue string_val; + string_val.set_string_val("10dsf"); + + proto::plan::GenericValue string_val2; + string_val2.set_string_val("abc"); + proto::plan::GenericValue bool_val2; + bool_val2.set_bool_val(bool(true)); + proto::plan::GenericValue float_val2; + float_val2.set_float_val(double(2.2)); + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(1)); + + std::vector> diff_testcases{ + {{int_val, bool_val, float_val, string_val}, + {"diff_type_array"}, + false}, + {{string_val2, bool_val2, float_val2, int_val2}, + {"diff_type_array"}, + true}, + }; + + for (auto& testcase : diff_testcases) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + } else { + ASSERT_EQ(ans, testcase.res); + } + } + } + + for (auto& testcase : diff_testcases) { + auto pointer = milvus::Json::pointer(testcase.nested_path); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + } else { + ASSERT_EQ(ans, testcase.res); + } + } + } +} diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index a406ab8e86bb0..f45079db184f5 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -299,32 +299,475 @@ TEST(StringExpr, Term) { } } +TEST(StringExpr, TermNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto vec_2k_3k = []() -> std::vector { + std::vector ret; + for (int i = 2000; i < 3000; i++) { + ret.push_back(std::to_string(i)); + } + return ret; + }(); + + std::map> terms = { + {0, {"2000", "3000"}}, + {1, {"2000"}}, + {2, {"3000"}}, + {3, {}}, + {4, {vec_2k_3k}}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [_, term] : terms) { + auto plan_proto = GenTermPlan(fvec_meta, str_meta, term); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto ref = std::find(term.begin(), term.end(), val) != term.end(); + ASSERT_EQ(ans, ref) << "@" << i << "!!" << val; + } + } +} + TEST(StringExpr, Compare) { auto schema = GenTestSchema(); const auto& fvec_meta = schema->operator[](FieldName("fvec")); const auto& str_meta = schema->operator[](FieldName("str")); - const auto& another_str_meta = schema->operator[](FieldName("another_str")); - - auto gen_compare_plan = - [&, fvec_meta, str_meta, another_str_meta]( - proto::plan::OpType op) -> std::unique_ptr { - auto str_col_info = - test::GenColumnInfo(str_meta.get_id().get(), - proto::schema::DataType::VarChar, - false, - false); - auto another_str_col_info = - test::GenColumnInfo(another_str_meta.get_id().get(), - proto::schema::DataType::VarChar, - false, - false); + const auto& another_str_meta = schema->operator[](FieldName("another_str")); + + auto gen_compare_plan = + [&, fvec_meta, str_meta, another_str_meta]( + proto::plan::OpType op) -> std::unique_ptr { + auto str_col_info = + test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto another_str_col_info = + test::GenColumnInfo(another_str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + + auto compare_expr = GenCompareExpr(op); + compare_expr->set_allocated_left_column_info(str_col_info); + compare_expr->set_allocated_right_column_info(another_str_col_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_compare_expr(compare_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + [](std::string& v1, std::string& v2) { return v1 > v2; }}, + {proto::plan::OpType::GreaterEqual, + [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + {proto::plan::OpType::LessThan, + [](std::string& v1, std::string& v2) { return v1 < v2; }}, + {proto::plan::OpType::LessEqual, + [](std::string& v1, std::string& v2) { return v1 <= v2; }}, + {proto::plan::OpType::Equal, + [](std::string& v1, std::string& v2) { return v1 == v2; }}, + {proto::plan::OpType::NotEqual, + [](std::string& v1, std::string& v2) { return v1 != v2; }}, + {proto::plan::OpType::PrefixMatch, + [](std::string& v1, std::string& v2) { + return PrefixMatch(v1, v2); + }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + std::vector another_str_col; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto reserve_col = [&, raw_data](const FieldMeta& field_meta, + std::vector& str_col) { + auto new_str_col = raw_data.get_col(field_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + }; + + reserve_col(str_meta, str_col); + reserve_col(another_str_meta, another_str_col); + + { + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, ref_func] : testcases) { + auto plan_proto = gen_compare_plan(op); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto val = str_col[i]; + auto another_val = another_str_col[i]; + auto ref = ref_func(val, another_val); + ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + } + } +} + +TEST(StringExpr, CompareNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + const auto& another_str_meta = schema->operator[](FieldName("another_str")); + + auto gen_compare_plan = + [&, fvec_meta, str_meta, another_str_meta]( + proto::plan::OpType op) -> std::unique_ptr { + auto str_col_info = + test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto another_str_col_info = + test::GenColumnInfo(another_str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + + auto compare_expr = GenCompareExpr(op); + compare_expr->set_allocated_left_column_info(str_col_info); + compare_expr->set_allocated_right_column_info(another_str_col_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_compare_expr(compare_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + [](std::string& v1, std::string& v2) { return v1 > v2; }}, + {proto::plan::OpType::GreaterEqual, + [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + {proto::plan::OpType::LessThan, + [](std::string& v1, std::string& v2) { return v1 < v2; }}, + {proto::plan::OpType::LessEqual, + [](std::string& v1, std::string& v2) { return v1 <= v2; }}, + {proto::plan::OpType::Equal, + [](std::string& v1, std::string& v2) { return v1 == v2; }}, + {proto::plan::OpType::NotEqual, + [](std::string& v1, std::string& v2) { return v1 != v2; }}, + {proto::plan::OpType::PrefixMatch, + [](std::string& v1, std::string& v2) { + return PrefixMatch(v1, v2); + }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + std::vector another_str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto reserve_col = [&, raw_data](const FieldMeta& field_meta, + std::vector& str_col) { + auto new_str_col = raw_data.get_col(field_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + }; + + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + + reserve_col(str_meta, str_col); + reserve_col(another_str_meta, another_str_col); + + { + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, ref_func] : testcases) { + auto plan_proto = gen_compare_plan(op); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto another_val = another_str_col[i]; + auto ref = ref_func(val, another_val); + ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + } + } +} + +TEST(StringExpr, CompareNullable2) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR); + schema->AddDebugField("another_str", DataType::VARCHAR, true); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + const auto& another_str_meta = schema->operator[](FieldName("another_str")); + + auto gen_compare_plan = + [&, fvec_meta, str_meta, another_str_meta]( + proto::plan::OpType op) -> std::unique_ptr { + auto str_col_info = + test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto another_str_col_info = + test::GenColumnInfo(another_str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + + auto compare_expr = GenCompareExpr(op); + compare_expr->set_allocated_left_column_info(str_col_info); + compare_expr->set_allocated_right_column_info(another_str_col_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_compare_expr(compare_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + std::vector>> + testcases{ + {proto::plan::OpType::GreaterThan, + [](std::string& v1, std::string& v2) { return v1 > v2; }}, + {proto::plan::OpType::GreaterEqual, + [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + {proto::plan::OpType::LessThan, + [](std::string& v1, std::string& v2) { return v1 < v2; }}, + {proto::plan::OpType::LessEqual, + [](std::string& v1, std::string& v2) { return v1 <= v2; }}, + {proto::plan::OpType::Equal, + [](std::string& v1, std::string& v2) { return v1 == v2; }}, + {proto::plan::OpType::NotEqual, + [](std::string& v1, std::string& v2) { return v1 != v2; }}, + {proto::plan::OpType::PrefixMatch, + [](std::string& v1, std::string& v2) { + return PrefixMatch(v1, v2); + }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + std::vector another_str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + + auto reserve_col = [&, raw_data](const FieldMeta& field_meta, + std::vector& str_col) { + auto new_str_col = raw_data.get_col(field_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + }; + + auto new_str_valid_col = + raw_data.get_col_valid(another_str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + + reserve_col(str_meta, str_col); + reserve_col(another_str_meta, another_str_col); + + { + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [op, ref_func] : testcases) { + auto plan_proto = gen_compare_plan(op); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto another_val = another_str_col[i]; + auto ref = ref_func(val, another_val); + ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + } + } +} + +TEST(StringExpr, UnaryRange) { + auto schema = GenTestSchema(); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); - auto compare_expr = GenCompareExpr(op); - compare_expr->set_allocated_left_column_info(str_col_info); - compare_expr->set_allocated_right_column_info(another_str_col_info); + auto gen_unary_range_plan = + [&, fvec_meta, str_meta]( + proto::plan::OpType op, + std::string value) -> std::unique_ptr { + auto column_info = test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto unary_range_expr = test::GenUnaryRangeExpr(op, value); + unary_range_expr->set_allocated_column_info(column_info); auto expr = test::GenExpr().release(); - expr->set_allocated_compare_expr(compare_expr); + expr->set_allocated_unary_range_expr(unary_range_expr); proto::plan::VectorType vector_type; if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { @@ -342,58 +785,47 @@ TEST(StringExpr, Compare) { }; std::vector>> + std::string, + std::function>> testcases{ {proto::plan::OpType::GreaterThan, - [](std::string& v1, std::string& v2) { return v1 > v2; }}, + "2000", + [](std::string& val) { return val > "2000"; }}, {proto::plan::OpType::GreaterEqual, - [](std::string& v1, std::string& v2) { return v1 >= v2; }}, + "2000", + [](std::string& val) { return val >= "2000"; }}, {proto::plan::OpType::LessThan, - [](std::string& v1, std::string& v2) { return v1 < v2; }}, + "3000", + [](std::string& val) { return val < "3000"; }}, {proto::plan::OpType::LessEqual, - [](std::string& v1, std::string& v2) { return v1 <= v2; }}, - {proto::plan::OpType::Equal, - [](std::string& v1, std::string& v2) { return v1 == v2; }}, - {proto::plan::OpType::NotEqual, - [](std::string& v1, std::string& v2) { return v1 != v2; }}, + "3000", + [](std::string& val) { return val <= "3000"; }}, {proto::plan::OpType::PrefixMatch, - [](std::string& v1, std::string& v2) { - return PrefixMatch(v1, v2); - }}, + "a", + [](std::string& val) { return PrefixMatch(val, "a"); }}, }; auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector str_col; - std::vector another_str_col; int num_iters = 100; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); - - auto reserve_col = [&, raw_data](const FieldMeta& field_meta, - std::vector& str_col) { - auto new_str_col = raw_data.get_col(field_meta.get_id()); - auto begin = FIELD_DATA(new_str_col, string).begin(); - auto end = FIELD_DATA(new_str_col, string).end(); - str_col.insert(str_col.end(), begin, end); - }; - - reserve_col(str_meta, str_col); - reserve_col(another_str_meta, another_str_col); - - { - seg->PreInsert(N); - seg->Insert(iter * N, - N, - raw_data.row_ids_.data(), - raw_data.timestamps_.data(), - raw_data.raw_); - } + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); } auto seg_promote = dynamic_cast(seg.get()); - for (const auto& [op, ref_func] : testcases) { - auto plan_proto = gen_compare_plan(op); + for (const auto& [op, value, ref_func] : testcases) { + auto plan_proto = gen_unary_range_plan(op, value); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); BitsetType final; final = ExecuteQueryExpr( @@ -407,15 +839,21 @@ TEST(StringExpr, Compare) { auto ans = final[i]; auto val = str_col[i]; - auto another_val = another_str_col[i]; - auto ref = ref_func(val, another_val); - ASSERT_EQ(ans, ref) << "@" << op << "@" << i << "!!" << val; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << "@" << op << "@" << value << "@" << i << "!!" << val; } } } -TEST(StringExpr, UnaryRange) { - auto schema = GenTestSchema(); +TEST(StringExpr, UnaryRangeNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); const auto& fvec_meta = schema->operator[](FieldName("fvec")); const auto& str_meta = schema->operator[](FieldName("str")); @@ -472,6 +910,7 @@ TEST(StringExpr, UnaryRange) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector str_col; + FixedVector valid_data; int num_iters = 100; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); @@ -479,6 +918,10 @@ TEST(StringExpr, UnaryRange) { auto begin = FIELD_DATA(new_str_col, string).begin(); auto end = FIELD_DATA(new_str_col, string).end(); str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); seg->PreInsert(N); seg->Insert(iter * N, N, @@ -501,7 +944,10 @@ TEST(StringExpr, UnaryRange) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } auto val = str_col[i]; auto ref = ref_func(val); ASSERT_EQ(ans, ref) @@ -625,6 +1071,135 @@ TEST(StringExpr, BinaryRange) { } } +TEST(StringExpr, BinaryRangeNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto gen_binary_range_plan = + [&, fvec_meta, str_meta]( + bool lb_inclusive, + bool ub_inclusive, + std::string lb, + std::string ub) -> std::unique_ptr { + auto column_info = test::GenColumnInfo(str_meta.get_id().get(), + proto::schema::DataType::VarChar, + false, + false); + auto binary_range_expr = + GenBinaryRangeExpr(lb_inclusive, ub_inclusive, lb, ub); + binary_range_expr->set_allocated_column_info(column_info); + + auto expr = test::GenExpr().release(); + expr->set_allocated_binary_range_expr(binary_range_expr); + + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); + + auto plan_node = std::make_unique(); + plan_node->set_allocated_vector_anns(anns); + return plan_node; + }; + + // bool lb_inclusive, bool ub_inclusive, std::string lb, std::string ub + std::vector>> + testcases{ + {false, + false, + "2000", + "3000", + [](std::string& val) { return val > "2000" && val < "3000"; }}, + {false, + true, + "2000", + "3000", + [](std::string& val) { return val > "2000" && val <= "3000"; }}, + {true, + false, + "2000", + "3000", + [](std::string& val) { return val >= "2000" && val < "3000"; }}, + {true, + true, + "2000", + "3000", + [](std::string& val) { return val >= "2000" && val <= "3000"; }}, + {true, + true, + "2000", + "1000", + [](std::string& val) { return false; }}, + }; + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector str_col; + FixedVector valid_data; + int num_iters = 100; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_str_col = raw_data.get_col(str_meta.get_id()); + auto begin = FIELD_DATA(new_str_col, string).begin(); + auto end = FIELD_DATA(new_str_col, string).end(); + str_col.insert(str_col.end(), begin, end); + auto new_str_valid_col = raw_data.get_col_valid(str_meta.get_id()); + valid_data.insert(valid_data.end(), + new_str_valid_col.begin(), + new_str_valid_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + for (const auto& [lb_inclusive, ub_inclusive, lb, ub, ref_func] : + testcases) { + auto plan_proto = + gen_binary_range_plan(lb_inclusive, ub_inclusive, lb, ub); + auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), + seg_promote, + N * num_iters, + final); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + if (!valid_data[i]) { + ASSERT_EQ(ans, false); + continue; + } + auto val = str_col[i]; + auto ref = ref_func(val); + ASSERT_EQ(ans, ref) + << "@" << lb_inclusive << "@" << ub_inclusive << "@" << lb + << "@" << ub << "@" << i << "!!" << val; + } + } +} + TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { auto schema = GenStrPKSchema(); const auto& fvec_meta = schema->operator[](FieldName("fvec")); @@ -733,4 +1308,47 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) { ASSERT_EQ(retrieved->fields_data().size(), 1); ASSERT_EQ(retrieved->fields_data(0).scalars().string_data().data().size(), N); + ASSERT_EQ(retrieved->fields_data(0).valid_data_size(), 0); +} + +TEST(AlwaysTrueStringPlan, QueryWithOutputFieldsNullable) { + auto schema = std::make_shared(); + schema->AddDebugField("str", DataType::VARCHAR, true); + schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema->AddDebugField("int64", DataType::INT64); + schema->set_primary_field_id(pk); + const auto& fvec_meta = schema->operator[](FieldName("fvec")); + const auto& str_meta = schema->operator[](FieldName("str")); + + auto N = 10000; + auto dataset = DataGen(schema, N); + auto vec_col = dataset.get_col(fvec_meta.get_id()); + auto str_col = + dataset.get_col(str_meta.get_id())->scalars().string_data().data(); + auto valid_data = dataset.get_col_valid(str_meta.get_id()); + auto segment = CreateGrowingSegment(schema, empty_index_meta); + segment->PreInsert(N); + segment->Insert(0, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + + auto expr_proto = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto plan_proto = GenPlanNode(); + plan_proto->mutable_query()->set_allocated_predicates(expr_proto); + SetTargetEntry(plan_proto, {str_meta.get_id().get()}); + auto plan = ProtoParser(*schema).CreateRetrievePlan(*plan_proto); + + Timestamp time = MAX_TIMESTAMP; + + auto retrieved = segment->Retrieve( + nullptr, plan.get(), time, DEFAULT_MAX_OUTPUT_SIZE, false); + ASSERT_EQ(retrieved->offset().size(), N); + ASSERT_EQ(retrieved->fields_data().size(), 1); + ASSERT_EQ(retrieved->fields_data(0).scalars().string_data().data().size(), + N); + ASSERT_EQ(retrieved->fields_data(0).valid_data().size(), N); } diff --git a/internal/core/unittest/test_utils/AssertUtils.h b/internal/core/unittest/test_utils/AssertUtils.h index 5e92369b90436..16130fd513610 100644 --- a/internal/core/unittest/test_utils/AssertUtils.h +++ b/internal/core/unittest/test_utils/AssertUtils.h @@ -139,7 +139,7 @@ template inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_EQ(index->Reverse_Lookup(offset), arr[offset]); + ASSERT_EQ(index->Reverse_Lookup(offset).value(), arr[offset]); } } @@ -147,7 +147,8 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(compare_float(index->Reverse_Lookup(offset), arr[offset])); + ASSERT_TRUE( + compare_float(index->Reverse_Lookup(offset).value(), arr[offset])); } } @@ -155,7 +156,8 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(compare_double(index->Reverse_Lookup(offset), arr[offset])); + ASSERT_TRUE( + compare_double(index->Reverse_Lookup(offset).value(), arr[offset])); } } @@ -164,7 +166,8 @@ inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE(arr[offset].compare(index->Reverse_Lookup(offset)) == 0); + ASSERT_TRUE( + arr[offset].compare(index->Reverse_Lookup(offset).value()) == 0); } } diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 6e4faa28b1394..48af55d6ef7d6 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -667,8 +667,14 @@ DataGenForJsonArray(SchemaPtr schema, auto insert_data = std::make_unique(); auto insert_cols = [&insert_data]( auto& data, int64_t count, auto& field_meta) { + FixedVector valid_data(count); + if (field_meta.is_nullable()) { + for (int i = 0; i < count; ++i) { + valid_data[i] = i % 2 == 0 ? true : false; + } + } auto array = milvus::segcore::CreateDataArrayFrom( - data.data(), nullptr, count, field_meta); + data.data(), valid_data.data(), count, field_meta); insert_data->mutable_fields_data()->AddAllocated(array.release()); }; for (auto field_id : schema->get_field_ids()) { From d02068900cf0d2adab048697e8bed8d0eb8529a7 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Mon, 26 Aug 2024 18:48:22 +0800 Subject: [PATCH 2/4] store valid bitset and add test in not expr Signed-off-by: lixinguo --- internal/core/src/common/Vector.h | 23 +- .../src/exec/expression/AlwaysTrueExpr.cpp | 6 +- .../expression/BinaryArithOpEvalRangeExpr.cpp | 310 +++++++++++-- .../expression/BinaryArithOpEvalRangeExpr.h | 4 +- .../src/exec/expression/BinaryRangeExpr.cpp | 66 +-- .../src/exec/expression/BinaryRangeExpr.h | 16 +- .../core/src/exec/expression/CompareExpr.cpp | 64 ++- .../core/src/exec/expression/CompareExpr.h | 18 + .../core/src/exec/expression/ExistsExpr.cpp | 11 +- internal/core/src/exec/expression/Expr.h | 155 ++++++- .../src/exec/expression/JsonContainsExpr.cpp | 94 ++-- .../src/exec/expression/LogicalUnaryExpr.cpp | 3 + .../core/src/exec/expression/TermExpr.cpp | 76 ++-- .../core/src/exec/expression/UnaryExpr.cpp | 110 +++-- internal/core/src/exec/expression/UnaryExpr.h | 9 +- internal/core/src/index/BitmapIndex.cpp | 30 +- internal/core/src/index/BitmapIndex.h | 2 +- .../core/src/segcore/SegmentSealedImpl.cpp | 6 +- internal/core/unittest/test_expr.cpp | 428 +++++++++++++++--- internal/core/unittest/test_string_expr.cpp | 15 +- 20 files changed, 1151 insertions(+), 295 deletions(-) diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index ac5f0b217b0c5..afc1d4766e079 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -19,6 +19,8 @@ #include #include +#include "EasyAssert.h" +#include "Types.h" #include "common/FieldData.h" namespace milvus { @@ -50,6 +52,7 @@ class BaseVector { protected: DataType type_kind_; size_t length_; + // todo: use null_count to skip some bitset operate std::optional null_count_; }; @@ -65,8 +68,8 @@ class ColumnVector final : public BaseVector { size_t length, std::optional null_count = std::nullopt) : BaseVector(data_type, length, null_count) { - //todo: support null expr values_ = InitScalarFieldData(data_type, false, length); + valid_values_ = InitScalarFieldData(data_type, false, length); } // ColumnVector(FixedVector&& data) @@ -75,15 +78,25 @@ class ColumnVector final : public BaseVector { // std::make_shared>(DataType::BOOL, std::move(data)); // } + // // the size is the number of bits + // ColumnVector(TargetBitmap&& bitmap) + // : BaseVector(DataType::INT8, bitmap.size()) { + // values_ = std::make_shared>( + // bitmap.size(), DataType::INT8, false, std::move(bitmap).into()); + // } + // the size is the number of bits - ColumnVector(TargetBitmap&& bitmap) + ColumnVector(TargetBitmap&& bitmap, TargetBitmap&& valid_bitmap) : BaseVector(DataType::INT8, bitmap.size()) { values_ = std::make_shared>(DataType::INT8, std::move(bitmap)); + valid_values_ = std::make_shared>( + DataType::INT8, std::move(valid_bitmap)); } virtual ~ColumnVector() override { values_.reset(); + valid_values_.reset(); } void* @@ -91,6 +104,11 @@ class ColumnVector final : public BaseVector { return values_->Data(); } + void* + GetValidRawData() { + return valid_values_->Data(); + } + template const As* RawAsValues() const { @@ -99,6 +117,7 @@ class ColumnVector final : public BaseVector { private: FieldDataPtr values_; + FieldDataPtr valid_values_; }; using ColumnVectorPtr = std::shared_ptr; diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp index 94ff5b96986ba..920fc86ee6a17 100644 --- a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp @@ -31,11 +31,13 @@ PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) { return; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); res.set(); + valid_res.set(); result = res_vec; current_pos_ += real_batch_size; diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index 055acb66e5950..6e24141d09596 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -113,9 +113,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto op_type = expr_->op_type_; @@ -131,6 +133,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { \ if (valid_data && !valid_data[i]) { \ res[i] = false; \ + valid_res[i] = false; \ continue; \ } \ auto x = data[i].template at(pointer); \ @@ -152,6 +155,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { \ if (valid_data && !valid_data[i]) { \ res[i] = false; \ + valid_res[i] = false; \ continue; \ } \ auto x = data[i].template at(pointer); \ @@ -172,6 +176,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, ValueType right_operand, const std::string& pointer) { @@ -208,6 +213,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -261,6 +267,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -314,6 +321,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -367,6 +375,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -420,6 +429,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -473,6 +483,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { res[i] = false; + valid_res[i] = false; continue; } int array_length = 0; @@ -504,6 +515,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { int64_t processed_size = ProcessDataChunks(execute_sub_batch, std::nullptr_t{}, res, + valid_res, value, right_operand, pointer); @@ -525,9 +537,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -546,6 +560,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { for (size_t i = 0; i < size; ++i) { \ if (valid_data && !valid_data[i]) { \ res[i] = false; \ + valid_res[i] = false; \ continue; \ } \ if (index >= data[i].length()) { \ @@ -561,6 +576,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, ValueType right_operand, int index) { @@ -596,6 +612,10 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { + if (valid_data && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } res[i] = data[i].length() == val; } break; @@ -640,7 +660,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].length() != val; @@ -687,7 +707,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].length() > val; @@ -734,7 +754,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].length() >= val; @@ -781,7 +801,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].length() < val; @@ -828,7 +848,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].length() <= val; @@ -852,8 +872,14 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } }; - int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, value, right_operand, index); + int64_t processed_size = + ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + valid_res, + value, + right_operand, + index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -1243,12 +1269,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() { return res; }; auto res = ProcessIndexChunks(execute_sub_batch, value, right_operand); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + // return std::make_shared(std::move(res)); + return res; } template @@ -1267,9 +1294,11 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { auto value = GetValueFromProto(expr_->value_); auto right_operand = GetValueFromProto(expr_->right_operand_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto op_type = expr_->op_type_; auto arith_type = expr_->arith_op_type_; @@ -1278,6 +1307,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, HighPrecisionType value, HighPrecisionType right_operand) { switch (op_type) { @@ -1288,7 +1318,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1296,7 +1332,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1304,7 +1346,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1312,7 +1360,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1320,7 +1374,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1339,7 +1399,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1347,7 +1413,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1355,7 +1427,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1363,7 +1441,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1371,7 +1455,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1390,7 +1480,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1398,7 +1494,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1406,7 +1508,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1414,7 +1522,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1422,7 +1536,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1441,7 +1561,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1449,7 +1575,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1457,7 +1589,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1465,7 +1603,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1473,7 +1617,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1492,7 +1642,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1500,7 +1656,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1508,7 +1670,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1516,7 +1684,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1524,7 +1698,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1543,7 +1723,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Add> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Sub: { @@ -1551,7 +1737,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Sub> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mul: { @@ -1559,7 +1751,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mul> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Div: { @@ -1567,7 +1765,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Div> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } case proto::plan::ArithOpType::Mod: { @@ -1575,7 +1779,13 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mod> func; - func(data, valid_data, size, value, right_operand, res); + func(data, + valid_data, + size, + value, + right_operand, + res, + valid_res); break; } default: @@ -1594,8 +1804,12 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { op_type); } }; - int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, value, right_operand); + int64_t processed_size = ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + valid_res, + value, + right_operand); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index a75c5c3f4dbb1..0db5a59f147ed 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -101,7 +101,8 @@ struct ArithOpElementFunc { size_t size, HighPrecisonType val, HighPrecisonType right_operand, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { /* // This is the original code, kept here for the documentation purposes for (int i = 0; i < size; ++i) { @@ -287,6 +288,7 @@ struct ArithOpElementFunc { } continue; } + valid_res[right] = false; execute_sub_batch( src + left, right - left, val, right_operand, res + left); left = right; diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index 94afd3d6abda9..c1ce0ad4383a7 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "BinaryRangeExpr.h" +#include #include "query/Utils.h" @@ -150,8 +151,12 @@ PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1, cached_overflow_res_->size() == batch_size) { return cached_overflow_res_; } - auto res = std::make_shared(TargetBitmap(batch_size)); - return res; + auto valid_res = ProcessChunksForValid(is_index_mode_); + auto res_vec = std::make_shared(TargetBitmap(batch_size), + std::move(valid_res)); + cached_overflow_res_ = res_vec; + + return res_vec; }; if constexpr (std::is_integral_v && !std::is_same_v) { @@ -207,12 +212,12 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { func(index_ptr, val1, val2, lower_inclusive, upper_inclusive)); }; auto res = ProcessIndexChunks(execute_sub_batch, val1, val2); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template @@ -240,29 +245,32 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { PreCheckOverflow(val1, val2, lower_inclusive, upper_inclusive)) { return res; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const T* data, const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, HighPrecisionType val1, HighPrecisionType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res); + func(val1, val2, data, valid_data, size, res, valid_res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res); + func(val1, val2, data, valid_data, size, res, valid_res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res); + func(val1, val2, data, valid_data, size, res, valid_res); } else { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res); + func(val1, val2, data, valid_data, size, res, valid_res); } }; auto skip_index_func = @@ -283,7 +291,7 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, skip_index_func, res, val1, val2); + execute_sub_batch, skip_index_func, res, valid_res, val1, val2); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -302,9 +310,11 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -317,24 +327,25 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val1, ValueType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, valid_data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, valid_data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, valid_data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } else { BinaryRangeElementFuncForJson func; - func(val1, val2, pointer, data, valid_data, size, res); + func(val1, val2, pointer, data, valid_data, size, res, valid_res); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val1, val2); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val1, val2); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -353,9 +364,11 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -371,25 +384,26 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val1, ValueType val2, int index) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, valid_data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, valid_data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, valid_data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } else { BinaryRangeElementFuncForArray func; - func(val1, val2, index, data, valid_data, size, res); + func(val1, val2, index, data, valid_data, size, res, valid_res); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val1, val2, index); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val1, val2, index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index e359224e82b35..66d37e8494eb1 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -40,7 +40,8 @@ struct BinaryRangeElementFunc { const T* src, const bool* valid_data, size_t n, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { auto execute_sub_batch = [](T val1, T val2, const T* src, @@ -79,6 +80,7 @@ struct BinaryRangeElementFunc { } continue; } + valid_res[right] = false; execute_sub_batch( val1, val2, src + left, right - left, res + left); left = right; @@ -91,8 +93,8 @@ struct BinaryRangeElementFunc { #define BinaryRangeJSONCompare(cmp) \ do { \ if (valid_data && !valid_data[i]) { \ - res[i] = false; \ - continue; \ + res[i] = valid_res[i] = false; \ + break; \ } \ auto x = src[i].template at(pointer); \ if (x.error()) { \ @@ -123,7 +125,8 @@ struct BinaryRangeElementFuncForJson { const milvus::Json* src, const bool* valid_data, size_t n, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (size_t i = 0; i < n; ++i) { if constexpr (lower_inclusive && upper_inclusive) { BinaryRangeJSONCompare(val1 <= value && value <= val2); @@ -150,10 +153,11 @@ struct BinaryRangeElementFuncForArray { const milvus::ArrayView* src, const bool* valid_data, size_t n, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (size_t i = 0; i < n; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (lower_inclusive && upper_inclusive) { diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index bb1a11800aa7f..27a7e6ad2e05d 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -425,6 +425,62 @@ PhyCompareFilterExpr::GetChunkData(DataType data_type, } } +// template +// VectorPtr +// PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { +// auto real_batch_size = GetNextBatchSize(); +// if (real_batch_size == 0) { +// return nullptr; +// } + +// auto res_vec = std::make_shared( +// TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); +// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); +// TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); +// valid_res.set(); + +// auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); +// auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); + +// int64_t processed_rows = 0; +// for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; +// ++chunk_id) { +// auto chunk_size = chunk_id == num_chunk_ - 1 +// ? active_count_ - chunk_id * size_per_chunk_ +// : size_per_chunk_; +// auto left = GetChunkData(expr_->left_data_type_, +// expr_->left_field_id_, +// chunk_id, +// left_data_barrier); +// auto right = GetChunkData(expr_->right_data_type_, +// expr_->right_field_id_, +// chunk_id, +// right_data_barrier); + +// for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; +// i < chunk_size; +// ++i) { +// if (!left(i).has_value() || !right(i).has_value()) { +// res[processed_rows] = false; +// valid_res[processed_rows] = false; +// } else { +// res[processed_rows] = boost::apply_visitor( +// milvus::query::Relational{}, +// left(i).value(), +// right(i).value()); +// } +// processed_rows++; + +// if (processed_rows >= batch_size_) { +// current_chunk_id_ = chunk_id; +// current_chunk_pos_ = i + 1; +// return res_vec; +// } +// } +// } +// return res_vec; +// } + void PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { // For segment both fields has no index, can use SIMD to speed up. @@ -528,9 +584,11 @@ PhyCompareFilterExpr::ExecCompareRightType() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* left, @@ -577,7 +635,7 @@ PhyCompareFilterExpr::ExecCompareRightType() { } }; int64_t processed_size = - ProcessBothDataChunks(execute_sub_batch, res); + ProcessBothDataChunks(execute_sub_batch, res, valid_res); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index a4b5dfaab77b2..569e305c5c83c 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -285,6 +285,7 @@ class PhyCompareFilterExpr : public Expr { int64_t ProcessBothDataChunksForSingleChunk(FUNC func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -314,10 +315,12 @@ class PhyCompareFilterExpr : public Expr { for (int i = 0; i < size; ++i) { if (left_valid_data && !left_valid_data[i + data_pos]) { res[processed_size + i] = false; + valid_res[processed_size + i] = false; continue; } if (right_valid_data && !right_valid_data[i + data_pos]) { res[processed_size + i] = false; + valid_res[processed_size + i] = false; } } processed_size += size; @@ -336,6 +339,7 @@ class PhyCompareFilterExpr : public Expr { int64_t ProcessBothDataChunksForMultipleChunk(FUNC func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -363,6 +367,20 @@ class PhyCompareFilterExpr : public Expr { const T* left_data = left_chunk.data() + data_pos; const U* right_data = right_chunk.data() + data_pos; func(left_data, right_data, size, res + processed_size, values...); + const bool* left_valid_data = left_chunk.valid_data(); + const bool* right_valid_data = right_chunk.valid_data(); + // mask with valid_data + for (int i = 0; i < size; ++i) { + if (left_valid_data && !left_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + continue; + } + if (right_valid_data && !right_valid_data[i + data_pos]) { + res[processed_size + i] = false; + valid_res[processed_size + i] = false; + } + } processed_size += size; if (processed_size >= batch_size_) { diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp index 8f56d4197ba35..0fab44ebdc463 100644 --- a/internal/core/src/exec/expression/ExistsExpr.cpp +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -44,19 +44,22 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const milvus::Json* data, const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer) { for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = data[i].exist(pointer); @@ -64,7 +67,7 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 3806b74b9bc72..2b9c16082218f 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -248,6 +249,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { // For sealed segment, only single chunk Assert(num_data_chunk_ == 1); @@ -262,6 +264,7 @@ class SegmentExpr : public Expr { views_info.second.data(), need_size, res, + valid_res, values...); } current_data_chunk_pos_ += need_size; @@ -274,6 +277,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -281,7 +285,7 @@ class SegmentExpr : public Expr { std::is_same_v) { if (segment_->type() == SegmentType::Sealed) { return ProcessChunkForSealedSeg( - func, skip_func, res, values...); + func, skip_func, res, valid_res, values...); } } @@ -307,7 +311,12 @@ class SegmentExpr : public Expr { if (valid_data != nullptr) { valid_data += data_pos; } - func(data, valid_data, size, res + processed_size, values...); + func(data, + valid_data, + size, + res + processed_size, + valid_res + processed_size, + values...); } processed_size += size; @@ -407,8 +416,10 @@ class SegmentExpr : public Expr { int ProcessIndexOneChunk(TargetBitmap& result, + TargetBitmap& valid_result, size_t chunk_id, const TargetBitmap& chunk_res, + const TargetBitmap& chunk_valid_res, int processed_rows) { auto data_pos = chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; @@ -420,33 +431,41 @@ class SegmentExpr : public Expr { // chunk_res.begin() + data_pos, // chunk_res.begin() + data_pos + size); result.append(chunk_res, data_pos, size); + valid_result.append(chunk_valid_res, data_pos, size); return size; } template - TargetBitmap + VectorPtr ProcessIndexChunks(FUNC func, ValTypes... values) { typedef std:: conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; TargetBitmap result; + TargetBitmap valid_result; int processed_rows = 0; for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { // This cache result help getting result for every batch loop. - // It avoids indexing execute for evevy batch because indexing + // It avoids indexing execute for every batch because indexing // executing costs quite much time. if (cached_index_chunk_id_ != i) { const Index& index = segment_->chunk_scalar_index(field_id_, i); auto* index_ptr = const_cast(&index); cached_index_chunk_res_ = std::move(func(index_ptr, values...)); + auto valid_result = index_ptr->IsNotNull(); + cached_index_chunk_valid_res_ = std::move(valid_result); cached_index_chunk_id_ = i; } - auto size = ProcessIndexOneChunk( - result, i, cached_index_chunk_res_, processed_rows); + auto size = ProcessIndexOneChunk(result, + valid_result, + i, + cached_index_chunk_res_, + cached_index_chunk_valid_res_, + processed_rows); if (processed_rows + size >= batch_size_) { current_index_chunk_ = i; @@ -458,18 +477,130 @@ class SegmentExpr : public Expr { processed_rows += size; } - return result; + return std::make_shared(std::move(result), + std::move(valid_result)); } - template + template + TargetBitmap + ProcessChunksForValid(bool use_index) { + if (use_index) { + return ProcessIndexChunksForValid(); + } else { + return ProcessDataChunksForValid(); + } + } + + template + TargetBitmap + ProcessDataChunksForValid() { + TargetBitmap valid_result; + valid_result.set(); + int64_t processed_size = 0; + for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) { + auto data_pos = + (i == current_data_chunk_) ? current_data_chunk_pos_ : 0; + auto size = + (i == (num_data_chunk_ - 1)) + ? (segment_->type() == SegmentType::Growing + ? (active_count_ % size_per_chunk_ == 0 + ? size_per_chunk_ - data_pos + : active_count_ % size_per_chunk_ - data_pos) + : active_count_ - data_pos) + : size_per_chunk_ - data_pos; + + size = std::min(size, batch_size_ - processed_size); + + auto chunk = segment_->chunk_data(field_id_, i); + const bool* valid_data = chunk.valid_data(); + if (valid_data == nullptr) { + return valid_result; + } + valid_data += data_pos; + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + valid_result[i] = false; + } + } + processed_size += size; + if (processed_size >= batch_size_) { + current_data_chunk_ = i; + current_data_chunk_pos_ = data_pos + size; + break; + } + } + return valid_result; + } + + int + ProcessIndexOneChunkForValid(TargetBitmap& valid_result, + size_t chunk_id, + const TargetBitmap& chunk_valid_res, + int processed_rows) { + auto data_pos = + chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; + auto size = std::min( + std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows), + int64_t(chunk_valid_res.size())); + + valid_result.append(chunk_valid_res, data_pos, size); + return size; + } + + template TargetBitmap + ProcessIndexChunksForValid() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + int processed_rows = 0; + TargetBitmap valid_result; + valid_result.set(); + + for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { + // This cache result help getting result for every batch loop. + // It avoids indexing execute for every batch because indexing + // executing costs quite much time. + if (cached_index_chunk_id_ != i) { + const Index& index = + segment_->chunk_scalar_index(field_id_, i); + auto* index_ptr = const_cast(&index); + auto execute_sub_batch = [](Index* index_ptr) { + TargetBitmap res = index_ptr->IsNotNull(); + return res; + }; + cached_index_chunk_valid_res_ = execute_sub_batch(index_ptr); + cached_index_chunk_id_ = i; + } + + auto size = ProcessIndexOneChunkForValid( + valid_result, i, cached_index_chunk_valid_res_, processed_rows); + + if (processed_rows + size >= batch_size_) { + current_index_chunk_ = i; + current_index_chunk_pos_ = i == current_index_chunk_ + ? current_index_chunk_pos_ + size + : size; + break; + } + processed_rows += size; + } + return valid_result; + } + + template + VectorPtr ProcessTextMatchIndex(FUNC func, ValTypes... values) { TargetBitmap result; + TargetBitmap valid_result; if (cached_match_res_ == nullptr) { auto index = segment_->GetTextIndex(field_id_); auto res = std::move(func(index, values...)); + auto valid_res = index->IsNotNull(); cached_match_res_ = std::make_shared(std::move(res)); + cached_index_chunk_valid_res_ = std::move(valid_result); if (cached_match_res_->size() < active_count_) { // some entities are not visible in inverted index. // only happend on growing segment. @@ -485,9 +616,13 @@ class SegmentExpr : public Expr { : batch_size_; result.append( *cached_match_res_, current_data_chunk_pos_, real_batch_size); + valid_result.append(cached_index_chunk_valid_res_, + current_data_chunk_pos_, + real_batch_size); current_data_chunk_pos_ += real_batch_size; - return result; + return std::make_shared(std::move(result), + std::move(valid_result)); } template @@ -585,6 +720,8 @@ class SegmentExpr : public Expr { // Cache for index scan to avoid search index every batch int64_t cached_index_chunk_id_{-1}; TargetBitmap cached_index_chunk_res_{}; + // Cache for chunk valid res. + TargetBitmap cached_index_chunk_valid_res_{}; // Cache for text match. std::shared_ptr cached_match_res_{nullptr}; diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index c91420e577182..897cad544755f 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "JsonContainsExpr.h" +#include #include "common/Types.h" namespace milvus { @@ -173,9 +174,11 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { AssertInfo(expr_->column_.nested_path_.size() == 0, "[ExecArrayContains]nested path must be null"); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; for (auto const& element : expr_->vals_) { @@ -185,6 +188,7 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& elements) { auto executor = [&](size_t i) { const auto& array = data[i]; @@ -197,7 +201,7 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { }; for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -205,7 +209,7 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -226,9 +230,11 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -239,6 +245,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](size_t i) { @@ -260,7 +267,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -268,7 +275,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -284,9 +291,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::vector elements; @@ -298,6 +307,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](size_t i) -> bool { @@ -328,7 +338,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -336,7 +346,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -359,9 +369,11 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::unordered_set elements; for (auto const& element : expr_->vals_) { @@ -372,6 +384,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& elements) { auto executor = [&](size_t i) { std::unordered_set tmp_elements(elements); @@ -386,7 +399,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { }; for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -394,7 +407,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -415,9 +428,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set elements; @@ -429,6 +444,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](const size_t i) -> bool { @@ -453,7 +469,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -461,7 +477,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -476,9 +492,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -495,6 +513,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements, const std::unordered_set elements_index) { @@ -580,7 +599,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -590,6 +609,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { int64_t processed_size = ProcessDataChunks(execute_sub_batch, std::nullptr_t{}, res, + valid_res, pointer, elements, elements_index); @@ -608,9 +628,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -623,6 +645,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { @@ -657,7 +680,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -665,7 +688,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -681,9 +704,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -700,6 +725,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { @@ -776,7 +802,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); @@ -784,7 +810,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, elements); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -872,12 +898,12 @@ PhyJsonContainsFilterExpr::ExecArrayContainsForIndexSegmentImpl() { } }; auto res = ProcessIndexChunks(execute_sub_batch, elems); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } } //namespace exec diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp index 4d4bb550691c2..14bde4decdab0 100644 --- a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp @@ -30,6 +30,9 @@ PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) { auto flat_vec = GetColumnVector(result); TargetBitmapView data(flat_vec->GetRawData(), flat_vec->size()); data.flip(); + TargetBitmapView valid_data(flat_vec->GetValidRawData(), + flat_vec->size()); + data &= valid_data; } } diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index e5322a84a93fa..9e1092661b4e3 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -15,6 +15,8 @@ // limitations under the License. #include "TermExpr.h" +#include +#include #include "query/Utils.h" namespace milvus { namespace exec { @@ -199,9 +201,12 @@ PhyTermFilterExpr::ExecPkTermImpl() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + // pk valid_bitmap is always all true + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); for (size_t i = 0; i < real_batch_size; ++i) { res[i] = cached_bits_[current_data_chunk_pos_++]; @@ -241,9 +246,11 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); @@ -253,6 +260,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const ValueType& target_val) { auto executor = [&](size_t i) { for (int i = 0; i < data[i].length(); i++) { @@ -265,7 +273,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { }; for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } executor(i); @@ -273,7 +281,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, target_val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, target_val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -294,9 +302,11 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -317,14 +327,15 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, int index, const std::unordered_set& term_set) { for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } - if (index >= data[i].length()) { + if (term_set.empty() || index >= data[i].length()) { res[i] = false; continue; } @@ -334,7 +345,7 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, index, term_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, index, term_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -354,9 +365,11 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); @@ -367,6 +380,7 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string pointer, const ValueType& target_val) { auto executor = [&](size_t i) { @@ -387,14 +401,14 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = executor(i); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -414,9 +428,11 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set term_set; @@ -434,6 +450,7 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::string pointer, const std::unordered_set& terms) { auto executor = [&](size_t i) { @@ -456,6 +473,10 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { }; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { + res[i] = valid_res[i] = false; + continue; + } + if (terms.empty()) { res[i] = false; continue; } @@ -463,7 +484,7 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, pointer, term_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, pointer, term_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -509,12 +530,12 @@ PhyTermFilterExpr::ExecVisitorImplForIndex() { return func(index_ptr, vals.size(), vals.data()); }; auto res = ProcessIndexChunks(execute_sub_batch, vals); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template <> @@ -536,7 +557,7 @@ PhyTermFilterExpr::ExecVisitorImplForIndex() { return std::move(func(index_ptr, vals.size(), (bool*)vals.data())); }; auto res = ProcessIndexChunks(execute_sub_batch, vals); - return std::make_shared(std::move(res)); + return res; } template @@ -547,9 +568,11 @@ PhyTermFilterExpr::ExecVisitorImplForData() { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); std::vector vals; for (auto& val : expr_->vals_) { @@ -565,18 +588,19 @@ PhyTermFilterExpr::ExecVisitorImplForData() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, const std::unordered_set& vals) { TermElementFuncSet func; for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } res[i] = func(vals, data[i]); } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, vals_set); + execute_sub_batch, std::nullptr_t{}, res, valid_res, vals_set); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 38cbb37ba5a31..ae10d9386d01a 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include "UnaryExpr.h" +#include #include "common/Json.h" namespace milvus { @@ -260,9 +261,11 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { if (real_batch_size == 0) { return nullptr; } - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); ValueType val = GetValueFromProto(expr_->val_); auto op_type = expr_->op_type_; @@ -274,46 +277,47 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ValueType val, int index) { switch (op_type) { case proto::plan::GreaterThan: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::GreaterEqual: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::LessThan: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::LessEqual: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::Equal: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::NotEqual: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } case proto::plan::PrefixMatch: { UnaryElementFuncForArray func; - func(data, valid_data, size, val, index, res); + func(data, valid_data, size, val, index, res, valid_res); break; } default: @@ -324,7 +328,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val, index); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val, index); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -433,14 +437,14 @@ PhyUnaryRangeFilterExpr::ExecArrayEqualForIndex(bool reverse) { } return res; }); - AssertInfo(batch_res.size() == real_batch_size, + AssertInfo(batch_res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - batch_res.size(), + batch_res->size(), real_batch_size); // return the result. - return std::make_shared(std::move(batch_res)); + return batch_res; } template @@ -456,9 +460,11 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } ExprValueType val = GetValueFromProto(expr_->val_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto op_type = expr_->op_type_; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -496,12 +502,13 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, ExprValueType val) { switch (op_type) { case proto::plan::GreaterThan: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -515,7 +522,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::GreaterEqual: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -529,7 +536,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::LessThan: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -543,7 +550,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::LessEqual: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -557,7 +564,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::Equal: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -577,7 +584,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::NotEqual: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -597,7 +604,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { case proto::plan::PrefixMatch: { for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -615,7 +622,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { RegexMatcher matcher(regex_pattern); for (size_t i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (std::is_same_v) { @@ -635,7 +642,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } }; int64_t processed_size = ProcessDataChunks( - execute_sub_batch, std::nullptr_t{}, res, val); + execute_sub_batch, std::nullptr_t{}, res, valid_res, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", @@ -727,12 +734,12 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { }; auto val = GetValueFromProto(expr_->val_); auto res = ProcessIndexChunks(execute_sub_batch, val); - AssertInfo(res.size() == real_batch_size, + AssertInfo(res->size() == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", - res.size(), + res->size(), real_batch_size); - return std::make_shared(std::move(res)); + return res; } template @@ -754,10 +761,11 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { switch (expr_->op_type_) { case proto::plan::GreaterThan: case proto::plan::GreaterEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; if (milvus::query::lt_lb(val)) { res.set(); @@ -767,10 +775,11 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { } case proto::plan::LessThan: case proto::plan::LessEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; if (milvus::query::gt_ub(val)) { res.set(); @@ -779,19 +788,21 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { return res_vec; } case proto::plan::Equal: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; res.reset(); return res_vec; } case proto::plan::NotEqual: { + auto valid_res = ProcessChunksForValid(CanUseIndex()); auto res_vec = std::make_shared( - TargetBitmap(batch_size)); - cached_overflow_res_ = res_vec; + TargetBitmap(batch_size), std::move(valid_res)); TargetBitmapView res(res_vec->GetRawData(), batch_size); + cached_overflow_res_ = res_vec; res.set(); return res_vec; @@ -822,54 +833,57 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { return nullptr; } IndexInnerType val = GetValueFromProto(expr_->val_); - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* data, const bool* valid_data, const int size, TargetBitmapView res, + TargetBitmapView valid_res, IndexInnerType val) { switch (expr_type) { case proto::plan::GreaterThan: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::GreaterEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::LessThan: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::LessEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::Equal: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::NotEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::PrefixMatch: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } case proto::plan::Match: { UnaryElementFunc func; - func(data, valid_data, size, val, res); + func(data, valid_data, size, val, res, valid_res); break; } default: @@ -885,8 +899,8 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { return skip_index.CanSkipUnaryRange( field_id, chunk_id, expr_type, val); }; - int64_t processed_size = - ProcessDataChunks(execute_sub_batch, skip_index_func, res, val); + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, skip_index_func, res, valid_res, val); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}, related params[active_count:{}, " @@ -916,7 +930,7 @@ PhyUnaryRangeFilterExpr::ExecTextMatch() { return index->MatchQuery(query); }; auto res = ProcessTextMatchIndex(func, query); - return std::make_shared(std::move(res)); + return res; }; } // namespace exec diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index 612b5ba1d70f9..d216decf07f47 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -63,7 +63,8 @@ struct UnaryElementFunc { const bool* valid_data, size_t size, IndexInnerType val, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { if constexpr (op == proto::plan::OpType::Match) { UnaryElementFuncForMatch func; func(src, size, val, res); @@ -153,6 +154,7 @@ struct UnaryElementFunc { } continue; } + valid_res[right] = false; execute_sub_batch(src + left, right - left, val, res + left); left = right; break; @@ -186,10 +188,11 @@ struct UnaryElementFuncForArray { size_t size, ValueType val, int index, - TargetBitmapView res) { + TargetBitmapView res, + TargetBitmapView valid_res) { for (int i = 0; i < size; ++i) { if (valid_data && !valid_data[i]) { - res[i] = false; + res[i] = valid_res[i] = false; continue; } if constexpr (op == proto::plan::OpType::Equal) { diff --git a/internal/core/src/index/BitmapIndex.cpp b/internal/core/src/index/BitmapIndex.cpp index 00cd6a58cbcfe..c35e73a5902bb 100644 --- a/internal/core/src/index/BitmapIndex.cpp +++ b/internal/core/src/index/BitmapIndex.cpp @@ -89,13 +89,13 @@ BitmapIndex::Build(size_t n, const T* data, const bool* valid_data) { } total_num_rows_ = n; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); T* p = const_cast(data); for (int i = 0; i < n; ++i, ++p) { if (!valid_data || valid_data[i]) { data_[*p].add(i); - valid_bitset.set(i); + valid_bitset_.set(i); } } @@ -122,7 +122,7 @@ BitmapIndex::BuildPrimitiveField( if (data->is_valid(i)) { auto val = reinterpret_cast(data->RawValue(i)); data_[*val].add(offset); - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } @@ -141,7 +141,7 @@ BitmapIndex::BuildWithFieldData( PanicInfo(DataIsEmpty, "scalar bitmap index can not build null values"); } total_num_rows_ = total_num_rows; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); switch (schema_.data_type()) { case proto::schema::DataType::Bool: @@ -186,7 +186,7 @@ BitmapIndex::BuildArrayField(const std::vector& field_datas) { auto val = array->template get_data(j); data_[val].add(offset); } - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } @@ -361,7 +361,7 @@ BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, data_[key] = value; } for (const auto& v : value) { - valid_bitset.set(v); + valid_bitset_.set(v); } } } @@ -424,7 +424,7 @@ BitmapIndex::DeserializeIndexData(const uint8_t* data_ptr, data_[key] = value; } for (const auto& v : value) { - valid_bitset.set(v); + valid_bitset_.set(v); } } } @@ -518,7 +518,7 @@ BitmapIndex::LoadWithoutAssemble(const BinarySet& binary_set, index_meta_buffer->size); auto index_length = index_meta.first; total_num_rows_ = index_meta.second; - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); auto index_data_buffer = binary_set.GetByName(BITMAP_INDEX_DATA); @@ -647,7 +647,7 @@ BitmapIndex::NotIn(const size_t n, const T* values) { } } // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - res &= valid_bitset; + res &= valid_bitset_; return res; } else { TargetBitmap res(total_num_rows_, false); @@ -659,7 +659,7 @@ BitmapIndex::NotIn(const size_t n, const T* values) { } res.flip(); // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - res &= valid_bitset; + res &= valid_bitset_; return res; } } @@ -669,7 +669,7 @@ const TargetBitmap BitmapIndex::IsNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap res(total_num_rows_, true); - res &= valid_bitset; + res &= valid_bitset_; res.flip(); return res; } @@ -679,7 +679,7 @@ const TargetBitmap BitmapIndex::IsNotNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap res(total_num_rows_, true); - res &= valid_bitset; + res &= valid_bitset_; return res; } @@ -1088,11 +1088,15 @@ BitmapIndex::Reverse_Lookup_InCache(size_t idx) const { } template -T +std::optional BitmapIndex::Reverse_Lookup(size_t idx) const { AssertInfo(is_built_, "index has not been built"); AssertInfo(idx < total_num_rows_, "out of range of total coun"); + if (!valid_bitset_[idx]) { + return std::nullopt; + } + if (use_offset_cache_) { return Reverse_Lookup_InCache(idx); } diff --git a/internal/core/src/index/BitmapIndex.h b/internal/core/src/index/BitmapIndex.h index 714ddf878c4d4..fb677e6f3194f 100644 --- a/internal/core/src/index/BitmapIndex.h +++ b/internal/core/src/index/BitmapIndex.h @@ -267,7 +267,7 @@ class BitmapIndex : public ScalarIndex { std::shared_ptr file_manager_; // generate valid_bitset to speed up NotIn and IsNull and IsNotNull operate - TargetBitmap valid_bitset; + TargetBitmap valid_bitset_; }; } // namespace index diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 7a2a2dd264e14..b7ef644deff06 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -2091,7 +2091,11 @@ SegmentSealedImpl::CreateTextIndex(FieldId field_id) { "converted to string index"); auto n = impl->Size(); for (size_t i = 0; i < n; i++) { - index->AddText(impl->Reverse_Lookup(i), i); + auto value = impl->Reverse_Lookup(i); + if (!value.has_value()) { + continue; + } + index->AddText(impl->Reverse_Lookup(i).value(), i); } } } diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index c45cc4e3b12fc..8d128338fea30 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -3335,6 +3335,373 @@ TEST(Expr, TestExprPerformance) { test_case_base(expr); } +TEST(Expr, TestExprNOT) { + auto schema = std::make_shared(); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8, true); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16, true); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32, true); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR, true); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT, true); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE, true); + schema->set_primary_field_id(str1_fid); + + std::map fids = {{DataType::INT8, int8_fid}, + {DataType::INT16, int16_fid}, + {DataType::INT32, int32_fid}, + {DataType::INT64, int64_fid}, + {DataType::VARCHAR, str2_fid}, + {DataType::FLOAT, float_fid}, + {DataType::DOUBLE, double_fid}}; + + auto seg = CreateSealedSegment(schema); + FixedVector valid_data_i8; + FixedVector valid_data_i16; + FixedVector valid_data_i32; + FixedVector valid_data_i64; + FixedVector valid_data_str; + FixedVector valid_data_float; + FixedVector valid_data_double; + int N = 1000; + auto raw_data = DataGen(schema, N); + valid_data_i8 = raw_data.get_col_valid(int8_fid); + valid_data_i16 = raw_data.get_col_valid(int16_fid); + valid_data_i32 = raw_data.get_col_valid(int32_fid); + valid_data_i64 = raw_data.get_col_valid(int64_fid); + valid_data_str = raw_data.get_col_valid(str2_fid); + valid_data_float = raw_data.get_col_valid(float_fid); + valid_data_double = raw_data.get_col_valid(double_fid); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; + + auto build_unary_range_expr = [&](DataType data_type, + int64_t value) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val; + val.set_int64_val(value); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val; + val.set_float_val(float(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(value)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::LessThan, + val); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_binary_range_expr = [&](DataType data_type, + int64_t low, + int64_t high) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(low); + proto::plan::GenericValue val2; + val2.set_int64_val(high); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(low)); + proto::plan::GenericValue val2; + val2.set_float_val(float(high)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else if (IsStringDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_string_val(std::to_string(low)); + proto::plan::GenericValue val2; + val2.set_string_val(std::to_string(low)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + val1, + val2, + true, + true); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_term_expr = + [&](DataType data_type, + std::vector in_vals) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_int64_val(v); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsFloatDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_float_val(float(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else if (IsStringDataType(data_type)) { + std::vector vals; + for (auto& v : in_vals) { + proto::plan::GenericValue val; + val.set_string_val(std::to_string(v)); + vals.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), vals, false); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_compare_expr = [&](DataType data_type) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type) || IsFloatDataType(data_type) || + IsStringDataType(data_type)) { + return std::make_shared( + fids[data_type], + fids[data_type], + data_type, + data_type, + proto::plan::OpType::LessThan); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 10); + auto child2_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + }; + + auto build_multi_logical_binary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child1_expr = build_unary_range_expr(data_type, 100); + auto child2_expr = build_unary_range_expr(data_type, 100); + auto child3_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child4_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + auto child5_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + auto child6_expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child3_expr, child4_expr); + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child5_expr, child6_expr); + }; + + auto build_arith_op_expr = [&](DataType data_type, + int64_t right_val, + int64_t val) -> expr::TypedExprPtr { + if (IsIntegerDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_int64_val(right_val); + proto::plan::GenericValue val2; + val2.set_int64_val(val); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else if (IsFloatDataType(data_type)) { + proto::plan::GenericValue val1; + val1.set_float_val(float(right_val)); + proto::plan::GenericValue val2; + val2.set_float_val(float(val)); + return std::make_shared( + expr::ColumnInfo(fids[data_type], data_type), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val1, + val2); + } else { + throw std::runtime_error("not supported type"); + } + }; + + auto build_logical_unary_expr = + [&](DataType data_type) -> expr::TypedExprPtr { + auto child_expr = build_unary_range_expr(data_type, 10); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + }; + + auto test_ans = [=, &seg](expr::TypedExprPtr expr, + FixedVector valid_data) { + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, expr); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), N, final); + EXPECT_EQ(final.size(), N); + for (int i = 0; i < N; i++) { + if (!valid_data[i]) { + EXPECT_EQ(final[i], false); + } + } + }; + + auto expr = build_unary_range_expr(DataType::INT8, 10); + test_ans(expr, valid_data_i8); + expr = build_unary_range_expr(DataType::INT16, 10); + test_ans(expr, valid_data_i16); + expr = build_unary_range_expr(DataType::INT32, 10); + test_ans(expr, valid_data_i32); + expr = build_unary_range_expr(DataType::INT64, 10); + test_ans(expr, valid_data_i64); + expr = build_unary_range_expr(DataType::FLOAT, 10); + test_ans(expr, valid_data_float); + expr = build_unary_range_expr(DataType::DOUBLE, 10); + test_ans(expr, valid_data_double); + expr = build_unary_range_expr(DataType::VARCHAR, 10); + test_ans(expr, valid_data_str); + + expr = build_binary_range_expr(DataType::INT8, 10, 100); + test_ans(expr, valid_data_i8); + expr = build_binary_range_expr(DataType::INT16, 10, 100); + test_ans(expr, valid_data_i16); + expr = build_binary_range_expr(DataType::INT32, 10, 100); + test_ans(expr, valid_data_i32); + expr = build_binary_range_expr(DataType::INT64, 10, 100); + test_ans(expr, valid_data_i64); + expr = build_binary_range_expr(DataType::FLOAT, 10, 100); + test_ans(expr, valid_data_float); + expr = build_binary_range_expr(DataType::DOUBLE, 10, 100); + test_ans(expr, valid_data_double); + expr = build_binary_range_expr(DataType::VARCHAR, 10, 100); + test_ans(expr, valid_data_str); + + expr = build_compare_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_compare_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_compare_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_compare_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_compare_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_compare_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_compare_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_arith_op_expr(DataType::INT8, 10, 100); + test_ans(expr, valid_data_i8); + expr = build_arith_op_expr(DataType::INT16, 10, 100); + test_ans(expr, valid_data_i16); + expr = build_arith_op_expr(DataType::INT32, 10, 100); + test_ans(expr, valid_data_i32); + expr = build_arith_op_expr(DataType::INT64, 10, 100); + test_ans(expr, valid_data_i64); + expr = build_arith_op_expr(DataType::FLOAT, 10, 100); + test_ans(expr, valid_data_float); + expr = build_arith_op_expr(DataType::DOUBLE, 10, 100); + test_ans(expr, valid_data_double); + + expr = build_logical_unary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_logical_unary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_logical_unary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_logical_unary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_logical_unary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_logical_unary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_logical_unary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_logical_binary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_logical_binary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_logical_binary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_logical_binary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_logical_binary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_logical_binary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_logical_binary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); + + expr = build_multi_logical_binary_expr(DataType::INT8); + test_ans(expr, valid_data_i8); + expr = build_multi_logical_binary_expr(DataType::INT16); + test_ans(expr, valid_data_i16); + expr = build_multi_logical_binary_expr(DataType::INT32); + test_ans(expr, valid_data_i32); + expr = build_multi_logical_binary_expr(DataType::INT64); + test_ans(expr, valid_data_i64); + expr = build_multi_logical_binary_expr(DataType::FLOAT); + test_ans(expr, valid_data_float); + expr = build_multi_logical_binary_expr(DataType::DOUBLE); + test_ans(expr, valid_data_double); + expr = build_multi_logical_binary_expr(DataType::VARCHAR); + test_ans(expr, valid_data_str); +} + TEST_P(ExprTest, test_term_pk) { auto schema = std::make_shared(); schema->AddField( @@ -3399,67 +3766,6 @@ TEST_P(ExprTest, test_term_pk) { } } -TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - schema->set_primary_field_id(str1_fid); - - auto seg = CreateSealedSegment(schema); - int N = 1000; - auto raw_data = DataGen(schema, N); - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } - - proto::plan::GenericValue val; - val.set_int64_val(10); - auto expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - auto plan_node = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - - std::vector test_batch_size = { - 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; - for (const auto& batch_size : test_batch_size) { - EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; - auto plan = plan::PlanFragment(plan_node); - auto query_context = std::make_shared( - "query id", seg.get(), N, MAX_TIMESTAMP); - - auto task = - milvus::exec::Task::Create("task_expr", plan, 0, query_context); - auto last_num = N % batch_size; - auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; - int iter = 0; - for (;;) { - auto result = task->Next(); - if (!result) { - break; - } - auto childrens = result->childrens(); - if (++iter != iter_num) { - EXPECT_EQ(childrens[0]->size(), batch_size); - } else { - EXPECT_EQ(childrens[0]->size(), last_num); - } - } - } -} - TEST_P(ExprTest, TestGrowingSegmentGetBatchSize) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index f45079db184f5..b01d1d326501f 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -166,7 +166,8 @@ GenAlwaysFalseExpr(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { } auto -GenAlwaysTrueExpr(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { +GenAlwaysTrueExprIfValid(const FieldMeta& fvec_meta, + const FieldMeta& str_meta) { auto always_false_expr = GenAlwaysFalseExpr(fvec_meta, str_meta); auto not_expr = GenNotExpr(); not_expr->set_allocated_child(always_false_expr); @@ -196,7 +197,7 @@ GenAlwaysFalsePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto GenAlwaysTruePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { - auto always_true_expr = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto always_true_expr = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); proto::plan::VectorType vector_type; if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { vector_type = proto::plan::VectorType::FloatVector; @@ -1293,7 +1294,7 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFields) { dataset.timestamps_.data(), dataset.raw_); - auto expr_proto = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto expr_proto = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); auto plan_proto = GenPlanNode(); plan_proto->mutable_query()->set_allocated_predicates(expr_proto); SetTargetEntry(plan_proto, {str_meta.get_id().get()}); @@ -1336,7 +1337,7 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFieldsNullable) { dataset.timestamps_.data(), dataset.raw_); - auto expr_proto = GenAlwaysTrueExpr(fvec_meta, str_meta); + auto expr_proto = GenAlwaysTrueExprIfValid(fvec_meta, str_meta); auto plan_proto = GenPlanNode(); plan_proto->mutable_query()->set_allocated_predicates(expr_proto); SetTargetEntry(plan_proto, {str_meta.get_id().get()}); @@ -1346,9 +1347,9 @@ TEST(AlwaysTrueStringPlan, QueryWithOutputFieldsNullable) { auto retrieved = segment->Retrieve( nullptr, plan.get(), time, DEFAULT_MAX_OUTPUT_SIZE, false); - ASSERT_EQ(retrieved->offset().size(), N); + ASSERT_EQ(retrieved->offset().size(), N / 2); ASSERT_EQ(retrieved->fields_data().size(), 1); ASSERT_EQ(retrieved->fields_data(0).scalars().string_data().data().size(), - N); - ASSERT_EQ(retrieved->fields_data(0).valid_data().size(), N); + N / 2); + ASSERT_EQ(retrieved->fields_data(0).valid_data().size(), N / 2); } From 460465caa244f75c58771782175823c58213b2f7 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Fri, 13 Sep 2024 15:11:57 +0800 Subject: [PATCH 3/4] fix conflict Signed-off-by: lixinguo --- internal/core/src/exec/expression/Expr.h | 7 +- .../src/exec/expression/LogicalBinaryExpr.cpp | 4 + .../core/src/exec/operator/FilterBitsNode.cpp | 7 +- internal/core/src/exec/operator/MvccNode.cpp | 9 +- internal/core/src/index/ScalarIndexSort.cpp | 24 +- internal/core/src/index/ScalarIndexSort.h | 4 +- internal/core/src/index/StringIndexMarisa.cpp | 13 +- internal/core/src/index/StringIndexMarisa.h | 4 +- internal/core/src/segcore/Utils.cpp | 50 +++ internal/core/unittest/test_expr.cpp | 359 +++++++++++------- internal/core/unittest/test_string_expr.cpp | 45 ++- tests/python_client/testcases/test_search.py | 1 - 12 files changed, 331 insertions(+), 196 deletions(-) diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 2b9c16082218f..734ecf0077c6e 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -494,7 +494,7 @@ class SegmentExpr : public Expr { template TargetBitmap ProcessDataChunksForValid() { - TargetBitmap valid_result; + TargetBitmap valid_result(batch_size_); valid_result.set(); int64_t processed_size = 0; for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) { @@ -519,7 +519,7 @@ class SegmentExpr : public Expr { valid_data += data_pos; for (int i = 0; i < size; i++) { if (!valid_data[i]) { - valid_result[i] = false; + valid_result[i + data_pos] = false; } } processed_size += size; @@ -600,12 +600,13 @@ class SegmentExpr : public Expr { auto res = std::move(func(index, values...)); auto valid_res = index->IsNotNull(); cached_match_res_ = std::make_shared(std::move(res)); - cached_index_chunk_valid_res_ = std::move(valid_result); + cached_index_chunk_valid_res_ = std::move(valid_res); if (cached_match_res_->size() < active_count_) { // some entities are not visible in inverted index. // only happend on growing segment. TargetBitmap tail(active_count_ - cached_match_res_->size()); cached_match_res_->append(tail); + cached_index_chunk_valid_res_.append(tail); } } diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp index d388ab2454cc3..4267f770389ec 100644 --- a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp @@ -45,6 +45,10 @@ PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) { "unsupported logical operator: {}", expr_->GetOpTypeString()); } + TargetBitmapView lvalid_view(lflat->GetValidRawData(), size); + TargetBitmapView rvalid_view(rflat->GetValidRawData(), size); + LogicalElementFunc func; + func(lvalid_view, rvalid_view, size); result = std::move(left); } diff --git a/internal/core/src/exec/operator/FilterBitsNode.cpp b/internal/core/src/exec/operator/FilterBitsNode.cpp index 7ad302cbec371..f7716a3fa19b0 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.cpp +++ b/internal/core/src/exec/operator/FilterBitsNode.cpp @@ -68,6 +68,7 @@ PhyFilterBitsNode::GetOutput() { operator_context_->get_exec_context(), exprs_.get(), input_.get()); TargetBitmap bitset; + TargetBitmap valid_bitset; while (num_processed_rows_ < need_process_rows_) { exprs_->Eval(0, 1, true, eval_ctx, results_); @@ -79,13 +80,17 @@ PhyFilterBitsNode::GetOutput() { auto col_vec_size = col_vec->size(); TargetBitmapView view(col_vec->GetRawData(), col_vec_size); bitset.append(view); + TargetBitmapView valid_view(col_vec->GetValidRawData(), col_vec_size); + valid_bitset.append(valid_view); num_processed_rows_ += col_vec_size; } bitset.flip(); Assert(bitset.size() == need_process_rows_); + Assert(valid_bitset.size() == need_process_rows_); // num_processed_rows_ = need_process_rows_; std::vector col_res; - col_res.push_back(std::make_shared(std::move(bitset))); + col_res.push_back(std::make_shared(std::move(bitset), + std::move(valid_bitset))); std::chrono::high_resolution_clock::time_point scalar_end = std::chrono::high_resolution_clock::now(); double scalar_cost = diff --git a/internal/core/src/exec/operator/MvccNode.cpp b/internal/core/src/exec/operator/MvccNode.cpp index eeae9ebf3748d..e191baff60d4b 100644 --- a/internal/core/src/exec/operator/MvccNode.cpp +++ b/internal/core/src/exec/operator/MvccNode.cpp @@ -52,12 +52,13 @@ PhyMvccNode::GetOutput() { return nullptr; } - auto col_input = - is_source_node_ - ? std::make_shared(TargetBitmap(active_count_)) - : GetColumnVector(input_); + auto col_input = is_source_node_ ? std::make_shared( + TargetBitmap(active_count_), + TargetBitmap(active_count_)) + : GetColumnVector(input_); TargetBitmapView data(col_input->GetRawData(), col_input->size()); + // need to expose null? segment_->mask_with_timestamps(data, query_timestamp_); segment_->mask_with_delete(data, active_count_, query_timestamp_); is_finished_ = true; diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index 873f475e06df9..a036d2ef512f6 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -70,14 +70,14 @@ ScalarIndexSort::Build(size_t n, const T* values, const bool* valid_data) { } data_.reserve(n); total_num_rows_ = n; - valid_bitset = TargetBitmap(total_num_rows_, false); - idx_to_offsets_.resize(n, -1); + valid_bitset_ = TargetBitmap(total_num_rows_, false); + idx_to_offsets_.resize(n); T* p = const_cast(values); for (size_t i = 0; i < n; ++i, ++p) { if (!valid_data || valid_data[i]) { data_.emplace_back(IndexStructure(*p, i)); - valid_bitset.set(i); + valid_bitset_.set(i); } } @@ -102,7 +102,7 @@ ScalarIndexSort::BuildWithFieldData( } data_.reserve(length); - valid_bitset = TargetBitmap(total_num_rows_, false); + valid_bitset_ = TargetBitmap(total_num_rows_, false); int64_t offset = 0; for (const auto& data : field_datas) { auto slice_num = data->get_num_rows(); @@ -110,14 +110,14 @@ ScalarIndexSort::BuildWithFieldData( if (data->is_valid(i)) { auto value = reinterpret_cast(data->RawValue(i)); data_.emplace_back(IndexStructure(*value, offset)); - valid_bitset.set(offset); + valid_bitset_.set(offset); } offset++; } } std::sort(data_.begin(), data_.end()); - idx_to_offsets_.resize(total_num_rows_, -1); + idx_to_offsets_.resize(total_num_rows_); for (size_t i = 0; i < length; ++i) { idx_to_offsets_[data_[i].idx_] = i; } @@ -179,12 +179,12 @@ ScalarIndexSort::LoadWithoutAssemble(const BinarySet& index_binary, memcpy(&total_num_rows_, index_num_rows->data.get(), (size_t)index_num_rows->size); - idx_to_offsets_.resize(total_num_rows_, -1); - valid_bitset = TargetBitmap(total_num_rows_, false); + idx_to_offsets_.resize(total_num_rows_); + valid_bitset_ = TargetBitmap(total_num_rows_, false); memcpy(data_.data(), index_data->data.get(), (size_t)index_data->size); for (size_t i = 0; i < data_.size(); ++i) { idx_to_offsets_[data_[i].idx_] = i; - valid_bitset.set(data_[i].idx_); + valid_bitset_.set(data_[i].idx_); } is_built_ = true; @@ -261,7 +261,7 @@ ScalarIndexSort::NotIn(const size_t n, const T* values) { } } // NotIn(null) and In(null) is both false, need to mask with IsNotNull operate - bitset &= valid_bitset; + bitset &= valid_bitset_; return bitset; } @@ -270,7 +270,7 @@ const TargetBitmap ScalarIndexSort::IsNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap bitset(total_num_rows_, true); - bitset &= valid_bitset; + bitset &= valid_bitset_; bitset.flip(); return bitset; } @@ -280,7 +280,7 @@ const TargetBitmap ScalarIndexSort::IsNotNull() { AssertInfo(is_built_, "index has not been built"); TargetBitmap bitset(total_num_rows_, true); - bitset &= valid_bitset; + bitset &= valid_bitset_; return bitset; } diff --git a/internal/core/src/index/ScalarIndexSort.h b/internal/core/src/index/ScalarIndexSort.h index f9ae09b3b3c69..1370b9dff89d6 100644 --- a/internal/core/src/index/ScalarIndexSort.h +++ b/internal/core/src/index/ScalarIndexSort.h @@ -127,8 +127,8 @@ class ScalarIndexSort : public ScalarIndex { std::vector> data_; std::shared_ptr file_manager_; size_t total_num_rows_{0}; - // generate valid_bitset to speed up NotIn and IsNull and IsNotNull operate - TargetBitmap valid_bitset; + // generate valid_bitset_ to speed up NotIn and IsNull and IsNotNull operate + TargetBitmap valid_bitset_; }; template diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index 3d71ecd8e6e35..870a58b0dd32d 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -137,7 +137,7 @@ StringIndexMarisa::Build(size_t n, } trie_.build(keyset, MARISA_LABEL_ORDER); - fill_str_ids(n, values); + fill_str_ids(n, values, valid_data); fill_offsets(); built_ = true; @@ -218,7 +218,7 @@ StringIndexMarisa::LoadWithoutAssemble(const BinarySet& set, auto str_ids = set.GetByName(MARISA_STR_IDS); auto str_ids_len = str_ids->size; - str_ids_.resize(str_ids_len / sizeof(size_t)); + str_ids_.resize(str_ids_len / sizeof(size_t), MARISA_NULL_KEY_ID); memcpy(str_ids_.data(), str_ids->data.get(), str_ids_len); fill_offsets(); @@ -496,9 +496,14 @@ StringIndexMarisa::PrefixMatch(std::string_view prefix) { } void -StringIndexMarisa::fill_str_ids(size_t n, const std::string* values) { - str_ids_.resize(n); +StringIndexMarisa::fill_str_ids(size_t n, + const std::string* values, + const bool* valid_data) { + str_ids_.resize(n, MARISA_NULL_KEY_ID); for (size_t i = 0; i < n; i++) { + if (valid_data && !valid_data[i]) { + continue; + } auto str = values[i]; auto str_id = lookup(str); AssertInfo(valid_str_id(str_id), "invalid marisa key"); diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h index 9bf44d73619c8..f3dff120897f0 100644 --- a/internal/core/src/index/StringIndexMarisa.h +++ b/internal/core/src/index/StringIndexMarisa.h @@ -102,7 +102,7 @@ class StringIndexMarisa : public StringIndex { private: void - fill_str_ids(size_t n, const std::string* values); + fill_str_ids(size_t n, const std::string* values, const bool* valid_data); void fill_offsets(); @@ -124,7 +124,7 @@ class StringIndexMarisa : public StringIndex { private: Config config_; marisa::Trie trie_; - std::vector str_ids_; // used to retrieve. + std::vector str_ids_; // used to retrieve. std::map> str_ids_to_offsets_; bool built_ = false; std::shared_ptr file_manager_; diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index a778d86cb1f43..d6273055ec9e2 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -683,6 +683,11 @@ ReverseDataFromIndex(const index::IndexBase* index, data_array->set_field_id(field_meta.get_id().get()); data_array->set_type(static_cast( field_meta.get_data_type())); + auto nullable = field_meta.is_nullable(); + std::vector valid_data; + if (nullable) { + valid_data.resize(count); + } auto scalar_array = data_array->mutable_scalars(); switch (data_type) { @@ -692,9 +697,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_bool_data(); @@ -707,9 +717,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); @@ -722,9 +737,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); @@ -737,9 +757,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_int_data(); @@ -752,9 +777,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_long_data(); @@ -767,9 +797,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_float_data(); @@ -782,9 +817,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_double_data(); @@ -797,9 +837,14 @@ ReverseDataFromIndex(const index::IndexBase* index, std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { auto value = ptr->Reverse_Lookup(seg_offsets[i]); + // if has no value, means nullable must be true, no need to check nullable again here if (!value.has_value()) { + valid_data[i] = false; continue; } + if (nullable) { + valid_data[i] = true; + } raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); } auto obj = scalar_array->mutable_string_data(); @@ -812,6 +857,11 @@ ReverseDataFromIndex(const index::IndexBase* index, } } + if (nullable) { + *(data_array->mutable_valid_data()) = {valid_data.begin(), + valid_data.end()}; + } + return data_array; } diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 8d128338fea30..51079941a55ad 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -610,10 +610,11 @@ TEST_P(ExprTest, TestRangeNullable) { CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -869,11 +870,10 @@ TEST_P(ExprTest, TestExistsJson) { auto expr = std::make_shared(milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path)); - BitsetType final; - plan.filter_plannode_ = + auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode( - plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + BitsetType final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -936,11 +936,10 @@ TEST_P(ExprTest, TestExistsJsonNullable) { auto expr = std::make_shared(milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path)); - BitsetType final; - plan.filter_plannode_ = + auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode( - plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + BitsetType final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1080,7 +1079,6 @@ TEST_P(ExprTest, TestUnaryRangeJson) { auto final = ExecuteQueryExpr( plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); - EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; @@ -1282,8 +1280,8 @@ TEST_P(ExprTest, TestUnaryRangeJsonNullable) { BitsetType final; auto plan = std::make_shared( DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); - EXPECT_EQ(final.size(), N * num_iters); + final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1348,7 +1346,8 @@ TEST_P(ExprTest, TestUnaryRangeJsonNullable) { BitsetType final; auto plan = std::make_shared( DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1417,7 +1416,8 @@ TEST_P(ExprTest, TestTermJson) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1496,7 +1496,8 @@ TEST_P(ExprTest, TestTermJsonNullable) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; @@ -1809,10 +1810,11 @@ TEST_P(ExprTest, TestCompare) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1940,10 +1942,11 @@ TEST_P(ExprTest, TestCompareNullable) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2071,10 +2074,11 @@ TEST_P(ExprTest, TestCompareNullable2) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2167,8 +2171,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -2300,8 +2307,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexNullable) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -2433,8 +2443,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexNullable2) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -2639,25 +2652,25 @@ TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT8); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT16); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT32); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT64); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::FLOAT); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::DOUBLE); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << "end compare test" << std::endl; } @@ -2795,25 +2808,25 @@ TEST_P(ExprTest, TestCompareExprNullable) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT8); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT16); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT32); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT64); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::FLOAT); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::DOUBLE); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << "end compare test" << std::endl; } @@ -2951,25 +2964,25 @@ TEST_P(ExprTest, TestCompareExprNullable2) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT8); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT16); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT32); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT64); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::FLOAT); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::DOUBLE); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << "end compare test" << std::endl; } @@ -3213,7 +3226,7 @@ TEST(Expr, TestExprPerformance) { std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); for (int i = 0; i < 100; i++) { - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); } std::cout << "cost: " @@ -3589,7 +3602,7 @@ TEST(Expr, TestExprNOT) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; i++) { if (!valid_data[i]) { @@ -3742,7 +3755,7 @@ TEST_P(ExprTest, test_term_pk) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < 10; ++i) { EXPECT_EQ(final[i], true); @@ -3759,7 +3772,7 @@ TEST_P(ExprTest, test_term_pk) { expr = std::make_shared( expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { EXPECT_EQ(final[i], false); @@ -3881,7 +3894,7 @@ TEST_P(ExprTest, TestConjuctExpr) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); for (int i = 0; i < N; ++i) { EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; } @@ -3952,7 +3965,7 @@ TEST_P(ExprTest, TestConjuctExprNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); for (int i = 0; i < N; ++i) { EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; } @@ -4021,7 +4034,7 @@ TEST_P(ExprTest, TestUnaryBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4101,7 +4114,7 @@ TEST_P(ExprTest, TestBinaryRangeBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4175,7 +4188,7 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4259,7 +4272,7 @@ TEST_P(ExprTest, TestBinaryLogicalBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4339,7 +4352,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4412,7 +4425,7 @@ TEST_P(ExprTest, TestCompareExprBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -4573,7 +4586,7 @@ TEST_P(ExprTest, TestRefactorExprs) { std::make_shared(DEFAULT_PLANNODE_ID, expr); std::cout << "start test" << std::endl; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << n << "cost: " << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4654,7 +4667,7 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { // load index for int64 field auto str2_col = raw_data.get_col(str2_fid); - auto str2_index = milvus::index::CreateScalarIndexSort(); + auto str2_index = milvus::index::CreateStringIndexMarisa(); str2_index->Build(N, str2_col.data()); load_index_info.field_id = str2_fid.get(); load_index_info.field_type = DataType::VARCHAR; @@ -4671,8 +4684,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -4782,7 +4798,7 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable) { // load index for int64 field auto nullable_col = raw_data.get_col(nullable_fid); auto valid_data_col = raw_data.get_col_valid(nullable_fid); - auto str2_index = milvus::index::CreateScalarIndexSort(); + auto str2_index = milvus::index::CreateStringIndexMarisa(); str2_index->Build(N, nullable_col.data(), valid_data_col.data()); load_index_info.field_id = nullable_fid.get(); load_index_info.field_type = DataType::VARCHAR; @@ -4799,8 +4815,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -4910,7 +4929,7 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable2) { // load index for int64 field auto nullable_col = raw_data.get_col(nullable_fid); auto valid_data_col = raw_data.get_col_valid(nullable_fid); - auto str2_index = milvus::index::CreateScalarIndexSort(); + auto str2_index = milvus::index::CreateStringIndexMarisa(); str2_index->Build(N, nullable_col.data(), valid_data_col.data()); load_index_info.field_id = nullable_fid.get(); load_index_info.field_type = DataType::VARCHAR; @@ -4927,8 +4946,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMarisNullable2) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -5636,10 +5658,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -6574,10 +6597,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeNullable) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -7411,10 +7435,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -8334,10 +8359,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONNullable) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -8415,7 +8441,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -8457,7 +8484,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -8546,7 +8574,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloatNullable) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -8591,7 +8620,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloatNullable) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -9089,8 +9119,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { *schema, binary_plan.data(), binary_plan.size()); BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg_promote, N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -9790,8 +9823,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndexNullable) { *schema, binary_plan.data(), binary_plan.size()); BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg_promote, N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -9987,10 +10023,11 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -10216,10 +10253,11 @@ TEST_P(ExprTest, TestUnaryRangeWithJSONNullable) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -10393,10 +10431,11 @@ TEST_P(ExprTest, TestTermWithJSON) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -10592,10 +10631,11 @@ TEST_P(ExprTest, TestTermWithJSONNullable) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -10934,10 +10974,11 @@ TEST_P(ExprTest, TestExistsWithJSONNullable) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -11032,7 +11073,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); // std::cout << "cost" // << std::chrono::duration_cast( // std::chrono::steady_clock::now() - start) @@ -11080,7 +11122,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11128,7 +11171,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11176,7 +11220,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11253,7 +11298,8 @@ TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); // std::cout << "cost" // << std::chrono::duration_cast( // std::chrono::steady_clock::now() - start) @@ -11304,7 +11350,8 @@ TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11355,7 +11402,8 @@ TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11407,7 +11455,8 @@ TEST_P(ExprTest, TestTermInFieldJsonNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11622,7 +11671,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11671,7 +11721,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11720,7 +11771,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11769,7 +11821,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11847,7 +11900,8 @@ TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11899,7 +11953,8 @@ TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -11951,7 +12006,8 @@ TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -12004,7 +12060,8 @@ TEST_P(ExprTest, TestJsonContainsAnyNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -12081,7 +12138,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -12250,7 +12308,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -12452,7 +12511,8 @@ TEST_P(ExprTest, TestJsonContainsAllNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -12510,7 +12570,8 @@ TEST_P(ExprTest, TestJsonContainsAllNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -13354,7 +13415,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArrayNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -13381,7 +13443,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArrayNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -13571,7 +13634,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -13601,7 +13665,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeNullable) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index b01d1d326501f..cb4ccf4131cbd 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -356,10 +356,11 @@ TEST(StringExpr, TermNullable) { auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -473,10 +474,11 @@ TEST(StringExpr, Compare) { auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -600,10 +602,11 @@ TEST(StringExpr, CompareNullable) { auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -731,10 +734,11 @@ TEST(StringExpr, CompareNullable2) { auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1180,10 +1184,11 @@ TEST(StringExpr, BinaryRangeNullable) { auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index fb0030b968310..71b3896882cc3 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -13017,7 +13017,6 @@ def test_search_collection_with_non_default_data_after_release_load(self, nq, _a @pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.GPU) - @pytest.mark.skip(reason="issue #36184") def test_search_after_different_index_with_params_none_default_data(self, varchar_scalar_index, numeric_scalar_index, null_data_percent, _async): """ From 37ce45bb001e04d42a78007701f066de711952f8 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Tue, 15 Oct 2024 16:35:18 +0800 Subject: [PATCH 4/4] fix comment Signed-off-by: lixinguo --- internal/core/src/common/FieldData.cpp | 2 +- .../expression/BinaryArithOpEvalRangeExpr.cpp | 280 +++------------- .../expression/BinaryArithOpEvalRangeExpr.h | 163 +++------ .../src/exec/expression/BinaryRangeExpr.cpp | 18 +- .../src/exec/expression/BinaryRangeExpr.h | 68 +--- .../core/src/exec/expression/CompareExpr.cpp | 314 ++++++++---------- .../core/src/exec/expression/CompareExpr.h | 9 +- .../core/src/exec/expression/ExistsExpr.cpp | 2 +- internal/core/src/exec/expression/Expr.h | 22 +- .../src/exec/expression/JsonContainsExpr.cpp | 16 +- .../core/src/exec/expression/TermExpr.cpp | 10 +- .../core/src/exec/expression/UnaryExpr.cpp | 42 ++- internal/core/src/exec/expression/UnaryExpr.h | 84 ++--- internal/core/src/exec/operator/MvccNode.cpp | 3 +- .../operator/groupby/SearchGroupByOperator.h | 6 +- internal/core/src/index/BitmapIndex.cpp | 4 +- internal/core/src/index/ScalarIndexSort.cpp | 4 +- internal/core/src/index/StringIndexMarisa.cpp | 4 +- .../src/segcore/ChunkedSegmentSealedImpl.cpp | 16 +- .../core/src/segcore/SegmentSealedImpl.cpp | 20 +- internal/core/src/segcore/Utils.cpp | 48 +-- .../core/unittest/test_utils/AssertUtils.h | 19 +- 22 files changed, 424 insertions(+), 730 deletions(-) diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index f64e677d9a036..af089fa4696e3 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -69,7 +69,7 @@ FieldDataImpl::FillFieldData( ssize_t byte_count = (element_count + 7) / 8; // Note: if 'nullable == true` and valid_data is nullptr // means null_count == 0, will fill it with 0xFF - if (!valid_data) { + if (valid_data == nullptr) { valid_data_.assign(byte_count, 0xFF); } else { std::copy_n(valid_data, byte_count, valid_data_.data()); diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index 6e24141d09596..e5b24ac4121ce 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -131,7 +131,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ - if (valid_data && !valid_data[i]) { \ + if (valid_data != nullptr && !valid_data[i]) { \ res[i] = false; \ valid_res[i] = false; \ continue; \ @@ -153,7 +153,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { #define BinaryArithRangeJSONCompareNotEqual(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ - if (valid_data && !valid_data[i]) { \ + if (valid_data != nullptr && !valid_data[i]) { \ res[i] = false; \ valid_res[i] = false; \ continue; \ @@ -211,7 +211,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -265,7 +265,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -319,7 +319,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -373,7 +373,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -427,7 +427,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -481,7 +481,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = false; valid_res[i] = false; continue; @@ -558,7 +558,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { #define BinaryArithRangeArrayCompare(cmp) \ do { \ for (size_t i = 0; i < size; ++i) { \ - if (valid_data && !valid_data[i]) { \ + if (valid_data != nullptr && !valid_data[i]) { \ res[i] = false; \ valid_res[i] = false; \ continue; \ @@ -612,7 +612,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -659,7 +659,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -706,7 +706,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -753,7 +753,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -800,7 +800,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -847,7 +847,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { } case proto::plan::ArithOpType::ArrayLength: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -1318,13 +1318,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1332,13 +1326,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1346,13 +1334,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1360,13 +1342,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1374,13 +1350,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::Equal, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1399,13 +1369,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1413,13 +1377,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1427,13 +1385,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1441,13 +1393,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1455,13 +1401,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::NotEqual, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1480,13 +1420,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1494,13 +1428,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1508,13 +1436,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1522,13 +1444,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1536,13 +1452,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterThan, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1561,13 +1471,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1575,13 +1479,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1589,13 +1487,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1603,13 +1495,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1617,13 +1503,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::GreaterEqual, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1642,13 +1522,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1656,13 +1530,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1670,13 +1538,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1684,13 +1546,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1698,13 +1554,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessThan, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1723,13 +1573,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Add> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Sub: { @@ -1737,13 +1581,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Sub> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mul: { @@ -1751,13 +1589,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mul> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Div: { @@ -1765,13 +1597,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Div> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } case proto::plan::ArithOpType::Mod: { @@ -1779,13 +1605,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { proto::plan::OpType::LessEqual, proto::plan::ArithOpType::Mod> func; - func(data, - valid_data, - size, - value, - right_operand, - res, - valid_res); + func(data, size, value, right_operand, res); break; } default: @@ -1803,6 +1623,16 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { "arithmetic eval expr: {}", op_type); } + // there is a batch operation in ArithOpElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } + } }; int64_t processed_size = ProcessDataChunks(execute_sub_batch, std::nullptr_t{}, diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index 0db5a59f147ed..5eef111438591 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -97,12 +97,10 @@ struct ArithOpElementFunc { HighPrecisonType; void operator()(const T* src, - const bool* valid_data, size_t size, HighPrecisonType val, HighPrecisonType right_operand, - TargetBitmapView res, - TargetBitmapView valid_res) { + TargetBitmapView res) { /* // This is the original code, kept here for the documentation purposes for (int i = 0; i < size; ++i) { @@ -241,59 +239,27 @@ struct ArithOpElementFunc { } } */ - auto execute_sub_batch = [](const T* src, - size_t size, - HighPrecisonType val, - HighPrecisonType right_operand, - TargetBitmapView res) { - if (size == 0) { - return; - } - if constexpr (!std::is_same_v::op), + if constexpr (!std::is_same_v::op), + void>) { + constexpr auto cmp_op_cvt = CmpOpHelper::op; + if constexpr (!std::is_same_v::op), void>) { - constexpr auto cmp_op_cvt = CmpOpHelper::op; - if constexpr (!std::is_same_v< - decltype(ArithOpHelper::op), - void>) { - constexpr auto arith_op_cvt = ArithOpHelper::op; + constexpr auto arith_op_cvt = ArithOpHelper::op; - res.inplace_arith_compare( - src, right_operand, val, size); - } else { - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported arith type:{} for ArithOpElementFunc", - arith_op)); - } + res.inplace_arith_compare( + src, right_operand, val, size); } else { - PanicInfo(OpTypeInvalid, - fmt::format( - "unsupported cmp type:{} for ArithOpElementFunc", - cmp_op)); - } - }; - if (valid_data == nullptr) { - return execute_sub_batch(src, size, val, right_operand, res); - } - for (int left = 0; left < size; left++) { - for (int right = left; right < size; right++) { - if (valid_data[right]) { - if (right == size - 1) { - execute_sub_batch(src + left, - right - left, - val, - right_operand, - res + left); - } - continue; - } - valid_res[right] = false; - execute_sub_batch( - src + left, right - left, val, right_operand, res + left); - left = right; - break; + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); } + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported cmp type:{} for ArithOpElementFunc", + cmp_op)); } } }; @@ -315,30 +281,26 @@ struct ArithOpIndexFunc { HighPrecisonType right_operand) { TargetBitmap res(size); for (size_t i = 0; i < size; ++i) { - if (!index->Reverse_Lookup(i).has_value()) { + auto raw = index->Reverse_Lookup(i); + if (!raw.has_value()) { res[i] = false; continue; } if constexpr (cmp_op == proto::plan::OpType::Equal) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) == val; + res[i] = (raw.value() + right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) == val; + res[i] = (raw.value() - right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) == val; + res[i] = (raw.value() * right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) == val; + res[i] = (raw.value() / right_operand) == val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) == val; + res[i] = (fmod(raw.value(), right_operand)) == val; } else { PanicInfo( OpTypeInvalid, @@ -348,24 +310,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::NotEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) != val; + res[i] = (raw.value() + right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) != val; + res[i] = (raw.value() - right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) != val; + res[i] = (raw.value() * right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) != val; + res[i] = (raw.value() / right_operand) != val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) != val; + res[i] = (fmod(raw.value(), right_operand)) != val; } else { PanicInfo( OpTypeInvalid, @@ -375,24 +332,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) > val; + res[i] = (raw.value() + right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) > val; + res[i] = (raw.value() - right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) > val; + res[i] = (raw.value() * right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) > val; + res[i] = (raw.value() / right_operand) > val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) > val; + res[i] = (fmod(raw.value(), right_operand)) > val; } else { PanicInfo( OpTypeInvalid, @@ -402,24 +354,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::GreaterEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) >= val; + res[i] = (raw.value() + right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) >= val; + res[i] = (raw.value() - right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) >= val; + res[i] = (raw.value() * right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) >= val; + res[i] = (raw.value() / right_operand) >= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) >= val; + res[i] = (fmod(raw.value(), right_operand)) >= val; } else { PanicInfo( OpTypeInvalid, @@ -429,24 +376,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessThan) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) < val; + res[i] = (raw.value() + right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) < val; + res[i] = (raw.value() - right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) < val; + res[i] = (raw.value() * right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) < val; + res[i] = (raw.value() / right_operand) < val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) < val; + res[i] = (fmod(raw.value(), right_operand)) < val; } else { PanicInfo( OpTypeInvalid, @@ -456,24 +398,19 @@ struct ArithOpIndexFunc { } } else if constexpr (cmp_op == proto::plan::OpType::LessEqual) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { - res[i] = (index->Reverse_Lookup(i).value() + - right_operand) <= val; + res[i] = (raw.value() + right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Sub) { - res[i] = (index->Reverse_Lookup(i).value() - - right_operand) <= val; + res[i] = (raw.value() - right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mul) { - res[i] = (index->Reverse_Lookup(i).value() * - right_operand) <= val; + res[i] = (raw.value() * right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Div) { - res[i] = (index->Reverse_Lookup(i).value() / - right_operand) <= val; + res[i] = (raw.value() / right_operand) <= val; } else if constexpr (arith_op == proto::plan::ArithOpType::Mod) { - res[i] = (fmod(index->Reverse_Lookup(i).value(), - right_operand)) <= val; + res[i] = (fmod(raw.value(), right_operand)) <= val; } else { PanicInfo( OpTypeInvalid, diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index c1ce0ad4383a7..26467cd4646a3 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -261,16 +261,26 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { HighPrecisionType val2) { if (lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res, valid_res); + func(val1, val2, data, size, res); } else if (lower_inclusive && !upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res, valid_res); + func(val1, val2, data, size, res); } else if (!lower_inclusive && upper_inclusive) { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res, valid_res); + func(val1, val2, data, size, res); } else { BinaryRangeElementFunc func; - func(val1, val2, data, valid_data, size, res, valid_res); + func(val1, val2, data, size, res); + } + // there is a batch operation in BinaryRangeElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } } }; auto skip_index_func = diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index 66d37e8494eb1..145a8955ffe88 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -35,64 +35,26 @@ struct BinaryRangeElementFunc { T> HighPrecisionType; void - operator()(T val1, - T val2, - const T* src, - const bool* valid_data, - size_t n, - TargetBitmapView res, - TargetBitmapView valid_res) { - auto execute_sub_batch = [](T val1, - T val2, - const T* src, - size_t n, - TargetBitmapView res) { - if (n == 0) { - return; - } - if constexpr (lower_inclusive && upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else if constexpr (lower_inclusive && !upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else if constexpr (!lower_inclusive && upper_inclusive) { - res.inplace_within_range_val( - val1, val2, src, n); - } else { - res.inplace_within_range_val( - val1, val2, src, n); - } - }; - if (valid_data == nullptr) { - return execute_sub_batch(val1, val2, src, n, res); - } - for (int left = 0; left < n; left++) { - for (int right = left; right < n; right++) { - if (valid_data[right]) { - if (right == n - 1) { - execute_sub_batch( - val1, val2, src + left, right - left, res + left); - } - continue; - } - valid_res[right] = false; - execute_sub_batch( - val1, val2, src + left, right - left, res + left); - left = right; - break; - } + operator()(T val1, T val2, const T* src, size_t n, TargetBitmapView res) { + if constexpr (lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (lower_inclusive && !upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (!lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else { + res.inplace_within_range_val( + val1, val2, src, n); } } }; #define BinaryRangeJSONCompare(cmp) \ do { \ - if (valid_data && !valid_data[i]) { \ + if (valid_data != nullptr && !valid_data[i]) { \ res[i] = valid_res[i] = false; \ break; \ } \ @@ -156,7 +118,7 @@ struct BinaryRangeElementFuncForArray { TargetBitmapView res, TargetBitmapView valid_res) { for (size_t i = 0; i < n; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index 27a7e6ad2e05d..5bc2e8dab15e1 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -59,12 +59,19 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_->chunk_scalar_index(field_id, current_chunk_id)); } - return indexing.Reverse_Lookup(current_chunk_pos++); + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } auto chunk_data = segment_->chunk_data(field_id, current_chunk_id).data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id).valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, ¤t_chunk_id, ¤t_chunk_pos]() mutable -> const number { @@ -73,10 +80,16 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, current_chunk_pos = 0; chunk_data = segment_->chunk_data(field_id, current_chunk_id).data(); + chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } - + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return chunk_data[current_chunk_pos++]; }; } @@ -104,7 +117,12 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_->chunk_scalar_index( field_id, current_chunk_id)); } - return indexing.Reverse_Lookup(current_chunk_pos++); + auto raw = indexing.Reverse_Lookup(current_chunk_pos); + current_chunk_pos++; + if (!raw.has_value()) { + return std::nullopt; + } + return raw.value(); }; } } @@ -115,6 +133,9 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto chunk_data = segment_->chunk_data(field_id, current_chunk_id) .data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, @@ -127,16 +148,26 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_ ->chunk_data(field_id, current_chunk_id) .data(); + chunk_valid_data = + segment_ + ->chunk_data(field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } - + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return chunk_data[current_chunk_pos++]; }; } else { auto chunk_data = segment_->chunk_view(field_id, current_chunk_id) .first.data(); + auto chunk_valid_data = + segment_->chunk_data(field_id, current_chunk_id) + .valid_data(); auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); return [=, @@ -149,9 +180,17 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, ->chunk_view( field_id, current_chunk_id) .first.data(); + chunk_valid_data = segment_ + ->chunk_data( + field_id, current_chunk_id) + .valid_data(); current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); } + if (chunk_valid_data && !chunk_valid_data[current_chunk_pos]) { + current_chunk_pos++; + return std::nullopt; + } return std::string(chunk_data[current_chunk_pos++]); }; @@ -198,133 +237,97 @@ PhyCompareFilterExpr::GetChunkData(DataType data_type, template VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { - auto real_batch_size = GetNextBatchSize(); - if (real_batch_size == 0) { - return nullptr; - } - - auto res_vec = - std::make_shared(TargetBitmap(real_batch_size)); - TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + if (segment_->is_chunked()) { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } - auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); - auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); - int64_t processed_rows = 0; - for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; - ++chunk_id) { - auto chunk_size = chunk_id == num_chunk_ - 1 - ? active_count_ - chunk_id * size_per_chunk_ - : size_per_chunk_; auto left = GetChunkData(expr_->left_data_type_, expr_->left_field_id_, - chunk_id, - left_data_barrier); + is_left_indexed_, + left_current_chunk_id_, + left_current_chunk_pos_); auto right = GetChunkData(expr_->right_data_type_, expr_->right_field_id_, - chunk_id, - right_data_barrier); - - for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; - i < chunk_size; - ++i) { - if (!left(i).has_value() || !right(i).has_value()) { - res[processed_rows] = false; - } else { - res[processed_rows] = boost::apply_visitor( - milvus::query::Relational{}, - left(i).value(), - right(i).value()); + is_right_indexed_, + right_current_chunk_id_, + right_current_chunk_pos_); + for (int i = 0; i < real_batch_size; ++i) { + if (!left().has_value() || !right().has_value()) { + res[i] = false; + valid_res[i] = false; + continue; } - processed_rows++; + res[i] = + boost::apply_visitor(milvus::query::Relational{}, + left().value(), + right().value()); + } + return res_vec; + } else { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = std::make_shared( + TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); + valid_res.set(); + + auto left_data_barrier = + segment_->num_chunk_data(expr_->left_field_id_); + auto right_data_barrier = + segment_->num_chunk_data(expr_->right_field_id_); + + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk_ - 1 + ? active_count_ - chunk_id * size_per_chunk_ + : size_per_chunk_; + auto left = GetChunkData(expr_->left_data_type_, + expr_->left_field_id_, + chunk_id, + left_data_barrier); + auto right = GetChunkData(expr_->right_data_type_, + expr_->right_field_id_, + chunk_id, + right_data_barrier); + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + if (!left(i).has_value() || !right(i).has_value()) { + res[processed_rows] = false; + valid_res[processed_rows] = false; + } else { + res[processed_rows] = boost::apply_visitor( + milvus::query::Relational{}, + left(i).value(), + right(i).value()); + } + processed_rows++; - if (processed_rows >= batch_size_) { - current_chunk_id_ = chunk_id; - current_chunk_pos_ = i + 1; - return res_vec; + if (processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + return res_vec; + } } } + return res_vec; } - return res_vec; } -// template -// VectorPtr -// PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { -// if (segment_->is_chunked()) { -// auto real_batch_size = GetNextBatchSize(); -// if (real_batch_size == 0) { -// return nullptr; -// } - -// auto res_vec = -// std::make_shared(TargetBitmap(real_batch_size)); -// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); - -// auto left = GetChunkData(expr_->left_data_type_, -// expr_->left_field_id_, -// is_left_indexed_, -// left_current_chunk_id_, -// left_current_chunk_pos_); -// auto right = GetChunkData(expr_->right_data_type_, -// expr_->right_field_id_, -// is_right_indexed_, -// right_current_chunk_id_, -// right_current_chunk_pos_); -// for (int i = 0; i < real_batch_size; ++i) { -// res[i] = boost::apply_visitor( -// milvus::query::Relational{}, left(), right()); -// } -// return res_vec; -// } else { -// auto real_batch_size = GetNextBatchSize(); -// if (real_batch_size == 0) { -// return nullptr; -// } - -// auto res_vec = -// std::make_shared(TargetBitmap(real_batch_size)); -// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); - -// auto left_data_barrier = -// segment_->num_chunk_data(expr_->left_field_id_); -// auto right_data_barrier = -// segment_->num_chunk_data(expr_->right_field_id_); - -// int64_t processed_rows = 0; -// for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; -// ++chunk_id) { -// auto chunk_size = chunk_id == num_chunk_ - 1 -// ? active_count_ - chunk_id * size_per_chunk_ -// : size_per_chunk_; -// auto left = GetChunkData(expr_->left_data_type_, -// expr_->left_field_id_, -// chunk_id, -// left_data_barrier); -// auto right = GetChunkData(expr_->right_data_type_, -// expr_->right_field_id_, -// chunk_id, -// right_data_barrier); - -// for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; -// i < chunk_size; -// ++i) { -// res[processed_rows++] = boost::apply_visitor( -// milvus::query::Relational{}, -// left(i), -// right(i)); - -// if (processed_rows >= batch_size_) { -// current_chunk_id_ = chunk_id; -// current_chunk_pos_ = i + 1; -// return res_vec; -// } -// } -// } -// return res_vec; -// } -// } - template ChunkDataAccessor PhyCompareFilterExpr::GetChunkData(FieldId field_id, @@ -334,10 +337,11 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const number { - if (!indexing.Reverse_Lookup(i).has_value()) { + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { return std::nullopt; } - return indexing.Reverse_Lookup(i).value(); + return raw.value(); }; } } @@ -362,10 +366,11 @@ PhyCompareFilterExpr::GetChunkData(FieldId field_id, segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const number { - if (!indexing.Reverse_Lookup(i).has_value()) { + auto raw = indexing.Reverse_Lookup(i); + if (!raw.has_value()) { return std::nullopt; } - return indexing.Reverse_Lookup(i).value(); + return raw.value(); }; } } @@ -425,62 +430,6 @@ PhyCompareFilterExpr::GetChunkData(DataType data_type, } } -// template -// VectorPtr -// PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { -// auto real_batch_size = GetNextBatchSize(); -// if (real_batch_size == 0) { -// return nullptr; -// } - -// auto res_vec = std::make_shared( -// TargetBitmap(real_batch_size), TargetBitmap(real_batch_size)); -// TargetBitmapView res(res_vec->GetRawData(), real_batch_size); -// TargetBitmapView valid_res(res_vec->GetValidRawData(), real_batch_size); -// valid_res.set(); - -// auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); -// auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); - -// int64_t processed_rows = 0; -// for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; -// ++chunk_id) { -// auto chunk_size = chunk_id == num_chunk_ - 1 -// ? active_count_ - chunk_id * size_per_chunk_ -// : size_per_chunk_; -// auto left = GetChunkData(expr_->left_data_type_, -// expr_->left_field_id_, -// chunk_id, -// left_data_barrier); -// auto right = GetChunkData(expr_->right_data_type_, -// expr_->right_field_id_, -// chunk_id, -// right_data_barrier); - -// for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; -// i < chunk_size; -// ++i) { -// if (!left(i).has_value() || !right(i).has_value()) { -// res[processed_rows] = false; -// valid_res[processed_rows] = false; -// } else { -// res[processed_rows] = boost::apply_visitor( -// milvus::query::Relational{}, -// left(i).value(), -// right(i).value()); -// } -// processed_rows++; - -// if (processed_rows >= batch_size_) { -// current_chunk_id_ = chunk_id; -// current_chunk_pos_ = i + 1; -// return res_vec; -// } -// } -// } -// return res_vec; -// } - void PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { // For segment both fields has no index, can use SIMD to speed up. @@ -627,11 +576,10 @@ PhyCompareFilterExpr::ExecCompareRightType() { break; } default: - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported operator type for compare column expr: {}", - expr_type)); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for " + "compare column expr: {}", + expr_type)); } }; int64_t processed_size = diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index 569e305c5c83c..8f4aaaed53709 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -268,16 +268,19 @@ class PhyCompareFilterExpr : public Expr { template int64_t - ProcessBothDataChunks(FUNC func, TargetBitmapView res, ValTypes... values) { + ProcessBothDataChunks(FUNC func, + TargetBitmapView res, + TargetBitmapView valid_res, + ValTypes... values) { if (segment_->is_chunked()) { return ProcessBothDataChunksForMultipleChunk( - func, res, values...); + func, res, valid_res, values...); } else { return ProcessBothDataChunksForSingleChunk( - func, res, values...); + func, res, valid_res, values...); } } diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp index 0fab44ebdc463..c73b4e007dc38 100644 --- a/internal/core/src/exec/expression/ExistsExpr.cpp +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -58,7 +58,7 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { TargetBitmapView valid_res, const std::string& pointer) { for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 734ecf0077c6e..307792a539ac2 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -260,6 +260,8 @@ class SegmentExpr : public Expr { if (!skip_func || !skip_func(skip_index, field_id_, 0)) { auto views_info = segment_->get_batch_views( field_id_, 0, current_data_chunk_pos_, need_size); + // first is the raw data, second is valid_data + // use valid_data to see if raw data is null func(views_info.first.data(), views_info.second.data(), need_size, @@ -335,6 +337,7 @@ class SegmentExpr : public Expr { FUNC func, std::function skip_func, TargetBitmapView res, + TargetBitmapView valid_res, ValTypes... values) { int64_t processed_size = 0; @@ -369,13 +372,21 @@ class SegmentExpr : public Expr { if constexpr (std::is_same_v || std::is_same_v) { if (segment_->type() == SegmentType::Sealed) { + // first is the raw data, second is valid_data + // use valid_data to see if raw data is null auto data_vec = segment_ ->get_batch_views( field_id_, i, data_pos, size) .first; + auto valid_data = segment_ + ->get_batch_views( + field_id_, i, data_pos, size) + .second; func(data_vec.data(), + valid_data.data(), size, res + processed_size, + valid_res + processed_size, values...); is_seal = true; } @@ -383,7 +394,16 @@ class SegmentExpr : public Expr { if (!is_seal) { auto chunk = segment_->chunk_data(field_id_, i); const T* data = chunk.data() + data_pos; - func(data, size, res + processed_size, values...); + const bool* valid_data = chunk.valid_data(); + if (valid_data != nullptr) { + valid_data += data_pos; + } + func(data, + valid_data, + size, + res + processed_size, + valid_res + processed_size, + values...); } } diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index 897cad544755f..b21714b4c8b6b 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -200,7 +200,7 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { return false; }; for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -266,7 +266,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { return false; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -337,7 +337,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { return false; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -398,7 +398,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { return tmp_elements.size() == 0; }; for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -468,7 +468,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { return tmp_elements.size() == 0; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -598,7 +598,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { return tmp_elements_index.size() == 0; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -679,7 +679,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { return exist_elements_index.size() == elements.size(); }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -801,7 +801,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return false; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index 9e1092661b4e3..fcb27a1c747a2 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -272,7 +272,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { return false; }; for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -331,7 +331,7 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { int index, const std::unordered_set& term_set) { for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -400,7 +400,7 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { return false; }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -472,7 +472,7 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { return terms.find(ValueType(x.value())) != terms.end(); }; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -592,7 +592,7 @@ PhyTermFilterExpr::ExecVisitorImplForData() { const std::unordered_set& vals) { TermElementFuncSet func; for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index ae10d9386d01a..ad3cd8cb294d1 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -507,7 +507,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { switch (op_type) { case proto::plan::GreaterThan: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -521,7 +521,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::GreaterEqual: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -535,7 +535,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessThan: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -549,7 +549,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::LessEqual: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -563,7 +563,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::Equal: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -583,7 +583,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::NotEqual: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -603,7 +603,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::PrefixMatch: { for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -621,7 +621,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { auto regex_pattern = translator(val); RegexMatcher matcher(regex_pattern); for (size_t i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } @@ -848,42 +848,42 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { switch (expr_type) { case proto::plan::GreaterThan: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::GreaterEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::LessThan: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::LessEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::Equal: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::NotEqual: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::PrefixMatch: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } case proto::plan::Match: { UnaryElementFunc func; - func(data, valid_data, size, val, res, valid_res); + func(data, size, val, res); break; } default: @@ -892,6 +892,16 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { fmt::format("unsupported operator type for unary expr: {}", expr_type)); } + // there is a batch operation in BinaryRangeElementFunc, + // so not divide data again for the reason that it may reduce performance if the null distribution is scattered + // but to mask res with valid_data after the batch operation. + if (valid_data != nullptr) { + for (int i = 0; i < size; i++) { + if (!valid_data[i]) { + res[i] = valid_res[i] = false; + } + } + } }; auto skip_index_func = [expr_type, val](const SkipIndex& skip_index, FieldId field_id, diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index d216decf07f47..71a8869ecd291 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -60,11 +60,9 @@ struct UnaryElementFunc { IndexInnerType; void operator()(const T* src, - const bool* valid_data, size_t size, IndexInnerType val, - TargetBitmapView res, - TargetBitmapView valid_res) { + TargetBitmapView res) { if constexpr (op == proto::plan::OpType::Match) { UnaryElementFuncForMatch func; func(src, size, val, res); @@ -98,67 +96,33 @@ struct UnaryElementFunc { } */ - auto execute_sub_batch = [](const T* src, - size_t size, - IndexInnerType val, - TargetBitmapView res) { - if (size == 0) { - return; - } - if constexpr (op == proto::plan::OpType::Equal) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::NotEqual) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::GreaterThan) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::LessThan) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::GreaterEqual) { - res.inplace_compare_val( - src, size, val); - } else if constexpr (op == proto::plan::OpType::LessEqual) { - res.inplace_compare_val( - src, size, val); - } else { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported op_type:{} for UnaryElementFunc", - op)); - } - }; - if constexpr (op == proto::plan::OpType::PrefixMatch) { for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { - res[i] = false; - continue; - } res[i] = milvus::query::Match( src[i], val, proto::plan::OpType::PrefixMatch); } - return; - } - if (!valid_data) { - return execute_sub_batch(src, size, val, res); - } - for (int left = 0; left < size; left++) { - for (int right = left; right < size; right++) { - if (valid_data[right]) { - if (right == size - 1) { - execute_sub_batch( - src + left, right - left, val, res + left); - } - continue; - } - valid_res[right] = false; - execute_sub_batch(src + left, right - left, val, res + left); - left = right; - break; - } + } else if constexpr (op == proto::plan::OpType::Equal) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res.inplace_compare_val( + src, size, val); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for UnaryElementFunc", op)); } } }; @@ -191,7 +155,7 @@ struct UnaryElementFuncForArray { TargetBitmapView res, TargetBitmapView valid_res) { for (int i = 0; i < size; ++i) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { res[i] = valid_res[i] = false; continue; } diff --git a/internal/core/src/exec/operator/MvccNode.cpp b/internal/core/src/exec/operator/MvccNode.cpp index e191baff60d4b..98d7b4862abff 100644 --- a/internal/core/src/exec/operator/MvccNode.cpp +++ b/internal/core/src/exec/operator/MvccNode.cpp @@ -51,7 +51,8 @@ PhyMvccNode::GetOutput() { is_finished_ = true; return nullptr; } - + // the first vector is filtering result and second bitset is a valid bitset + // if valid_bitset[i]==false, means result[i] is null auto col_input = is_source_node_ ? std::make_shared( TargetBitmap(active_count_), TargetBitmap(active_count_)) diff --git a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h index e266b5b34ce82..e6a95c6603809 100644 --- a/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h @@ -100,9 +100,9 @@ class SealedDataGetter : public DataGetter { } return field_data_->operator[](idx); } else { - auto value = (*field_index_).Reverse_Lookup(idx); - AssertInfo(value.has_value(), "field data not found"); - return value.value(); + auto raw = (*field_index_).Reverse_Lookup(idx); + AssertInfo(raw.has_value(), "field data not found"); + return raw.value(); } } }; diff --git a/internal/core/src/index/BitmapIndex.cpp b/internal/core/src/index/BitmapIndex.cpp index c35e73a5902bb..b5bca930d57b5 100644 --- a/internal/core/src/index/BitmapIndex.cpp +++ b/internal/core/src/index/BitmapIndex.cpp @@ -93,7 +93,7 @@ BitmapIndex::Build(size_t n, const T* data, const bool* valid_data) { T* p = const_cast(data); for (int i = 0; i < n; ++i, ++p) { - if (!valid_data || valid_data[i]) { + if (valid_data == nullptr || valid_data[i]) { data_[*p].add(i); valid_bitset_.set(i); } @@ -1127,11 +1127,11 @@ BitmapIndex::Reverse_Lookup(size_t idx) const { } } } - return std::nullopt; PanicInfo(UnexpectedError, fmt::format( "scalar bitmap index can not lookup target value of index {}", idx)); + return std::nullopt; } template diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index a036d2ef512f6..8d55832b1d4c7 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -365,10 +365,10 @@ ScalarIndexSort::Reverse_Lookup(size_t idx) const { AssertInfo(idx < idx_to_offsets_.size(), "out of range of total count"); AssertInfo(is_built_, "index has not been built"); - auto offset = idx_to_offsets_[idx]; - if (offset < 0) { + if (!valid_bitset_[idx]) { return std::nullopt; } + auto offset = idx_to_offsets_[idx]; return data_[offset].a_; } diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index 870a58b0dd32d..289ba2409da86 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -130,7 +130,7 @@ StringIndexMarisa::Build(size_t n, { // fill key set. for (size_t i = 0; i < n; i++) { - if (!valid_data || valid_data[i]) { + if (valid_data == nullptr || valid_data[i]) { keyset.push_back(values[i].c_str()); } } @@ -501,7 +501,7 @@ StringIndexMarisa::fill_str_ids(size_t n, const bool* valid_data) { str_ids_.resize(n, MARISA_NULL_KEY_ID); for (size_t i = 0; i < n; i++) { - if (valid_data && !valid_data[i]) { + if (valid_data != nullptr && !valid_data[i]) { continue; } auto str = values[i]; diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index a95ae1ecd1665..9987347f3fe57 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -196,8 +196,9 @@ ChunkedSegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && int64_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk(int64_index->Reverse_Lookup(i), - i); + auto raw = int64_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "pk not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -210,8 +211,9 @@ ChunkedSegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && string_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - insert_record_.insert_pk( - string_index->Reverse_Lookup(i), i); + auto raw = string_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "pk not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -1630,7 +1632,11 @@ ChunkedSegmentSealedImpl::CreateTextIndex(FieldId field_id) { "converted to string index"); auto n = impl->Size(); for (size_t i = 0; i < n; i++) { - index->AddText(impl->Reverse_Lookup(i), i); + auto raw = impl->Reverse_Lookup(i); + if (!raw.has_value()) { + continue; + } + index->AddText(raw.value(), i); } } } diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index b7ef644deff06..ddbfbe44baa6c 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -198,10 +198,9 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && int64_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - AssertInfo(int64_index->Reverse_Lookup(i).has_value(), - "Primary key not found"); - insert_record_.insert_pk( - int64_index->Reverse_Lookup(i).value(), i); + auto raw = int64_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "Primary key not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -214,10 +213,9 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (!is_sorted_by_pk_ && insert_record_.empty_pks() && string_index->HasRawData()) { for (int i = 0; i < row_count; ++i) { - AssertInfo(string_index->Reverse_Lookup(i).has_value(), - "Primary key not found"); - insert_record_.insert_pk( - string_index->Reverse_Lookup(i).value(), i); + auto raw = string_index->Reverse_Lookup(i); + AssertInfo(raw.has_value(), "Primary key not found"); + insert_record_.insert_pk(raw.value(), i); } insert_record_.seal_pks(); } @@ -2091,11 +2089,11 @@ SegmentSealedImpl::CreateTextIndex(FieldId field_id) { "converted to string index"); auto n = impl->Size(); for (size_t i = 0; i < n; i++) { - auto value = impl->Reverse_Lookup(i); - if (!value.has_value()) { + auto raw = impl->Reverse_Lookup(i); + if (!raw.has_value()) { continue; } - index->AddText(impl->Reverse_Lookup(i).value(), i); + index->AddText(raw.value(), i); } } } diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index d6273055ec9e2..30b01caa86a4d 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -696,16 +696,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_bool_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -716,16 +716,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -736,16 +736,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -756,16 +756,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_int_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -776,16 +776,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_long_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -796,16 +796,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_float_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -816,16 +816,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_double_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; @@ -836,16 +836,16 @@ ReverseDataFromIndex(const index::IndexBase* index, auto ptr = dynamic_cast(index); std::vector raw_data(count); for (int64_t i = 0; i < count; ++i) { - auto value = ptr->Reverse_Lookup(seg_offsets[i]); + auto raw = ptr->Reverse_Lookup(seg_offsets[i]); // if has no value, means nullable must be true, no need to check nullable again here - if (!value.has_value()) { + if (!raw.has_value()) { valid_data[i] = false; continue; } if (nullable) { valid_data[i] = true; } - raw_data[i] = ptr->Reverse_Lookup(seg_offsets[i]).value(); + raw_data[i] = raw.value(); } auto obj = scalar_array->mutable_string_data(); *(obj->mutable_data()) = {raw_data.begin(), raw_data.end()}; diff --git a/internal/core/unittest/test_utils/AssertUtils.h b/internal/core/unittest/test_utils/AssertUtils.h index 16130fd513610..837e44fe768e1 100644 --- a/internal/core/unittest/test_utils/AssertUtils.h +++ b/internal/core/unittest/test_utils/AssertUtils.h @@ -139,7 +139,9 @@ template inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_EQ(index->Reverse_Lookup(offset).value(), arr[offset]); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_EQ(raw.value(), arr[offset]); } } @@ -147,8 +149,9 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE( - compare_float(index->Reverse_Lookup(offset).value(), arr[offset])); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(compare_float(raw.value(), arr[offset])); } } @@ -156,8 +159,9 @@ template <> inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE( - compare_double(index->Reverse_Lookup(offset).value(), arr[offset])); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(compare_double(raw.value(), arr[offset])); } } @@ -166,8 +170,9 @@ inline void assert_reverse(ScalarIndex* index, const std::vector& arr) { for (size_t offset = 0; offset < arr.size(); ++offset) { - ASSERT_TRUE( - arr[offset].compare(index->Reverse_Lookup(offset).value()) == 0); + auto raw = index->Reverse_Lookup(offset); + ASSERT_TRUE(raw.has_value()); + ASSERT_TRUE(arr[offset].compare(raw.value()) == 0); } }