Skip to content

Commit

Permalink
fix: mask with valid data when preCheckOverflow (#37221)
Browse files: browse the repository at this point in the history
#37175

---------

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Oct 31, 2024
1 parent: 2092dc0; commit: b849249
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 38 deletions.
6 changes: 0 additions & 6 deletions internal/core/src/exec/expression/BinaryRangeExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,9 @@ PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1,
? active_count_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
auto valid_res = ProcessChunksForValid<T>(is_index_mode_);
auto res_vec = std::make_shared<ColumnVector>(TargetBitmap(batch_size),
std::move(valid_res));
cached_overflow_res_ = res_vec;

return res_vec;
};

Expand Down
1 change: 0 additions & 1 deletion internal/core/src/exec/expression/BinaryRangeExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ class PhyBinaryRangeFilterExpr : public SegmentExpr {

private:
std::shared_ptr<const milvus::expr::BinaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} //namespace exec
Expand Down
2 changes: 1 addition & 1 deletion internal/core/src/exec/expression/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class SegmentExpr : public Expr {
template <typename T>
TargetBitmap
ProcessDataChunksForValid() {
TargetBitmap valid_result(batch_size_);
TargetBitmap valid_result(GetNextBatchSize());
valid_result.set();
int64_t processed_size = 0;
for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) {
Expand Down
36 changes: 8 additions & 28 deletions internal/core/src/exec/expression/UnaryExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,57 +754,37 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() {
? active_count_ - overflow_check_pos_
: batch_size_;
overflow_check_pos_ += batch_size;
if (cached_overflow_res_ != nullptr &&
cached_overflow_res_->size() == batch_size) {
return cached_overflow_res_;
}
auto valid = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
TargetBitmapView valid_res(res_vec->GetValidRawData(), batch_size);
switch (expr_->op_type_) {
case proto::plan::GreaterThan:
case proto::plan::GreaterEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;

if (milvus::query::lt_lb<T>(val)) {
res.set();
res &= valid_res;
return res_vec;
}
return res_vec;
}
case proto::plan::LessThan:
case proto::plan::LessEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;

if (milvus::query::gt_ub<T>(val)) {
res.set();
res &= valid_res;
return res_vec;
}
return res_vec;
}
case proto::plan::Equal: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;

res.reset();
return res_vec;
}
case proto::plan::NotEqual: {
auto valid_res = ProcessChunksForValid<T>(CanUseIndex<T>());
auto res_vec = std::make_shared<ColumnVector>(
TargetBitmap(batch_size), std::move(valid_res));
TargetBitmapView res(res_vec->GetRawData(), batch_size);
cached_overflow_res_ = res_vec;

res.set();
res &= valid_res;
return res_vec;
}
default: {
Expand Down
1 change: 0 additions & 1 deletion internal/core/src/exec/expression/UnaryExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr {

private:
std::shared_ptr<const milvus::expr::UnaryRangeFilterExpr> expr_;
ColumnVectorPtr cached_overflow_res_{nullptr};
int64_t overflow_check_pos_{0};
};
} // namespace exec
Expand Down
117 changes: 116 additions & 1 deletion internal/core/unittest/test_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,117 @@ TEST_P(ExprTest, TestRangeNullable) {
}
return v != 2000;
}},
{R"(binary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
lower_inclusive: false,
upper_inclusive: false,
lower_value: <
int64_val: 1000000
>
upper_value: <
int64_val: 1000001
>
>)",
[](int v, bool valid) { return false; }},
{R"(binary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
lower_inclusive: false,
upper_inclusive: false,
lower_value: <
int64_val: -1000001
>
upper_value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: GreaterEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: GreaterEqual,
value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: LessEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: LessThan,
value: <
int64_val: -1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: Equal,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) { return false; }},
{R"(unary_range_expr: <
column_info: <
field_id: 103
data_type: Int8
>
op: NotEqual,
value: <
int64_val: 1000000
>
>)",
[](int v, bool valid) {
if (!valid) {
return false;
}
return true;
}},
};

std::string raw_plan_tmp = R"(vector_anns: <
Expand All @@ -582,6 +693,9 @@ TEST_P(ExprTest, TestRangeNullable) {
auto nullable_fid =
schema->AddDebugField("nullable", DataType::INT64, true);

auto nullable_fid_pre_check =
schema->AddDebugField("pre_check", DataType::INT8, true);

auto seg = CreateGrowingSegment(schema, empty_index_meta);
int N = 1000;
std::vector<int> data_col;
Expand Down Expand Up @@ -625,7 +739,8 @@ TEST_P(ExprTest, TestRangeNullable) {
auto val = data_col[i];
auto valid_data = valid_data_col[i];
auto ref = ref_func(val, valid_data);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
ASSERT_EQ(ans, ref)
<< clause << "@" << i << "!!" << val << "!!" << valid_data;
}
}
}
Expand Down

0 comments on commit b849249

Please sign in to comment.