From f93ea840cf446e6953307a09a5320188bb5d8c13 Mon Sep 17 00:00:00 2001 From: luzhang Date: Fri, 2 Aug 2024 16:44:56 +0800 Subject: [PATCH] enhance: refactor executor framework V2 Signed-off-by: luzhang --- internal/core/src/bitset/bitset.h | 2 +- internal/core/src/common/Types.h | 1 + internal/core/src/common/Utils.h | 17 +- internal/core/src/common/Vector.h | 10 + internal/core/src/exec/CMakeLists.txt | 4 + internal/core/src/exec/Driver.cpp | 19 + internal/core/src/exec/QueryContext.h | 67 + .../core/src/exec/operator/CallbackSink.h | 5 + internal/core/src/exec/operator/CountNode.cpp | 72 + internal/core/src/exec/operator/CountNode.h | 78 + .../core/src/exec/operator/FilterBits.cpp | 77 +- internal/core/src/exec/operator/FilterBits.h | 6 + internal/core/src/exec/operator/MvccNode.cpp | 75 + internal/core/src/exec/operator/MvccNode.h | 78 + internal/core/src/exec/operator/Operator.h | 10 + .../core/src/exec/operator/VectorSearch.cpp | 128 + .../core/src/exec/operator/VectorSearch.h | 83 + .../groupby/SearchGroupByOperator.cpp | 7 +- .../operator}/groupby/SearchGroupByOperator.h | 6 +- internal/core/src/plan/PlanNode.h | 108 +- internal/core/src/plan/PlanNodeIdGenerator.h | 68 + internal/core/src/query/CMakeLists.txt | 13 +- .../core/src/query/ExecPlanNodeVisitor.cpp | 222 ++ .../{generated => }/ExecPlanNodeVisitor.h | 29 +- internal/core/src/query/Expr.h | 359 -- internal/core/src/query/ExprImpl.h | 115 - internal/core/src/query/Plan.cpp | 17 - internal/core/src/query/PlanImpl.h | 1 + .../src/query/{generated => }/PlanNode.cpp | 0 internal/core/src/query/PlanNode.h | 8 +- .../query/{generated => }/PlanNodeVisitor.h | 0 internal/core/src/query/PlanProto.cpp | 813 +--- internal/core/src/query/PlanProto.h | 36 +- internal/core/src/query/Relational.h | 1 - internal/core/src/query/SearchOnIndex.cpp | 14 +- internal/core/src/query/SearchOnSealed.cpp | 14 +- internal/core/src/query/Utils.h | 1 - internal/core/src/query/generated/.gitignore | 3 - .../src/query/generated/ExecExprVisitor.h | 249 -- internal/core/src/query/generated/Expr.cpp | 66 - .../core/src/query/generated/ExprVisitor.h | 52 - .../query/generated/ExtractInfoExprVisitor.h | 59 - .../generated/ExtractInfoPlanNodeVisitor.h | 47 - .../src/query/generated/ShowExprVisitor.h | 82 - .../src/query/generated/ShowPlanNodeVisitor.h | 60 - .../src/query/generated/VerifyExprVisitor.h | 58 - .../query/generated/VerifyPlanNodeVisitor.h | 49 - .../src/query/visitors/ExecExprVisitor.cpp | 3525 ----------------- .../query/visitors/ExecPlanNodeVisitor.cpp | 338 -- .../query/visitors/ExtractInfoExprVisitor.cpp | 83 - .../visitors/ExtractInfoPlanNodeVisitor.cpp | 86 - .../src/query/visitors/ShowExprVisitor.cpp | 373 -- .../query/visitors/ShowPlanNodeVisitor.cpp | 175 - .../src/query/visitors/VerifyExprVisitor.cpp | 65 - .../query/visitors/VerifyPlanNodeVisitor.cpp | 53 - internal/core/src/segcore/DeletedRecord.h | 2 +- .../core/src/segcore/SegmentGrowingImpl.cpp | 4 +- .../core/src/segcore/SegmentGrowingImpl.h | 4 +- .../core/src/segcore/SegmentInterface.cpp | 10 +- internal/core/src/segcore/SegmentInterface.h | 4 +- .../core/src/segcore/SegmentSealedImpl.cpp | 4 +- internal/core/src/segcore/SegmentSealedImpl.h | 4 +- .../core/unittest/test_always_true_expr.cpp | 10 +- internal/core/unittest/test_array_expr.cpp | 71 +- .../unittest/test_array_inverted_index.cpp | 11 +- internal/core/unittest/test_c_api.cpp | 102 +- internal/core/unittest/test_chunk_vector.cpp | 1 - internal/core/unittest/test_exec.cpp | 6 +- internal/core/unittest/test_expr.cpp | 350 +- .../unittest/test_expr_materialized_view.cpp | 2 +- internal/core/unittest/test_float16.cpp | 68 +- internal/core/unittest/test_growing.cpp | 3 +- .../core/unittest/test_integer_overflow.cpp | 15 +- internal/core/unittest/test_plan_proto.cpp | 2 +- internal/core/unittest/test_query.cpp | 25 +- internal/core/unittest/test_regex_query.cpp | 37 +- internal/core/unittest/test_retrieve.cpp | 33 +- internal/core/unittest/test_sealed.cpp | 9 +- internal/core/unittest/test_string_expr.cpp | 45 +- .../core/unittest/test_utils/GenExprProto.h | 42 + 80 files changed, 1595 insertions(+), 7256 deletions(-) create mode 100644 internal/core/src/exec/operator/CountNode.cpp create mode 100644 internal/core/src/exec/operator/CountNode.h create mode 100644 internal/core/src/exec/operator/MvccNode.cpp create mode 100644 internal/core/src/exec/operator/MvccNode.h create mode 100644 internal/core/src/exec/operator/VectorSearch.cpp create mode 100644 internal/core/src/exec/operator/VectorSearch.h rename internal/core/src/{query => exec/operator}/groupby/SearchGroupByOperator.cpp (98%) rename internal/core/src/{query => exec/operator}/groupby/SearchGroupByOperator.h (99%) create mode 100644 internal/core/src/plan/PlanNodeIdGenerator.h create mode 100644 internal/core/src/query/ExecPlanNodeVisitor.cpp rename internal/core/src/query/{generated => }/ExecPlanNodeVisitor.h (78%) delete mode 100644 internal/core/src/query/Expr.h delete mode 100644 internal/core/src/query/ExprImpl.h rename internal/core/src/query/{generated => }/PlanNode.cpp (100%) rename internal/core/src/query/{generated => }/PlanNodeVisitor.h (100%) delete mode 100644 internal/core/src/query/generated/.gitignore delete mode 100644 internal/core/src/query/generated/ExecExprVisitor.h delete mode 100644 internal/core/src/query/generated/Expr.cpp delete mode 100644 internal/core/src/query/generated/ExprVisitor.h delete mode 100644 internal/core/src/query/generated/ExtractInfoExprVisitor.h delete mode 100644 internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h delete mode 100644 internal/core/src/query/generated/ShowExprVisitor.h delete mode 100644 internal/core/src/query/generated/ShowPlanNodeVisitor.h delete mode 100644 internal/core/src/query/generated/VerifyExprVisitor.h delete mode 100644 internal/core/src/query/generated/VerifyPlanNodeVisitor.h delete mode 100644 internal/core/src/query/visitors/ExecExprVisitor.cpp delete mode 100644 internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp delete mode 100644 internal/core/src/query/visitors/ExtractInfoExprVisitor.cpp delete mode 100644 internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp delete mode 100644 internal/core/src/query/visitors/ShowExprVisitor.cpp delete mode 100644 internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp delete mode 100644 internal/core/src/query/visitors/VerifyExprVisitor.cpp delete mode 100644 internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp diff --git a/internal/core/src/bitset/bitset.h b/internal/core/src/bitset/bitset.h index 27a659ae14560..7a9ed5ecc1c48 100644 --- a/internal/core/src/bitset/bitset.h +++ b/internal/core/src/bitset/bitset.h @@ -797,13 +797,13 @@ class BitsetBase { this->data(), other.data(), this->offset(), other.offset(), size); } - private: // Return the starting bit offset in our container. inline size_type offset() const { return as_derived().offset_impl(); } + private: // CRTP inline ImplT& as_derived() { diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index e9f6fe042821a..f76c112fe120b 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -376,6 +376,7 @@ using SegOffset = //using BitsetType = boost::dynamic_bitset<>; using BitsetType = CustomBitset; +using BitsetTypeView = CustomBitsetView; using BitsetTypePtr = std::shared_ptr; using BitsetTypeOpt = std::optional; diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index feb7b2bb1746b..c96aed0f92f5b 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -268,7 +268,8 @@ SparseBytesToRows(const Iterable& rows, const bool validate = false) { // SparseRowsToProto converts a list of knowhere::sparse::SparseRow to // a milvus::proto::schema::SparseFloatArray. The resulting proto is a deep copy // of the source data. source(i) returns the i-th row to be copied. -inline void SparseRowsToProto( +inline void +SparseRowsToProto( const std::function*(size_t)>& source, int64_t rows, @@ -287,4 +288,18 @@ inline void SparseRowsToProto( proto->set_dim(max_dim); } +class Defer { + public: + Defer(std::function fn) : fn_(fn) { + } + ~Defer() { + fn_(); + } + + private: + std::function fn_; +}; + +#define DeferLambda(fn) Defer Defer_##__COUNTER__(fn); + } // namespace milvus diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index bdffd67689cf9..2e54341f52605 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -129,6 +129,16 @@ class RowVector : public BaseVector { } } + RowVector(std::vector&& children) + : BaseVector(DataType::ROW, 0) { + children_values_ = std::move(children); + for (auto& child : children_values_) { + if (child->size() > length_) { + length_ = child->size(); + } + } + } + const std::vector& childrens() { return children_values_; diff --git a/internal/core/src/exec/CMakeLists.txt b/internal/core/src/exec/CMakeLists.txt index 8e134f5128d35..5a6d7eced9694 100644 --- a/internal/core/src/exec/CMakeLists.txt +++ b/internal/core/src/exec/CMakeLists.txt @@ -24,6 +24,10 @@ set(MILVUS_EXEC_SRCS expression/ExistsExpr.cpp operator/FilterBits.cpp operator/Operator.cpp + operator/MvccNode.cpp + operator/VectorSearch.cpp + operator/CountNode.cpp + operator/groupby/SearchGroupByOperator.cpp Driver.cpp Task.cpp ) diff --git a/internal/core/src/exec/Driver.cpp b/internal/core/src/exec/Driver.cpp index c2ee0c5580fe9..e6699d8a6a0cd 100644 --- a/internal/core/src/exec/Driver.cpp +++ b/internal/core/src/exec/Driver.cpp @@ -19,9 +19,13 @@ #include #include +#include "common/EasyAssert.h" #include "exec/operator/CallbackSink.h" +#include "exec/operator/CountNode.h" #include "exec/operator/FilterBits.h" +#include "exec/operator/MvccNode.h" #include "exec/operator/Operator.h" +#include "exec/operator/VectorSearch.h" #include "exec/Task.h" #include "common/EasyAssert.h" @@ -52,6 +56,21 @@ DriverFactory::CreateDriver(std::unique_ptr ctx, plannode)) { operators.push_back( std::make_unique(id, ctx.get(), filternode)); + } else if (auto mvccnode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back( + std::make_unique(id, ctx.get(), mvccnode)); + } else if (auto countnode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back( + std::make_unique(id, ctx.get(), countnode)); + } else if (auto vectorsearchnode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back(std::make_unique( + id, ctx.get(), vectorsearchnode)); } // TODO: add more operators } diff --git a/internal/core/src/exec/QueryContext.h b/internal/core/src/exec/QueryContext.h index dbda904e08088..63c1577973707 100644 --- a/internal/core/src/exec/QueryContext.h +++ b/internal/core/src/exec/QueryContext.h @@ -225,6 +225,61 @@ class QueryContext : public Context { return active_count_; } + bool + get_pk_term_offset_cache_initialized() const { + return pk_term_offset_cache_initialized_; + } + + void + set_pk_term_offset_cache_initialized(bool val) { + pk_term_offset_cache_initialized_ = val; + } + + void + set_pk_term_offset_cache(std::vector&& val) { + pk_term_offset_cache_ = std::move(val); + } + + milvus::SearchInfo + get_search_info() { + return search_info_; + } + + const query::PlaceholderGroup* + get_placeholder_group() { + return placeholder_group_; + } + + void + set_search_info(const milvus::SearchInfo& search_info) { + search_info_ = search_info; + } + + void + set_placeholder_group(const query::PlaceholderGroup* placeholder_group) { + placeholder_group_ = placeholder_group; + } + + void + set_search_result(milvus::SearchResult&& result) { + search_result_ = std::move(result); + } + + milvus::SearchResult&& + get_search_result() { + return std::move(search_result_); + } + + void + set_retrieve_result(milvus::RetrieveResult&& result) { + retrieve_result_ = std::move(result); + } + + milvus::RetrieveResult&& + get_retrieve_result() { + return std::move(retrieve_result_); + } + private: folly::Executor* executor_; //folly::Executor::KeepAlive<> executor_keepalive_; @@ -238,6 +293,18 @@ class QueryContext : public Context { int64_t active_count_; // timestamp this query generate milvus::Timestamp query_timestamp_; + + // used for pk term optimization + bool pk_term_offset_cache_initialized_; + std::vector pk_term_offset_cache_; + + // used for vector search + milvus::SearchInfo search_info_; + const query::PlaceholderGroup* placeholder_group_; + + // used for store segment search/retrieve result + milvus::SearchResult search_result_; + milvus::RetrieveResult retrieve_result_; }; // Represent the state of one thread of query execution. diff --git a/internal/core/src/exec/operator/CallbackSink.h b/internal/core/src/exec/operator/CallbackSink.h index 5e5c7479b5776..d0f5e2d37afc1 100644 --- a/internal/core/src/exec/operator/CallbackSink.h +++ b/internal/core/src/exec/operator/CallbackSink.h @@ -71,6 +71,11 @@ class CallbackSink : public Operator { return BlockingReason::kNotBlocked; } + virtual std::string + ToString() const override { + return "CallbackSink"; + } + private: void Close() override { diff --git a/internal/core/src/exec/operator/CountNode.cpp b/internal/core/src/exec/operator/CountNode.cpp new file mode 100644 index 0000000000000..2c298f02a54fa --- /dev/null +++ b/internal/core/src/exec/operator/CountNode.cpp @@ -0,0 +1,72 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "CountNode.h" + +namespace milvus { +namespace exec { + +static std::unique_ptr +wrap_num_entities(int64_t cnt, int64_t size) { + auto retrieve_result = std::make_unique(); + DataArray arr; + arr.set_type(milvus::proto::schema::Int64); + auto scalar = arr.mutable_scalars(); + scalar->mutable_long_data()->mutable_data()->Add(cnt); + retrieve_result->field_data_ = {arr}; + retrieve_result->total_data_cnt_ = size; + return retrieve_result; +} + +PhyCountNode::PhyCountNode(int32_t operator_id, + DriverContext* driverctx, + const std::shared_ptr& node) + : Operator(driverctx, node->output_type(), operator_id, node->id()) { + ExecContext* exec_context = operator_context_->get_exec_context(); + query_context_ = exec_context->get_query_context(); + segment_ = query_context_->get_segment(); + query_timestamp_ = query_context_->get_query_timestamp(); + active_count_ = query_context_->get_active_count(); +} + +void +PhyCountNode::AddInput(RowVectorPtr& input) { + input_ = std::move(input); +} + +RowVectorPtr +PhyCountNode::GetOutput() { + if (is_finished_ || !no_more_input_) { + return nullptr; + } + + auto col_input = GetColumnVector(input_); + TargetBitmapView view(col_input->GetRawData(), col_input->size()); + auto cnt = view.size() - view.count(); + query_context_->set_retrieve_result( + std::move(*(wrap_num_entities(cnt, view.size())))); + is_finished_ = true; + + return input_; +} + +bool +PhyCountNode::IsFinished() { + return is_finished_; +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/CountNode.h b/internal/core/src/exec/operator/CountNode.h new file mode 100644 index 0000000000000..cfb9512a555b2 --- /dev/null +++ b/internal/core/src/exec/operator/CountNode.h @@ -0,0 +1,78 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "exec/Driver.h" +#include "exec/expression/Expr.h" +#include "exec/operator/Operator.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { + +class PhyCountNode : public Operator { + public: + PhyCountNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node); + + bool + IsFilter() override { + return false; + } + + bool + NeedInput() const override { + return !is_finished_; + } + + void + AddInput(RowVectorPtr& input); + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override; + + void + Close() override { + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + virtual std::string + ToString() const override { + return "PhyCountNode"; + } + + private: + const segcore::SegmentInternalInterface* segment_; + milvus::Timestamp query_timestamp_; + int64_t active_count_; + QueryContext* query_context_; + bool is_finished_{false}; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/FilterBits.cpp b/internal/core/src/exec/operator/FilterBits.cpp index ac7a19d1814db..608c4c20d84b8 100644 --- a/internal/core/src/exec/operator/FilterBits.cpp +++ b/internal/core/src/exec/operator/FilterBits.cpp @@ -28,11 +28,11 @@ FilterBits::FilterBits( filter->id(), "FilterBits") { ExecContext* exec_context = operator_context_->get_exec_context(); - QueryContext* query_context = exec_context->get_query_context(); + query_context_ = exec_context->get_query_context(); std::vector filters; filters.emplace_back(filter->filter()); exprs_ = std::make_unique(filters, exec_context); - need_process_rows_ = query_context->get_active_count(); + need_process_rows_ = query_context_->get_active_count(); num_processed_rows_ = 0; } @@ -61,21 +61,76 @@ FilterBits::GetOutput() { return nullptr; } + std::chrono::high_resolution_clock::time_point scalar_start = + std::chrono::high_resolution_clock::now(); + EvalCtx eval_ctx( operator_context_->get_exec_context(), exprs_.get(), input_.get()); - exprs_->Eval(0, 1, true, eval_ctx, results_); + TargetBitmap bitset; + while (num_processed_rows_ < need_process_rows_) { + exprs_->Eval(0, 1, true, eval_ctx, results_); + + AssertInfo( + results_.size() == 1 && results_[0] != nullptr, + "FilterBits result size should be size one and not be nullptr"); + + if (results_[0]->type() == DataType::ROW) { + auto row_vec = std::dynamic_pointer_cast(results_[0]); + auto col_vec = + std::dynamic_pointer_cast(row_vec->child(0)); + auto col_vec_size = col_vec->size(); + TargetBitmapView view(col_vec->GetRawData(), col_vec_size); + bitset.append(view); + num_processed_rows_ += col_vec_size; - AssertInfo(results_.size() == 1 && results_[0] != nullptr, - "FilterBits result size should be one and not be nullptr"); + // check whether can use pk term optimization, + // store info to query context. + if (!query_context_->get_pk_term_offset_cache_initialized()) { + auto cache_offset_vec = + std::dynamic_pointer_cast(row_vec->child(1)); - if (results_[0]->type() == DataType::ROW) { - auto row_vec = std::dynamic_pointer_cast(results_[0]); - num_processed_rows_ += row_vec->child(0)->size(); - } else { - num_processed_rows_ += results_[0]->size(); + // if get empty cache offset, means that no row heated all the segment, + // so no need to get next batch. + if (cache_offset_vec->size() == 0) { + bitset.resize(need_process_rows_); + num_processed_rows_ = need_process_rows_; + break; + } + + // cached pk term offset to query context + // ensure query context is safe-thread + auto cache_offset_data = + (int64_t*)cache_offset_vec->GetRawData(); + std::vector cached_offset; + for (size_t i = 0; i < cache_offset_vec->size(); i++) { + cached_offset.push_back(cache_offset_data[i]); + } + query_context_->set_pk_term_offset_cache( + std::move(cached_offset)); + query_context_->set_pk_term_offset_cache_initialized(true); + } + } else { + auto col_vec = std::dynamic_pointer_cast(results_[0]); + auto col_vec_size = col_vec->size(); + TargetBitmapView view(col_vec->GetRawData(), col_vec_size); + bitset.append(view); + num_processed_rows_ += col_vec_size; + } } - return std::make_shared(results_); + bitset.flip(); + Assert(bitset.size() == need_process_rows_); + // num_processed_rows_ = need_process_rows_; + std::vector col_res; + col_res.push_back(std::make_shared(std::move(bitset))); + std::chrono::high_resolution_clock::time_point scalar_end = + std::chrono::high_resolution_clock::now(); + double scalar_cost = + std::chrono::duration(scalar_end - scalar_start) + .count(); + monitor::internal_core_search_latency_scalar.Observe(scalar_cost); + + return std::make_shared(col_res); } } // namespace exec diff --git a/internal/core/src/exec/operator/FilterBits.h b/internal/core/src/exec/operator/FilterBits.h index 462c8dc5e50a9..ca29c3ecd6804 100644 --- a/internal/core/src/exec/operator/FilterBits.h +++ b/internal/core/src/exec/operator/FilterBits.h @@ -65,8 +65,14 @@ class FilterBits : public Operator { bool AllInputProcessed(); + virtual std::string + ToString() const override { + return "FilterBits"; + } + private: std::unique_ptr exprs_; + QueryContext* query_context_; int64_t num_processed_rows_; int64_t need_process_rows_; }; diff --git a/internal/core/src/exec/operator/MvccNode.cpp b/internal/core/src/exec/operator/MvccNode.cpp new file mode 100644 index 0000000000000..eeae9ebf3748d --- /dev/null +++ b/internal/core/src/exec/operator/MvccNode.cpp @@ -0,0 +1,75 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "MvccNode.h" + +namespace milvus { +namespace exec { + +PhyMvccNode::PhyMvccNode(int32_t operator_id, + DriverContext* driverctx, + const std::shared_ptr& mvcc_node) + : Operator( + driverctx, mvcc_node->output_type(), operator_id, mvcc_node->id()) { + ExecContext* exec_context = operator_context_->get_exec_context(); + QueryContext* query_context = exec_context->get_query_context(); + segment_ = query_context->get_segment(); + query_timestamp_ = query_context->get_query_timestamp(); + active_count_ = query_context->get_active_count(); + is_source_node_ = mvcc_node->sources().size() == 0; +} + +void +PhyMvccNode::AddInput(RowVectorPtr& input) { + input_ = std::move(input); +} + +RowVectorPtr +PhyMvccNode::GetOutput() { + if (is_finished_) { + return nullptr; + } + + if (!is_source_node_ && input_ == nullptr) { + return nullptr; + } + + if (active_count_ == 0) { + is_finished_ = true; + return nullptr; + } + + auto col_input = + is_source_node_ + ? std::make_shared(TargetBitmap(active_count_)) + : GetColumnVector(input_); + + TargetBitmapView data(col_input->GetRawData(), col_input->size()); + segment_->mask_with_timestamps(data, query_timestamp_); + segment_->mask_with_delete(data, active_count_, query_timestamp_); + is_finished_ = true; + + // input_ have already been updated + return std::make_shared(std::vector{col_input}); +} + +bool +PhyMvccNode::IsFinished() { + return is_finished_; +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/MvccNode.h b/internal/core/src/exec/operator/MvccNode.h new file mode 100644 index 0000000000000..332dc71c5333a --- /dev/null +++ b/internal/core/src/exec/operator/MvccNode.h @@ -0,0 +1,78 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "exec/Driver.h" +#include "exec/expression/Expr.h" +#include "exec/operator/Operator.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { + +class PhyMvccNode : public Operator { + public: + PhyMvccNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& mvcc_node); + + bool + IsFilter() override { + return false; + } + + bool + NeedInput() const override { + return !is_finished_; + } + + void + AddInput(RowVectorPtr& input); + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override; + + void + Close() override { + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + virtual std::string + ToString() const override { + return "PhyMvccNode"; + } + + private: + const segcore::SegmentInternalInterface* segment_; + milvus::Timestamp query_timestamp_; + int64_t active_count_; + bool is_finished_{false}; + bool is_source_node_{false}; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/Operator.h b/internal/core/src/exec/operator/Operator.h index 0f3b40902b04f..1115ee263ac50 100644 --- a/internal/core/src/exec/operator/Operator.h +++ b/internal/core/src/exec/operator/Operator.h @@ -153,6 +153,11 @@ class Operator { return operator_context_->get_plannode_id(); } + virtual std::string + ToString() const { + return "Base Operator"; + } + protected: std::unique_ptr operator_context_; @@ -191,6 +196,11 @@ class SourceOperator : public Operator { PanicInfo(NotImplemented, "SourceOperator does not support noMoreInput()"); } + + virtual std::string + ToString() const override { + return "source operator"; + } }; } // namespace exec diff --git a/internal/core/src/exec/operator/VectorSearch.cpp b/internal/core/src/exec/operator/VectorSearch.cpp new file mode 100644 index 0000000000000..e79a17df71f3c --- /dev/null +++ b/internal/core/src/exec/operator/VectorSearch.cpp @@ -0,0 +1,128 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "VectorSearch.h" + +#include "exec/operator/groupby/SearchGroupByOperator.h" + +namespace milvus { +namespace exec { + +static milvus::SearchResult +empty_search_result(int64_t num_queries) { + milvus::SearchResult final_result; + final_result.total_nq_ = num_queries; + final_result.unity_topK_ = 0; // no result + final_result.total_data_cnt_ = 0; + return final_result; +} + +PhyVectorSearchNode::PhyVectorSearchNode( + int32_t operator_id, + DriverContext* driverctx, + const std::shared_ptr& search_node) + : Operator(driverctx, + search_node->output_type(), + operator_id, + search_node->id()) { + ExecContext* exec_context = operator_context_->get_exec_context(); + query_context_ = exec_context->get_query_context(); + segment_ = query_context_->get_segment(); + query_timestamp_ = query_context_->get_query_timestamp(); + active_count_ = query_context_->get_active_count(); + placeholder_group_ = query_context_->get_placeholder_group(); + search_info_ = query_context_->get_search_info(); +} + +void +PhyVectorSearchNode::AddInput(RowVectorPtr& input) { + input_ = std::move(input); +} + +RowVectorPtr +PhyVectorSearchNode::GetOutput() { + if (is_finished_ || !no_more_input_) { + return nullptr; + } + + DeferLambda([&]() { is_finished_ = true; }); + if (input_ == nullptr) { + return nullptr; + } + + std::chrono::high_resolution_clock::time_point vector_start = + std::chrono::high_resolution_clock::now(); + + auto& ph = placeholder_group_->at(0); + auto src_data = ph.get_blob(); + auto num_queries = ph.num_of_queries_; + milvus::SearchResult search_result; + + auto col_input = GetColumnVector(input_); + TargetBitmapView view(col_input->GetRawData(), col_input->size()); + if (view.all()) { + query_context_->set_search_result( + std::move(empty_search_result(num_queries))); + return input_; + } + + // TODO: uniform knowhere BitsetView and milvus BitsetView + milvus::BitsetView final_view((uint8_t*)col_input->GetRawData(), + col_input->size()); + segment_->vector_search(search_info_, + src_data, + num_queries, + query_timestamp_, + final_view, + search_result); + search_result.total_data_cnt_ = final_view.size(); + if (search_result.vector_iterators_.has_value()) { + std::vector group_by_values; + milvus::exec::SearchGroupBy(search_result.vector_iterators_.value(), + search_info_, + group_by_values, + *segment_, + search_result.seg_offsets_, + search_result.distances_, + search_result.topk_per_nq_prefix_sum_); + search_result.group_by_values_ = std::move(group_by_values); + AssertInfo(search_result.seg_offsets_.size() == + search_result.group_by_values_.value().size(), + "Wrong state! search_result group_by_values_ size:{} is not " + "equal to search_result.seg_offsets.size:{}", + search_result.group_by_values_.value().size(), + search_result.seg_offsets_.size()); + } + query_context_->set_search_result(std::move(search_result)); + std::chrono::high_resolution_clock::time_point vector_end = + std::chrono::high_resolution_clock::now(); + double vector_cost = + std::chrono::duration(vector_end - vector_start) + .count(); + monitor::internal_core_search_latency_vector.Observe(vector_cost); + // for now, vector search as the end node, + // and result store in query_context + // this node interface just return bitset + return input_; +} + +bool +PhyVectorSearchNode::IsFinished() { + return is_finished_; +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/VectorSearch.h b/internal/core/src/exec/operator/VectorSearch.h new file mode 100644 index 0000000000000..e6ec630eed9c9 --- /dev/null +++ b/internal/core/src/exec/operator/VectorSearch.h @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "exec/Driver.h" +#include "exec/expression/Expr.h" +#include "exec/operator/Operator.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { + +class PhyVectorSearchNode : public Operator { + public: + PhyVectorSearchNode( + int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& search_node); + + bool + IsFilter() override { + return false; + } + + bool + NeedInput() const override { + return !is_finished_; + } + + void + AddInput(RowVectorPtr& input) override; + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override; + + void + Close() override { + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + virtual std::string + ToString() const override { + return "PhyVectorSearchNode"; + } + + private: + const milvus::segcore::SegmentInternalInterface* segment_; + QueryContext* query_context_; + milvus::Timestamp query_timestamp_; + int64_t active_count_; + bool is_finished_{false}; + + const milvus::query::PlaceholderGroup* placeholder_group_; + milvus::SearchInfo search_info_; + + milvus::SearchResult* search_result_; +}; +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.cpp b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.cpp similarity index 98% rename from internal/core/src/query/groupby/SearchGroupByOperator.cpp rename to internal/core/src/exec/operator/groupby/SearchGroupByOperator.cpp index 7b04f9cd2faff..812283f1d7a63 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.cpp +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.cpp @@ -19,7 +19,7 @@ #include "query/Utils.h" namespace milvus { -namespace query { +namespace exec { void SearchGroupBy(const std::vector>& iterators, @@ -189,7 +189,8 @@ GroupIteratorResult(const std::shared_ptr& iterator, //3. sorted based on distances and metrics auto customComparator = [&](const auto& lhs, const auto& rhs) { - return dis_closer(std::get<1>(lhs), std::get<1>(rhs), metrics_type); + return milvus::query::dis_closer( + std::get<1>(lhs), std::get<1>(rhs), metrics_type); }; std::sort(res.begin(), res.end(), customComparator); @@ -201,5 +202,5 @@ GroupIteratorResult(const std::shared_ptr& iterator, } } -} // namespace query +} // namespace exec } // namespace milvus diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.h b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h similarity index 99% rename from internal/core/src/query/groupby/SearchGroupByOperator.h rename to internal/core/src/exec/operator/groupby/SearchGroupByOperator.h index dfc51d318ebc6..899b7856a89b8 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.h +++ b/internal/core/src/exec/operator/groupby/SearchGroupByOperator.h @@ -26,7 +26,7 @@ #include "query/Utils.h" namespace milvus { -namespace query { +namespace exec { template class DataGetter { @@ -199,7 +199,7 @@ struct GroupByMap { public: GroupByMap(int group_capacity, int group_size) - : group_capacity_(group_capacity), group_size_(group_size){}; + : group_capacity_(group_capacity), group_size_(group_size) {}; bool IsGroupResEnough() { return group_map_.size() == group_capacity_ && @@ -235,5 +235,5 @@ GroupIteratorResult(const std::shared_ptr& iterator, std::vector& distances, const knowhere::MetricType& metrics_type); -} // namespace query +} // namespace exec } // namespace milvus diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h index 04cfe5f219efd..da49e66a1b6a8 100644 --- a/internal/core/src/plan/PlanNode.h +++ b/internal/core/src/plan/PlanNode.h @@ -25,6 +25,7 @@ #include "expr/ITypeExpr.h" #include "common/EasyAssert.h" #include "segcore/SegmentInterface.h" +#include "plan/PlanNodeIdGenerator.h" namespace milvus { namespace plan { @@ -68,6 +69,14 @@ class PlanNode { return {}; }; + std::string + SourceToString() const { + std::vector sources_str; + for (auto& source : sources()) { + sources_str.emplace_back(source->ToString()); + } + } + private: PlanNodeId id_; }; @@ -244,7 +253,7 @@ class FilterBitsNode : public PlanNode { std::string ToString() const override { - return fmt::format("FilterBitsNode:[filter_expr:{}]", + return fmt::format("FilterBitsNode:\n\t[filter_expr:{}]", filter_->ToString()); } @@ -260,6 +269,103 @@ class FilterBitsNode : public PlanNode { const expr::TypedExprPtr filter_; }; +class MvccNode : public PlanNode { + public: + MvccNode(const PlanNodeId& id, + std::vector sources = std::vector{}) + : PlanNode(id), sources_{std::move(sources)} { + } + + DataType + output_type() const override { + return DataType::BOOL; + } + + std::vector + sources() const override { + return sources_; + } + + std::string_view + name() const override { + return "MvccNode"; + } + + std::string + ToString() const override { + return fmt::format("MvccNode:\n\t[source node:{}]", SourceToString()); + } + + private: + const std::vector sources_; +}; + +class VectorSearchNode : public PlanNode { + public: + VectorSearchNode( + const PlanNodeId& id, + std::vector sources = std::vector{}) + : PlanNode(id), sources_{std::move(sources)} { + } + + DataType + output_type() const override { + return DataType::BOOL; + } + + std::vector + sources() const override { + return sources_; + } + + std::string_view + name() const override { + return "VectorSearchNode"; + } + + std::string + ToString() const override { + return fmt::format("VectorSearchNode:\n\t[source node:{}]", + SourceToString()); + } + + private: + const std::vector sources_; +}; + +class CountNode : public PlanNode { + public: + CountNode( + const PlanNodeId& id, + const std::vector& sources = std::vector{}) + : PlanNode(id), sources_{std::move(sources)} { + } + + DataType + output_type() const override { + return DataType::INT64; + } + + std::vector + sources() const override { + return sources_; + } + + std::string_view + name() const override { + return "CountNode"; + } + + std::string + ToString() const override { + return fmt::format("VectorSearchNode:\n\t[source node:{}]", + SourceToString()); + } + + private: + const std::vector sources_; +}; + enum class ExecutionStrategy { // Process splits as they come in any available driver. kUngrouped, diff --git a/internal/core/src/plan/PlanNodeIdGenerator.h b/internal/core/src/plan/PlanNodeIdGenerator.h new file mode 100644 index 0000000000000..312ee67fdc503 --- /dev/null +++ b/internal/core/src/plan/PlanNodeIdGenerator.h @@ -0,0 +1,68 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "fmt/format.h" + +namespace milvus::plan { + +typedef std::string PlanNodeId; + +class PlanNodeIdGenerator { + public: + static PlanNodeIdGenerator& + GetInstance() { + static PlanNodeIdGenerator instance; + return instance; + } + + explicit PlanNodeIdGenerator(int start_id = 0) : next_id_(start_id) { + } + + PlanNodeId + Next() { + if (next_id_ >= std::numeric_limits::max()) { + next_id_ = 0; + } + return fmt::format("{}", next_id_++); + } + + void + Set(int id) { + next_id_ = id; + } + + void + ReSet() { + next_id_ = 0; + } + + private: + int next_id_; +}; + +inline PlanNodeId +GetNextPlanNodeId() { + return PlanNodeIdGenerator::GetInstance().Next(); +} + +} // namespace milvus::plan \ No newline at end of file diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 6bbc488c3bd7d..0b08ae6c23f3a 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -10,23 +10,14 @@ # or implied. See the License for the specific language governing permissions and limitations under the License set(MILVUS_QUERY_SRCS - generated/PlanNode.cpp - generated/Expr.cpp - visitors/ShowPlanNodeVisitor.cpp - visitors/ShowExprVisitor.cpp - visitors/ExecPlanNodeVisitor.cpp - visitors/ExecExprVisitor.cpp - visitors/VerifyPlanNodeVisitor.cpp - visitors/VerifyExprVisitor.cpp - visitors/ExtractInfoPlanNodeVisitor.cpp - visitors/ExtractInfoExprVisitor.cpp + PlanNode.cpp + ExecPlanNodeVisitor.cpp Plan.cpp SearchOnGrowing.cpp SearchOnSealed.cpp SearchOnIndex.cpp SearchBruteForce.cpp SubSearchResult.cpp - groupby/SearchGroupByOperator.cpp PlanProto.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) diff --git a/internal/core/src/query/ExecPlanNodeVisitor.cpp b/internal/core/src/query/ExecPlanNodeVisitor.cpp new file mode 100644 index 0000000000000..f247b71be63de --- /dev/null +++ b/internal/core/src/query/ExecPlanNodeVisitor.cpp @@ -0,0 +1,222 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "query/ExecPlanNodeVisitor.h" + +#include +#include + +#include "expr/ITypeExpr.h" +#include "query/PlanImpl.h" +#include "query/SubSearchResult.h" +#include "query/Utils.h" +#include "segcore/SegmentGrowing.h" +#include "common/Json.h" +#include "log/Log.h" +#include "plan/PlanNode.h" +#include "exec/Task.h" +#include "segcore/SegmentInterface.h" +namespace milvus::query { + +namespace impl { +// THIS CONTAINS EXTRA BODY FOR VISITOR +// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ +class ExecPlanNodeVisitor : PlanNodeVisitor { + public: + ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, + Timestamp timestamp, + const PlaceholderGroup& placeholder_group) + : segment_(segment), + timestamp_(timestamp), + placeholder_group_(placeholder_group) { + } + + SearchResult + get_moved_result(PlanNode& node) { + assert(!search_result_opt_.has_value()); + node.accept(*this); + assert(search_result_opt_.has_value()); + auto ret = std::move(search_result_opt_).value(); + search_result_opt_ = std::nullopt; + return ret; + } + + private: + template + void + VectorVisitorImpl(VectorPlanNode& node); + + private: + const segcore::SegmentInterface& segment_; + Timestamp timestamp_; + const PlaceholderGroup& placeholder_group_; + + SearchResultOpt search_result_opt_; +}; +} // namespace impl + +static SearchResult +empty_search_result(int64_t num_queries, SearchInfo& search_info) { + SearchResult final_result; + final_result.total_nq_ = num_queries; + final_result.unity_topK_ = 0; // no result + final_result.total_data_cnt_ = 0; + return final_result; +} + +BitsetType +ExecPlanNodeVisitor::ExecuteTask( + plan::PlanFragment& plan, + std::shared_ptr query_context) { + LOG_DEBUG("plannode: {}, active_count: {}, timestamp: {}", + plan.plan_node_->ToString(), + query_context->get_active_count(), + query_context->get_query_timestamp()); + + auto task = + milvus::exec::Task::Create(DEFAULT_TASK_ID, plan, 0, query_context); + int64_t processed_num = 0; + BitsetType bitset_holder; + for (;;) { + auto result = task->Next(); + if (!result) { + Assert(processed_num == query_context->get_active_count()); + break; + } + auto childrens = result->childrens(); + AssertInfo(childrens.size() == 1, + "plannode result vector's children size not equal one"); + LOG_DEBUG("output result length:{}", childrens[0]->size()); + if (auto vec = std::dynamic_pointer_cast(childrens[0])) { + processed_num += vec->size(); + BitsetTypeView view(vec->GetRawData(), vec->size()); + bitset_holder.append(view); + } else { + PanicInfo(UnexpectedError, "expr return type not matched"); + } + } + return bitset_holder; +} + +template +void +ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { + assert(!search_result_opt_.has_value()); + auto segment = + dynamic_cast(&segment_); + AssertInfo(segment, "support SegmentSmallIndex Only"); + + auto active_count = segment->get_active_count(timestamp_); + + // PreExecute: skip all calculation + if (active_count == 0) { + search_result_opt_ = std::move(SearchResult()); + return; + } + + // Construct plan fragment + auto plan = plan::PlanFragment(node.plannodes_); + + // Set query context + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment, active_count, timestamp_); + query_context->set_search_info(node.search_info_); + query_context->set_placeholder_group(placeholder_group_); + + // Do plan fragment task work + auto result = ExecuteTask(plan, query_context); + + // Store result + search_result_opt_ = std::move(query_context->get_search_result()); +} + +std::unique_ptr +wrap_num_entities(int64_t cnt) { + auto retrieve_result = std::make_unique(); + DataArray arr; + arr.set_type(milvus::proto::schema::Int64); + auto scalar = arr.mutable_scalars(); + scalar->mutable_long_data()->mutable_data()->Add(cnt); + retrieve_result->field_data_ = {arr}; + retrieve_result->total_data_cnt_ = 0; + return retrieve_result; +} + +void +ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { + assert(!retrieve_result_opt_.has_value()); + auto segment = + dynamic_cast(&segment_); + AssertInfo(segment, "Support SegmentSmallIndex Only"); + RetrieveResult retrieve_result; + retrieve_result.total_data_cnt_ = 0; + + auto active_count = segment->get_active_count(timestamp_); + + // PreExecute: skip all calculation + if (active_count == 0 && !node.is_count_) { + retrieve_result_opt_ = std::move(retrieve_result); + return; + } + + if (active_count == 0 && node.is_count_) { + retrieve_result = *(wrap_num_entities(0)); + retrieve_result_opt_ = std::move(retrieve_result); + return; + } + + // Get plan + auto plan = plan::PlanFragment(node.plannodes_); + + // Set query context + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment, active_count, timestamp_); + + // Do task execution + auto bitset_holder = ExecuteTask(plan, query_context); + + // Store result + if (node.is_count_) { + retrieve_result_opt_ = std::move(query_context->get_retrieve_result()); + } else { + auto results_pair = segment->find_first(node.limit_, bitset_holder); + retrieve_result.result_offsets_ = std::move(results_pair.first); + retrieve_result.has_more_result = results_pair.second; + retrieve_result_opt_ = std::move(retrieve_result); + } +} + +void +ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { + VectorVisitorImpl(node); +} + +void +ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) { + VectorVisitorImpl(node); +} + +void +ExecPlanNodeVisitor::visit(Float16VectorANNS& node) { + VectorVisitorImpl(node); +} + +void +ExecPlanNodeVisitor::visit(BFloat16VectorANNS& node) { + VectorVisitorImpl(node); +} + +void +ExecPlanNodeVisitor::visit(SparseFloatVectorANNS& node) { + VectorVisitorImpl(node); +} + +} // namespace milvus::query diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/ExecPlanNodeVisitor.h similarity index 78% rename from internal/core/src/query/generated/ExecPlanNodeVisitor.h rename to internal/core/src/query/ExecPlanNodeVisitor.h index 96b5d9b2f948e..601f12ba3eb03 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/ExecPlanNodeVisitor.h @@ -17,6 +17,8 @@ #include "segcore/SegmentGrowing.h" #include #include "PlanNodeVisitor.h" +#include "plan/PlanNode.h" +#include "exec/QueryContext.h" namespace milvus::query { @@ -88,11 +90,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { return expr_use_pk_index_; } - void - ExecuteExprNode(const std::shared_ptr& plannode, - const milvus::segcore::SegmentInternalInterface* segment, - int64_t active_count, - BitsetType& result); + static BitsetType + ExecuteTask(plan::PlanFragment& plan, + std::shared_ptr query_context); private: template @@ -108,4 +108,23 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { RetrieveResultOpt retrieve_result_opt_; bool expr_use_pk_index_ = false; }; + +// for test use only +inline BitsetType +ExecuteQueryExpr(std::shared_ptr plannode, + const milvus::segcore::SegmentInternalInterface* segment, + uint64_t active_count, + uint64_t timestamp) { + auto plan_fragment = plan::PlanFragment(plannode); + + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment, active_count, timestamp); + auto bitset = + ExecPlanNodeVisitor::ExecuteTask(plan_fragment, query_context); + + // For test case, bitset 1 indicates true but executor is verse + bitset.flip(); + return bitset; +} + } // namespace milvus::query diff --git a/internal/core/src/query/Expr.h b/internal/core/src/query/Expr.h deleted file mode 100644 index 93a52f0076c3c..0000000000000 --- a/internal/core/src/query/Expr.h +++ /dev/null @@ -1,359 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "common/Schema.h" -#include "common/Types.h" -#include "pb/plan.pb.h" - -namespace milvus::query { - -using optype = proto::plan::OpType; - -class ExprVisitor; - -struct ColumnInfo { - FieldId field_id; - DataType data_type; - std::vector nested_path; - - ColumnInfo(const proto::plan::ColumnInfo& column_info) - : field_id(column_info.field_id()), - data_type(static_cast(column_info.data_type())), - nested_path(column_info.nested_path().begin(), - column_info.nested_path().end()) { - } - - ColumnInfo(FieldId field_id, - DataType data_type, - std::vector nested_path = {}) - : field_id(field_id), - data_type(data_type), - nested_path(std::move(nested_path)) { - } -}; - -// Base of all Exprs -struct Expr { - public: - virtual ~Expr() = default; - virtual void - accept(ExprVisitor&) = 0; -}; - -using ExprPtr = std::unique_ptr; - -struct BinaryExprBase : Expr { - const ExprPtr left_; - const ExprPtr right_; - - BinaryExprBase() = delete; - - BinaryExprBase(ExprPtr& left, ExprPtr& right) - : left_(std::move(left)), right_(std::move(right)) { - } -}; - -struct UnaryExprBase : Expr { - const ExprPtr child_; - - UnaryExprBase() = delete; - - explicit UnaryExprBase(ExprPtr& child) : child_(std::move(child)) { - } -}; - -struct LogicalUnaryExpr : UnaryExprBase { - enum class OpType { Invalid = 0, LogicalNot = 1 }; - const OpType op_type_; - - LogicalUnaryExpr(const OpType op_type, ExprPtr& child) - : UnaryExprBase(child), op_type_(op_type) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -struct LogicalBinaryExpr : BinaryExprBase { - // Note: bitA - bitB == bitA & ~bitB, alias to LogicalMinus - enum class OpType { - Invalid = 0, - LogicalAnd = 1, - LogicalOr = 2, - LogicalXor = 3, - LogicalMinus = 4 - }; - const OpType op_type_; - - LogicalBinaryExpr(const OpType op_type, ExprPtr& left, ExprPtr& right) - : BinaryExprBase(left, right), op_type_(op_type) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -struct TermExpr : Expr { - const ColumnInfo column_; - const proto::plan::GenericValue::ValCase val_case_; - const bool is_in_field_; - - protected: - // prevent accidental instantiation - TermExpr() = delete; - - TermExpr(ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case, - const bool is_in_field) - : column_(std::move(column)), - val_case_(val_case), - is_in_field_(is_in_field) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -static const std::map arith_op_mapping_ = { - // arith_op_name -> arith_op - {"add", ArithOpType::Add}, - {"sub", ArithOpType::Sub}, - {"mul", ArithOpType::Mul}, - {"div", ArithOpType::Div}, - {"mod", ArithOpType::Mod}, -}; - -static const std::map mapping_arith_op_ = { - // arith_op_name -> arith_op - {ArithOpType::Add, "add"}, - {ArithOpType::Sub, "sub"}, - {ArithOpType::Mul, "mul"}, - {ArithOpType::Div, "div"}, - {ArithOpType::Mod, "mod"}, -}; - -struct BinaryArithOpEvalRangeExpr : Expr { - const ColumnInfo column_; - const proto::plan::GenericValue::ValCase val_case_; - const OpType op_type_; - const ArithOpType arith_op_; - - protected: - // prevent accidental instantiation - BinaryArithOpEvalRangeExpr() = delete; - - BinaryArithOpEvalRangeExpr( - ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case, - const OpType op_type, - const ArithOpType arith_op) - : column_(std::move(column)), - val_case_(val_case), - op_type_(op_type), - arith_op_(arith_op) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -static const std::map mapping_ = { - // op_name -> op - {"lt", OpType::LessThan}, - {"le", OpType::LessEqual}, - {"lte", OpType::LessEqual}, - {"gt", OpType::GreaterThan}, - {"ge", OpType::GreaterEqual}, - {"gte", OpType::GreaterEqual}, - {"eq", OpType::Equal}, - {"ne", OpType::NotEqual}, -}; - -struct UnaryRangeExpr : Expr { - ColumnInfo column_; - const OpType op_type_; - const proto::plan::GenericValue::ValCase val_case_; - - protected: - // prevent accidental instantiation - UnaryRangeExpr() = delete; - - UnaryRangeExpr(ColumnInfo column, - const OpType op_type, - const proto::plan::GenericValue::ValCase val_case) - : column_(std::move(column)), op_type_(op_type), val_case_(val_case) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -struct BinaryRangeExpr : Expr { - const ColumnInfo column_; - const proto::plan::GenericValue::ValCase val_case_; - const bool lower_inclusive_; - const bool upper_inclusive_; - - protected: - // prevent accidental instantiation - BinaryRangeExpr() = delete; - - BinaryRangeExpr(ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case, - const bool lower_inclusive, - const bool upper_inclusive) - : column_(std::move(column)), - val_case_(val_case), - lower_inclusive_(lower_inclusive), - upper_inclusive_(upper_inclusive) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -struct CompareExpr : Expr { - FieldId left_field_id_; - FieldId right_field_id_; - DataType left_data_type_; - DataType right_data_type_; - OpType op_type_; - - public: - void - accept(ExprVisitor&) override; -}; - -struct ExistsExpr : Expr { - const ColumnInfo column_; - - protected: - // prevent accidental instantiation - ExistsExpr() = delete; - - ExistsExpr(ColumnInfo column) : column_(std::move(column)) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -struct AlwaysTrueExpr : Expr { - public: - void - accept(ExprVisitor&) override; -}; - -inline ExprPtr -CreateAlwaysTrueExpr() { - return std::make_unique(); -} - -struct JsonContainsExpr : Expr { - const ColumnInfo column_; - bool same_type_; - ContainsType op_; - const proto::plan::GenericValue::ValCase val_case_; - - protected: - JsonContainsExpr() = delete; - - JsonContainsExpr(ColumnInfo column, - const bool same_type, - ContainsType op, - proto::plan::GenericValue::ValCase val_case) - : column_(std::move(column)), - same_type_(same_type), - op_(op), - val_case_(val_case) { - } - - public: - void - accept(ExprVisitor&) override; -}; - -inline bool -IsTermExpr(Expr* expr) { - TermExpr* term_expr = dynamic_cast(expr); - return term_expr != nullptr; -} - -} // namespace milvus::query - -template <> -struct fmt::formatter - : formatter { - auto - format(milvus::query::LogicalUnaryExpr::OpType c, - format_context& ctx) const { - string_view name = "unknown"; - switch (c) { - case milvus::query::LogicalUnaryExpr::OpType::Invalid: - name = "Invalid"; - break; - case milvus::query::LogicalUnaryExpr::OpType::LogicalNot: - name = "LogicalNot"; - break; - } - return formatter::format(name, ctx); - } -}; - -template <> -struct fmt::formatter - : formatter { - auto - format(milvus::query::LogicalBinaryExpr::OpType c, - format_context& ctx) const { - string_view name = "unknown"; - switch (c) { - case milvus::query::LogicalBinaryExpr::OpType::Invalid: - name = "Invalid"; - break; - case milvus::query::LogicalBinaryExpr::OpType::LogicalAnd: - name = "LogicalAdd"; - break; - case milvus::query::LogicalBinaryExpr::OpType::LogicalOr: - name = "LogicalOr"; - break; - case milvus::query::LogicalBinaryExpr::OpType::LogicalXor: - name = "LogicalXor"; - break; - case milvus::query::LogicalBinaryExpr::OpType::LogicalMinus: - name = "LogicalMinus"; - break; - } - return formatter::format(name, ctx); - } -}; diff --git a/internal/core/src/query/ExprImpl.h b/internal/core/src/query/ExprImpl.h deleted file mode 100644 index d91645aadac44..0000000000000 --- a/internal/core/src/query/ExprImpl.h +++ /dev/null @@ -1,115 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "Expr.h" -#include "pb/plan.pb.h" - -namespace milvus::query { - -template -struct TermExprImpl : TermExpr { - const std::vector terms_; - - TermExprImpl(ColumnInfo column, - const std::vector& terms, - const proto::plan::GenericValue::ValCase val_case, - const bool is_in_field = false) - : TermExpr(std::forward(column), val_case, is_in_field), - terms_(terms) { - } -}; - -template -struct BinaryArithOpEvalRangeExprImpl : BinaryArithOpEvalRangeExpr { - const T right_operand_; - const T value_; - - BinaryArithOpEvalRangeExprImpl( - ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case, - const ArithOpType arith_op, - const T right_operand, - const OpType op_type, - const T value) - : BinaryArithOpEvalRangeExpr( - std::forward(column), val_case, op_type, arith_op), - right_operand_(right_operand), - value_(value) { - } -}; - -template -struct UnaryRangeExprImpl : UnaryRangeExpr { - const T value_; - - UnaryRangeExprImpl(ColumnInfo column, - const OpType op_type, - const T value, - const proto::plan::GenericValue::ValCase val_case) - : UnaryRangeExpr(std::forward(column), op_type, val_case), - value_(value) { - } -}; - -template -struct BinaryRangeExprImpl : BinaryRangeExpr { - const T lower_value_; - const T upper_value_; - - BinaryRangeExprImpl(ColumnInfo column, - const proto::plan::GenericValue::ValCase val_case, - const bool lower_inclusive, - const bool upper_inclusive, - const T lower_value, - const T upper_value) - : BinaryRangeExpr(std::forward(column), - val_case, - lower_inclusive, - upper_inclusive), - lower_value_(lower_value), - upper_value_(upper_value) { - } -}; - -struct ExistsExprImpl : ExistsExpr { - ExistsExprImpl(ColumnInfo column) - : ExistsExpr(std::forward(column)) { - } -}; - -template -struct JsonContainsExprImpl : JsonContainsExpr { - const std::vector elements_; - - JsonContainsExprImpl(ColumnInfo column, - std::vector elements, - const bool same_type, - ContainsType op, - proto::plan::GenericValue::ValCase val_case) - : JsonContainsExpr( - std::forward(column), same_type, op, val_case), - elements_(std::move(elements)) { - } -}; - -} // namespace milvus::query diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index d0cf1542ba9ea..a5c948fbb25f8 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -17,7 +17,6 @@ #include "Plan.h" #include "common/Utils.h" #include "PlanProto.h" -#include "generated/ShowPlanNodeVisitor.h" namespace milvus::query { @@ -145,20 +144,4 @@ GetNumOfQueries(const PlaceholderGroup* group) { // return plan; //} -void -Plan::check_identical(Plan& other) { - Assert(&schema_ == &other.schema_); - auto json = ShowPlanNodeVisitor().call_child(*this->plan_node_); - auto other_json = ShowPlanNodeVisitor().call_child(*other.plan_node_); - Assert(json.dump(2) == other_json.dump(2)); - Assert(this->extra_info_opt_.has_value() == - other.extra_info_opt_.has_value()); - if (this->extra_info_opt_.has_value()) { - Assert(this->extra_info_opt_->involved_fields_ == - other.extra_info_opt_->involved_fields_); - } - Assert(this->tag2field_ == other.tag2field_); - Assert(this->target_entries_ == other.target_entries_); -} - } // namespace milvus::query diff --git a/internal/core/src/query/PlanImpl.h b/internal/core/src/query/PlanImpl.h index 089902e95742f..11606b2b9117d 100644 --- a/internal/core/src/query/PlanImpl.h +++ b/internal/core/src/query/PlanImpl.h @@ -22,6 +22,7 @@ #include "common/EasyAssert.h" #include "common/Json.h" #include "common/Consts.h" +#include "common/Schema.h" namespace milvus::query { diff --git a/internal/core/src/query/generated/PlanNode.cpp b/internal/core/src/query/PlanNode.cpp similarity index 100% rename from internal/core/src/query/generated/PlanNode.cpp rename to internal/core/src/query/PlanNode.cpp diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index de39c0afd1370..5f771f40aa88d 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -18,7 +18,6 @@ #include #include "common/QueryInfo.h" -#include "query/Expr.h" namespace milvus::plan { class PlanNode; @@ -37,10 +36,9 @@ struct PlanNode { using PlanNodePtr = std::unique_ptr; struct VectorPlanNode : PlanNode { - std::optional predicate_; - std::optional> filter_plannode_; SearchInfo search_info_; std::string placeholder_tag_; + std::shared_ptr plannodes_; }; struct FloatVectorANNS : VectorPlanNode { @@ -78,8 +76,8 @@ struct RetrievePlanNode : PlanNode { void accept(PlanNodeVisitor&) override; - std::optional predicate_; - std::optional> filter_plannode_; + std::shared_ptr plannodes_; + bool is_count_; int64_t limit_; }; diff --git a/internal/core/src/query/generated/PlanNodeVisitor.h b/internal/core/src/query/PlanNodeVisitor.h similarity index 100% rename from internal/core/src/query/generated/PlanNodeVisitor.h rename to internal/core/src/query/PlanNodeVisitor.h diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 1b9c01151541f..395ec65bdff5a 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -16,203 +16,52 @@ #include #include -#include "ExprImpl.h" #include "common/VectorTrait.h" #include "common/EasyAssert.h" -#include "generated/ExtractInfoExprVisitor.h" -#include "generated/ExtractInfoPlanNodeVisitor.h" #include "pb/plan.pb.h" #include "query/Utils.h" #include "knowhere/comp/materialized_view.h" +#include "plan/PlanNode.h" namespace milvus::query { namespace planpb = milvus::proto::plan; -template -std::unique_ptr> -ExtractTermExprImpl(FieldId field_id, - DataType data_type, - const planpb::TermExpr& expr_proto) { - static_assert(IsScalar); - auto size = expr_proto.values_size(); - std::vector terms; - terms.reserve(size); - auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET; - for (int i = 0; i < size; ++i) { - auto& value_proto = expr_proto.values(i); - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal); - terms.push_back(static_cast(value_proto.bool_val())); - val_case = proto::plan::GenericValue::ValCase::kBoolVal; - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - auto value = value_proto.int64_val(); - if (out_of_range(value)) { - continue; - } - terms.push_back(static_cast(value)); - val_case = proto::plan::GenericValue::ValCase::kInt64Val; - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - terms.push_back(static_cast(value_proto.float_val())); - val_case = proto::plan::GenericValue::ValCase::kFloatVal; - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); - terms.push_back(static_cast(value_proto.string_val())); - val_case = proto::plan::GenericValue::ValCase::kStringVal; - } else { - static_assert(always_false); - } - } - std::sort(terms.begin(), terms.end()); - return std::make_unique>( - expr_proto.column_info(), terms, val_case, expr_proto.is_in_field()); -} - -template -std::unique_ptr> -ExtractUnaryRangeExprImpl(FieldId field_id, - DataType data_type, - const planpb::UnaryRangeExpr& expr_proto) { - static_assert(IsScalar); - auto getValue = [&](const auto& value_proto) -> T { - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal); - return static_cast(value_proto.bool_val()); - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - return static_cast(value_proto.int64_val()); - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - return static_cast(value_proto.float_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); - return static_cast(value_proto.string_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kArrayVal); - return static_cast(value_proto.array_val()); - } else { - static_assert(always_false); - } - }; - return std::make_unique>( - expr_proto.column_info(), - static_cast(expr_proto.op()), - getValue(expr_proto.value()), - expr_proto.value().val_case()); -} - -template -std::unique_ptr> -ExtractBinaryRangeExprImpl(FieldId field_id, - DataType data_type, - const planpb::BinaryRangeExpr& expr_proto) { - static_assert(IsScalar); - auto getValue = [&](const auto& value_proto) -> T { - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal); - return static_cast(value_proto.bool_val()); - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - return static_cast(value_proto.int64_val()); - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - return static_cast(value_proto.float_val()); - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); - return static_cast(value_proto.string_val()); - } else { - static_assert(always_false); - } - }; - return std::make_unique>( - expr_proto.column_info(), - expr_proto.lower_value().val_case(), - expr_proto.lower_inclusive(), - expr_proto.upper_inclusive(), - getValue(expr_proto.lower_value()), - getValue(expr_proto.upper_value())); -} - -template -std::unique_ptr> -ExtractBinaryArithOpEvalRangeExprImpl( - FieldId field_id, - DataType data_type, - const planpb::BinaryArithOpEvalRangeExpr& expr_proto) { - static_assert(std::is_fundamental_v); - auto getValue = [&](const auto& value_proto) -> T { - if constexpr (std::is_same_v) { - // Handle bool here. Otherwise, it can go in `is_integral_v` - static_assert(always_false); - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - return static_cast(value_proto.int64_val()); - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - return static_cast(value_proto.float_val()); - } else { - static_assert(always_false); - } - }; - if (expr_proto.arith_op() == proto::plan::ArrayLength) { - return std::make_unique>( - expr_proto.column_info(), - expr_proto.value().val_case(), - expr_proto.arith_op(), - 0, - expr_proto.op(), - getValue(expr_proto.value())); - } - return std::make_unique>( - expr_proto.column_info(), - expr_proto.value().val_case(), - expr_proto.arith_op(), - getValue(expr_proto.right_operand()), - expr_proto.op(), - getValue(expr_proto.value())); -} - std::unique_ptr ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { // TODO: add more buffs Assert(plan_node_proto.has_vector_anns()); auto& anns_proto = plan_node_proto.vector_anns(); - auto expr_opt = [&]() -> std::optional { - if (!anns_proto.has_predicates()) { - return std::nullopt; - } else { - return ParseExpr(anns_proto.predicates()); - } - }(); auto expr_parser = [&]() -> plan::PlanNodePtr { auto expr = ParseExprs(anns_proto.predicates()); - return std::make_shared(DEFAULT_PLANNODE_ID, - expr); + return std::make_shared( + milvus::plan::GetNextPlanNodeId(), expr); }; - auto& query_info_proto = anns_proto.query_info(); - - SearchInfo search_info; - auto field_id = FieldId(anns_proto.field_id()); - search_info.field_id_ = field_id; - - search_info.metric_type_ = query_info_proto.metric_type(); - search_info.topk_ = query_info_proto.topk(); - search_info.round_decimal_ = query_info_proto.round_decimal(); - search_info.search_params_ = - nlohmann::json::parse(query_info_proto.search_params()); - search_info.materialized_view_involved = - query_info_proto.materialized_view_involved(); - - if (query_info_proto.group_by_field_id() > 0) { - auto group_by_field_id = FieldId(query_info_proto.group_by_field_id()); - search_info.group_by_field_id_ = group_by_field_id; - search_info.group_size_ = query_info_proto.group_size() > 0 - ? query_info_proto.group_size() - : 1; - } + auto search_info_parser = [&]() -> SearchInfo { + SearchInfo search_info; + auto& query_info_proto = anns_proto.query_info(); + auto field_id = FieldId(anns_proto.field_id()); + search_info.field_id_ = field_id; + + search_info.metric_type_ = query_info_proto.metric_type(); + search_info.topk_ = query_info_proto.topk(); + search_info.round_decimal_ = query_info_proto.round_decimal(); + search_info.search_params_ = + nlohmann::json::parse(query_info_proto.search_params()); + search_info.materialized_view_involved = + query_info_proto.materialized_view_involved(); + + if (query_info_proto.group_by_field_id() > 0) { + auto group_by_field_id = + FieldId(query_info_proto.group_by_field_id()); + search_info.group_by_field_id_ = group_by_field_id; + search_info.group_size_ = query_info_proto.group_size() > 0 + ? query_info_proto.group_size() + : 1; + } + return search_info; + }; auto plan_node = [&]() -> std::unique_ptr { if (anns_proto.vector_type() == @@ -232,29 +81,39 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { } }(); plan_node->placeholder_tag_ = anns_proto.placeholder_tag(); - plan_node->predicate_ = std::move(expr_opt); + plan_node->search_info_ = std::move(search_info_parser()); + + milvus::plan::PlanNodePtr plannode; + std::vector sources; if (anns_proto.has_predicates()) { - plan_node->filter_plannode_ = std::move(expr_parser()); + plannode = std::move(expr_parser()); + if (plan_node->search_info_.materialized_view_involved) { + const auto expr_info = plannode->GatherInfo(); + knowhere::MaterializedViewSearchInfo materialized_view_search_info; + for (const auto& [expr_field_id, vals] : + expr_info.field_id_to_values) { + materialized_view_search_info + .field_id_to_touched_categories_cnt[expr_field_id] = + vals.size(); + } + materialized_view_search_info.is_pure_and = expr_info.is_pure_and; + materialized_view_search_info.has_not = expr_info.has_not; + + plan_node->search_info_ + .search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] = + materialized_view_search_info; + } + sources = std::vector{plannode}; } - plan_node->search_info_ = std::move(search_info); - - if (plan_node->search_info_.materialized_view_involved && - plan_node->filter_plannode_.has_value()) { - const auto expr_info = - plan_node->filter_plannode_.value()->GatherInfo(); - knowhere::MaterializedViewSearchInfo materialized_view_search_info; - for (const auto& [expr_field_id, vals] : expr_info.field_id_to_values) { - materialized_view_search_info - .field_id_to_touched_categories_cnt[expr_field_id] = - vals.size(); - } - materialized_view_search_info.is_pure_and = expr_info.is_pure_and; - materialized_view_search_info.has_not = expr_info.has_not; - plan_node->search_info_ - .search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] = - materialized_view_search_info; - } + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + sources = std::vector{plannode}; + + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + + plan_node->plannodes_ = plannode; return plan_node; } @@ -264,38 +123,49 @@ ProtoParser::RetrievePlanNodeFromProto( const planpb::PlanNode& plan_node_proto) { Assert(plan_node_proto.has_predicates() || plan_node_proto.has_query()); + milvus::plan::PlanNodePtr plannode; + std::vector sources; + auto plan_node = [&]() -> std::unique_ptr { auto node = std::make_unique(); if (plan_node_proto.has_predicates()) { // version before 2023.03.30. node->is_count_ = false; auto& predicate_proto = plan_node_proto.predicates(); - auto expr_opt = [&]() -> ExprPtr { - return ParseExpr(predicate_proto); - }(); auto expr_parser = [&]() -> plan::PlanNodePtr { auto expr = ParseExprs(predicate_proto); return std::make_shared( - DEFAULT_PLANNODE_ID, expr); + milvus::plan::GetNextPlanNodeId(), expr); }(); - node->predicate_ = std::move(expr_opt); - node->filter_plannode_ = std::move(expr_parser); + plannode = std::move(expr_parser); + sources = std::vector{plannode}; + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + node->plannodes_ = std::move(plannode); } else { auto& query = plan_node_proto.query(); if (query.has_predicates()) { auto& predicate_proto = query.predicates(); - auto expr_opt = [&]() -> ExprPtr { - return ParseExpr(predicate_proto); - }(); auto expr_parser = [&]() -> plan::PlanNodePtr { auto expr = ParseExprs(predicate_proto); return std::make_shared( - DEFAULT_PLANNODE_ID, expr); + milvus::plan::GetNextPlanNodeId(), expr); }(); - node->predicate_ = std::move(expr_opt); - node->filter_plannode_ = std::move(expr_parser); + plannode = std::move(expr_parser); + sources = std::vector{plannode}; } + + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + sources = std::vector{plannode}; + node->is_count_ = query.is_count(); node->limit_ = query.limit(); + if (node->is_count_) { + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + sources = std::vector{plannode}; + } + node->plannodes_ = plannode; } return node; }(); @@ -308,13 +178,11 @@ ProtoParser::CreatePlan(const proto::plan::PlanNode& plan_node_proto) { auto plan = std::make_unique(schema); auto plan_node = PlanNodeFromProto(plan_node_proto); - ExtractedPlanInfo plan_info(schema.size()); - ExtractInfoPlanNodeVisitor extractor(plan_info); - plan_node->accept(extractor); - plan->tag2field_["$0"] = plan_node->search_info_.field_id_; plan->plan_node_ = std::move(plan_node); - plan->extra_info_opt_ = std::move(plan_info); + ExtractedPlanInfo extra_info(schema.size()); + extra_info.add_involved_field(plan->plan_node_->search_info_.field_id_); + plan->extra_info_opt_ = std::move(extra_info); for (auto field_id_raw : plan_node_proto.output_field_ids()) { auto field_id = FieldId(field_id_raw); @@ -329,9 +197,6 @@ ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) { auto retrieve_plan = std::make_unique(schema); auto plan_node = RetrievePlanNodeFromProto(plan_node_proto); - ExtractedPlanInfo plan_info(schema.size()); - ExtractInfoPlanNodeVisitor extractor(plan_info); - plan_node->accept(extractor); retrieve_plan->plan_node_ = std::move(plan_node); for (auto field_id_raw : plan_node_proto.output_field_ids()) { @@ -351,74 +216,6 @@ ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) { expr::ColumnInfo(column_info), expr_pb.op(), expr_pb.value()); } -ExprPtr -ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) { - auto& column_info = expr_pb.column_info(); - auto field_id = FieldId(column_info.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == static_cast(column_info.data_type())); - - auto result = [&]() -> ExprPtr { - switch (data_type) { - case DataType::BOOL: { - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - } - - // see also: https://github.com/milvus-io/milvus/issues/23646. - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: { - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - } - - case DataType::FLOAT: { - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::DOUBLE: { - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::VARCHAR: { - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::JSON: - case DataType::ARRAY: { - switch (expr_pb.value().val_case()) { - case proto::plan::GenericValue::ValCase::kBoolVal: - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kStringVal: - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kArrayVal: - return ExtractUnaryRangeExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type: {} in expression", - expr_pb.value().val_case())); - } - } - default: { - PanicInfo(DataTypeInvalid, "unsupported data type"); - } - } - }(); - return result; -} - expr::TypedExprPtr ProtoParser::ParseBinaryRangeExprs( const proto::plan::BinaryRangeExpr& expr_pb) { @@ -434,90 +231,6 @@ ProtoParser::ParseBinaryRangeExprs( expr_pb.upper_inclusive()); } -ExprPtr -ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) { - auto& columnInfo = expr_pb.column_info(); - auto field_id = FieldId(columnInfo.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == (DataType)columnInfo.data_type()); - - auto result = [&]() -> ExprPtr { - switch (data_type) { - case DataType::BOOL: { - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - } - - // see also: https://github.com/milvus-io/milvus/issues/23646. - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: { - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - } - - case DataType::FLOAT: { - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::DOUBLE: { - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::VARCHAR: { - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::JSON: { - switch (expr_pb.lower_value().val_case()) { - case proto::plan::GenericValue::ValCase::kBoolVal: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kStringVal: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type in expression {}", - data_type)); - } - } - case DataType::ARRAY: { - switch (expr_pb.lower_value().val_case()) { - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kStringVal: - return ExtractBinaryRangeExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type in expression {}", - data_type)); - } - } - - default: { - PanicInfo( - DataTypeInvalid, "unsupported data type {}", data_type); - } - } - }(); - return result; -} - expr::TypedExprPtr ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { auto& left_column_info = expr_pb.left_column_info(); @@ -539,31 +252,6 @@ ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { expr_pb.op()); } -ExprPtr -ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) { - auto& left_column_info = expr_pb.left_column_info(); - auto left_field_id = FieldId(left_column_info.field_id()); - auto left_data_type = schema[left_field_id].get_data_type(); - Assert(left_data_type == - static_cast(left_column_info.data_type())); - - auto& right_column_info = expr_pb.right_column_info(); - auto right_field_id = FieldId(right_column_info.field_id()); - auto right_data_type = schema[right_field_id].get_data_type(); - Assert(right_data_type == - static_cast(right_column_info.data_type())); - - return [&]() -> ExprPtr { - auto result = std::make_unique(); - result->left_field_id_ = left_field_id; - result->left_data_type_ = left_data_type; - result->right_field_id_ = right_field_id; - result->right_data_type_ = right_data_type; - result->op_type_ = static_cast(expr_pb.op()); - return result; - }(); -} - expr::TypedExprPtr ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) { auto& columnInfo = expr_pb.column_info(); @@ -578,113 +266,6 @@ ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) { columnInfo, values, expr_pb.is_in_field()); } -ExprPtr -ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) { - auto& columnInfo = expr_pb.column_info(); - auto field_id = FieldId(columnInfo.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == (DataType)columnInfo.data_type()); - - // auto& field_meta = schema[field_offset]; - auto result = [&]() -> ExprPtr { - switch (data_type) { - case DataType::BOOL: { - return ExtractTermExprImpl(field_id, data_type, expr_pb); - } - case DataType::INT8: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::INT16: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::INT32: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::INT64: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::FLOAT: { - return ExtractTermExprImpl(field_id, data_type, expr_pb); - } - case DataType::DOUBLE: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::VARCHAR: { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - case DataType::JSON: { - if (expr_pb.values().size() == 0) { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - switch (expr_pb.values()[0].val_case()) { - case proto::plan::GenericValue::ValCase::kBoolVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kStringVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type: {} in expression", - expr_pb.values()[0].val_case())); - } - } - case DataType::ARRAY: { - if (expr_pb.values().size() == 0) { - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - } - switch (expr_pb.values()[0].val_case()) { - case proto::plan::GenericValue::ValCase::kBoolVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kStringVal: - return ExtractTermExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type: {} in expression", - expr_pb.values()[0].val_case())); - } - } - default: { - PanicInfo( - DataTypeInvalid, "unsupported data type {}", data_type); - } - } - }(); - return result; -} - -ExprPtr -ProtoParser::ParseUnaryExpr(const proto::plan::UnaryExpr& expr_pb) { - auto op = static_cast(expr_pb.op()); - Assert(op == LogicalUnaryExpr::OpType::LogicalNot); - auto expr = this->ParseExpr(expr_pb.child()); - return std::make_unique(op, expr); -} - expr::TypedExprPtr ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) { auto op = static_cast(expr_pb.op()); @@ -693,14 +274,6 @@ ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) { return std::make_shared(op, child_expr); } -ExprPtr -ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) { - auto op = static_cast(expr_pb.op()); - auto left_expr = this->ParseExpr(expr_pb.left()); - auto right_expr = this->ParseExpr(expr_pb.right()); - return std::make_unique(op, left_expr, right_expr); -} - expr::TypedExprPtr ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) { auto op = static_cast(expr_pb.op()); @@ -709,70 +282,6 @@ ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) { return std::make_shared(op, left_expr, right_expr); } -ExprPtr -ProtoParser::ParseBinaryArithOpEvalRangeExpr( - const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) { - auto& column_info = expr_pb.column_info(); - auto field_id = FieldId(column_info.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == static_cast(column_info.data_type())); - - auto result = [&]() -> ExprPtr { - switch (data_type) { - // see also: https://github.com/milvus-io/milvus/issues/23646. - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: { - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - } - - case DataType::FLOAT: { - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::DOUBLE: { - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - } - case DataType::JSON: { - switch (expr_pb.value().val_case()) { - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {} in expression", - expr_pb.value().val_case()); - } - } - case DataType::ARRAY: { - switch (expr_pb.value().val_case()) { - case proto::plan::GenericValue::ValCase::kInt64Val: - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - case proto::plan::GenericValue::ValCase::kFloatVal: - return ExtractBinaryArithOpEvalRangeExprImpl( - field_id, data_type, expr_pb); - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {} in expression", - expr_pb.value().val_case()); - } - } - default: { - PanicInfo( - DataTypeInvalid, "unsupported data type {}", data_type); - } - } - }(); - return result; -} - expr::TypedExprPtr ProtoParser::ParseBinaryArithOpEvalRangeExprs( const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) { @@ -788,11 +297,6 @@ ProtoParser::ParseBinaryArithOpEvalRangeExprs( expr_pb.right_operand()); } -std::unique_ptr -ExtractExistsExprImpl(const proto::plan::ExistsExpr& expr_proto) { - return std::make_unique(expr_proto.info()); -} - expr::TypedExprPtr ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) { auto& column_info = expr_pb.info(); @@ -802,77 +306,6 @@ ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) { return std::make_shared(column_info); } -ExprPtr -ProtoParser::ParseExistExpr(const proto::plan::ExistsExpr& expr_pb) { - auto& column_info = expr_pb.info(); - auto field_id = FieldId(column_info.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == static_cast(column_info.data_type())); - - auto result = [&]() -> ExprPtr { - switch (data_type) { - case DataType::JSON: { - return ExtractExistsExprImpl(expr_pb); - } - default: { - PanicInfo( - DataTypeInvalid, "unsupported data type {}", data_type); - } - } - }(); - return result; -} - -template -std::unique_ptr> -ExtractJsonContainsExprImpl(const proto::plan::JSONContainsExpr& expr_proto) { - static_assert(IsScalar or std::is_same_v or - std::is_same_v); - auto size = expr_proto.elements_size(); - std::vector terms; - terms.reserve(size); - auto val_case = proto::plan::GenericValue::VAL_NOT_SET; - for (int i = 0; i < size; ++i) { - auto& value_proto = expr_proto.elements(i); - if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal); - terms.push_back(static_cast(value_proto.bool_val())); - val_case = proto::plan::GenericValue::ValCase::kBoolVal; - } else if constexpr (std::is_integral_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - auto value = value_proto.int64_val(); - if (out_of_range(value)) { - continue; - } - terms.push_back(static_cast(value)); - val_case = proto::plan::GenericValue::ValCase::kInt64Val; - } else if constexpr (std::is_floating_point_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - terms.push_back(static_cast(value_proto.float_val())); - val_case = proto::plan::GenericValue::ValCase::kFloatVal; - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); - terms.push_back(static_cast(value_proto.string_val())); - val_case = proto::plan::GenericValue::ValCase::kStringVal; - } else if constexpr (std::is_same_v) { - Assert(value_proto.val_case() == planpb::GenericValue::kArrayVal); - terms.push_back(static_cast(value_proto.array_val())); - val_case = proto::plan::GenericValue::ValCase::kArrayVal; - } else if constexpr (std::is_same_v) { - terms.push_back(value_proto); - } else { - static_assert(always_false); - } - } - - return std::make_unique>( - expr_proto.column_info(), - terms, - expr_proto.elements_same_type(), - expr_proto.op(), - val_case); -} - expr::TypedExprPtr ProtoParser::ParseJsonContainsExprs( const proto::plan::JSONContainsExpr& expr_pb) { @@ -891,43 +324,6 @@ ProtoParser::ParseJsonContainsExprs( std::move(values)); } -ExprPtr -ProtoParser::ParseJsonContainsExpr( - const proto::plan::JSONContainsExpr& expr_pb) { - auto& columnInfo = expr_pb.column_info(); - auto field_id = FieldId(columnInfo.field_id()); - auto data_type = schema[field_id].get_data_type(); - Assert(data_type == (DataType)columnInfo.data_type()); - - // auto& field_meta = schema[field_offset]; - auto result = [&]() -> ExprPtr { - if (expr_pb.elements_size() == 0) { - PanicInfo(DataIsEmpty, "no elements in expression"); - } - if (expr_pb.elements_same_type()) { - switch (expr_pb.elements(0).val_case()) { - case proto::plan::GenericValue::kBoolVal: - return ExtractJsonContainsExprImpl(expr_pb); - case proto::plan::GenericValue::kInt64Val: - return ExtractJsonContainsExprImpl(expr_pb); - case proto::plan::GenericValue::kFloatVal: - return ExtractJsonContainsExprImpl(expr_pb); - case proto::plan::GenericValue::kStringVal: - return ExtractJsonContainsExprImpl(expr_pb); - case proto::plan::GenericValue::kArrayVal: - return ExtractJsonContainsExprImpl( - expr_pb); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported data type {}", data_type)); - } - } - return ExtractJsonContainsExprImpl(expr_pb); - }(); - return result; -} - expr::TypedExprPtr ProtoParser::CreateAlwaysTrueExprs() { return std::make_shared(); @@ -977,47 +373,4 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { } } -ExprPtr -ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) { - using ppe = proto::plan::Expr; - switch (expr_pb.expr_case()) { - case ppe::kBinaryExpr: { - return ParseBinaryExpr(expr_pb.binary_expr()); - } - case ppe::kUnaryExpr: { - return ParseUnaryExpr(expr_pb.unary_expr()); - } - case ppe::kTermExpr: { - return ParseTermExpr(expr_pb.term_expr()); - } - case ppe::kUnaryRangeExpr: { - return ParseUnaryRangeExpr(expr_pb.unary_range_expr()); - } - case ppe::kBinaryRangeExpr: { - return ParseBinaryRangeExpr(expr_pb.binary_range_expr()); - } - case ppe::kCompareExpr: { - return ParseCompareExpr(expr_pb.compare_expr()); - } - case ppe::kBinaryArithOpEvalRangeExpr: { - return ParseBinaryArithOpEvalRangeExpr( - expr_pb.binary_arith_op_eval_range_expr()); - } - case ppe::kExistsExpr: { - return ParseExistExpr(expr_pb.exists_expr()); - } - case ppe::kAlwaysTrueExpr: { - return CreateAlwaysTrueExpr(); - } - case ppe::kJsonContainsExpr: { - return ParseJsonContainsExpr(expr_pb.json_contains_expr()); - } - default: { - std::string s; - google::protobuf::TextFormat::PrintToString(expr_pb, &s); - PanicInfo(ExprInvalid, "unsupported expr proto node: {}", s); - } - } -} - } // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index 51843d9c57cee..63673cefb9270 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -27,40 +27,6 @@ class ProtoParser { explicit ProtoParser(const Schema& schema) : schema(schema) { } - // ExprPtr - // ExprFromProto(const proto::plan::Expr& expr_proto); - - ExprPtr - ParseBinaryArithOpEvalRangeExpr( - const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb); - - ExprPtr - ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb); - - ExprPtr - ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb); - - ExprPtr - ParseCompareExpr(const proto::plan::CompareExpr& expr_pb); - - ExprPtr - ParseTermExpr(const proto::plan::TermExpr& expr_pb); - - ExprPtr - ParseUnaryExpr(const proto::plan::UnaryExpr& expr_pb); - - ExprPtr - ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb); - - ExprPtr - ParseExistExpr(const proto::plan::ExistsExpr& expr_pb); - - ExprPtr - ParseJsonContainsExpr(const proto::plan::JSONContainsExpr& expr_pb); - - ExprPtr - ParseExpr(const proto::plan::Expr& expr_pb); - std::unique_ptr PlanNodeFromProto(const proto::plan::PlanNode& plan_node_proto); @@ -112,7 +78,7 @@ class ProtoParser { }; } // namespace milvus::query - // +// template <> struct fmt::formatter : formatter { diff --git a/internal/core/src/query/Relational.h b/internal/core/src/query/Relational.h index 1839221db65ce..23dfbb3a4ccce 100644 --- a/internal/core/src/query/Relational.h +++ b/internal/core/src/query/Relational.h @@ -17,7 +17,6 @@ #include "common/Utils.h" #include "common/VectorTrait.h" #include "common/EasyAssert.h" -#include "query/Expr.h" #include "query/Utils.h" namespace milvus::query { diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp index 2eb7cf9f3a344..0204f791ce217 100644 --- a/internal/core/src/query/SearchOnIndex.cpp +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -10,7 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include "SearchOnIndex.h" -#include "query/groupby/SearchGroupByOperator.h" +#include "exec/operator/groupby/SearchGroupByOperator.h" namespace milvus::query { void @@ -26,12 +26,12 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset, auto dataset = knowhere::GenDataSet(num_queries, dim, search_dataset.query_data); dataset->SetIsSparse(is_sparse); - if (!PrepareVectorIteratorsFromIndex(search_conf, - num_queries, - dataset, - search_result, - bitset, - indexing)) { + if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_conf, + num_queries, + dataset, + search_result, + bitset, + indexing)) { indexing.Query(dataset, search_conf, bitset, search_result); } } diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index db524c6a98f36..e3b2f21ae6150 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -17,7 +17,7 @@ #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" #include "query/helper.h" -#include "query/groupby/SearchGroupByOperator.h" +#include "exec/operator/groupby/SearchGroupByOperator.h" namespace milvus::query { @@ -48,12 +48,12 @@ SearchOnSealedIndex(const Schema& schema, dataset->SetIsSparse(is_sparse); auto vec_index = dynamic_cast(field_indexing->indexing_.get()); - if (!PrepareVectorIteratorsFromIndex(search_info, - num_queries, - dataset, - search_result, - bitset, - *vec_index)) { + if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info, + num_queries, + dataset, + search_result, + bitset, + *vec_index)) { auto index_type = vec_index->GetIndexType(); vec_index->Query(dataset, search_info, bitset, search_result); float* distances = search_result.distances_.data(); diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 830744da99f8e..8eb535d72d56e 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -14,7 +14,6 @@ #include #include -#include "query/Expr.h" #include "common/Utils.h" namespace milvus::query { diff --git a/internal/core/src/query/generated/.gitignore b/internal/core/src/query/generated/.gitignore deleted file mode 100644 index cad3ab59456e0..0000000000000 --- a/internal/core/src/query/generated/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -!.gitignore -*PlanNodeVisitor.cpp -*ExprVisitor.cpp \ No newline at end of file diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h deleted file mode 100644 index 2da1cd0cc5f68..0000000000000 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include -#include -#include -#include -#include -#include "segcore/SegmentGrowingImpl.h" -#include "query/ExprImpl.h" -#include "ExprVisitor.h" -#include "ExecPlanNodeVisitor.h" - -namespace milvus::query { - -void -AppendOneChunk(BitsetType& result, const TargetBitmapView chunk_res); - -class ExecExprVisitor : public ExprVisitor { - public: - void - visit(LogicalUnaryExpr& expr) override; - - void - visit(LogicalBinaryExpr& expr) override; - - void - visit(TermExpr& expr) override; - - void - visit(UnaryRangeExpr& expr) override; - - void - visit(BinaryArithOpEvalRangeExpr& expr) override; - - void - visit(BinaryRangeExpr& expr) override; - - void - visit(CompareExpr& expr) override; - - void - visit(ExistsExpr& expr) override; - - void - visit(AlwaysTrueExpr& expr) override; - - void - visit(JsonContainsExpr& expr) override; - - public: - ExecExprVisitor(const segcore::SegmentInternalInterface& segment, - int64_t row_count, - Timestamp timestamp) - : segment_(segment), - row_count_(row_count), - timestamp_(timestamp), - plan_visitor_(nullptr) { - } - - ExecExprVisitor(const segcore::SegmentInternalInterface& segment, - ExecPlanNodeVisitor* plan_visitor, - int64_t row_count, - Timestamp timestamp) - : segment_(segment), - plan_visitor_(plan_visitor), - row_count_(row_count), - timestamp_(timestamp) { - } - - BitsetType - call_child(Expr& expr) { - Assert(!bitset_opt_.has_value()); - expr.accept(*this); - Assert(bitset_opt_.has_value()); - auto res = std::move(bitset_opt_); - bitset_opt_ = std::nullopt; - return std::move(res.value()); - } - - public: - template - auto - ExecRangeVisitorImpl(FieldId field_id, - IndexFunc func, - ElementFunc element_func, - SkipIndexFunc skip_index_func) -> BitsetType; - - template - auto - ExecDataRangeVisitorImpl(FieldId field_id, - IndexFunc index_func, - ElementFunc element_func) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcherArray(UnaryRangeExpr& expr_raw) - -> BitsetType; - - template - auto - ExecBinaryArithOpEvalRangeVisitorDispatcherJson( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecBinaryArithOpEvalRangeVisitorDispatcherArray( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecBinaryArithOpEvalRangeVisitorDispatcher( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw) - -> BitsetType; - - template - auto - ExecBinaryRangeVisitorDispatcherArray(BinaryRangeExpr& expr_raw) - -> BitsetType; - - template - auto - ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermArrayVariableInField(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermArrayFieldInVariable(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImplTemplateArray(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) - -> BitsetType; - - template - auto - ExecJsonContains(JsonContainsExpr& expr_raw) -> BitsetType; - - template - auto - ExecArrayContains(JsonContainsExpr& expr_raw) -> BitsetType; - - auto - ExecJsonContainsArray(JsonContainsExpr& expr_raw) -> BitsetType; - - auto - ExecJsonContainsWithDiffType(JsonContainsExpr& expr_raw) -> BitsetType; - - template - auto - ExecJsonContainsAll(JsonContainsExpr& expr_raw) -> BitsetType; - - template - auto - ExecArrayContainsAll(JsonContainsExpr& expr_raw) -> BitsetType; - - auto - ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) -> BitsetType; - - auto - ExecJsonContainsAllWithDiffType(JsonContainsExpr& expr_raw) -> BitsetType; - - template - BitsetType - ExecCompareExprDispatcherForNonIndexedSegment(CompareExpr& expr, - CmpFunc cmp_func); - - // This function only used to compare sealed segment - // which has only one chunk. - template - TargetBitmap - ExecCompareRightType(const T* left_raw_data, - const FieldId& right_field_id, - const int64_t current_chunk_id, - CmpFunc cmp_func); - - template - BitsetType - ExecCompareLeftType(const FieldId& left_field_id, - const FieldId& right_field_id, - const DataType& right_field_type, - CmpFunc cmp_func); - - private: - const segcore::SegmentInternalInterface& segment_; - Timestamp timestamp_; - int64_t row_count_; - - BitsetTypeOpt bitset_opt_; - ExecPlanNodeVisitor* plan_visitor_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/Expr.cpp b/internal/core/src/query/generated/Expr.cpp deleted file mode 100644 index b5f68211c9407..0000000000000 --- a/internal/core/src/query/generated/Expr.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -// Generated File -// DO NOT EDIT -#include "query/Expr.h" -#include "ExprVisitor.h" - -namespace milvus::query { -void -LogicalUnaryExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -LogicalBinaryExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -TermExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -UnaryRangeExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -BinaryArithOpEvalRangeExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} -void -BinaryRangeExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -CompareExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -ExistsExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -AlwaysTrueExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} - -void -JsonContainsExpr::accept(ExprVisitor& visitor) { - visitor.visit(*this); -} -} // namespace milvus::query diff --git a/internal/core/src/query/generated/ExprVisitor.h b/internal/core/src/query/generated/ExprVisitor.h deleted file mode 100644 index d76af63f68968..0000000000000 --- a/internal/core/src/query/generated/ExprVisitor.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "query/Expr.h" -namespace milvus::query { -class ExprVisitor { - public: - virtual ~ExprVisitor() = default; - - public: - virtual void - visit(LogicalUnaryExpr&) = 0; - - virtual void - visit(LogicalBinaryExpr&) = 0; - - virtual void - visit(TermExpr&) = 0; - - virtual void - visit(UnaryRangeExpr&) = 0; - - virtual void - visit(BinaryArithOpEvalRangeExpr&) = 0; - - virtual void - visit(BinaryRangeExpr&) = 0; - - virtual void - visit(CompareExpr&) = 0; - - virtual void - visit(ExistsExpr&) = 0; - - virtual void - visit(AlwaysTrueExpr&) = 0; - - virtual void - visit(JsonContainsExpr&) = 0; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/ExtractInfoExprVisitor.h b/internal/core/src/query/generated/ExtractInfoExprVisitor.h deleted file mode 100644 index 6ce758e5f7dfa..0000000000000 --- a/internal/core/src/query/generated/ExtractInfoExprVisitor.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "query/Plan.h" -#include "ExprVisitor.h" - -namespace milvus::query { -class ExtractInfoExprVisitor : public ExprVisitor { - public: - void - visit(LogicalUnaryExpr& expr) override; - - void - visit(LogicalBinaryExpr& expr) override; - - void - visit(TermExpr& expr) override; - - void - visit(UnaryRangeExpr& expr) override; - - void - visit(BinaryArithOpEvalRangeExpr& expr) override; - - void - visit(BinaryRangeExpr& expr) override; - - void - visit(CompareExpr& expr) override; - - void - visit(ExistsExpr& expr) override; - - void - visit(AlwaysTrueExpr& expr) override; - - void - visit(JsonContainsExpr& expr) override; - - public: - explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) - : plan_info_(plan_info) { - } - - private: - ExtractedPlanInfo& plan_info_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h b/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h deleted file mode 100644 index 48f813b7d5886..0000000000000 --- a/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "query/Plan.h" -#include "PlanNodeVisitor.h" - -namespace milvus::query { -class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor { - public: - void - visit(FloatVectorANNS& node) override; - - void - visit(BinaryVectorANNS& node) override; - - void - visit(Float16VectorANNS& node) override; - - void - visit(BFloat16VectorANNS& node) override; - - void - visit(SparseFloatVectorANNS& node) override; - - void - visit(RetrievePlanNode& node) override; - - public: - explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) - : plan_info_(plan_info) { - } - - private: - ExtractedPlanInfo& plan_info_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/ShowExprVisitor.h b/internal/core/src/query/generated/ShowExprVisitor.h deleted file mode 100644 index 64532a00ddaa0..0000000000000 --- a/internal/core/src/query/generated/ShowExprVisitor.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "query/Plan.h" -#include -#include "ExprVisitor.h" - -namespace milvus::query { -class ShowExprVisitor : public ExprVisitor { - public: - void - visit(LogicalUnaryExpr& expr) override; - - void - visit(LogicalBinaryExpr& expr) override; - - void - visit(TermExpr& expr) override; - - void - visit(UnaryRangeExpr& expr) override; - - void - visit(BinaryArithOpEvalRangeExpr& expr) override; - - void - visit(BinaryRangeExpr& expr) override; - - void - visit(CompareExpr& expr) override; - - void - visit(ExistsExpr& expr) override; - - void - visit(AlwaysTrueExpr& expr) override; - - void - visit(JsonContainsExpr& expr) override; - - public: - Json - - call_child(Expr& expr) { - assert(!json_opt_.has_value()); - expr.accept(*this); - assert(json_opt_.has_value()); - auto ret = std::move(json_opt_); - json_opt_ = std::nullopt; - return std::move(ret.value()); - } - - Json - combine(Json&& extra, UnaryExprBase& expr) { - auto result = std::move(extra); - result["child"] = call_child(*expr.child_); - return result; - } - - Json - combine(Json&& extra, BinaryExprBase& expr) { - auto result = std::move(extra); - result["left_child"] = call_child(*expr.left_); - result["right_child"] = call_child(*expr.right_); - return result; - } - - private: - std::optional json_opt_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/ShowPlanNodeVisitor.h b/internal/core/src/query/generated/ShowPlanNodeVisitor.h deleted file mode 100644 index ec94659465471..0000000000000 --- a/internal/core/src/query/generated/ShowPlanNodeVisitor.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "common/EasyAssert.h" -#include "common/Json.h" -#include -#include - -#include "PlanNodeVisitor.h" - -namespace milvus::query { -class ShowPlanNodeVisitor : public PlanNodeVisitor { - public: - void - visit(FloatVectorANNS& node) override; - - void - visit(BinaryVectorANNS& node) override; - - void - visit(Float16VectorANNS& node) override; - - void - visit(BFloat16VectorANNS& node) override; - - void - visit(SparseFloatVectorANNS& node) override; - - void - visit(RetrievePlanNode& node) override; - - public: - using RetType = nlohmann::json; - - public: - RetType - call_child(PlanNode& node) { - assert(!ret_.has_value()); - node.accept(*this); - assert(ret_.has_value()); - auto ret = std::move(ret_); - ret_ = std::nullopt; - return std::move(ret.value()); - } - - private: - std::optional ret_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/VerifyExprVisitor.h b/internal/core/src/query/generated/VerifyExprVisitor.h deleted file mode 100644 index e791f2736db5d..0000000000000 --- a/internal/core/src/query/generated/VerifyExprVisitor.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include -#include -#include -#include -#include "segcore/SegmentGrowingImpl.h" -#include "query/ExprImpl.h" -#include "ExprVisitor.h" - -namespace milvus::query { -class VerifyExprVisitor : public ExprVisitor { - public: - void - visit(LogicalUnaryExpr& expr) override; - - void - visit(LogicalBinaryExpr& expr) override; - - void - visit(TermExpr& expr) override; - - void - visit(UnaryRangeExpr& expr) override; - - void - visit(BinaryArithOpEvalRangeExpr& expr) override; - - void - visit(BinaryRangeExpr& expr) override; - - void - visit(CompareExpr& expr) override; - - void - visit(ExistsExpr& expr) override; - - void - visit(AlwaysTrueExpr& expr) override; - - void - visit(JsonContainsExpr& expr) override; - - public: -}; -} // namespace milvus::query diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h deleted file mode 100644 index 40836460da340..0000000000000 --- a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -// Generated File -// DO NOT EDIT -#include "common/Json.h" -#include "query/PlanImpl.h" -#include "segcore/SegmentGrowing.h" -#include -#include "PlanNodeVisitor.h" - -namespace milvus::query { -class VerifyPlanNodeVisitor : public PlanNodeVisitor { - public: - void - visit(FloatVectorANNS& node) override; - - void - visit(BinaryVectorANNS& node) override; - - void - visit(Float16VectorANNS& node) override; - - void - visit(BFloat16VectorANNS& node) override; - - void - visit(SparseFloatVectorANNS& node) override; - - void - visit(RetrievePlanNode& node) override; - - public: - using RetType = SearchResult; - VerifyPlanNodeVisitor() = default; - - private: - std::optional ret_; -}; -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp deleted file mode 100644 index 7652e0e436083..0000000000000 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ /dev/null @@ -1,3525 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/generated/ExecExprVisitor.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "arrow/type_fwd.h" -#include "common/Json.h" -#include "common/Types.h" -#include "common/EasyAssert.h" -#include "fmt/core.h" -#include "pb/plan.pb.h" -#include "query/ExprImpl.h" -#include "query/Relational.h" -#include "query/Utils.h" -#include "segcore/SegmentGrowingImpl.h" -#include "simdjson/error.h" -#include "query/PlanProto.h" -#include "index/SkipIndex.h" -#include "index/Meta.h" - -namespace milvus::query { -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR -namespace impl { -class ExecExprVisitor : ExprVisitor { - public: - ExecExprVisitor(const segcore::SegmentInternalInterface& segment, - int64_t row_count, - Timestamp timestamp) - : segment_(segment), row_count_(row_count), timestamp_(timestamp) { - } - - BitsetType - call_child(Expr& expr) { - AssertInfo(!bitset_opt_.has_value(), - "[ExecExprVisitor]Bitset already has value before accept"); - expr.accept(*this); - AssertInfo(bitset_opt_.has_value(), - "[ExecExprVisitor]Bitset doesn't have value after accept"); - auto res = std::move(bitset_opt_); - bitset_opt_ = std::nullopt; - return std::move(res.value()); - } - - public: - template - auto - ExecRangeVisitorImpl(FieldId field_id, - IndexFunc func, - ElementFunc element_func, - SkipIndexFunc skip_index_func) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecBinaryArithOpEvalRangeVisitorDispatcher( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType; - - template - auto - ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) - -> BitsetType; - - private: - const segcore::SegmentInternalInterface& segment_; - int64_t row_count_; - Timestamp timestamp_; - BitsetTypeOpt bitset_opt_; -}; -} // namespace impl - -void -ExecExprVisitor::visit(LogicalUnaryExpr& expr) { - using OpType = LogicalUnaryExpr::OpType; - auto child_res = call_child(*expr.child_); - BitsetType res = std::move(child_res); - switch (expr.op_type_) { - case OpType::LogicalNot: { - res.flip(); - break; - } - default: { - PanicInfo(OpTypeInvalid, "Invalid Unary Op {}", expr.op_type_); - } - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -void -ExecExprVisitor::visit(LogicalBinaryExpr& expr) { - using OpType = LogicalBinaryExpr::OpType; - auto skip_right_expr = [](const BitsetType& left_res, - const OpType& op_type) -> bool { - return (op_type == OpType::LogicalAnd && left_res.none()) || - (op_type == OpType::LogicalOr && left_res.all()); - }; - - auto left = call_child(*expr.left_); - // skip execute right node for some situations - if (skip_right_expr(left, expr.op_type_)) { - AssertInfo(left.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(left); - return; - } - auto right = call_child(*expr.right_); - AssertInfo(left.size() == right.size(), - "[ExecExprVisitor]Left size not equal to right size"); - auto res = std::move(left); - switch (expr.op_type_) { - case OpType::LogicalAnd: { - res &= right; - break; - } - case OpType::LogicalOr: { - res |= right; - break; - } - case OpType::LogicalXor: { - res ^= right; - break; - } - case OpType::LogicalMinus: { - res -= right; - break; - } - default: { - PanicInfo(OpTypeInvalid, "Invalid Binary Op {}", expr.op_type_); - } - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -static auto -Assemble(const std::deque& srcs) -> BitsetType { - BitsetType res; - - int64_t total_size = 0; - for (auto& chunk : srcs) { - total_size += chunk.size(); - } - res.reserve(total_size); - - for (auto& chunk : srcs) { - res.append(chunk); - } - return res; -} - -void -AppendOneChunk(BitsetType& result, const TargetBitmapView chunk_res) { - result.append(chunk_res); -} - -BitsetType -AssembleChunk(const std::vector& results) { - BitsetType assemble_result; - for (auto& result : results) { - AppendOneChunk(assemble_result, result.view()); - } - return assemble_result; -} - -template -auto -ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, - IndexFunc index_func, - ElementFunc element_func, - SkipIndexFunc skip_index_func) - -> BitsetType { - auto& schema = segment_.get_schema(); - auto& field_meta = schema[field_id]; - auto indexing_barrier = segment_.num_chunk_index(field_id); - auto size_per_chunk = segment_.size_per_chunk(); - auto num_chunk = upper_div(row_count_, size_per_chunk); - std::vector results; - results.reserve(num_chunk); - - typedef std:: - conditional_t, std::string, T> - IndexInnerType; - using Index = index::ScalarIndex; - for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) { - const Index& indexing = - segment_.chunk_scalar_index(field_id, chunk_id); - // NOTE: knowhere is not const-ready - // This is a dirty workaround - auto data = index_func(const_cast(&indexing)); - AssertInfo(data.size() == size_per_chunk, - "[ExecExprVisitor]Data size not equal to size_per_chunk"); - results.emplace_back(std::move(data)); - } - - for (auto chunk_id = indexing_barrier; chunk_id < num_chunk; ++chunk_id) { - auto this_size = chunk_id == num_chunk - 1 - ? row_count_ - chunk_id * size_per_chunk - : size_per_chunk; - TargetBitmap chunk_res(this_size); - //check possible chunk metrics - auto& skipIndex = segment_.GetSkipIndex(); - if (skip_index_func(skipIndex, field_id, chunk_id)) { - results.emplace_back(std::move(chunk_res)); - continue; - } - auto chunk = segment_.chunk_data(field_id, chunk_id); - const T* data = chunk.data(); - // Can use CPU SIMD optimazation to speed up - for (int index = 0; index < this_size; ++index) { - chunk_res[index] = element_func(data[index]); - } - results.emplace_back(std::move(chunk_res)); - } - - auto final_result = AssembleChunk(results); - AssertInfo(final_result.size() == row_count_, - "[ExecExprVisitor]Final result size not equal to row count"); - return final_result; -} - -template -auto -ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, - IndexFunc index_func, - ElementFunc element_func) - -> BitsetType { - auto& schema = segment_.get_schema(); - auto& field_meta = schema[field_id]; - auto size_per_chunk = segment_.size_per_chunk(); - auto num_chunk = upper_div(row_count_, size_per_chunk); - auto indexing_barrier = segment_.num_chunk_index(field_id); - auto data_barrier = segment_.num_chunk_data(field_id); - AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk, - "max(data_barrier, index_barrier) not equal to num_chunk"); - std::vector results; - results.reserve(num_chunk); - - // for growing segment, indexing_barrier will always less than data_barrier - // so growing segment will always execute expr plan using raw data - // if sealed segment has loaded raw data on this field, then index_barrier = 0 and data_barrier = 1 - // in this case, sealed segment execute expr plan using raw data - for (auto chunk_id = 0; chunk_id < data_barrier; ++chunk_id) { - auto this_size = chunk_id == num_chunk - 1 - ? row_count_ - chunk_id * size_per_chunk - : size_per_chunk; - TargetBitmap result(this_size); - auto chunk = segment_.chunk_data(field_id, chunk_id); - const T* data = chunk.data(); - for (int index = 0; index < this_size; ++index) { - result[index] = element_func(data[index]); - } - AssertInfo(result.size() == this_size, - "[ExecExprVisitor]Chunk result size not equal to " - "expected size"); - results.emplace_back(std::move(result)); - } - - // if sealed segment has loaded scalar index for this field, then index_barrier = 1 and data_barrier = 0 - // in this case, sealed segment execute expr plan using scalar index - typedef std:: - conditional_t, std::string, T> - IndexInnerType; - using Index = index::ScalarIndex; - for (auto chunk_id = data_barrier; chunk_id < indexing_barrier; - ++chunk_id) { - auto& indexing = - segment_.chunk_scalar_index(field_id, chunk_id); - auto this_size = const_cast(&indexing)->Count(); - TargetBitmap result(this_size); - for (int offset = 0; offset < this_size; ++offset) { - result[offset] = index_func(const_cast(&indexing), offset); - } - results.emplace_back(std::move(result)); - } - - auto final_result = AssembleChunk(results); - AssertInfo(final_result.size() == row_count_, - "[ExecExprVisitor]Final result size not equal to row count"); - return final_result; -} - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "Simplify" -template -auto -ExecExprVisitor::ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw) - -> BitsetType { - typedef std:: - conditional_t, std::string, T> - IndexInnerType; - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - - auto op = expr.op_type_; - auto val = IndexInnerType(expr.value_); - auto field_id = expr.column_.field_id; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - switch (op) { - case OpType::Equal: { - auto index_func = [&](Index* index) { return index->In(1, &val); }; - auto elem_func = [&](MayConstRef x) { return (x == val); }; - auto skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::Equal, val); - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, skip_index_func); - } - case OpType::NotEqual: { - auto index_func = [&](Index* index) { - return index->NotIn(1, &val); - }; - auto elem_func = [&](MayConstRef x) { return (x != val); }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::GreaterEqual: { - auto index_func = [&](Index* index) { - return index->Range(val, OpType::GreaterEqual); - }; - auto elem_func = [&](MayConstRef x) { return (x >= val); }; - auto skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::GreaterEqual, val); - }; - - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, skip_index_func); - } - case OpType::GreaterThan: { - auto index_func = [&](Index* index) { - return index->Range(val, OpType::GreaterThan); - }; - auto elem_func = [&](MayConstRef x) { return (x > val); }; - auto skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::GreaterThan, val); - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, skip_index_func); - } - case OpType::LessEqual: { - auto index_func = [&](Index* index) { - return index->Range(val, OpType::LessEqual); - }; - auto elem_func = [&](MayConstRef x) { return (x <= val); }; - auto skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::LessEqual, val); - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, skip_index_func); - } - case OpType::LessThan: { - auto index_func = [&](Index* index) { - return index->Range(val, OpType::LessThan); - }; - auto elem_func = [&](MayConstRef x) { return (x < val); }; - auto skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::LessThan, val); - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, skip_index_func); - } - case OpType::PrefixMatch: { - auto index_func = [&](Index* index) { - auto dataset = std::make_unique(); - dataset->Set(milvus::index::OPERATOR_TYPE, OpType::PrefixMatch); - dataset->Set(milvus::index::PREFIX_VALUE, val); - return index->Query(std::move(dataset)); - }; - auto elem_func = [&](MayConstRef x) { - return Match(x, val, op); - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - // TODO: PostfixMatch - default: { - PanicInfo(OpTypeInvalid, "unsupported range node {}", op); - } - } -} -#pragma clang diagnostic pop - -template -auto -ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) - -> BitsetType { - // bool type is integral but will never be overflowed, - // the check method may evaluate it out of range with bool type, - // exclude bool type here - if constexpr (std::is_integral_v && !std::is_same_v) { - auto& expr = static_cast&>(expr_raw); - auto val = expr.value_; - - if (!out_of_range(val)) { - return ExecUnaryRangeVisitorDispatcherImpl(expr_raw); - } - - // see also: https://github.com/milvus-io/milvus/issues/23646. - switch (expr.op_type_) { - case proto::plan::GreaterThan: - case proto::plan::GreaterEqual: { - BitsetType r(row_count_); - if (lt_lb(val)) { - r.set(); - } - return r; - } - - case proto::plan::LessThan: - case proto::plan::LessEqual: { - BitsetType r(row_count_); - if (gt_ub(val)) { - r.set(); - } - return r; - } - - case proto::plan::Equal: { - BitsetType r(row_count_); - r.reset(); - return r; - } - - case proto::plan::NotEqual: { - BitsetType r(row_count_); - r.set(); - return r; - } - - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported range node {}", expr.op_type_)); - } - } - } - return ExecUnaryRangeVisitorDispatcherImpl(expr_raw); -} - -template -bool -CompareTwoJsonArray(T arr1, const proto::plan::Array& arr2) { - int json_array_length = 0; - if constexpr (std::is_same_v< - T, - simdjson::simdjson_result>) { - json_array_length = arr1.count_elements(); - } - if constexpr (std::is_same_v>>) { - json_array_length = arr1.size(); - } - if (arr2.array_size() != json_array_length) { - return false; - } - int i = 0; - for (auto&& it : arr1) { - switch (arr2.array(i).val_case()) { - case proto::plan::GenericValue::kBoolVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).bool_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kInt64Val: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).int64_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kFloatVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).float_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kStringVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).string_val()) { - return false; - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - arr2.array(i).val_case()); - } - i++; - } - return true; -} - -template -auto -ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - - auto op = expr.op_type_; - auto val = expr.value_; - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto field_id = expr.column_.field_id; - auto index_func = [=](Index* index) { return TargetBitmap{}; }; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - -#define UnaryRangeJSONCompare(cmp) \ - do { \ - auto x = json.template at(pointer); \ - if (x.error()) { \ - if constexpr (std::is_same_v) { \ - auto x = json.template at(pointer); \ - return !x.error() && (cmp); \ - } \ - return false; \ - } \ - return (cmp); \ - } while (false) - -#define UnaryRangeJSONCompareNotEqual(cmp) \ - do { \ - auto x = json.template at(pointer); \ - if (x.error()) { \ - if constexpr (std::is_same_v) { \ - auto x = json.template at(pointer); \ - return x.error() || (cmp); \ - } \ - return true; \ - } \ - return (cmp); \ - } while (false) - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - switch (op) { - case OpType::Equal: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - return CompareTwoJsonArray(array, val); - } else { - UnaryRangeJSONCompare(x.value() == val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::NotEqual: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - return !CompareTwoJsonArray(array, val); - } else { - UnaryRangeJSONCompareNotEqual(x.value() != val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::GreaterEqual: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - return false; - } else { - UnaryRangeJSONCompare(x.value() >= val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::GreaterThan: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - return false; - } else { - UnaryRangeJSONCompare(x.value() > val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::LessEqual: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - return false; - } else { - UnaryRangeJSONCompare(x.value() <= val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::LessThan: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - return false; - } else { - UnaryRangeJSONCompare(x.value() < val); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::PrefixMatch: { - auto elem_func = [&](const milvus::Json& json) { - if constexpr (std::is_same_v) { - return false; - } else { - UnaryRangeJSONCompare( - Match(ExprValueType(x.value()), val, op)); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - // TODO: PostfixMatch - default: { - PanicInfo(OpTypeInvalid, "unsupported range node {}", op); - } - } -} - -template -auto -ExecExprVisitor::ExecUnaryRangeVisitorDispatcherArray(UnaryRangeExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - - auto op = expr.op_type_; - auto val = expr.value_; - auto field_id = expr.column_.field_id; - auto index_func = [=](Index* index) { return TargetBitmap{}; }; - int index = -1; - if (expr.column_.nested_path.size() > 0) { - index = std::stoi(expr.column_.nested_path[0]); - } - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - switch (op) { - case OpType::Equal: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return array.is_same_array(val); - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data == val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::NotEqual: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return !array.is_same_array(val); - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data != val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::GreaterEqual: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return false; - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data >= val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::GreaterThan: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return false; - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data > val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::LessEqual: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return false; - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data <= val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::LessThan: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return false; - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return array_data < val; - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - case OpType::PrefixMatch: { - auto elem_func = [&](const milvus::ArrayView& array) { - if constexpr (std::is_same_v) { - return false; - } else { - if (index >= array.length()) { - return false; - } - auto array_data = array.template get_data(index); - return Match(array_data, val, op); - } - }; - return ExecRangeVisitorImpl( - field_id, index_func, elem_func, default_skip_index_func); - } - // TODO: PostfixMatch - default: { - PanicInfo(OpTypeInvalid, "unsupported range node {}", op); - } - } -} - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "Simplify" -template -auto -ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType { - // see also: https://github.com/milvus-io/milvus/issues/23646. - typedef std::conditional_t && - !std::is_same_v, - int64_t, - T> - HighPrecisionType; - - auto& expr = - static_cast&>( - expr_raw); - using Index = index::ScalarIndex; - auto arith_op = expr.arith_op_; - auto right_operand = expr.right_operand_; - auto op = expr.op_type_; - auto val = expr.value_; - - switch (op) { - case OpType::Equal: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x + right_operand) == val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x + right_operand) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x - right_operand) == val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x - right_operand) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x * right_operand) == val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x * right_operand) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x / right_operand) == val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x / right_operand) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return static_cast(fmod(x, right_operand)) == val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return (static_cast(fmod(x, right_operand)) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - case OpType::NotEqual: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x + right_operand) != val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x + right_operand) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x - right_operand) != val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x - right_operand) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x * right_operand) != val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x * right_operand) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return (x / right_operand) != val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return ((x / right_operand) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - auto x = index->Reverse_Lookup(offset); - return static_cast(fmod(x, right_operand)) != val; - }; - auto elem_func = [val, right_operand](MayConstRef x) { - return (static_cast(fmod(x, right_operand)) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported range node with arithmetic operation {}", op)); - } - } -} -#pragma clang diagnostic pop - -template -auto -ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType { - auto& expr = - static_cast&>(expr_raw); - using Index = index::ScalarIndex; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - auto arith_op = expr.arith_op_; - auto right_operand = expr.right_operand_; - auto op = expr.op_type_; - auto val = expr.value_; - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - -#define BinaryArithRangeJSONCompare(cmp) \ - do { \ - auto x = json.template at(pointer); \ - if (x.error()) { \ - if constexpr (std::is_same_v) { \ - auto x = json.template at(pointer); \ - return !x.error() && (cmp); \ - } \ - return false; \ - } \ - return (cmp); \ - } while (false) - -#define BinaryArithRangeJSONCompareNotEqual(cmp) \ - do { \ - auto x = json.template at(pointer); \ - if (x.error()) { \ - if constexpr (std::is_same_v) { \ - auto x = json.template at(pointer); \ - return x.error() || (cmp); \ - } \ - return true; \ - } \ - return (cmp); \ - } while (false) - - switch (op) { - case OpType::Equal: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompare(x.value() + right_operand == - val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompare(x.value() - right_operand == - val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompare(x.value() * right_operand == - val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompare(x.value() / right_operand == - val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompare( - static_cast( - fmod(x.value(), right_operand)) == val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::ArrayLength: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); - } - return array_length == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - case OpType::NotEqual: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompareNotEqual( - x.value() + right_operand != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompareNotEqual( - x.value() - right_operand != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompareNotEqual( - x.value() * right_operand != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompareNotEqual( - x.value() / right_operand != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - BinaryArithRangeJSONCompareNotEqual( - static_cast( - fmod(x.value(), right_operand)) != val); - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::ArrayLength: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::Json& json) { - int array_length = 0; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (!array.error()) { - array_length = array.count_elements(); - } - return array_length != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported range node with arithmetic operation {}", op)); - } - } -} // namespace milvus::query - -template -auto -ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherArray( - BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType { - auto& expr = - static_cast&>(expr_raw); - using Index = index::ScalarIndex; - - auto arith_op = expr.arith_op_; - auto right_operand = expr.right_operand_; - auto op = expr.op_type_; - auto val = expr.value_; - int index = -1; - if (expr.column_.nested_path.size() > 0) { - index = std::stoi(expr.column_.nested_path[0]); - } - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - switch (op) { - case OpType::Equal: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value + right_operand == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value - right_operand == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value * right_operand == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value / right_operand == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return static_cast( - fmod(value, right_operand)) == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::ArrayLength: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - return array.length() == val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - case OpType::NotEqual: { - switch (arith_op) { - case ArithOpType::Add: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value + right_operand != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Sub: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value - right_operand != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mul: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value * right_operand != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Div: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return value / right_operand != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::Mod: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return static_cast( - fmod(value, right_operand)) != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - case ArithOpType::ArrayLength: { - auto index_func = [val, right_operand](Index* index, - size_t offset) { - return false; - }; - auto elem_func = [&](const milvus::ArrayView& array) { - return array.length() != val; - }; - return ExecDataRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func); - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format("unsupported arithmetic operation {}", op)); - } - } - } - default: { - PanicInfo( - OpTypeInvalid, - fmt::format( - "unsupported range node with arithmetic operation {}", op)); - } - } -} // namespace milvus::query - -#pragma clang diagnostic push -#pragma ide diagnostic ignored "Simplify" -template -auto -ExecExprVisitor::ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) - -> BitsetType { - typedef std:: - conditional_t, std::string, T> - IndexInnerType; - using Index = index::ScalarIndex; - - // see also: https://github.com/milvus-io/milvus/issues/23646. - typedef std::conditional_t && - !std::is_same_v, - int64_t, - IndexInnerType> - HighPrecisionType; - auto& expr = static_cast&>(expr_raw); - - bool lower_inclusive = expr.lower_inclusive_; - bool upper_inclusive = expr.upper_inclusive_; - auto val1 = static_cast(expr.lower_value_); - auto val2 = static_cast(expr.upper_value_); - - if constexpr (std::is_integral_v && !std::is_same_v) { - if (gt_ub(val1)) { - BitsetType r(row_count_); - r.reset(); - return r; - } else if (lt_lb(val1)) { - val1 = std::numeric_limits::min(); - lower_inclusive = true; - } - - if (gt_ub(val2)) { - val2 = std::numeric_limits::max(); - upper_inclusive = true; - } else if (lt_lb(val2)) { - BitsetType r(row_count_); - r.reset(); - return r; - } - } - - auto index_func = [=](Index* index) { - return index->Range(val1, lower_inclusive, val2, upper_inclusive); - }; - if (lower_inclusive && upper_inclusive) { - auto elem_func = [val1, val2](MayConstRef x) { - return (val1 <= x && x <= val2); - }; - auto skip_index_func = [&](const SkipIndex& skip_index, - FieldId field_id, - int64_t chunk_id) { - return skip_index.CanSkipBinaryRange( - field_id, chunk_id, val1, val2, true, true); - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, skip_index_func); - } else if (lower_inclusive && !upper_inclusive) { - auto elem_func = [val1, val2](MayConstRef x) { - return (val1 <= x && x < val2); - }; - auto skip_index_func = [&](const SkipIndex& skip_index, - FieldId field_id, - int64_t chunk_id) { - return skip_index.CanSkipBinaryRange( - field_id, chunk_id, val1, val2, true, false); - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, skip_index_func); - } else if (!lower_inclusive && upper_inclusive) { - auto elem_func = [val1, val2](MayConstRef x) { - return (val1 < x && x <= val2); - }; - auto skip_index_func = [&](const SkipIndex& skip_index, - FieldId field_id, - int64_t chunk_id) { - return skip_index.CanSkipBinaryRange( - field_id, chunk_id, val1, val2, false, true); - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, skip_index_func); - } else { - auto elem_func = [val1, val2](MayConstRef x) { - return (val1 < x && x < val2); - }; - auto skip_index_func = [&](const SkipIndex& skip_index, - FieldId field_id, - int64_t chunk_id) { - return skip_index.CanSkipBinaryRange( - field_id, chunk_id, val1, val2, false, false); - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, skip_index_func); - } -} -#pragma clang diagnostic pop - -template -auto -ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - auto& expr = static_cast&>(expr_raw); - bool lower_inclusive = expr.lower_inclusive_; - bool upper_inclusive = expr.upper_inclusive_; - ExprValueType val1 = expr.lower_value_; - ExprValueType val2 = expr.upper_value_; - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - - // no json index now - auto index_func = [=](Index* index) { return TargetBitmap{}; }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - -#define BinaryRangeJSONCompare(cmp) \ - do { \ - auto x = json.template at(pointer); \ - if (x.error()) { \ - if constexpr (std::is_same_v) { \ - auto x = json.template at(pointer); \ - if (!x.error()) { \ - auto value = x.value(); \ - return (cmp); \ - } \ - } \ - return false; \ - } \ - auto value = x.value(); \ - return (cmp); \ - } while (false) - - if (lower_inclusive && upper_inclusive) { - auto elem_func = [&](const milvus::Json& json) { - BinaryRangeJSONCompare(val1 <= value && value <= val2); - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else if (lower_inclusive && !upper_inclusive) { - auto elem_func = [&](const milvus::Json& json) { - BinaryRangeJSONCompare(val1 <= value && value < val2); - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else if (!lower_inclusive && upper_inclusive) { - auto elem_func = [&](const milvus::Json& json) { - BinaryRangeJSONCompare(val1 < value && value <= val2); - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else { - auto elem_func = [&](const milvus::Json& json) { - BinaryRangeJSONCompare(val1 < value && value < val2); - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } -} - -template -auto -ExecExprVisitor::ExecBinaryRangeVisitorDispatcherArray( - BinaryRangeExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - auto& expr = static_cast&>(expr_raw); - bool lower_inclusive = expr.lower_inclusive_; - bool upper_inclusive = expr.upper_inclusive_; - ExprValueType val1 = expr.lower_value_; - ExprValueType val2 = expr.upper_value_; - int index = -1; - if (expr.column_.nested_path.size() > 0) { - index = std::stoi(expr.column_.nested_path[0]); - } - - // no json index now - auto index_func = [=](Index* index) { return TargetBitmap{}; }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - if (lower_inclusive && upper_inclusive) { - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return val1 <= value && value <= val2; - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else if (lower_inclusive && !upper_inclusive) { - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return val1 <= value && value < val2; - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else if (!lower_inclusive && upper_inclusive) { - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return val1 < value && value <= val2; - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } else { - auto elem_func = [&](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return val1 < value && value < val2; - }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } -} - -void -ExecExprVisitor::visit(UnaryRangeExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo(expr.column_.data_type == field_meta.get_data_type(), - "[ExecExprVisitor]DataType of expr isn't field_meta data type"); - BitsetType res; - switch (expr.column_.data_type) { - case DataType::BOOL: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT8: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT16: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT32: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT64: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::FLOAT: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::DOUBLE: { - res = ExecUnaryRangeVisitorDispatcher(expr); - break; - } - case DataType::VARCHAR: { - if (segment_.type() == SegmentType::Growing) { - res = ExecUnaryRangeVisitorDispatcher(expr); - } else { - res = ExecUnaryRangeVisitorDispatcher(expr); - } - break; - } - case DataType::JSON: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: - res = ExecUnaryRangeVisitorDispatcherJson(expr); - break; - case proto::plan::GenericValue::ValCase::kInt64Val: - res = ExecUnaryRangeVisitorDispatcherJson(expr); - break; - case proto::plan::GenericValue::ValCase::kFloatVal: - res = ExecUnaryRangeVisitorDispatcherJson(expr); - break; - case proto::plan::GenericValue::ValCase::kStringVal: - res = - ExecUnaryRangeVisitorDispatcherJson(expr); - break; - case proto::plan::GenericValue::ValCase::kArrayVal: - res = - ExecUnaryRangeVisitorDispatcherJson( - expr); - break; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type: {}", expr.val_case_)); - } - break; - } - case DataType::ARRAY: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: - res = ExecUnaryRangeVisitorDispatcherArray(expr); - break; - case proto::plan::GenericValue::ValCase::kInt64Val: - res = ExecUnaryRangeVisitorDispatcherArray(expr); - break; - case proto::plan::GenericValue::ValCase::kFloatVal: - res = ExecUnaryRangeVisitorDispatcherArray(expr); - break; - case proto::plan::GenericValue::ValCase::kStringVal: - res = - ExecUnaryRangeVisitorDispatcherArray(expr); - break; - case proto::plan::GenericValue::ValCase::kArrayVal: - res = ExecUnaryRangeVisitorDispatcherArray< - proto::plan::Array>(expr); - break; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unknown data type: {}", expr.val_case_)); - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type: {}", - expr.column_.data_type); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -void -ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo(expr.column_.data_type == field_meta.get_data_type(), - "[ExecExprVisitor]DataType of expr isn't field_meta data type"); - BitsetType res; - switch (expr.column_.data_type) { - case DataType::INT8: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::INT16: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::INT32: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::INT64: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::FLOAT: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::DOUBLE: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcher(expr); - break; - } - case DataType::JSON: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcherJson( - expr); - break; - } - case proto::plan::GenericValue::ValCase::kInt64Val: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcherJson< - int64_t>(expr); - break; - } - case proto::plan::GenericValue::ValCase::kFloatVal: { - res = - ExecBinaryArithOpEvalRangeVisitorDispatcherJson( - expr); - break; - } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported value type {} in expression", - expr.val_case_)); - } - } - break; - } - case DataType::ARRAY: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kInt64Val: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcherArray< - int64_t>(expr); - break; - } - case proto::plan::GenericValue::ValCase::kFloatVal: { - res = ExecBinaryArithOpEvalRangeVisitorDispatcherArray< - double>(expr); - break; - } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported value type {} in expression", - expr.val_case_)); - } - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type: {}", - expr.column_.data_type); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -void -ExecExprVisitor::visit(BinaryRangeExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo(expr.column_.data_type == field_meta.get_data_type(), - "[ExecExprVisitor]DataType of expr isn't field_meta data type"); - BitsetType res; - switch (expr.column_.data_type) { - case DataType::BOOL: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT8: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT16: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT32: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::INT64: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::FLOAT: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::DOUBLE: { - res = ExecBinaryRangeVisitorDispatcher(expr); - break; - } - case DataType::VARCHAR: { - if (segment_.type() == SegmentType::Growing) { - res = ExecBinaryRangeVisitorDispatcher(expr); - } else { - res = ExecBinaryRangeVisitorDispatcher(expr); - } - break; - } - case DataType::JSON: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kInt64Val: { - res = ExecBinaryRangeVisitorDispatcherJson(expr); - break; - } - case proto::plan::GenericValue::ValCase::kFloatVal: { - res = ExecBinaryRangeVisitorDispatcherJson(expr); - break; - } - case proto::plan::GenericValue::ValCase::kStringVal: { - res = - ExecBinaryRangeVisitorDispatcherJson(expr); - break; - } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported value type {} in expression", - expr.val_case_)); - } - } - break; - } - case DataType::ARRAY: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kInt64Val: { - res = ExecBinaryRangeVisitorDispatcherArray(expr); - break; - } - case proto::plan::GenericValue::ValCase::kFloatVal: { - res = ExecBinaryRangeVisitorDispatcherArray(expr); - break; - } - case proto::plan::GenericValue::ValCase::kStringVal: { - res = ExecBinaryRangeVisitorDispatcherArray( - expr); - break; - } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported value type {} in expression", - expr.val_case_)); - } - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type: {}", - expr.column_.data_type); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -template -struct relational { - template - bool - operator()(T const& a, U const& b) const { - return Op{}(a, b); - } - template - bool - operator()(T const&...) const { - PanicInfo(OpTypeInvalid, "incompatible operands"); - } -}; - -template -TargetBitmap -ExecExprVisitor::ExecCompareRightType(const T* left_raw_data, - const FieldId& right_field_id, - const int64_t current_chunk_id, - CmpFunc cmp_func) { - auto size_per_chunk = segment_.size_per_chunk(); - auto num_chunks = upper_div(row_count_, size_per_chunk); - auto size = current_chunk_id == num_chunks - 1 - ? row_count_ - current_chunk_id * size_per_chunk - : size_per_chunk; - - TargetBitmap result(size); - const U* right_raw_data = - segment_.chunk_data(right_field_id, current_chunk_id).data(); - - for (int i = 0; i < size; ++i) { - result[i] = cmp_func(left_raw_data[i], right_raw_data[i]); - } - - return result; -} - -template -BitsetType -ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id, - const FieldId& right_field_id, - const DataType& right_field_type, - CmpFunc cmp_func) { - auto size_per_chunk = segment_.size_per_chunk(); - auto num_chunks = upper_div(row_count_, size_per_chunk); - std::vector results; - results.reserve(num_chunks); - - for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) { - TargetBitmap result; - const T* left_raw_data = - segment_.chunk_data(left_field_id, chunk_id).data(); - - switch (right_field_type) { - case DataType::BOOL: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::INT8: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::INT16: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::INT32: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::INT64: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::FLOAT: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - case DataType::DOUBLE: - result = ExecCompareRightType( - left_raw_data, right_field_id, chunk_id, cmp_func); - break; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported right datatype {} of compare expr", - right_field_type)); - } - results.push_back(std::move(result)); - } - auto final_result = AssembleChunk(results); - AssertInfo(final_result.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - return final_result; -} - -template -BitsetType -ExecExprVisitor::ExecCompareExprDispatcherForNonIndexedSegment( - CompareExpr& expr, CmpFunc cmp_func) { - switch (expr.left_data_type_) { - case DataType::BOOL: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::INT8: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::INT16: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::INT32: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::INT64: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::FLOAT: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - case DataType::DOUBLE: - return ExecCompareLeftType(expr.left_field_id_, - expr.right_field_id_, - expr.right_data_type_, - cmp_func); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported right datatype {} of compare expr", - expr.left_data_type_)); - } -} - -template -auto -ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) - -> BitsetType { - using number = boost::variant; - auto is_string_expr = [&expr]() -> bool { - return expr.left_data_type_ == DataType::VARCHAR || - expr.right_data_type_ == DataType::VARCHAR; - }; - - auto size_per_chunk = segment_.size_per_chunk(); - auto num_chunk = upper_div(row_count_, size_per_chunk); - std::deque bitsets; - - // check for sealed segment, load either raw field data or index - auto left_indexing_barrier = segment_.num_chunk_index(expr.left_field_id_); - auto left_data_barrier = segment_.num_chunk_data(expr.left_field_id_); - AssertInfo(std::max(left_data_barrier, left_indexing_barrier) == num_chunk, - "max(left_data_barrier, left_indexing_barrier) not equal to " - "num_chunk"); - - auto right_indexing_barrier = - segment_.num_chunk_index(expr.right_field_id_); - auto right_data_barrier = segment_.num_chunk_data(expr.right_field_id_); - AssertInfo( - std::max(right_data_barrier, right_indexing_barrier) == num_chunk, - "max(right_data_barrier, right_indexing_barrier) not equal to " - "num_chunk"); - - // For segment both fields has no index, can use SIMD to speed up. - // Avoiding too much call stack that blocks SIMD. - if (left_indexing_barrier == 0 && right_indexing_barrier == 0 && - !is_string_expr()) { - return ExecCompareExprDispatcherForNonIndexedSegment(expr, op); - } - - // TODO: refactoring the code that contains too much call stack. - for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) { - auto size = chunk_id == num_chunk - 1 - ? row_count_ - chunk_id * size_per_chunk - : size_per_chunk; - auto getChunkData = - [&, chunk_id](DataType type, FieldId field_id, int64_t data_barrier) - -> std::function { - switch (type) { - case DataType::BOOL: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::INT8: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::INT16: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::INT32: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::INT64: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::FLOAT: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::DOUBLE: { - if (chunk_id < data_barrier) { - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = segment_.chunk_scalar_index( - field_id, chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - case DataType::VARCHAR: { - if (chunk_id < data_barrier) { - if (segment_.type() == SegmentType::Growing && - !storage::MmapManager::GetInstance() - .GetMmapConfig() - .growing_enable_mmap) { - auto chunk_data = - segment_ - .chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } else { - auto chunk_data = segment_ - .chunk_data( - field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return std::string(chunk_data[i]); - }; - } - } else { - // for case, sealed segment has loaded index for scalar field instead of raw data - auto& indexing = - segment_.chunk_scalar_index(field_id, - chunk_id); - if (indexing.HasRawData()) { - return [&indexing](int i) -> const number { - return indexing.Reverse_Lookup(i); - }; - } - auto chunk_data = - segment_.chunk_data(field_id, chunk_id) - .data(); - return [chunk_data](int i) -> const number { - return chunk_data[i]; - }; - } - } - default: - PanicInfo( - DataTypeInvalid, "unsupported data type {}", type); - } - }; - auto left = getChunkData( - expr.left_data_type_, expr.left_field_id_, left_data_barrier); - auto right = getChunkData( - expr.right_data_type_, expr.right_field_id_, right_data_barrier); - - BitsetType bitset(size); - for (int i = 0; i < size; ++i) { - bool is_in = boost::apply_visitor( - Relational{}, left(i), right(i)); - bitset[i] = is_in; - } - bitsets.emplace_back(std::move(bitset)); - } - auto final_result = Assemble(bitsets); - AssertInfo(final_result.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - return final_result; -} - -void -ExecExprVisitor::visit(CompareExpr& expr) { - auto& schema = segment_.get_schema(); - auto& left_field_meta = schema[expr.left_field_id_]; - auto& right_field_meta = schema[expr.right_field_id_]; - AssertInfo(expr.left_data_type_ == left_field_meta.get_data_type(), - "[ExecExprVisitor]Left data type not equal to left field " - "meta type"); - AssertInfo(expr.right_data_type_ == right_field_meta.get_data_type(), - "[ExecExprVisitor]right data type not equal to right field " - "meta type"); - - BitsetType res; - switch (expr.op_type_) { - case OpType::Equal: { - res = ExecCompareExprDispatcher(expr, std::equal_to<>{}); - break; - } - case OpType::NotEqual: { - res = ExecCompareExprDispatcher(expr, std::not_equal_to<>{}); - break; - } - case OpType::GreaterEqual: { - res = ExecCompareExprDispatcher(expr, std::greater_equal<>{}); - break; - } - case OpType::GreaterThan: { - res = ExecCompareExprDispatcher(expr, std::greater<>{}); - break; - } - case OpType::LessEqual: { - res = ExecCompareExprDispatcher(expr, std::less_equal<>{}); - break; - } - case OpType::LessThan: { - res = ExecCompareExprDispatcher(expr, std::less<>{}); - break; - } - case OpType::PrefixMatch: { - res = - ExecCompareExprDispatcher(expr, MatchOp{}); - break; - } - // case OpType::PostfixMatch: { - // } - default: { - PanicInfo(OpTypeInvalid, "unsupported optype {}", expr.op_type_); - } - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -template -auto -ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType { - typedef std:: - conditional_t, std::string, T> - InnerType; - auto& expr = static_cast&>(expr_raw); - auto& schema = segment_.get_schema(); - auto primary_filed_id = schema.get_primary_field_id(); - auto field_id = expr_raw.column_.field_id; - auto& field_meta = schema[field_id]; - - bool use_pk_index = false; - if (primary_filed_id.has_value()) { - use_pk_index = primary_filed_id.value() == field_id && - IsPrimaryKeyDataType(field_meta.get_data_type()); - } - - if (use_pk_index) { - auto id_array = std::make_unique(); - switch (field_meta.get_data_type()) { - case DataType::INT64: { - auto dst_ids = id_array->mutable_int_id(); - for (const auto& id : expr.terms_) { - dst_ids->add_data((int64_t&)id); - } - break; - } - case DataType::VARCHAR: { - auto dst_ids = id_array->mutable_str_id(); - for (const auto& id : expr.terms_) { - dst_ids->add_data((std::string&)id); - } - break; - } - default: { - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported data type {}", expr.val_case_)); - } - } - - auto [uids, seg_offsets] = segment_.search_ids(*id_array, timestamp_); - BitsetType bitset(row_count_); - std::vector cached_offsets; - for (const auto& offset : seg_offsets) { - auto _offset = (int64_t)offset.get(); - bitset[_offset] = true; - cached_offsets.push_back(_offset); - } - // If enable plan_visitor pk index cache, pass offsets_ to it - if (plan_visitor_ != nullptr) { - plan_visitor_->SetExprUsePkIndex(true); - } - AssertInfo(bitset.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - return bitset; - } - - return ExecTermVisitorImplTemplate(expr_raw); -} - -template -auto -ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType { - typedef std:: - conditional_t, std::string, T> - IndexInnerType; - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - const auto& terms = expr.terms_; - auto n = terms.size(); - std::unordered_set term_set(expr.terms_.begin(), expr.terms_.end()); - - auto index_func = [&terms, n](Index* index) { - return index->In(n, terms.data()); - }; - - auto elem_func = [&term_set](MayConstRef x) { - return term_set.find(x) != term_set.end(); - }; - - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -// TODO: bool is so ugly here. -template <> -auto -ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) - -> BitsetType { - using T = bool; - auto& expr = static_cast&>(expr_raw); - using Index = index::ScalarIndex; - const auto& terms = expr.terms_; - auto n = terms.size(); - std::unordered_set term_set(expr.terms_.begin(), expr.terms_.end()); - - auto index_func = [&terms, n](Index* index) { - auto bool_arr_copy = new bool[terms.size()]; - int it = 0; - for (auto elem : terms) { - bool_arr_copy[it++] = elem; - } - auto bitset = index->In(n, bool_arr_copy); - delete[] bool_arr_copy; - return bitset; - }; - - auto elem_func = [&terms, &term_set](MayConstRef x) { - //// terms has already been sorted. - // return std::binary_search(terms.begin(), terms.end(), x); - return term_set.find(x) != term_set.end(); - }; - auto skip_index_func = - [&](const SkipIndex& skipIndex, FieldId fieldId, int64_t chunkId) { - for (const auto& term : term_set) { - if (!skipIndex.CanSkipUnaryRange( - fieldId, chunkId, OpType::Equal, term)) { - return false; - } - } - return true; - }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, skip_index_func); -} - -template -auto -ExecExprVisitor::ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - - std::unordered_set term_set(expr.terms_.begin(), - expr.terms_.end()); - - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - if (term_set.empty()) { - auto elem_func = [=](const milvus::Json& json) { return false; }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } - - auto elem_func = [&term_set, &pointer](const milvus::Json& json) { - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - auto x = json.template at(pointer); - if (x.error()) { - if constexpr (std::is_same_v) { - auto x = json.template at(pointer); - if (x.error()) { - return false; - } - - auto value = x.value(); - // if the term set is {1}, and the value is 1.1, we should not return true. - return std::floor(value) == value && - term_set.find(ExprValueType(value)) != term_set.end(); - } - return false; - } - return term_set.find(ExprValueType(x.value())) != term_set.end(); - }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecTermArrayFieldInVariable(TermExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - int index = -1; - if (expr.column_.nested_path.size() > 0) { - index = std::stoi(expr.column_.nested_path[0]); - } - std::unordered_set term_set(expr.terms_.begin(), - expr.terms_.end()); - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - if (term_set.empty()) { - auto elem_func = [=](const milvus::ArrayView& array) { return false; }; - return ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - } - - auto elem_func = [&term_set, &index](const milvus::ArrayView& array) { - if (index >= array.length()) { - return false; - } - auto value = array.get_data(index); - return term_set.find(ExprValueType(value)) != term_set.end(); - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - AssertInfo(expr.terms_.size() == 1, - "element length in json array must be one"); - ExprValueType target_val = expr.terms_[0]; - - auto elem_func = [&target_val, &pointer](const milvus::Json& json) { - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) - return false; - for (auto it = array.begin(); it != array.end(); ++it) { - auto val = (*it).template get(); - if (val.error()) { - return false; - } - if (val.value() == target_val) { - return true; - } - } - return false; - }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecTermArrayVariableInField(TermExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - - AssertInfo(expr.terms_.size() == 1, - "element length in json array must be one"); - ExprValueType target_val = expr.terms_[0]; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - auto elem_func = [&target_val](const milvus::ArrayView& array) { - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - for (int i = 0; i < array.length(); i++) { - auto val = array.template get_data(i); - if (val == target_val) { - return true; - } - } - return false; - }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) - -> BitsetType { - if (expr_raw.is_in_field_) { - return ExecTermJsonVariableInField(expr_raw); - } else { - return ExecTermJsonFieldInVariable(expr_raw); - } -} - -template -auto -ExecExprVisitor::ExecTermVisitorImplTemplateArray(TermExpr& expr_raw) - -> BitsetType { - if (expr_raw.is_in_field_) { - return ExecTermArrayVariableInField(expr_raw); - } else { - return ExecTermArrayFieldInVariable(expr_raw); - } -} - -void -ExecExprVisitor::visit(TermExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo(expr.column_.data_type == field_meta.get_data_type(), - "[ExecExprVisitor]DataType of expr isn't field_meta " - "data type "); - BitsetType res; - switch (expr.column_.data_type) { - case DataType::BOOL: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT8: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT16: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT32: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT64: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::FLOAT: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::DOUBLE: { - res = ExecTermVisitorImpl(expr); - break; - } - case DataType::VARCHAR: { - if (segment_.type() == SegmentType::Growing) { - res = ExecTermVisitorImpl(expr); - } else { - res = ExecTermVisitorImpl(expr); - } - break; - } - case DataType::JSON: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: - res = ExecTermVisitorImplTemplateJson(expr); - break; - case proto::plan::GenericValue::ValCase::kInt64Val: - res = ExecTermVisitorImplTemplateJson(expr); - break; - case proto::plan::GenericValue::ValCase::kFloatVal: - res = ExecTermVisitorImplTemplateJson(expr); - break; - case proto::plan::GenericValue::ValCase::kStringVal: - res = ExecTermVisitorImplTemplateJson(expr); - break; - case proto::plan::GenericValue::ValCase::VAL_NOT_SET: - res = ExecTermVisitorImplTemplateJson(expr); - break; - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - expr.val_case_); - } - break; - } - case DataType::ARRAY: { - switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: - res = ExecTermVisitorImplTemplateArray(expr); - break; - case proto::plan::GenericValue::ValCase::kInt64Val: - res = ExecTermVisitorImplTemplateArray(expr); - break; - case proto::plan::GenericValue::ValCase::kFloatVal: - res = ExecTermVisitorImplTemplateArray(expr); - break; - case proto::plan::GenericValue::ValCase::kStringVal: - res = ExecTermVisitorImplTemplateArray(expr); - break; - case proto::plan::GenericValue::ValCase::VAL_NOT_SET: - res = ExecTermVisitorImplTemplateArray(expr); - break; - default: - PanicInfo( - Unsupported, - fmt::format("unknown data type: {}", expr.val_case_)); - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - expr.column_.data_type); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -void -ExecExprVisitor::visit(ExistsExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo(expr.column_.data_type == field_meta.get_data_type(), - "[ExecExprVisitor]DataType of expr isn't field_meta data type"); - BitsetType res; - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - switch (expr.column_.data_type) { - case DataType::JSON: { - using Index = index::ScalarIndex; - auto index_func = [&](Index* index) { return TargetBitmap{}; }; - auto elem_func = [&](const milvus::Json& json) { - auto x = json.exist(pointer); - return x; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { - return false; - }; - res = ExecRangeVisitorImpl(expr.column_.field_id, - index_func, - elem_func, - default_skip_index_func); - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - expr.column_.data_type); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -void -ExecExprVisitor::visit(AlwaysTrueExpr& expr) { - BitsetType res(row_count_); - res.set(); - bitset_opt_ = std::move(res); -} - -template -auto -ExecExprVisitor::ExecJsonContains(JsonContainsExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - std::unordered_set elements; - for (auto const& element : expr.elements_) { - elements.insert(element); - } - auto elem_func = [&elements, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - for (auto&& it : array) { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (elements.count(val.value()) > 0) { - return true; - } - } - return false; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecArrayContains(JsonContainsExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - AssertInfo(expr.column_.nested_path.size() == 0, - "[ExecArrayContains]nested path must be null"); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - std::unordered_set elements; - for (auto const& element : expr.elements_) { - elements.insert(element); - } - auto elem_func = [&elements](const milvus::ArrayView& array) { - for (int i = 0; i < array.length(); ++i) { - if (elements.count(array.template get_data(i)) > 0) { - return true; - } - } - return false; - }; - - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -auto -ExecExprVisitor::ExecJsonContainsArray(JsonContainsExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = - static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - auto& elements = expr.elements_; - auto elem_func = [&elements, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - for (auto&& it : array) { - auto val = it.get_array(); - if (val.error()) { - continue; - } - std::vector> - json_array; - json_array.reserve(val.count_elements()); - for (auto&& e : val) { - json_array.emplace_back(e); - } - for (auto const& element : elements) { - if (CompareTwoJsonArray(json_array, element)) { - return true; - } - } - } - return false; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -auto -ExecExprVisitor::ExecJsonContainsWithDiffType(JsonContainsExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = - static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - auto& elements = expr.elements_; - auto elem_func = [&elements, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - // Note: array can only be iterated once - for (auto&& it : array) { - for (auto const& element : elements) { - switch (element.val_case()) { - case proto::plan::GenericValue::kBoolVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.bool_val()) { - return true; - } - break; - } - case proto::plan::GenericValue::kInt64Val: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.int64_val()) { - return true; - } - break; - } - case proto::plan::GenericValue::kFloatVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.float_val()) { - return true; - } - break; - } - case proto::plan::GenericValue::kStringVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.string_val()) { - return true; - } - break; - } - case proto::plan::GenericValue::kArrayVal: { - auto val = it.get_array(); - if (val.error()) { - continue; - } - if (CompareTwoJsonArray(val, element.array_val())) { - return true; - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - element.val_case()); - } - } - } - return false; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecJsonContainsAll(JsonContainsExpr& expr_raw) -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - std::unordered_set elements; - for (auto const& element : expr.elements_) { - elements.insert(element); - } - // auto elements = expr.elements_; - auto elem_func = [&elements, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - std::unordered_set tmp_elements(elements); - // Note: array can only be iterated once - for (auto&& it : array) { - auto val = it.template get(); - if (val.error()) { - continue; - } - tmp_elements.erase(val.value()); - if (tmp_elements.size() == 0) { - return true; - } - } - return tmp_elements.size() == 0; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -template -auto -ExecExprVisitor::ExecArrayContainsAll(JsonContainsExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = static_cast&>(expr_raw); - AssertInfo(expr.column_.nested_path.size() == 0, - "[ExecArrayContains]nested path must be null"); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - using GetType = - std::conditional_t, - std::string_view, - ExprValueType>; - - std::unordered_set elements; - for (auto const& element : expr.elements_) { - elements.insert(element); - } - // auto elements = expr.elements_; - auto elem_func = [&elements](const milvus::ArrayView& array) { - std::unordered_set tmp_elements(elements); - // Note: array can only be iterated once - for (int i = 0; i < array.length(); ++i) { - tmp_elements.erase(array.template get_data(i)); - if (tmp_elements.size() == 0) { - return true; - } - } - return tmp_elements.size() == 0; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -auto -ExecExprVisitor::ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = - static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - auto& elements = expr.elements_; - - auto elem_func = [&elements, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - std::unordered_set exist_elements_index; - for (auto&& it : array) { - auto val = it.get_array(); - if (val.error()) { - continue; - } - std::vector> - json_array; - json_array.reserve(val.count_elements()); - for (auto&& e : val) { - json_array.emplace_back(e); - } - for (int index = 0; index < elements.size(); ++index) { - if (CompareTwoJsonArray(json_array, elements[index])) { - exist_elements_index.insert(index); - } - } - if (exist_elements_index.size() == elements.size()) { - return true; - } - } - return exist_elements_index.size() == elements.size(); - }; - - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -auto -ExecExprVisitor::ExecJsonContainsAllWithDiffType(JsonContainsExpr& expr_raw) - -> BitsetType { - using Index = index::ScalarIndex; - auto& expr = - static_cast&>(expr_raw); - auto pointer = milvus::Json::pointer(expr.column_.nested_path); - auto index_func = [](Index* index) { return TargetBitmap{}; }; - - auto elements = expr.elements_; - std::unordered_set elements_index; - int i = 0; - for (auto& element : expr.elements_) { - elements_index.insert(i); - i++; - } - auto elem_func = - [&elements, &elements_index, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; - } - std::unordered_set tmp_elements_index(elements_index); - for (auto&& it : array) { - int i = -1; - for (auto& element : elements) { - i++; - switch (element.val_case()) { - case proto::plan::GenericValue::kBoolVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.bool_val()) { - tmp_elements_index.erase(i); - } - break; - } - case proto::plan::GenericValue::kInt64Val: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.int64_val()) { - tmp_elements_index.erase(i); - } - break; - } - case proto::plan::GenericValue::kFloatVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.float_val()) { - tmp_elements_index.erase(i); - } - break; - } - case proto::plan::GenericValue::kStringVal: { - auto val = it.template get(); - if (val.error()) { - continue; - } - if (val.value() == element.string_val()) { - tmp_elements_index.erase(i); - } - break; - } - case proto::plan::GenericValue::kArrayVal: { - auto val = it.get_array(); - if (val.error()) { - continue; - } - if (CompareTwoJsonArray(val, element.array_val())) { - tmp_elements_index.erase(i); - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - element.val_case()); - } - if (tmp_elements_index.size() == 0) { - return true; - } - } - if (tmp_elements_index.size() == 0) { - return true; - } - } - return tmp_elements_index.size() == 0; - }; - auto default_skip_index_func = [&](const SkipIndex& skipIndex, - FieldId fieldId, - int64_t chunkId) { return false; }; - - return ExecRangeVisitorImpl( - expr.column_.field_id, index_func, elem_func, default_skip_index_func); -} - -void -ExecExprVisitor::visit(JsonContainsExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.column_.field_id]; - AssertInfo( - expr.column_.data_type == DataType::JSON || - expr.column_.data_type == DataType::ARRAY, - "[ExecExprVisitor]DataType of JsonContainsExpr isn't json data type"); - BitsetType res; - auto data_type = expr.column_.data_type; - switch (expr.op_) { - case proto::plan::JSONContainsExpr_JSONOp_Contains: - case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: { - if (IsArrayDataType(data_type)) { - switch (expr.val_case_) { - case proto::plan::GenericValue::kBoolVal: { - res = ExecArrayContains(expr); - break; - } - case proto::plan::GenericValue::kInt64Val: { - res = ExecArrayContains(expr); - break; - } - case proto::plan::GenericValue::kFloatVal: { - res = ExecArrayContains(expr); - break; - } - case proto::plan::GenericValue::kStringVal: { - res = ExecArrayContains(expr); - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - expr.val_case_); - } - } else { - if (expr.same_type_) { - switch (expr.val_case_) { - case proto::plan::GenericValue::kBoolVal: { - res = ExecJsonContains(expr); - break; - } - case proto::plan::GenericValue::kInt64Val: { - res = ExecJsonContains(expr); - break; - } - case proto::plan::GenericValue::kFloatVal: { - res = ExecJsonContains(expr); - break; - } - case proto::plan::GenericValue::kStringVal: { - res = ExecJsonContains(expr); - break; - } - case proto::plan::GenericValue::kArrayVal: { - res = ExecJsonContainsArray(expr); - break; - } - default: - PanicInfo(Unsupported, - "unsupported value type {}", - expr.val_case_); - } - } else { - res = ExecJsonContainsWithDiffType(expr); - } - } - break; - } - case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { - if (IsArrayDataType(data_type)) { - switch (expr.val_case_) { - case proto::plan::GenericValue::kBoolVal: { - res = ExecArrayContainsAll(expr); - break; - } - case proto::plan::GenericValue::kInt64Val: { - res = ExecArrayContainsAll(expr); - break; - } - case proto::plan::GenericValue::kFloatVal: { - res = ExecArrayContainsAll(expr); - break; - } - case proto::plan::GenericValue::kStringVal: { - res = ExecArrayContainsAll(expr); - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported data type {}", - expr.val_case_); - } - } else { - if (expr.same_type_) { - switch (expr.val_case_) { - case proto::plan::GenericValue::kBoolVal: { - res = ExecJsonContainsAll(expr); - break; - } - case proto::plan::GenericValue::kInt64Val: { - res = ExecJsonContainsAll(expr); - break; - } - case proto::plan::GenericValue::kFloatVal: { - res = ExecJsonContainsAll(expr); - break; - } - case proto::plan::GenericValue::kStringVal: { - res = ExecJsonContainsAll(expr); - break; - } - case proto::plan::GenericValue::kArrayVal: { - res = ExecJsonContainsAllArray(expr); - break; - } - default: - PanicInfo( - Unsupported, - fmt::format( - "unsupported value type {} in expression", - expr.val_case_)); - } - } else { - res = ExecJsonContainsAllWithDiffType(expr); - } - } - break; - } - default: - PanicInfo(DataTypeInvalid, - "unsupported json contains type {}", - expr.val_case_); - } - AssertInfo(res.size() == row_count_, - "[ExecExprVisitor]Size of results not equal row count"); - bitset_opt_ = std::move(res); -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp deleted file mode 100644 index f3c01beb30b5d..0000000000000 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/generated/ExecPlanNodeVisitor.h" - -#include -#include - -#include "expr/ITypeExpr.h" -#include "query/PlanImpl.h" -#include "query/SubSearchResult.h" -#include "query/generated/ExecExprVisitor.h" -#include "query/Utils.h" -#include "segcore/SegmentGrowing.h" -#include "common/Json.h" -#include "log/Log.h" -#include "plan/PlanNode.h" -#include "exec/Task.h" -#include "segcore/SegmentInterface.h" -#include "query/groupby/SearchGroupByOperator.h" -namespace milvus::query { - -namespace impl { -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ -class ExecPlanNodeVisitor : PlanNodeVisitor { - public: - ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, - Timestamp timestamp, - const PlaceholderGroup& placeholder_group) - : segment_(segment), - timestamp_(timestamp), - placeholder_group_(placeholder_group) { - } - - SearchResult - get_moved_result(PlanNode& node) { - assert(!search_result_opt_.has_value()); - node.accept(*this); - assert(search_result_opt_.has_value()); - auto ret = std::move(search_result_opt_).value(); - search_result_opt_ = std::nullopt; - return ret; - } - - private: - template - void - VectorVisitorImpl(VectorPlanNode& node); - - private: - const segcore::SegmentInterface& segment_; - Timestamp timestamp_; - const PlaceholderGroup& placeholder_group_; - - SearchResultOpt search_result_opt_; -}; -} // namespace impl - -static SearchResult -empty_search_result(int64_t num_queries, SearchInfo& search_info) { - SearchResult final_result; - final_result.total_nq_ = num_queries; - final_result.unity_topK_ = 0; // no result - final_result.total_data_cnt_ = 0; - return final_result; -} - -void -ExecPlanNodeVisitor::ExecuteExprNode( - const std::shared_ptr& plannode, - const milvus::segcore::SegmentInternalInterface* segment, - int64_t active_count, - BitsetType& bitset_holder) { - bitset_holder.clear(); - LOG_DEBUG("plannode: {}, active_count: {}, timestamp: {}", - plannode->ToString(), - active_count, - timestamp_); - auto plan = plan::PlanFragment(plannode); - // TODO: get query id from proxy - auto query_context = std::make_shared( - DEAFULT_QUERY_ID, segment, active_count, timestamp_); - - auto task = - milvus::exec::Task::Create(DEFAULT_TASK_ID, plan, 0, query_context); - bool cache_offset_getted = false; - for (;;) { - auto result = task->Next(); - if (!result) { - break; - } - auto childrens = result->childrens(); - AssertInfo(childrens.size() == 1, - "expr result vector's children size not equal one"); - LOG_DEBUG("output result length:{}", childrens[0]->size()); - if (auto vec = std::dynamic_pointer_cast(childrens[0])) { - TargetBitmapView view(vec->GetRawData(), vec->size()); - AppendOneChunk(bitset_holder, view); - } else if (auto row = - std::dynamic_pointer_cast(childrens[0])) { - auto bit_vec = - std::dynamic_pointer_cast(row->child(0)); - TargetBitmapView view(bit_vec->GetRawData(), bit_vec->size()); - AppendOneChunk(bitset_holder, view); - - if (!cache_offset_getted) { - // offset cache only get once because not support iterator batch - auto cache_bits_vec = - std::dynamic_pointer_cast(row->child(1)); - TargetBitmapView view(cache_bits_vec->GetRawData(), - cache_bits_vec->size()); - // If get empty cached bits. mean no record hits in this segment - // no need to get next batch. - if (view.count() == 0) { - bitset_holder.resize(active_count); - task->RequestCancel(); - break; - } - cache_offset_getted = true; - } - } else { - PanicInfo(UnexpectedError, "expr return type not matched"); - } - } - // std::string s; - // boost::to_string(*bitset_holder, s); - // std::cout << bitset_holder->size() << " . " << s << std::endl; -} - -template -void -ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { - // TODO: optimize here, remove the dynamic cast - assert(!search_result_opt_.has_value()); - auto segment = - dynamic_cast(&segment_); - AssertInfo(segment, "support SegmentSmallIndex Only"); - SearchResult search_result; - auto& ph = placeholder_group_->at(0); - auto src_data = ph.get_blob(); - auto num_queries = ph.num_of_queries_; - - // TODO: add API to unify row_count - // auto row_count = segment->get_row_count(); - auto active_count = segment->get_active_count(timestamp_); - - // skip all calculation - if (active_count == 0) { - search_result_opt_ = - empty_search_result(num_queries, node.search_info_); - return; - } - - std::chrono::high_resolution_clock::time_point scalar_start = - std::chrono::high_resolution_clock::now(); - std::unique_ptr bitset_holder; - if (node.filter_plannode_.has_value()) { - BitsetType expr_res; - ExecuteExprNode( - node.filter_plannode_.value(), segment, active_count, expr_res); - bitset_holder = std::make_unique(expr_res.clone()); - bitset_holder->flip(); - } else { - bitset_holder = std::make_unique(active_count, false); - } - segment->mask_with_timestamps(*bitset_holder, timestamp_); - - segment->mask_with_delete(*bitset_holder, active_count, timestamp_); - std::chrono::high_resolution_clock::time_point scalar_end = - std::chrono::high_resolution_clock::now(); - double scalar_cost = - std::chrono::duration(scalar_end - scalar_start) - .count(); - monitor::internal_core_search_latency_scalar.Observe(scalar_cost); - - // if bitset_holder is all 1's, we got empty result - if (bitset_holder->all()) { - search_result_opt_ = - empty_search_result(num_queries, node.search_info_); - return; - } - - std::chrono::high_resolution_clock::time_point vector_start = - std::chrono::high_resolution_clock::now(); - BitsetView final_view = *bitset_holder; - segment->vector_search(node.search_info_, - src_data, - num_queries, - timestamp_, - final_view, - search_result); - search_result.total_data_cnt_ = final_view.size(); - if (search_result.vector_iterators_.has_value()) { - AssertInfo(search_result.vector_iterators_.value().size() == - search_result.total_nq_, - "Vector Iterators' count must be equal to total_nq_, Check " - "your code"); - std::vector group_by_values; - SearchGroupBy(search_result.vector_iterators_.value(), - node.search_info_, - group_by_values, - *segment, - search_result.seg_offsets_, - search_result.distances_, - search_result.topk_per_nq_prefix_sum_); - search_result.group_by_values_ = std::move(group_by_values); - search_result.group_size_ = node.search_info_.group_size_; - AssertInfo(search_result.seg_offsets_.size() == - search_result.group_by_values_.value().size(), - "Wrong state! search_result group_by_values_ size:{} is not " - "equal to search_result.seg_offsets.size:{}", - search_result.group_by_values_.value().size(), - search_result.seg_offsets_.size()); - } - search_result_opt_ = std::move(search_result); - std::chrono::high_resolution_clock::time_point vector_end = - std::chrono::high_resolution_clock::now(); - double vector_cost = - std::chrono::duration(vector_end - vector_start) - .count(); - monitor::internal_core_search_latency_vector.Observe(vector_cost); - - double total_cost = - std::chrono::duration(vector_end - scalar_start) - .count(); - double scalar_ratio = total_cost > 0.0 ? scalar_cost / total_cost : 0.0; - monitor::internal_core_search_latency_scalar_proportion.Observe( - scalar_ratio); -} - -std::unique_ptr -wrap_num_entities(int64_t cnt) { - auto retrieve_result = std::make_unique(); - DataArray arr; - arr.set_type(milvus::proto::schema::Int64); - auto scalar = arr.mutable_scalars(); - scalar->mutable_long_data()->mutable_data()->Add(cnt); - retrieve_result->field_data_ = {arr}; - retrieve_result->total_data_cnt_ = 0; - return retrieve_result; -} - -void -ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { - assert(!retrieve_result_opt_.has_value()); - auto segment = - dynamic_cast(&segment_); - AssertInfo(segment, "Support SegmentSmallIndex Only"); - RetrieveResult retrieve_result; - retrieve_result.total_data_cnt_ = 0; - - auto active_count = segment->get_active_count(timestamp_); - - if (active_count == 0 && !node.is_count_) { - retrieve_result_opt_ = std::move(retrieve_result); - return; - } - - if (active_count == 0 && node.is_count_) { - retrieve_result = *(wrap_num_entities(0)); - retrieve_result_opt_ = std::move(retrieve_result); - return; - } - - BitsetType bitset_holder; - // For case that retrieve by expression, bitset will be allocated when expression is being executed. - if (node.is_count_) { - bitset_holder.resize(active_count); - } - - std::vector cache_offsets; - if (node.filter_plannode_.has_value()) { - ExecuteExprNode(node.filter_plannode_.value(), - segment, - active_count, - bitset_holder); - bitset_holder.flip(); - } - - segment->mask_with_timestamps(bitset_holder, timestamp_); - - segment->mask_with_delete(bitset_holder, active_count, timestamp_); - // if bitset_holder is all 1's, we got empty result - if (bitset_holder.all() && !node.is_count_) { - retrieve_result_opt_ = std::move(retrieve_result); - return; - } - - if (node.is_count_) { - auto cnt = bitset_holder.size() - bitset_holder.count(); - retrieve_result = *(wrap_num_entities(cnt)); - retrieve_result.total_data_cnt_ = bitset_holder.size(); - retrieve_result_opt_ = std::move(retrieve_result); - return; - } - - retrieve_result.total_data_cnt_ = bitset_holder.size(); - auto results_pair = segment->find_first(node.limit_, bitset_holder); - retrieve_result.result_offsets_ = std::move(results_pair.first); - retrieve_result.has_more_result = results_pair.second; - retrieve_result_opt_ = std::move(retrieve_result); -} - -void -ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { - VectorVisitorImpl(node); -} - -void -ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) { - VectorVisitorImpl(node); -} - -void -ExecPlanNodeVisitor::visit(Float16VectorANNS& node) { - VectorVisitorImpl(node); -} - -void -ExecPlanNodeVisitor::visit(BFloat16VectorANNS& node) { - VectorVisitorImpl(node); -} - -void -ExecPlanNodeVisitor::visit(SparseFloatVectorANNS& node) { - VectorVisitorImpl(node); -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExtractInfoExprVisitor.cpp b/internal/core/src/query/visitors/ExtractInfoExprVisitor.cpp deleted file mode 100644 index a6a7acd272110..0000000000000 --- a/internal/core/src/query/visitors/ExtractInfoExprVisitor.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/Plan.h" -#include "query/generated/ExtractInfoExprVisitor.h" - -namespace milvus::query { - -namespace impl { -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ -class ExtractInfoExprVisitor : ExprVisitor { - public: - explicit ExtractInfoExprVisitor(ExtractedPlanInfo& plan_info) - : plan_info_(plan_info) { - } - - private: - ExtractedPlanInfo& plan_info_; -}; -} // namespace impl - -void -ExtractInfoExprVisitor::visit(LogicalUnaryExpr& expr) { - expr.child_->accept(*this); -} - -void -ExtractInfoExprVisitor::visit(LogicalBinaryExpr& expr) { - expr.left_->accept(*this); - expr.right_->accept(*this); -} - -void -ExtractInfoExprVisitor::visit(TermExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -void -ExtractInfoExprVisitor::visit(UnaryRangeExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -void -ExtractInfoExprVisitor::visit(BinaryRangeExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -void -ExtractInfoExprVisitor::visit(CompareExpr& expr) { - plan_info_.add_involved_field(expr.left_field_id_); - plan_info_.add_involved_field(expr.right_field_id_); -} - -void -ExtractInfoExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -void -ExtractInfoExprVisitor::visit(ExistsExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -void -ExtractInfoExprVisitor::visit(AlwaysTrueExpr& expr) { - // all is involved. -} - -void -ExtractInfoExprVisitor::visit(JsonContainsExpr& expr) { - plan_info_.add_involved_field(expr.column_.field_id); -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp deleted file mode 100644 index 2de8f92df6d38..0000000000000 --- a/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/Plan.h" -#include "query/generated/ExtractInfoPlanNodeVisitor.h" -#include "query/generated/ExtractInfoExprVisitor.h" - -namespace milvus::query { - -namespace impl { -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ -class ExtractInfoPlanNodeVisitor : PlanNodeVisitor { - public: - explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) - : plan_info_(plan_info) { - } - - private: - ExtractedPlanInfo& plan_info_; -}; -} // namespace impl - -void -ExtractInfoPlanNodeVisitor::visit(FloatVectorANNS& node) { - plan_info_.add_involved_field(node.search_info_.field_id_); - if (node.predicate_.has_value()) { - ExtractInfoExprVisitor expr_visitor(plan_info_); - node.predicate_.value()->accept(expr_visitor); - } -} - -void -ExtractInfoPlanNodeVisitor::visit(BinaryVectorANNS& node) { - plan_info_.add_involved_field(node.search_info_.field_id_); - if (node.predicate_.has_value()) { - ExtractInfoExprVisitor expr_visitor(plan_info_); - node.predicate_.value()->accept(expr_visitor); - } -} - -void -ExtractInfoPlanNodeVisitor::visit(Float16VectorANNS& node) { - plan_info_.add_involved_field(node.search_info_.field_id_); - if (node.predicate_.has_value()) { - ExtractInfoExprVisitor expr_visitor(plan_info_); - node.predicate_.value()->accept(expr_visitor); - } -} - -void -ExtractInfoPlanNodeVisitor::visit(BFloat16VectorANNS& node) { - plan_info_.add_involved_field(node.search_info_.field_id_); - if (node.predicate_.has_value()) { - ExtractInfoExprVisitor expr_visitor(plan_info_); - node.predicate_.value()->accept(expr_visitor); - } -} - -void -ExtractInfoPlanNodeVisitor::visit(SparseFloatVectorANNS& node) { - plan_info_.add_involved_field(node.search_info_.field_id_); - if (node.predicate_.has_value()) { - ExtractInfoExprVisitor expr_visitor(plan_info_); - node.predicate_.value()->accept(expr_visitor); - } -} - -void -ExtractInfoPlanNodeVisitor::visit(RetrievePlanNode& node) { - // Assert(node.predicate_.has_value()); - ExtractInfoExprVisitor expr_visitor(plan_info_); - if (node.predicate_.has_value()) { - node.predicate_.value()->accept(expr_visitor); - } -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ShowExprVisitor.cpp b/internal/core/src/query/visitors/ShowExprVisitor.cpp deleted file mode 100644 index ba2320e820eab..0000000000000 --- a/internal/core/src/query/visitors/ShowExprVisitor.cpp +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include - -#include "common/Types.h" -#include "query/ExprImpl.h" -#include "query/Plan.h" -#include "query/generated/ShowExprVisitor.h" - -namespace milvus::query { -using Json = nlohmann::json; - -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR -namespace impl { -class ShowExprNodeVisitor : ExprVisitor { - public: - using RetType = Json; - - public: - RetType - call_child(Expr& expr) { - assert(!ret_.has_value()); - expr.accept(*this); - assert(ret_.has_value()); - auto ret = std::move(ret_); - ret_ = std::nullopt; - return std::move(ret.value()); - } - - Json - combine(Json&& extra, UnaryExprBase& expr) { - auto result = std::move(extra); - result["child"] = call_child(*expr.child_); - return result; - } - - Json - combine(Json&& extra, BinaryExprBase& expr) { - auto result = std::move(extra); - result["left_child"] = call_child(*expr.left_); - result["right_child"] = call_child(*expr.right_); - return result; - } - - private: - std::optional ret_; -}; -} // namespace impl - -void -ShowExprVisitor::visit(LogicalUnaryExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - using OpType = LogicalUnaryExpr::OpType; - - // TODO: use magic_enum if available - AssertInfo(expr.op_type_ == OpType::LogicalNot, - "[ShowExprVisitor]Expr op type isn't LogicNot"); - auto op_name = "LogicalNot"; - - Json extra{ - {"expr_type", "BoolUnary"}, - {"op", op_name}, - }; - json_opt_ = this->combine(std::move(extra), expr); -} - -void -ShowExprVisitor::visit(LogicalBinaryExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - using OpType = LogicalBinaryExpr::OpType; - - // TODO: use magic_enum if available - auto op_name = [](OpType op) { - switch (op) { - case OpType::LogicalAnd: - return "LogicalAnd"; - case OpType::LogicalOr: - return "LogicalOr"; - case OpType::LogicalXor: - return "LogicalXor"; - default: - PanicInfo(OpTypeInvalid, - fmt::format("unsupported operation {}", op)); - } - }(expr.op_type_); - - Json extra{ - {"expr_type", "BoolBinary"}, - {"op", op_name}, - }; - json_opt_ = this->combine(std::move(extra), expr); -} - -template -static Json -TermExtract(const TermExpr& expr_raw) { - auto expr = dynamic_cast*>(&expr_raw); - AssertInfo(expr, "[ShowExprVisitor]TermExpr cast to TermExprImpl failed"); - return Json{expr->terms_}; -} - -void -ShowExprVisitor::visit(TermExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - AssertInfo(IsVectorDataType(expr.column_.data_type) == false, - "[ShowExprVisitor]Data type of expr isn't vector type"); - auto terms = [&] { - switch (expr.column_.data_type) { - case DataType::BOOL: - return TermExtract(expr); - case DataType::INT8: - return TermExtract(expr); - case DataType::INT16: - return TermExtract(expr); - case DataType::INT32: - return TermExtract(expr); - case DataType::INT64: - return TermExtract(expr); - case DataType::DOUBLE: - return TermExtract(expr); - case DataType::FLOAT: - return TermExtract(expr); - case DataType::JSON: - return TermExtract(expr); - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported type {}", expr.column_.data_type)); - } - }(); - - Json res{{"expr_type", "Term"}, - {"field_id", expr.column_.field_id.get()}, - {"data_type", GetDataTypeName(expr.column_.data_type)}, - {"terms", std::move(terms)}}; - - json_opt_ = res; -} - -template -static Json -UnaryRangeExtract(const UnaryRangeExpr& expr_raw) { - using proto::plan::OpType; - using proto::plan::OpType_Name; - auto expr = dynamic_cast*>(&expr_raw); - AssertInfo( - expr, - "[ShowExprVisitor]UnaryRangeExpr cast to UnaryRangeExprImpl failed"); - Json res{{"expr_type", "UnaryRange"}, - {"field_id", expr->column_.field_id.get()}, - {"data_type", GetDataTypeName(expr->column_.data_type)}, - {"op", OpType_Name(static_cast(expr->op_type_))}, - {"value", expr->value_}}; - return res; -} - -void -ShowExprVisitor::visit(UnaryRangeExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - AssertInfo(IsVectorDataType(expr.column_.data_type) == false, - "[ShowExprVisitor]Data type of expr isn't vector type"); - switch (expr.column_.data_type) { - case DataType::BOOL: - json_opt_ = UnaryRangeExtract(expr); - return; - - // see also: https://github.com/milvus-io/milvus/issues/23646. - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: - json_opt_ = UnaryRangeExtract(expr); - return; - - case DataType::DOUBLE: - json_opt_ = UnaryRangeExtract(expr); - return; - case DataType::FLOAT: - json_opt_ = UnaryRangeExtract(expr); - return; - case DataType::JSON: - json_opt_ = UnaryRangeExtract(expr); - return; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported type {}", expr.column_.data_type)); - } -} - -template -static Json -BinaryRangeExtract(const BinaryRangeExpr& expr_raw) { - using proto::plan::OpType; - using proto::plan::OpType_Name; - auto expr = dynamic_cast*>(&expr_raw); - AssertInfo( - expr, - "[ShowExprVisitor]BinaryRangeExpr cast to BinaryRangeExprImpl failed"); - Json res{{"expr_type", "BinaryRange"}, - {"field_id", expr->column_.field_id.get()}, - {"data_type", GetDataTypeName(expr->column_.data_type)}, - {"lower_inclusive", expr->lower_inclusive_}, - {"upper_inclusive", expr->upper_inclusive_}, - {"lower_value", expr->lower_value_}, - {"upper_value", expr->upper_value_}}; - return res; -} - -void -ShowExprVisitor::visit(BinaryRangeExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - AssertInfo(IsVectorDataType(expr.column_.data_type) == false, - "[ShowExprVisitor]Data type of expr isn't vector type"); - switch (expr.column_.data_type) { - case DataType::BOOL: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::INT8: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::INT16: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::INT32: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::INT64: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::DOUBLE: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::FLOAT: - json_opt_ = BinaryRangeExtract(expr); - return; - case DataType::JSON: - json_opt_ = BinaryRangeExtract(expr); - return; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported type {}", expr.column_.data_type)); - } -} - -void -ShowExprVisitor::visit(CompareExpr& expr) { - using proto::plan::OpType; - using proto::plan::OpType_Name; - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - - Json res{{"expr_type", "Compare"}, - {"left_field_id", expr.left_field_id_.get()}, - {"left_data_type", GetDataTypeName(expr.left_data_type_)}, - {"right_field_id", expr.right_field_id_.get()}, - {"right_data_type", GetDataTypeName(expr.right_data_type_)}, - {"op", OpType_Name(static_cast(expr.op_type_))}}; - json_opt_ = res; -} - -template -static Json -BinaryArithOpEvalRangeExtract(const BinaryArithOpEvalRangeExpr& expr_raw) { - using proto::plan::ArithOpType; - using proto::plan::ArithOpType_Name; - using proto::plan::OpType; - using proto::plan::OpType_Name; - - auto expr = - dynamic_cast*>(&expr_raw); - AssertInfo(expr, - "[ShowExprVisitor]BinaryArithOpEvalRangeExpr cast to " - "BinaryArithOpEvalRangeExprImpl failed"); - - Json res{{"expr_type", "BinaryArithOpEvalRange"}, - {"field_offset", expr->column_.field_id.get()}, - {"data_type", GetDataTypeName(expr->column_.data_type)}, - {"arith_op", - ArithOpType_Name(static_cast(expr->arith_op_))}, - {"right_operand", expr->right_operand_}, - {"op", OpType_Name(static_cast(expr->op_type_))}, - {"value", expr->value_}}; - return res; -} - -void -ShowExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - AssertInfo(IsVectorDataType(expr.column_.data_type) == false, - "[ShowExprVisitor]Data type of expr isn't vector type"); - switch (expr.column_.data_type) { - // see also: https://github.com/milvus-io/milvus/issues/23646. - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: - json_opt_ = BinaryArithOpEvalRangeExtract(expr); - return; - - case DataType::DOUBLE: - json_opt_ = BinaryArithOpEvalRangeExtract(expr); - return; - case DataType::FLOAT: - json_opt_ = BinaryArithOpEvalRangeExtract(expr); - return; - case DataType::JSON: - json_opt_ = BinaryArithOpEvalRangeExtract(expr); - return; - default: - PanicInfo( - DataTypeInvalid, - fmt::format("unsupported type {}", expr.column_.data_type)); - } -} - -void -ShowExprVisitor::visit(ExistsExpr& expr) { - using proto::plan::OpType; - using proto::plan::OpType_Name; - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - - Json res{{"expr_type", "Exists"}, - {"field_id", expr.column_.field_id.get()}, - {"data_type", expr.column_.data_type}, - {"nested_path", expr.column_.nested_path}}; - json_opt_ = res; -} - -void -ShowExprVisitor::visit(AlwaysTrueExpr& expr) { - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - Json res{{"expr_type", "AlwaysTrue"}}; - json_opt_ = res; -} - -void -ShowExprVisitor::visit(JsonContainsExpr& expr) { - using proto::plan::OpType; - using proto::plan::OpType_Name; - AssertInfo(!json_opt_.has_value(), - "[ShowExprVisitor]Ret json already has value before visit"); - - Json res{{"expr_type", "JsonContains"}, - {"field_id", expr.column_.field_id.get()}, - {"data_type", expr.column_.data_type}, - {"nested_path", expr.column_.nested_path}, - {"same_type", expr.same_type_}, - {"op", expr.op_}, - {"val_case", expr.val_case_}}; - json_opt_ = res; -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp deleted file mode 100644 index 6b438cbcbf09b..0000000000000 --- a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include - -#include "common/EasyAssert.h" -#include "common/Json.h" -#include "query/generated/ShowExprVisitor.h" -#include "query/generated/ShowPlanNodeVisitor.h" - -namespace milvus::query { -#if 0 -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR -class ShowPlanNodeVisitorImpl : PlanNodeVisitor { - public: - using RetType = nlohmann::json; - - public: - RetType - call_child(PlanNode& node) { - assert(!ret_.has_value()); - node.accept(*this); - assert(ret_.has_value()); - auto ret = std::move(ret_); - ret_ = std::nullopt; - return std::move(ret.value()); - } - - private: - std::optional ret_; -}; -#endif - -using Json = nlohmann::json; - -static std::string -get_indent(int indent) { - return std::string(10, '\t'); -} - -void -ShowPlanNodeVisitor::visit(FloatVectorANNS& node) { - // std::vector data(node.data_.get(), node.data_.get() + node.total_nq_ * node.dim_); - assert(!ret_); - auto& info = node.search_info_; - Json json_body{ - {"node_type", "FloatVectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_.get()}, // - {"topk", info.topk_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - AssertInfo(node.predicate_.value(), - "[ShowPlanNodeVisitor]Can't get value from node predict"); - json_body["predicate"] = - expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; -} - -void -ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) { - assert(!ret_); - auto& info = node.search_info_; - Json json_body{ - {"node_type", "BinaryVectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_.get()}, // - {"topk", info.topk_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - AssertInfo(node.predicate_.value(), - "[ShowPlanNodeVisitor]Can't get value from node predict"); - json_body["predicate"] = - expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; -} - -void -ShowPlanNodeVisitor::visit(Float16VectorANNS& node) { - assert(!ret_); - auto& info = node.search_info_; - Json json_body{ - {"node_type", "Float16VectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_.get()}, // - {"topk", info.topk_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - AssertInfo(node.predicate_.value(), - "[ShowPlanNodeVisitor]Can't get value from node predict"); - json_body["predicate"] = - expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; -} - -void -ShowPlanNodeVisitor::visit(BFloat16VectorANNS& node) { - assert(!ret_); - auto& info = node.search_info_; - Json json_body{ - {"node_type", "BFloat16VectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_.get()}, // - {"topk", info.topk_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - AssertInfo(node.predicate_.value(), - "[ShowPlanNodeVisitor]Can't get value from node predict"); - json_body["predicate"] = - expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; -} - -void -ShowPlanNodeVisitor::visit(SparseFloatVectorANNS& node) { - assert(!ret_); - auto& info = node.search_info_; - Json json_body{ - {"node_type", "SparseFloatVectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_.get()}, // - {"topk", info.topk_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - AssertInfo(node.predicate_.value(), - "[ShowPlanNodeVisitor]Can't get value from node predict"); - json_body["predicate"] = - expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; -} - -void -ShowPlanNodeVisitor::visit(RetrievePlanNode& node) { -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/VerifyExprVisitor.cpp b/internal/core/src/query/visitors/VerifyExprVisitor.cpp deleted file mode 100644 index a3c158cdb9414..0000000000000 --- a/internal/core/src/query/visitors/VerifyExprVisitor.cpp +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/generated/VerifyExprVisitor.h" - -namespace milvus::query { -void -VerifyExprVisitor::visit(LogicalUnaryExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(LogicalBinaryExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(TermExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(UnaryRangeExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(BinaryRangeExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(CompareExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(ExistsExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(AlwaysTrueExpr& expr) { - // TODO -} - -void -VerifyExprVisitor::visit(JsonContainsExpr& expr) { - // TODO -} - -} // namespace milvus::query diff --git a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp deleted file mode 100644 index 2612e37daaa38..0000000000000 --- a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include "query/generated/VerifyPlanNodeVisitor.h" - -namespace milvus::query { - -namespace impl { -// THIS CONTAINS EXTRA BODY FOR VISITOR -// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ -class VerifyPlanNodeVisitor : PlanNodeVisitor { - public: - using RetType = SearchResult; - VerifyPlanNodeVisitor() = default; - - private: - std::optional ret_; -}; -} // namespace impl - -void -VerifyPlanNodeVisitor::visit(FloatVectorANNS&) { -} - -void -VerifyPlanNodeVisitor::visit(BinaryVectorANNS&) { -} - -void -VerifyPlanNodeVisitor::visit(Float16VectorANNS&) { -} - -void -VerifyPlanNodeVisitor::visit(BFloat16VectorANNS&) { -} - -void -VerifyPlanNodeVisitor::visit(SparseFloatVectorANNS&) { -} - -void -VerifyPlanNodeVisitor::visit(RetrievePlanNode&) { -} - -} // namespace milvus::query diff --git a/internal/core/src/segcore/DeletedRecord.h b/internal/core/src/segcore/DeletedRecord.h index 7238e5fc10fac..82bf01fe83f7d 100644 --- a/internal/core/src/segcore/DeletedRecord.h +++ b/internal/core/src/segcore/DeletedRecord.h @@ -89,7 +89,7 @@ class DeletedRecord { } void - Query(BitsetType& bitset, int64_t insert_barrier, Timestamp timestamp) { + Query(BitsetTypeView& bitset, int64_t insert_barrier, Timestamp timestamp) { Assert(bitset.size() == insert_barrier); // TODO: add cache to bitset if (deleted_pairs_.size() == 0) { diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index c6cc0fa35b86d..e2a4b8fb642a1 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -43,7 +43,7 @@ SegmentGrowingImpl::PreInsert(int64_t size) { } void -SegmentGrowingImpl::mask_with_delete(BitsetType& bitset, +SegmentGrowingImpl::mask_with_delete(BitsetTypeView& bitset, int64_t ins_barrier, Timestamp timestamp) const { deleted_record_.Query(bitset, ins_barrier, timestamp); @@ -773,7 +773,7 @@ SegmentGrowingImpl::get_active_count(Timestamp ts) const { } void -SegmentGrowingImpl::mask_with_timestamps(BitsetType& bitset_chunk, +SegmentGrowingImpl::mask_with_timestamps(BitsetTypeView& bitset_chunk, Timestamp timestamp) const { // DO NOTHING } diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 734ef83bc8688..e16aa3e4d3490 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -256,7 +256,7 @@ class SegmentGrowingImpl : public SegmentGrowing { } void - mask_with_timestamps(BitsetType& bitset_chunk, + mask_with_timestamps(BitsetTypeView& bitset_chunk, Timestamp timestamp) const override; void @@ -272,7 +272,7 @@ class SegmentGrowingImpl : public SegmentGrowing { public: void - mask_with_delete(BitsetType& bitset, + mask_with_delete(BitsetTypeView& bitset, int64_t ins_barrier, Timestamp timestamp) const override; diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index e62c378d97786..c7ed889993596 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -18,7 +18,7 @@ #include "common/SystemProperty.h" #include "common/Tracer.h" #include "common/Types.h" -#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" namespace milvus::segcore { @@ -238,6 +238,14 @@ SegmentInternalInterface::get_real_count() const { #endif auto plan = std::make_unique(get_schema()); plan->plan_node_ = std::make_unique(); + milvus::plan::PlanNodePtr plannode; + std::vector sources; + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId()); + sources = std::vector{plannode}; + plannode = std::make_shared( + milvus::plan::GetNextPlanNodeId(), sources); + plan->plan_node_->plannodes_ = plannode; plan->plan_node_->is_count_ = true; auto res = Retrieve(nullptr, plan.get(), MAX_TIMESTAMP, INT64_MAX, false); AssertInfo(res->fields_data().size() == 1, diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 845dbade5cdbf..b11049334d2b0 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -271,7 +271,7 @@ class SegmentInternalInterface : public SegmentInterface { SearchResult& output) const = 0; virtual void - mask_with_delete(BitsetType& bitset, + mask_with_delete(BitsetTypeView& bitset, int64_t ins_barrier, Timestamp timestamp) const = 0; @@ -285,7 +285,7 @@ class SegmentInternalInterface : public SegmentInterface { // bitset 1 means not hit. 0 means hit. virtual void - mask_with_timestamps(BitsetType& bitset_chunk, + mask_with_timestamps(BitsetTypeView& bitset_chunk, Timestamp timestamp) const = 0; // count of chunks diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index e8878ea6979bb..762970e6b1bd3 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -733,7 +733,7 @@ SegmentSealedImpl::get_schema() const { } void -SegmentSealedImpl::mask_with_delete(BitsetType& bitset, +SegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset, int64_t ins_barrier, Timestamp timestamp) const { deleted_record_.Query(bitset, ins_barrier, timestamp); @@ -1569,7 +1569,7 @@ SegmentSealedImpl::get_active_count(Timestamp ts) const { } void -SegmentSealedImpl::mask_with_timestamps(BitsetType& bitset_chunk, +SegmentSealedImpl::mask_with_timestamps(BitsetTypeView& bitset_chunk, Timestamp timestamp) const { // TODO change the AssertInfo(insert_record_.timestamps_.num_chunk() == 1, diff --git a/internal/core/src/segcore/SegmentSealedImpl.h b/internal/core/src/segcore/SegmentSealedImpl.h index d739d69d95ccf..7546cbaf2dc8e 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.h +++ b/internal/core/src/segcore/SegmentSealedImpl.h @@ -264,7 +264,7 @@ class SegmentSealedImpl : public SegmentSealed { } void - mask_with_timestamps(BitsetType& bitset_chunk, + mask_with_timestamps(BitsetTypeView& bitset_chunk, Timestamp timestamp) const override; void @@ -276,7 +276,7 @@ class SegmentSealedImpl : public SegmentSealed { SearchResult& output) const override; void - mask_with_delete(BitsetType& bitset, + mask_with_delete(BitsetTypeView& bitset, int64_t ins_barrier, Timestamp timestamp) const override; diff --git a/internal/core/unittest/test_always_true_expr.cpp b/internal/core/unittest/test_always_true_expr.cpp index ab0e03f1f3edf..2d54525e8a306 100644 --- a/internal/core/unittest/test_always_true_expr.cpp +++ b/internal/core/unittest/test_always_true_expr.cpp @@ -16,12 +16,12 @@ #include #include "common/Types.h" -#include "query/Expr.h" -#include "query/generated/ExecExprVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "test_utils/DataGen.h" +#include "test_utils/GenExprProto.h" #include "expr/ITypeExpr.h" #include "plan/PlanNode.h" +#include "query/ExecPlanNodeVisitor.h" class ExprAlwaysTrueTest : public ::testing::TestWithParam {}; @@ -61,12 +61,10 @@ TEST_P(ExprAlwaysTrueTest, AlwaysTrue) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); auto expr = std::make_shared(); BitsetType final; - std::shared_ptr plan = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + auto plan = milvus::test::CreateRetrievePlanByExpr(expr); + final = ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_array_expr.cpp b/internal/core/unittest/test_array_expr.cpp index ec503d8952274..3e5621d5e2968 100644 --- a/internal/core/unittest/test_array_expr.cpp +++ b/internal/core/unittest/test_array_expr.cpp @@ -21,11 +21,9 @@ #include "index/IndexFactory.h" #include "pb/plan.pb.h" #include "plan/PlanNode.h" -#include "query/Expr.h" -#include "query/ExprImpl.h" #include "query/Plan.h" #include "query/PlanNode.h" -#include "query/generated/ExecExprVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "simdjson/padded_string.h" #include "test_utils/DataGen.h" @@ -598,7 +596,6 @@ TEST(Expr, TestArrayRange) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, array_type, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -607,10 +604,11 @@ TEST(Expr, TestArrayRange) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -715,7 +713,6 @@ TEST(Expr, TestArrayEqual) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -724,10 +721,11 @@ TEST(Expr, TestArrayEqual) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -891,7 +889,6 @@ TEST(Expr, TestArrayContains) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true, true}, {}}, {{false, false}, {}}}; @@ -921,7 +918,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -975,7 +973,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1019,7 +1018,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1073,7 +1073,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1118,7 +1119,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1170,7 +1172,8 @@ TEST(Expr, TestArrayContains) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1237,7 +1240,6 @@ TEST(Expr, TestArrayBinaryArith) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vectorplan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2168,7 +2171,6 @@ TEST(Expr, TestArrayStringMatch) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> prefix_testcases{ {OpType::PrefixMatch, @@ -2206,7 +2208,8 @@ TEST(Expr, TestArrayStringMatch) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -2268,7 +2271,6 @@ TEST(Expr, TestArrayInTerm) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vectorplan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2451,7 +2454,6 @@ TEST(Expr, TestTermInArray) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); struct TermTestCases { std::vector values; @@ -2499,7 +2501,8 @@ TEST(Expr, TestTermInArray) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) diff --git a/internal/core/unittest/test_array_inverted_index.cpp b/internal/core/unittest/test_array_inverted_index.cpp index f0d59022bf43f..1c543b7711e9b 100644 --- a/internal/core/unittest/test_array_inverted_index.cpp +++ b/internal/core/unittest/test_array_inverted_index.cpp @@ -19,7 +19,7 @@ #include "test_utils/DataGen.h" #include "test_utils/GenExprProto.h" #include "query/PlanProto.h" -#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" using namespace milvus; using namespace milvus::query; @@ -156,9 +156,8 @@ TYPED_TEST_P(ArrayInvertedIndexTest, ArrayContainsAny) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(this->seg_.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + final = ExecuteQueryExpr(parsed, segpromote, this->N_, MAX_TIMESTAMP); std::unordered_set elems(this->vec_of_array_[0].begin(), this->vec_of_array_[0].end()); @@ -205,9 +204,8 @@ TYPED_TEST_P(ArrayInvertedIndexTest, ArrayContainsAll) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(this->seg_.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + final = ExecuteQueryExpr(parsed, segpromote, this->N_, MAX_TIMESTAMP); std::unordered_set elems(this->vec_of_array_[0].begin(), this->vec_of_array_[0].end()); @@ -262,9 +260,8 @@ TYPED_TEST_P(ArrayInvertedIndexTest, ArrayEqual) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(this->seg_.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + final = ExecuteQueryExpr(parsed, segpromote, this->N_, MAX_TIMESTAMP); auto ref = [this](size_t offset) -> bool { if (this->vec_of_array_[0].size() != diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index d58150e28fe0a..f5e79f4a33276 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -29,7 +29,6 @@ #include "index/IndexFactory.h" #include "knowhere/comp/index_param.h" #include "pb/plan.pb.h" -#include "query/ExprImpl.h" #include "segcore/Collection.h" #include "segcore/reduce/Reduce.h" #include "segcore/reduce_c.h" @@ -39,7 +38,7 @@ #include "test_utils/PbHelper.h" #include "test_utils/indexbuilder_test_utils.h" #include "test_utils/storage_test_utils.h" -#include "query/generated/ExecExprVisitor.h" +#include "test_utils/GenExprProto.h" #include "expr/ITypeExpr.h" #include "plan/PlanNode.h" #include "exec/expression/Expr.h" @@ -49,6 +48,7 @@ namespace chrono = std::chrono; using namespace milvus; +using namespace milvus::test; using namespace milvus::index; using namespace milvus::segcore; using namespace milvus::tracer; @@ -640,8 +640,7 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { FieldId(101), DataType::INT64, std::vector()), retrive_pks); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; @@ -667,8 +666,7 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), retrive_pks); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); suc = query_result->ParseFromArray(retrieve_result->proto_blob, @@ -754,8 +752,7 @@ TEST(CApiTest, MultiDeleteSealedSegment) { retrive_pks); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; @@ -781,8 +778,7 @@ TEST(CApiTest, MultiDeleteSealedSegment) { milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), retrive_pks); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); suc = query_result->ParseFromArray(retrieve_result->proto_blob, @@ -875,8 +871,7 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) { retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -958,8 +953,7 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) { FieldId(101), DataType::INT64, std::vector()), retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -1137,8 +1131,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) { FieldId(101), DataType::INT64, std::vector()), retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -1235,8 +1228,7 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) { FieldId(101), DataType::INT64, std::vector()), retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -1433,8 +1425,7 @@ TEST(CApiTest, RetrieveTestWithExpr) { FieldId(101), DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -4438,8 +4429,7 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) { milvus::expr::ColumnInfo( i64_fid, DataType::INT64, std::vector()), retrive_row_ids); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = CreateRetrievePlanByExpr(term_expr); std::vector target_field_ids; // retrieve value @@ -4754,49 +4744,6 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) { DeleteSegment(segment); } -TEST(CApiTest, AssembeChunkTest) { - TargetBitmap chunk(1000); - for (size_t i = 0; i < 1000; ++i) { - chunk[i] = (i % 2 == 0); - } - BitsetType result; - milvus::query::AppendOneChunk(result, chunk); - // std::string s; - // boost::to_string(result, s); - // std::cout << s << std::endl; - int index = 0; - for (size_t i = 0; i < 1000; i++) { - ASSERT_EQ(result[index++], chunk[i]) << i; - } - - chunk = TargetBitmap(934); - for (int i = 0; i < 934; ++i) { - chunk[i] = (i % 2 == 0); - } - milvus::query::AppendOneChunk(result, chunk); - for (size_t i = 0; i < 934; i++) { - ASSERT_EQ(result[index++], chunk[i]) << i; - } - - chunk = TargetBitmap(62); - for (int i = 0; i < 62; ++i) { - chunk[i] = (i % 2 == 0); - } - milvus::query::AppendOneChunk(result, chunk); - for (size_t i = 0; i < 62; i++) { - ASSERT_EQ(result[index++], chunk[i]) << i; - } - - chunk = TargetBitmap(105); - for (int i = 0; i < 105; ++i) { - chunk[i] = (i % 2 == 0); - } - milvus::query::AppendOneChunk(result, chunk); - for (size_t i = 0; i < 105; i++) { - ASSERT_EQ(result[index++], chunk[i]) << i; - } -} - std::vector search_id(const BitsetType& bitset, Timestamp* timestamps, @@ -4878,31 +4825,6 @@ TEST(CApiTest, SearchIdTest) { } } -TEST(CApiTest, AssembeChunkPerfTest) { - TargetBitmap chunk(100000000); - for (size_t i = 0; i < 100000000; ++i) { - chunk[i] = (i % 2 == 0); - } - BitsetType result; - // while (true) { - std::cout << "start test" << std::endl; - auto start = std::chrono::steady_clock::now(); - milvus::query::AppendOneChunk(result, chunk); - std::cout << "cost: " - << std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count() - << "us" << std::endl; - int index = 0; - for (size_t i = 0; i < 1000; i++) { - ASSERT_EQ(result[index++], chunk[i]) << i; - } - // } - // std::string s; - // boost::to_string(result, s); - // std::cout << s << std::endl; -} - TEST(CApiTest, Indexing_Without_Predicate_float16) { // insert data to segment constexpr auto TOPK = 5; diff --git a/internal/core/unittest/test_chunk_vector.cpp b/internal/core/unittest/test_chunk_vector.cpp index b0d67663e4df1..2e5991d4c57cf 100644 --- a/internal/core/unittest/test_chunk_vector.cpp +++ b/internal/core/unittest/test_chunk_vector.cpp @@ -18,7 +18,6 @@ #include "pb/schema.pb.h" #include "test_utils/DataGen.h" #include "query/Plan.h" -#include "query/generated/ExecExprVisitor.h" using namespace milvus::segcore; using namespace milvus; diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index 1d871cf7de9f9..884376ee63cb5 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -17,12 +17,8 @@ #include #include -#include "query/Expr.h" -#include "query/PlanImpl.h" #include "query/PlanNode.h" -#include "query/generated/ExecPlanNodeVisitor.h" -#include "query/generated/ExprVisitor.h" -#include "query/generated/ShowPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentSealed.h" #include "test_utils/AssertUtils.h" #include "test_utils/DataGen.h" diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 302f86aac1623..b03ce5a09b509 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -22,13 +22,10 @@ #include "common/Json.h" #include "common/Types.h" #include "pb/plan.pb.h" -#include "query/Expr.h" -#include "query/ExprImpl.h" #include "query/Plan.h" #include "query/PlanNode.h" #include "query/PlanProto.h" -#include "query/generated/ShowPlanNodeVisitor.h" -#include "query/generated/ExecExprVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "simdjson/padded_string.h" #include "segcore/segment_c.h" @@ -125,7 +122,6 @@ TEST_P(ExprTest, Range) { schema->AddDebugField("age", DataType::INT32); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - ShowPlanNodeVisitor shower; Assert(plan->tag2field_.at("$0") == schema->get_field_id(FieldName("fakevec"))); } @@ -191,13 +187,6 @@ TEST_P(ExprTest, ShowExecutor) { info.metric_type_ = metric_type; info.topk_ = 20; info.field_id_ = field_id; - node->predicate_ = std::nullopt; - ShowPlanNodeVisitor show_visitor; - PlanNodePtr base(node.release()); - auto res = show_visitor.call_child(*base); - auto dup = res; - dup["data"] = "...collased..."; - std::cout << dup.dump(4); } TEST_P(ExprTest, TestRange) { @@ -365,7 +354,6 @@ TEST_P(ExprTest, TestRange) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -373,12 +361,12 @@ TEST_P(ExprTest, TestRange) { auto plan_str = translate_text_plan_with_metric_type(raw_plan); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -434,7 +422,6 @@ TEST_P(ExprTest, TestBinaryRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](int64_t value) { int64_t lower = testcase.lower, upper = testcase.upper; @@ -447,7 +434,6 @@ TEST_P(ExprTest, TestBinaryRangeJSON) { return lower <= value && value <= upper; }; auto pointer = milvus::Json::pointer(testcase.nested_path); - RetrievePlanNode plan; milvus::proto::plan::GenericValue lower_val; lower_val.set_int64_val(testcase.lower); milvus::proto::plan::GenericValue upper_val; @@ -459,11 +445,10 @@ TEST_P(ExprTest, TestBinaryRangeJSON) { upper_val, testcase.lower_inclusive, testcase.upper_inclusive); - BitsetType final; - plan.filter_plannode_ = + auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode( - plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + BitsetType final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -525,19 +510,17 @@ TEST_P(ExprTest, TestExistsJson) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](bool value) { return value; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); auto expr = std::make_shared(milvus::expr::ColumnInfo( json_fid, DataType::JSON, testcase.nested_path)); BitsetType final; - plan.filter_plannode_ = + auto plannode = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode( - plan.filter_plannode_.value(), seg_promote, N * num_iters, final); + final = ExecuteQueryExpr( + plannode, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -621,7 +604,7 @@ TEST_P(ExprTest, TestUnaryRangeJson) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + std::vector ops{ OpType::Equal, OpType::NotEqual, @@ -672,10 +655,10 @@ TEST_P(ExprTest, TestUnaryRangeJson) { json_fid, DataType::JSON, testcase.nested_path), op, value); - BitsetType final; auto plan = std::make_shared( DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + auto final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); EXPECT_EQ(final.size(), N * num_iters); @@ -738,7 +721,8 @@ TEST_P(ExprTest, TestUnaryRangeJson) { BitsetType final; auto plan = std::make_shared( DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = ExecuteQueryExpr( + plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -786,7 +770,6 @@ TEST_P(ExprTest, TestTermJson) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](int64_t value) { std::unordered_set term_set(testcase.term.begin(), @@ -807,7 +790,8 @@ TEST_P(ExprTest, TestTermJson) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -892,7 +876,6 @@ TEST_P(ExprTest, TestTerm) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -901,10 +884,11 @@ TEST_P(ExprTest, TestTerm) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -979,7 +963,6 @@ TEST_P(ExprTest, TestCompare) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -988,10 +971,11 @@ TEST_P(ExprTest, TestCompare) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1074,7 +1058,6 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { load_index_info.index = std::move(age64_index); seg->LoadIndex(load_index_info); - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % clause % @@ -1086,8 +1069,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndex) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -1139,7 +1125,6 @@ TEST_P(ExprTest, TestCompareExpr) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { switch (type) { case DataType::BOOL: { @@ -1227,25 +1212,25 @@ TEST_P(ExprTest, TestCompareExpr) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT8); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT16); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT32); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::INT64); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::FLOAT); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); expr = build_expr(DataType::DOUBLE); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << "end compare test" << std::endl; } @@ -1482,14 +1467,13 @@ TEST(Expr, TestExprPerformance) { }; auto test_case_base = [=, &seg](expr::TypedExprPtr expr) { - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); std::cout << expr->ToString() << std::endl; BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); for (int i = 0; i < 100; i++) { - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); } std::cout << "cost: " @@ -1647,11 +1631,11 @@ TEST_P(ExprTest, test_term_pk) { } auto expr = std::make_shared( expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < 10; ++i) { EXPECT_EQ(final[i], true); @@ -1668,74 +1652,13 @@ TEST_P(ExprTest, test_term_pk) { expr = std::make_shared( expr::ColumnInfo(int64_fid, DataType::INT64), retrieve_ints); plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { EXPECT_EQ(final[i], false); } } -TEST_P(ExprTest, TestSealedSegmentGetBatchSize) { - auto schema = std::make_shared(); - auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); - auto int8_fid = schema->AddDebugField("int8", DataType::INT8); - auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); - schema->set_primary_field_id(str1_fid); - - auto seg = CreateSealedSegment(schema); - int N = 100000; - auto raw_data = DataGen(schema, N); - // load field data - auto fields = schema->get_fields(); - for (auto field_data : raw_data.raw_->fields_data()) { - int64_t field_id = field_data.field_id(); - - auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); - auto field_meta = fields.at(FieldId(field_id)); - info.channel->push( - CreateFieldDataFromDataArray(N, &field_data, field_meta)); - info.channel->close(); - - seg->LoadFieldData(FieldId(field_id), info); - } - - proto::plan::GenericValue val; - val.set_int64_val(10); - auto expr = std::make_shared( - expr::ColumnInfo(int8_fid, DataType::INT8), - proto::plan::OpType::GreaterThan, - val); - auto plan_node = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - - std::vector test_batch_size = { - 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; - for (const auto& batch_size : test_batch_size) { - EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; - auto plan = plan::PlanFragment(plan_node); - auto query_context = std::make_shared( - "query id", seg.get(), N, MAX_TIMESTAMP); - - auto task = - milvus::exec::Task::Create("task_expr", plan, 0, query_context); - auto last_num = N % batch_size; - auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; - int iter = 0; - for (;;) { - auto result = task->Next(); - if (!result) { - break; - } - auto childrens = result->childrens(); - if (++iter != iter_num) { - EXPECT_EQ(childrens[0]->size(), batch_size); - } else { - EXPECT_EQ(childrens[0]->size(), last_num); - } - } - } -} - TEST_P(ExprTest, TestGrowingSegmentGetBatchSize) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); @@ -1824,7 +1747,6 @@ TEST_P(ExprTest, TestConjuctExpr) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); auto build_expr = [&](int l, int r) -> expr::TypedExprPtr { ::milvus::proto::plan::GenericValue value; @@ -1851,7 +1773,7 @@ TEST_P(ExprTest, TestConjuctExpr) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); for (int i = 0; i < N; ++i) { EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i; } @@ -1893,8 +1815,6 @@ TEST_P(ExprTest, TestUnaryBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector> test_cases = { {int8_fid, DataType::INT8}, {int16_fid, DataType::INT16}, @@ -1920,7 +1840,7 @@ TEST_P(ExprTest, TestUnaryBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -1964,8 +1884,6 @@ TEST_P(ExprTest, TestBinaryRangeBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector> test_cases = { {int8_fid, DataType::INT8}, {int16_fid, DataType::INT16}, @@ -2000,7 +1918,7 @@ TEST_P(ExprTest, TestBinaryRangeBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -2044,8 +1962,6 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector> test_cases = { {int8_fid, DataType::INT8}, {int16_fid, DataType::INT16}, @@ -2074,7 +1990,7 @@ TEST_P(ExprTest, TestLogicalUnaryBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -2118,8 +2034,6 @@ TEST_P(ExprTest, TestBinaryLogicalBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector> test_cases = { {int8_fid, DataType::INT8}, {int16_fid, DataType::INT16}, @@ -2158,7 +2072,7 @@ TEST_P(ExprTest, TestBinaryLogicalBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -2202,8 +2116,6 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector> test_cases = { {int8_fid, DataType::INT8}, {int16_fid, DataType::INT16}, @@ -2238,7 +2150,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeBenchExpr) { int64_t all_cost = 0; for (int i = 0; i < 50; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -2285,8 +2197,6 @@ TEST_P(ExprTest, TestCompareExprBenchTest) { seg->LoadFieldData(FieldId(field_id), info); } - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); - std::vector< std::pair, std::pair>> test_cases = { @@ -2311,7 +2221,7 @@ TEST_P(ExprTest, TestCompareExprBenchTest) { int64_t all_cost = 0; for (int i = 0; i < 10; i++) { auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); all_cost += std::chrono::duration_cast( std::chrono::steady_clock::now() - start) .count(); @@ -2466,13 +2376,13 @@ TEST_P(ExprTest, TestRefactorExprs) { }; auto test_case = [&](int n) { auto expr = build_expr(UnaryRangeExpr, n); - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); std::cout << "start test" << std::endl; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg.get(), N, final); + final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); std::cout << n << "cost: " << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -2562,7 +2472,6 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { load_index_info.index = std::move(str2_index); seg->LoadIndex(load_index_info); - query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % clause % str1_fid.get() % str2_fid.get(); @@ -2572,8 +2481,11 @@ TEST_P(ExprTest, TestCompareWithScalarIndexMaris) { *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg.get(), N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -3255,7 +3167,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto [clause, ref_func, dtype] : testcases) { auto loc = raw_plan_tmp.find("@@@@@"); auto raw_plan = raw_plan_tmp; @@ -3281,10 +3193,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRange) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg.get(), + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -4108,7 +4021,6 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@@"); @@ -4118,10 +4030,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSON) { auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -4175,7 +4088,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + for (auto testcase : testcases) { auto check = [&](double value) { if (testcase.op == OpType::Equal) { @@ -4198,7 +4111,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -4240,7 +4154,8 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeJSONFloat) { BitsetType final; auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -4708,7 +4623,7 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { seg->LoadIndex(load_index_info); auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -4744,8 +4659,11 @@ TEST_P(ExprTest, TestBinaryArithOpEvalRangeWithScalarSortIndex) { *schema, binary_plan.data(), binary_plan.size()); BitsetType final; - visitor.ExecuteExprNode( - plan->plan_node_->filter_plannode_.value(), seg_promote, N, final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -4895,7 +4813,7 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -4941,10 +4859,11 @@ TEST_P(ExprTest, TestUnaryRangeWithJSON) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -5072,7 +4991,7 @@ TEST_P(ExprTest, TestTermWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -5118,10 +5037,11 @@ TEST_P(ExprTest, TestTermWithJSON) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -5216,7 +5136,7 @@ TEST_P(ExprTest, TestExistsWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -5269,10 +5189,11 @@ TEST_P(ExprTest, TestExistsWithJSON) { *schema, unary_plan.data(), unary_plan.size()); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -5341,7 +5262,6 @@ TEST_P(ExprTest, TestTermInFieldJson) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true}, {"bool"}}, {{false}, {"bool"}}}; @@ -5367,7 +5287,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); // std::cout << "cost" // << std::chrono::duration_cast( // std::chrono::steady_clock::now() - start) @@ -5415,7 +5336,9 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); + std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5463,7 +5386,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5511,7 +5435,8 @@ TEST_P(ExprTest, TestTermInFieldJson) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5699,7 +5624,6 @@ TEST_P(ExprTest, TestJsonContainsAny) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true}, {"bool"}}, {{false}, {"bool"}}}; @@ -5726,7 +5650,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5775,7 +5700,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5824,7 +5750,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5873,7 +5800,8 @@ TEST_P(ExprTest, TestJsonContainsAny) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -5919,7 +5847,6 @@ TEST_P(ExprTest, TestJsonContainsAll) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true, true}, {"bool"}}, {{false, false}, {"bool"}}}; @@ -5951,7 +5878,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6007,7 +5935,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6063,7 +5992,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6117,7 +6047,8 @@ TEST_P(ExprTest, TestJsonContainsAll) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6163,7 +6094,6 @@ TEST_P(ExprTest, TestJsonContainsArray) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); proto::plan::GenericValue generic_a; auto* a = generic_a.mutable_array_val(); @@ -6223,7 +6153,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6256,7 +6187,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6308,7 +6240,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6337,7 +6270,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6391,7 +6325,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6421,7 +6356,8 @@ TEST_P(ExprTest, TestJsonContainsArray) { std::make_shared(DEFAULT_PLANNODE_ID, expr); BitsetType final; auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6490,7 +6426,6 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArray) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); proto::plan::GenericValue int_value; int_value.set_int64_val(1); @@ -6525,7 +6460,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArray) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6552,7 +6488,8 @@ TEST_P(ExprTest, TestJsonContainsDiffTypeArray) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6592,7 +6529,6 @@ TEST_P(ExprTest, TestJsonContainsDiffType) { } auto seg_promote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); proto::plan::GenericValue int_val; int_val.set_int64_val(int64_t(3)); @@ -6633,7 +6569,8 @@ TEST_P(ExprTest, TestJsonContainsDiffType) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -6659,7 +6596,8 @@ TEST_P(ExprTest, TestJsonContainsDiffType) { auto plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - visitor.ExecuteExprNode(plan, seg_promote, N * num_iters, final); + final = + ExecuteQueryExpr(plan, seg_promote, N * num_iters, MAX_TIMESTAMP); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) diff --git a/internal/core/unittest/test_expr_materialized_view.cpp b/internal/core/unittest/test_expr_materialized_view.cpp index a0d56952416fa..596344d6d9357 100644 --- a/internal/core/unittest/test_expr_materialized_view.cpp +++ b/internal/core/unittest/test_expr_materialized_view.cpp @@ -27,7 +27,7 @@ #include "knowhere/config.h" #include "query/Plan.h" #include "query/PlanImpl.h" -#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "plan/PlanNode.h" #include "segcore/SegmentSealed.h" #include "segcore/SegmentSealedImpl.h" diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index 38da5af55588c..670855c5c330d 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -15,7 +15,6 @@ #include "common/Types.h" #include "index/IndexFactory.h" #include "knowhere/comp/index_param.h" -#include "query/ExprImpl.h" #include "segcore/reduce/Reduce.h" #include "segcore/reduce_c.h" #include "test_utils/DataGen.h" @@ -24,24 +23,20 @@ #include "pb/schema.pb.h" #include "pb/plan.pb.h" -#include "query/Expr.h" #include "query/Plan.h" #include "query/Utils.h" #include "query/PlanImpl.h" #include "query/PlanNode.h" #include "query/PlanProto.h" #include "query/SearchBruteForce.h" -#include "query/generated/ExecPlanNodeVisitor.h" -#include "query/generated/PlanNodeVisitor.h" -#include "query/generated/ExecExprVisitor.h" -#include "query/generated/ExprVisitor.h" -#include "query/generated/ShowPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/Collection.h" #include "segcore/SegmentSealed.h" #include "segcore/SegmentGrowing.h" #include "segcore/SegmentGrowingImpl.h" #include "test_utils/AssertUtils.h" #include "test_utils/DataGen.h" +#include "test_utils/GenExprProto.h" using namespace milvus; using namespace milvus::index; @@ -91,26 +86,6 @@ const int64_t ROW_COUNT = 100 * 1000; // } // } -TEST(Float16, ShowExecutor) { - auto metric_type = knowhere::metric::L2; - auto node = std::make_unique(); - auto schema = std::make_shared(); - auto field_id = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT16, 16, metric_type); - int64_t num_queries = 100L; - auto raw_data = DataGen(schema, num_queries); - auto& info = node->search_info_; - info.metric_type_ = metric_type; - info.topk_ = 20; - info.field_id_ = field_id; - node->predicate_ = std::nullopt; - ShowPlanNodeVisitor show_visitor; - PlanNodePtr base(node.release()); - auto res = show_visitor.call_child(*base); - auto dup = res; - std::cout << dup.dump(4); -} - TEST(Float16, ExecWithoutPredicateFlat) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( @@ -235,8 +210,8 @@ TEST(Float16, RetrieveEmpty) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -347,26 +322,6 @@ TEST(Float16, ExecWithPredicate) { // } // } -TEST(BFloat16, ShowExecutor) { - auto metric_type = knowhere::metric::L2; - auto node = std::make_unique(); - auto schema = std::make_shared(); - auto field_id = schema->AddDebugField( - "fakevec", DataType::VECTOR_BFLOAT16, 16, metric_type); - int64_t num_queries = 100L; - auto raw_data = DataGen(schema, num_queries); - auto& info = node->search_info_; - info.metric_type_ = metric_type; - info.topk_ = 20; - info.field_id_ = field_id; - node->predicate_ = std::nullopt; - ShowPlanNodeVisitor show_visitor; - PlanNodePtr base(node.release()); - auto res = show_visitor.call_child(*base); - auto dup = res; - std::cout << dup.dump(4); -} - TEST(BFloat16, ExecWithoutPredicateFlat) { auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( @@ -480,16 +435,19 @@ TEST(BFloat16, RetrieveEmpty) { auto plan = std::make_unique(*schema); std::vector values; + std::vector retrieve_ints; for (int i = 0; i < req_size; ++i) { values.emplace_back(choose(i)); + proto::plan::GenericValue val; + val.set_int64_val(i); + retrieve_ints.push_back(val); } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( - fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + auto term_expr = std::make_shared( + expr::ColumnInfo(fid_64, DataType::INT64), retrieve_ints); + auto expr_plan = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->plannodes_ = std::move(expr_plan); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; diff --git a/internal/core/unittest/test_growing.cpp b/internal/core/unittest/test_growing.cpp index 77a75f78998d5..fc35bed1a68d0 100644 --- a/internal/core/unittest/test_growing.cpp +++ b/internal/core/unittest/test_growing.cpp @@ -87,7 +87,8 @@ TEST(Growing, RemoveDuplicatedRecords) { BitsetType bitset(c); std::cout << "start to search delete" << std::endl; - segment->mask_with_delete(bitset, c, 1003); + BitsetTypeView bitset_view(bitset); + segment->mask_with_delete(bitset_view, c, 1003); for (int i = 0; i < bitset.size(); i++) { ASSERT_EQ(bitset[i], bits[i]) << "index:" << i << std::endl; diff --git a/internal/core/unittest/test_integer_overflow.cpp b/internal/core/unittest/test_integer_overflow.cpp index be0e3e67fe28f..84b98e3d00813 100644 --- a/internal/core/unittest/test_integer_overflow.cpp +++ b/internal/core/unittest/test_integer_overflow.cpp @@ -16,9 +16,8 @@ #include #include "common/Types.h" -#include "query/Expr.h" #include "query/Plan.h" -#include "query/generated/ExecExprVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "test_utils/DataGen.h" @@ -620,12 +619,14 @@ binary_arith_op_eval_range_expr: < auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + // vectorsearch node => mvcc node => filter node + // just test filter node + final = ExecuteQueryExpr( + (plan->plan_node_->plannodes_->sources()[0])->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_plan_proto.cpp b/internal/core/unittest/test_plan_proto.cpp index e848e52d45a09..6c803c41e01fa 100644 --- a/internal/core/unittest/test_plan_proto.cpp +++ b/internal/core/unittest/test_plan_proto.cpp @@ -28,5 +28,5 @@ TEST(PlanProto, NotSetUnsupported) { proto::plan::Expr expr_pb; ProtoParser parser(*schema); - ASSERT_ANY_THROW(parser.ParseExpr(expr_pb)); + ASSERT_ANY_THROW(parser.ParseExprs(expr_pb)); } diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index ff5b8a5b48576..a9e4c80eb0ea3 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -12,12 +12,9 @@ #include #include "pb/schema.pb.h" -#include "query/Expr.h" #include "query/PlanImpl.h" #include "query/PlanNode.h" -#include "query/generated/ExecPlanNodeVisitor.h" -#include "query/generated/ExprVisitor.h" -#include "query/generated/ShowPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentSealed.h" #include "test_utils/AssertUtils.h" #include "test_utils/DataGen.h" @@ -31,26 +28,6 @@ namespace { const int64_t ROW_COUNT = 100 * 1000; } -TEST(Query, ShowExecutor) { - auto metric_type = knowhere::metric::L2; - auto node = std::make_unique(); - auto schema = std::make_shared(); - auto field_id = schema->AddDebugField( - "fakevec", DataType::VECTOR_FLOAT, 16, metric_type); - int64_t num_queries = 100L; - auto raw_data = DataGen(schema, num_queries); - auto& info = node->search_info_; - info.metric_type_ = metric_type; - info.topk_ = 20; - info.field_id_ = field_id; - node->predicate_ = std::nullopt; - ShowPlanNodeVisitor show_visitor; - PlanNodePtr base(node.release()); - auto res = show_visitor.call_child(*base); - auto dup = res; - std::cout << dup.dump(4); -} - TEST(Query, ParsePlaceholderGroup) { const char* raw_plan = R"(vector_anns: < field_id: 100 diff --git a/internal/core/unittest/test_regex_query.cpp b/internal/core/unittest/test_regex_query.cpp index 455a582d7a42e..71751553e4bb5 100644 --- a/internal/core/unittest/test_regex_query.cpp +++ b/internal/core/unittest/test_regex_query.cpp @@ -26,7 +26,7 @@ #include "knowhere/comp/brute_force.h" #include "test_utils/GenExprProto.h" #include "query/PlanProto.h" -#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "index/InvertedIndexTantivy.h" using namespace milvus; @@ -125,11 +125,8 @@ TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnNonStringField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - ASSERT_ANY_THROW( - - visitor.ExecuteExprNode(parsed, segpromote, N, final)); + ASSERT_ANY_THROW(ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP)); } TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnStringField) { @@ -150,9 +147,8 @@ TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnStringField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_TRUE(final[1]); ASSERT_TRUE(final[2]); @@ -177,9 +173,8 @@ TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnJsonField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_FALSE(final[1]); ASSERT_TRUE(final[2]); @@ -333,9 +328,7 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnNonStringField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); - BitsetType final; - ASSERT_ANY_THROW(visitor.ExecuteExprNode(parsed, segpromote, N, final)); + ASSERT_ANY_THROW(ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP)); } TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) { @@ -356,9 +349,8 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_TRUE(final[1]); ASSERT_TRUE(final[2]); @@ -383,9 +375,8 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnJsonField) { std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_FALSE(final[1]); ASSERT_TRUE(final[2]); @@ -413,9 +404,7 @@ TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnIndexedNonStringField) { auto segpromote = dynamic_cast(seg.get()); query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - ASSERT_ANY_THROW( - - visitor.ExecuteExprNode(parsed, segpromote, N, final)); + ASSERT_ANY_THROW(ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP)); } TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnStlSortStringField) { @@ -438,9 +427,8 @@ TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnStlSortStringField) { LoadStlSortIndex(); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_TRUE(final[1]); ASSERT_TRUE(final[2]); @@ -468,10 +456,8 @@ TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnInvertedIndexStringField) { LoadInvertedIndex(); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_TRUE(final[1]); ASSERT_TRUE(final[2]); @@ -499,10 +485,9 @@ TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnUnsupportedIndex) { LoadMockIndex(); auto segpromote = dynamic_cast(seg.get()); - query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; // regex query under this index will be executed using raw data (brute force). - visitor.ExecuteExprNode(parsed, segpromote, N, final); + final = ExecuteQueryExpr(parsed, segpromote, N, MAX_TIMESTAMP); ASSERT_FALSE(final[0]); ASSERT_TRUE(final[1]); ASSERT_TRUE(final[2]); diff --git a/internal/core/unittest/test_retrieve.cpp b/internal/core/unittest/test_retrieve.cpp index 31846efbbfc68..dea5a2493e525 100644 --- a/internal/core/unittest/test_retrieve.cpp +++ b/internal/core/unittest/test_retrieve.cpp @@ -13,8 +13,8 @@ #include "common/Types.h" #include "knowhere/comp/index_param.h" -#include "query/Expr.h" #include "test_utils/DataGen.h" +#include "test_utils/GenExprProto.h" #include "plan/PlanNode.h" using namespace milvus; @@ -77,8 +77,8 @@ TEST_P(RetrieveTest, AutoID) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_fields_id{fid_64, fid_vec}; plan->field_ids_ = target_fields_id; @@ -137,8 +137,8 @@ TEST_P(RetrieveTest, AutoID2) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -202,8 +202,8 @@ TEST_P(RetrieveTest, NotExist) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -259,8 +259,8 @@ TEST_P(RetrieveTest, Empty) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -302,8 +302,7 @@ TEST_P(RetrieveTest, Limit) { OpType::GreaterEqual, unary_val); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, expr); + plan->plan_node_->plannodes_ = milvus::test::CreateRetrievePlanByExpr(expr); // test query results exceed the limit size std::vector target_fields{TimestampFieldID, fid_64, fid_vec}; @@ -350,9 +349,7 @@ TEST_P(RetrieveTest, FillEntry) { OpType::GreaterEqual, unary_val); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, expr); - + plan->plan_node_->plannodes_ = milvus::test::CreateRetrievePlanByExpr(expr); // test query results exceed the limit size std::vector target_fields{TimestampFieldID, fid_64, @@ -403,8 +400,8 @@ TEST_P(RetrieveTest, LargeTimestamp) { values); ; plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -476,8 +473,8 @@ TEST_P(RetrieveTest, Delete) { fid_64, DataType::INT64, std::vector()), values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->filter_plannode_ = - std::make_shared(DEFAULT_PLANNODE_ID, term_expr); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); std::vector target_offsets{fid_ts, fid_64, fid_vec}; plan->field_ids_ = target_offsets; diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 5494e0bfcf552..16a16a356f787 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -1035,7 +1035,8 @@ TEST(Sealed, Delete) { segment->LoadDeletedRecord(info); BitsetType bitset(N, false); - segment->mask_with_delete(bitset, 10, 11); + auto bitset_view = BitsetTypeView(bitset); + segment->mask_with_delete(bitset_view, 10, 11); ASSERT_EQ(bitset.count(), pks.size()); int64_t new_count = 3; @@ -1133,11 +1134,12 @@ TEST(Sealed, OverlapDelete) { segment->LoadDeletedRecord(overlap_info); BitsetType bitset(N, false); + auto bitset_view = BitsetTypeView(bitset); auto deleted_record2 = pks.size(); ASSERT_EQ(segment->get_deleted_count(), deleted_record1 + deleted_record2) << "deleted_count=" << segment->get_deleted_count() << " pks_count=" << deleted_record1 + deleted_record2 << std::endl; - segment->mask_with_delete(bitset, 10, 12); + segment->mask_with_delete(bitset_view, 10, 12); ASSERT_EQ(bitset.count(), pks.size()) << "bitset_count=" << bitset.count() << " pks_count=" << pks.size() << std::endl; @@ -1330,7 +1332,8 @@ TEST(Sealed, DeleteDuplicatedRecords) { BitsetType bitset(c); std::cout << "start to search delete" << std::endl; - segment->mask_with_delete(bitset, c, 1003); + BitsetTypeView bitset_view(bitset); + segment->mask_with_delete(bitset_view, c, 1003); for (int i = 0; i < bitset.size(); i++) { ASSERT_EQ(bitset[i], bits[i]) << "index:" << i << std::endl; diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index c3b29e54f6e96..a406ab8e86bb0 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -16,12 +16,11 @@ #include "common/Tracer.h" #include "pb/plan.pb.h" -#include "query/Expr.h" #include "query/PlanProto.h" #include "query/SearchBruteForce.h" #include "query/Utils.h" -#include "query/generated/ExecExprVisitor.h" -#include "query/generated/PlanNodeVisitor.h" +#include "query/PlanNodeVisitor.h" +#include "query/ExecPlanNodeVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "test_utils/DataGen.h" #include "test_utils/GenExprProto.h" @@ -282,12 +281,12 @@ TEST(StringExpr, Term) { for (const auto& [_, term] : terms) { auto plan_proto = GenTermPlan(fvec_meta, str_meta, term); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -396,12 +395,12 @@ TEST(StringExpr, Compare) { for (const auto& [op, ref_func] : testcases) { auto plan_proto = gen_compare_plan(op); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -492,12 +491,12 @@ TEST(StringExpr, UnaryRange) { for (const auto& [op, value, ref_func] : testcases) { auto plan_proto = gen_unary_range_plan(op, value); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -606,12 +605,12 @@ TEST(StringExpr, BinaryRange) { auto plan_proto = gen_binary_range_plan(lb_inclusive, ub_inclusive, lb, ub); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); BitsetType final; - visitor.ExecuteExprNode(plan->plan_node_->filter_plannode_.value(), - seg_promote, - N * num_iters, - final); + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_utils/GenExprProto.h b/internal/core/unittest/test_utils/GenExprProto.h index 77f0a4964e4bb..a1744d3c5e268 100644 --- a/internal/core/unittest/test_utils/GenExprProto.h +++ b/internal/core/unittest/test_utils/GenExprProto.h @@ -11,7 +11,13 @@ #pragma once +#include +#include + +#include "common/Consts.h" +#include "expr/ITypeExpr.h" #include "pb/plan.pb.h" +#include "plan/PlanNode.h" namespace milvus::test { inline auto @@ -62,4 +68,40 @@ inline auto GenExpr() { return std::make_unique(); } + +inline std::shared_ptr +CreateRetrievePlanByExpr(std::shared_ptr expr) { + auto init_plannode_id = std::stoi(DEFAULT_PLANNODE_ID); + milvus::plan::PlanNodePtr plannode; + std::vector sources; + + plannode = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + sources = std::vector{plannode}; + + plannode = std::make_shared( + std::to_string(init_plannode_id++), sources); + return plannode; +} + +inline std::shared_ptr +CreateSearchPlanByExpr(std::shared_ptr expr) { + auto init_plannode_id = std::stoi(DEFAULT_PLANNODE_ID); + milvus::plan::PlanNodePtr plannode; + std::vector sources; + + plannode = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + sources = std::vector{plannode}; + + plannode = std::make_shared( + std::to_string(init_plannode_id++), sources); + sources = std::vector{plannode}; + + plannode = std::make_shared( + std::to_string(init_plannode_id++), sources); + + return plannode; +} + } // namespace milvus::test