From 12824a875b802e7bb54f1d9eff7285310f9fcf14 Mon Sep 17 00:00:00 2001 From: Yinzuo Jiang Date: Tue, 17 Sep 2024 19:53:23 +0800 Subject: [PATCH] feat: Implement custom function module in milvus expr OSPP 2024 project: https://summer-ospp.ac.cn/org/prodetail/247410235?list=org&navpage=org Solutions: - parser (planparserv2) - add CallExpr in planparserv2/Plan.g4 - update parser_visitor and show_visitor - grpc protobuf - add CallExpr in plan.proto - execution (`core/src/exec`) - add `CallExpr` `ValueExpr` and `ColumnExpr` (both logical and physical) for function call and function parameters - function factory (`core/src/exec/expression/function`) - create a global hashmap when starting milvus (see server.go) - the global hashmap stores function signatures and their function pointers, the CallExpr in execution engine can get the function pointer by function signature. - custom functions - empty(string) - add cpp/go unittests and E2E tests closes: #36559 Signed-off-by: Yinzuo Jiang --- docs/design_docs/segcore/visitor.md | 13 +- internal/core/CMakeLists.txt | 6 + internal/core/src/common/Types.h | 1 + internal/core/src/common/Vector.h | 89 +++- internal/core/src/common/init_c.cpp | 2 +- internal/core/src/exec/Task.cpp | 2 +- .../core/src/exec/expression/CallExpr.cpp | 46 ++ internal/core/src/exec/expression/CallExpr.h | 83 +++ .../core/src/exec/expression/ColumnExpr.cpp | 359 +++++++++++++ .../core/src/exec/expression/ColumnExpr.h | 200 +++++++ .../core/src/exec/expression/CompareExpr.h | 2 +- internal/core/src/exec/expression/Expr.cpp | 42 +- internal/core/src/exec/expression/Expr.h | 6 +- .../core/src/exec/expression/ValueExpr.cpp | 96 ++++ internal/core/src/exec/expression/ValueExpr.h | 67 +++ .../expression/function/FunctionFactory.cpp | 81 +++ .../expression/function/FunctionFactory.h | 113 ++++ .../exec/expression/function/impl/Empty.cpp | 48 ++ .../src/exec/expression/function/impl/Empty.h | 33 ++ .../src/exec/expression/function/init_c.cpp | 23 + .../src/exec/expression/function/init_c.h | 28 + .../core/src/exec/operator/FilterBitsNode.cpp | 24 +- internal/core/src/expr/ITypeExpr.h | 144 ++++-- internal/core/src/query/PlanProto.cpp | 88 +++- internal/core/src/query/PlanProto.h | 40 +- internal/core/unittest/test_exec.cpp | 63 ++- internal/core/unittest/test_expr.cpp | 69 +++ internal/parser/planparserv2/Plan.g4 | 1 + .../planparserv2/check_identical_test.go | 28 +- .../parser/planparserv2/generated/Plan.interp | 2 +- .../generated/plan_base_visitor.go | 4 + .../planparserv2/generated/plan_parser.go | 488 ++++++++++++------ .../planparserv2/generated/plan_visitor.go | 3 + .../parser/planparserv2/parser_visitor.go | 22 + .../planparserv2/plan_parser_v2_test.go | 46 +- internal/parser/planparserv2/show_visitor.go | 16 +- internal/proto/plan.proto | 6 + internal/querynodev2/server.go | 3 + tests/python_client/requirements.txt | 1 + tests/python_client/testcases/test_query.py | 17 + 40 files changed, 2147 insertions(+), 258 deletions(-) create mode 100644 internal/core/src/exec/expression/CallExpr.cpp create mode 100644 internal/core/src/exec/expression/CallExpr.h create mode 100644 internal/core/src/exec/expression/ColumnExpr.cpp create mode 100644 internal/core/src/exec/expression/ColumnExpr.h create mode 100644 internal/core/src/exec/expression/ValueExpr.cpp create mode 100644 internal/core/src/exec/expression/ValueExpr.h create mode 100644 internal/core/src/exec/expression/function/FunctionFactory.cpp create mode 100644 internal/core/src/exec/expression/function/FunctionFactory.h create mode 100644 internal/core/src/exec/expression/function/impl/Empty.cpp create mode 100644 internal/core/src/exec/expression/function/impl/Empty.h create mode 100644 internal/core/src/exec/expression/function/init_c.cpp create mode 100644 internal/core/src/exec/expression/function/init_c.h diff --git a/docs/design_docs/segcore/visitor.md b/docs/design_docs/segcore/visitor.md index 6cf70d7a8f08f..2c8fd5568fece 100644 --- a/docs/design_docs/segcore/visitor.md +++ b/docs/design_docs/segcore/visitor.md @@ -1,20 +1,15 @@ # Visitor Pattern Visitor Pattern is used in segcore for parse and execute Execution Plan. -1. Inside `${core}/src/query/PlanNode.h`, contains physical plan for vector search: +1. Inside `${internal/core}/src/query/PlanNode.h`, contains physical plan for vector search: 1. `FloatVectorANNS` FloatVector search execution node 2. `BinaryVectorANNS` BinaryVector search execution node -2. `${core}/src/query/Expr.h` contains physical plan for scalar expression: +2. `${internal/core}/src/query/Expr.h` contains physical plan for scalar expression: 1. `TermExpr` support operation like `col in [1, 2, 3]` 2. `RangeExpr` support constant compare with data column like `a >= 5` `1 < b < 2` 3. `CompareExpr` support compare with different columns, like `a < b` 4. `LogicalBinaryExpr` support and/or 5. `LogicalUnaryExpr` support not -Currently, under `${core/query/visitors}` directory, there are the following visitors: -1. `ShowPlanNodeVisitor` prints PlanNode in json -2. `ShowExprVisitor` Expr -> json -3. `Verify...Visitor` validates ... -4. `ExtractInfo...Visitor` extracts info from..., including involved_fields and else -5. `ExecExprVisitor` generates bitmask according to expression -6. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now +Currently, under `${internal/core/src/query}` directory, there are the following visitors: +1. `ExecPlanNodeVistor` physical plan executor only supports ANNS node for now diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index 10b3c3e76aff5..07c111a452401 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -292,6 +292,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/segcore/ FILES_MATCHING PATTERN "*_c.h" ) +# Install exec/expression/function +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/exec/expression/function/ + DESTINATION include/exec/expression/function + FILES_MATCHING PATTERN "*_c.h" +) + # Install indexbuilder install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/ DESTINATION include/indexbuilder diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 2473b21a88372..a9276cb0f7c6d 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -59,6 +59,7 @@ using float16 = knowhere::fp16; using bfloat16 = knowhere::bf16; using bin1 = knowhere::bin1; +// See also: schema.proto enum class DataType { NONE = 0, BOOL = 1, diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index ac5f0b217b0c5..de8ddc429b51b 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -17,9 +17,11 @@ #pragma once #include -#include +#include #include "common/FieldData.h" +#include "common/FieldDataInterface.h" +#include "common/Types.h" namespace milvus { @@ -27,7 +29,6 @@ namespace milvus { * @brief base class for different type vector * @todo implement full null value support */ - class BaseVector { public: BaseVector(DataType data_type, @@ -87,7 +88,7 @@ class ColumnVector final : public BaseVector { } void* - GetRawData() { + GetRawData() const { return values_->Data(); } @@ -103,6 +104,80 @@ class ColumnVector final : public BaseVector { using ColumnVectorPtr = std::shared_ptr; +template +class ValueVector : public BaseVector { + public: + ValueVector(DataType data_type, + size_t length, + std::optional null_count = std::nullopt) + : BaseVector(data_type, length) { + } + + virtual const T + GetValueAt(size_t index) const = 0; + + virtual TargetBitmap + Apply(std::function func) = 0; +}; + +template +class ConstantVector final : public ValueVector { + public: + ConstantVector(DataType data_type, + size_t length, + const T& val, + std::optional null_count = std::nullopt) + : ValueVector(data_type, length, null_count), val_(val) { + } + + const T + GetValueAt(size_t _) const override { + return val_; + } + + TargetBitmap + Apply(std::function func) override { + return TargetBitmap(this->size(), func(val_)); + } + + private: + const T val_; +}; + +// TODO: simd support +template +class ColumnValueVector final : public ValueVector { + public: + ColumnValueVector(DataType data_type, + size_t length, + std::optional null_count = std::nullopt) + : ValueVector(data_type, length, null_count) { + vec_.reserve(length); + } + + void + Set(size_t index, const T& val) { + vec_[index] = val; + }; + + const T + GetValueAt(size_t i) const override { + return vec_[i]; + } + + TargetBitmap + Apply(std::function func) override { + TargetBitmap result_vec(this->size()); + for (int i = 0; i < this->size(); ++i) { + result_vec.set(i, func(GetValueAt(i))); + } + return result_vec; + } + + private: + FixedVector vec_; +}; + /** * @brief Multi vectors for scalar types * mainly using it to pass internal result in segcore scalar engine system @@ -130,8 +205,7 @@ class RowVector : public BaseVector { } RowVector(std::vector&& children) - : BaseVector(DataType::ROW, 0) { - children_values_ = std::move(children); + : BaseVector(DataType::ROW, 0), children_values_(std::move(children)) { for (auto& child : children_values_) { if (child->size() > length_) { length_ = child->size(); @@ -140,12 +214,12 @@ class RowVector : public BaseVector { } const std::vector& - childrens() { + childrens() const { return children_values_; } VectorPtr - child(int index) { + child(int index) const { assert(index < children_values_.size()); return children_values_[index]; } @@ -155,5 +229,4 @@ class RowVector : public BaseVector { }; using RowVectorPtr = std::shared_ptr; - } // namespace milvus diff --git a/internal/core/src/common/init_c.cpp b/internal/core/src/common/init_c.cpp index ce961b7d8bde1..77764ffa555e9 100644 --- a/internal/core/src/common/init_c.cpp +++ b/internal/core/src/common/init_c.cpp @@ -105,4 +105,4 @@ SetTrace(CTraceConfig* config) { config->oltpSecure, config->nodeID}; milvus::tracer::initTelemetry(traceConfig); -} \ No newline at end of file +} diff --git a/internal/core/src/exec/Task.cpp b/internal/core/src/exec/Task.cpp index d03ca3f97fb93..14731417f0ebb 100644 --- a/internal/core/src/exec/Task.cpp +++ b/internal/core/src/exec/Task.cpp @@ -235,4 +235,4 @@ Task::Next(ContinueFuture* future) { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/exec/expression/CallExpr.cpp b/internal/core/src/exec/expression/CallExpr.cpp new file mode 100644 index 0000000000000..0f1de8b638bce --- /dev/null +++ b/internal/core/src/exec/expression/CallExpr.cpp @@ -0,0 +1,46 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/FieldDataInterface.h" +#include "common/Vector.h" +#include "exec/expression/CallExpr.h" +#include "exec/expression/EvalCtx.h" +#include "exec/expression/function/FunctionFactory.h" + +#include +#include + +namespace milvus { +namespace exec { + +void +PhyCallExpr::Eval(EvalCtx& context, VectorPtr& result) { + AssertInfo(inputs_.size() == expr_->inputs().size(), + "logical call expr needs {} inputs, but {} inputs are provided", + expr_->inputs().size(), + inputs_.size()); + std::vector args; + for (auto &input: this->inputs_) { + VectorPtr arg_result; + input->Eval(context, arg_result); + args.push_back(std::move(arg_result)); + } + RowVector row_vector(std::move(args)); + this->expr_->function_ptr()(context, row_vector, result); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CallExpr.h b/internal/core/src/exec/expression/CallExpr.h new file mode 100644 index 0000000000000..f074c7b423e77 --- /dev/null +++ b/internal/core/src/exec/expression/CallExpr.h @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "common/EasyAssert.h" +#include "common/FieldDataInterface.h" +#include "common/Utils.h" +#include "common/Vector.h" +#include "exec/expression/EvalCtx.h" +#include "exec/expression/Expr.h" +#include "exec/expression/function/FunctionFactory.h" +#include "expr/ITypeExpr.h" +#include "fmt/core.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyCallExpr : public Expr { + public: + PhyCallExpr(const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(DataType::BOOL, std::move(input), name), + expr_(expr), + active_count_(active_count), + segment_(segment), + batch_size_(batch_size) { + size_per_chunk_ = segment_->size_per_chunk(); + num_chunk_ = upper_div(active_count_, size_per_chunk_); + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + void + MoveCursor() override { + for (auto input : inputs_) { + input->MoveCursor(); + } + } + + private: + std::shared_ptr expr_; + + int64_t active_count_{0}; + int64_t num_chunk_{0}; + int64_t current_chunk_id_{0}; + int64_t current_chunk_pos_{0}; + int64_t size_per_chunk_{0}; + + const segcore::SegmentInternalInterface* segment_; + int64_t batch_size_; +}; + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ColumnExpr.cpp b/internal/core/src/exec/expression/ColumnExpr.cpp new file mode 100644 index 0000000000000..d40c60d030aee --- /dev/null +++ b/internal/core/src/exec/expression/ColumnExpr.cpp @@ -0,0 +1,359 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ColumnExpr.h" + +namespace milvus { +namespace exec { + +int64_t +PhyColumnExpr::GetNextBatchSize() { + auto current_rows = GetCurrentRows(); + + return current_rows + batch_size_ >= active_count_ + ? active_count_ - current_rows + : batch_size_; +} + +template +MultipleChunkDataAccessor +PhyColumnExpr::GetChunkData(FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) { + if (index) { + auto& indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, current_chunk_id)); + auto current_chunk_size = segment_->type() == SegmentType::Growing + ? size_per_chunk_ + : active_count_; + + if (indexing.HasRawData()) { + return [&, current_chunk_size]() -> const number { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, + current_chunk_id)); + } + return indexing.Reverse_Lookup(current_chunk_pos++); + }; + } + } + auto chunk_data = + segment_->chunk_data(field_id, current_chunk_id).data(); + auto current_chunk_size = segment_->chunk_size(field_id, current_chunk_id); + return + [=, ¤t_chunk_id, ¤t_chunk_pos]() mutable -> const number { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = + segment_->chunk_data(field_id, current_chunk_id).data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + + return chunk_data[current_chunk_pos++]; + }; +} + +template <> +MultipleChunkDataAccessor +PhyColumnExpr::GetChunkData(FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) { + if (index) { + auto& indexing = const_cast&>( + segment_->chunk_scalar_index(field_id, + current_chunk_id)); + auto current_chunk_size = segment_->type() == SegmentType::Growing + ? size_per_chunk_ + : active_count_; + + if (indexing.HasRawData()) { + return [&, current_chunk_size]() mutable -> const number { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + indexing = const_cast&>( + segment_->chunk_scalar_index( + field_id, current_chunk_id)); + } + return indexing.Reverse_Lookup(current_chunk_pos++); + }; + } + } + if (segment_->type() == SegmentType::Growing && + !storage::MmapManager::GetInstance() + .GetMmapConfig() + .growing_enable_mmap) { + auto chunk_data = + segment_->chunk_data(field_id, current_chunk_id) + .data(); + auto current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + return [=, + ¤t_chunk_id, + ¤t_chunk_pos]() mutable -> const number { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = + segment_ + ->chunk_data(field_id, current_chunk_id) + .data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + + return chunk_data[current_chunk_pos++]; + }; + } else { + auto chunk_data = + segment_->chunk_view(field_id, current_chunk_id) + .first.data(); + auto current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + return [=, + ¤t_chunk_id, + ¤t_chunk_pos]() mutable -> const number { + if (current_chunk_pos >= current_chunk_size) { + current_chunk_id++; + current_chunk_pos = 0; + chunk_data = segment_ + ->chunk_view( + field_id, current_chunk_id) + .first.data(); + current_chunk_size = + segment_->chunk_size(field_id, current_chunk_id); + } + + return std::string(chunk_data[current_chunk_pos++]); + }; + } +} + +MultipleChunkDataAccessor +PhyColumnExpr::GetChunkData(DataType data_type, + FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos) { + switch (data_type) { + case DataType::BOOL: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT8: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT16: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT32: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::INT64: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::FLOAT: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::DOUBLE: + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + case DataType::VARCHAR: { + return GetChunkData( + field_id, index, current_chunk_id, current_chunk_pos); + } + default: + PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); + } +} + +template +ChunkDataAccessor +PhyColumnExpr::GetChunkData(FieldId field_id, int chunk_id, int data_barrier) { + if (chunk_id >= data_barrier) { + auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const number { + return indexing.Reverse_Lookup(i); + }; + } + } + auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); + return [chunk_data](int i) -> const number { return chunk_data[i]; }; +} + +template <> +ChunkDataAccessor +PhyColumnExpr::GetChunkData(FieldId field_id, + int chunk_id, + int data_barrier) { + if (chunk_id >= data_barrier) { + auto& indexing = + segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const std::string { + return indexing.Reverse_Lookup(i); + }; + } + } + if (segment_->type() == SegmentType::Growing && + !storage::MmapManager::GetInstance() + .GetMmapConfig() + .growing_enable_mmap) { + auto chunk_data = + segment_->chunk_data(field_id, chunk_id).data(); + return [chunk_data](int i) -> const number { return chunk_data[i]; }; + } else { + auto chunk_data = + segment_->chunk_view(field_id, chunk_id) + .first.data(); + return [chunk_data](int i) -> const number { + return std::string(chunk_data[i]); + }; + } +} + +ChunkDataAccessor +PhyColumnExpr::GetChunkData(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier) { + switch (data_type) { + case DataType::BOOL: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT8: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT16: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT32: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT64: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::FLOAT: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::DOUBLE: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::VARCHAR: { + return GetChunkData(field_id, chunk_id, data_barrier); + } + default: + PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); + } +} + +void +PhyColumnExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (this->expr_->type()) { + case DataType::BOOL: + result = DoEval(); + break; + case DataType::INT8: + result = DoEval(); + break; + case DataType::INT16: + result = DoEval(); + break; + case DataType::INT32: + result = DoEval(); + break; + case DataType::INT64: + result = DoEval(); + break; + case DataType::FLOAT: + result = DoEval(); + break; + case DataType::DOUBLE: + result = DoEval(); + break; + case DataType::VARCHAR: { + result = DoEval(); + break; + } + default: + PanicInfo(DataTypeInvalid, + "unsupported data type: {}", + this->expr_->type()); + } +} + +template +VectorPtr +PhyColumnExpr::DoEval() { + // same as PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) + if (segment_->is_chunked()) { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = std::make_shared>( + expr_->GetColumn().data_type_, real_batch_size); + + auto chunk_data = GetChunkData(expr_->GetColumn().data_type_, + expr_->GetColumn().field_id_, + is_indexed_, + current_chunk_id_, + current_chunk_pos_); + for (int i = 0; i < real_batch_size; ++i) { + res_vec->Set(i, boost::get(chunk_data())); + } + return res_vec; + } else { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = std::make_shared>( + expr_->GetColumn().data_type_, real_batch_size); + + auto data_barrier = + segment_->num_chunk_data(expr_->GetColumn().field_id_); + + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk_ - 1 + ? active_count_ - chunk_id * size_per_chunk_ + : size_per_chunk_; + auto chunk_data = GetChunkData(expr_->GetColumn().data_type_, + expr_->GetColumn().field_id_, + chunk_id, + data_barrier); + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + res_vec->Set(processed_rows++, boost::get(chunk_data(i))); + if (processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + return res_vec; + } + } + } + return res_vec; + } +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ColumnExpr.h b/internal/core/src/exec/expression/ColumnExpr.h new file mode 100644 index 0000000000000..7a723eb4eff2c --- /dev/null +++ b/internal/core/src/exec/expression/ColumnExpr.h @@ -0,0 +1,200 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +// NOTE: similar to PhyCompareFilterExpr +using number = boost::variant; +using ChunkDataAccessor = std::function; +using MultipleChunkDataAccessor = std::function; + +class PhyColumnExpr : public Expr { + public: + PhyColumnExpr(const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(expr->type(), std::move(input), name), + active_count_(active_count), + segment_(segment), + batch_size_(batch_size), + expr_(expr) { + is_indexed_ = segment_->HasIndex(expr_->GetColumn().field_id_); + size_per_chunk_ = segment_->size_per_chunk(); + // NOTE: similar to PhyCompareFilterExpr + if (segment_->is_chunked()) { + num_chunk_ = + is_indexed_ + ? segment_->num_chunk_index(expr_->GetColumn().field_id_) + : segment_->type() == SegmentType::Growing + ? upper_div(active_count_, size_per_chunk_) + : segment_->num_chunk_data(expr_->GetColumn().field_id_); + } else { + num_chunk_ = + is_indexed_ + ? segment_->num_chunk_index(expr_->GetColumn().field_id_) + : upper_div(active_count_, size_per_chunk_); + } + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + // NOTE: similar to PhyCompareFilterExpr + void + MoveCursor() override { + if (segment_->is_chunked()) { + MoveCursorForMultipleChunk(); + } else { + MoveCursorForSingleChunk(); + } + } + + void + MoveCursorForMultipleChunk() { + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = 0; + if (segment_->type() == SegmentType::Growing) { + chunk_size = chunk_id == num_chunk_ - 1 + ? active_count_ - chunk_id * size_per_chunk_ + : size_per_chunk_; + } else { + chunk_size = segment_->chunk_size(expr_->GetColumn().field_id_, + chunk_id); + } + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + if (++processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + } + } + } + } + + void + MoveCursorForSingleChunk() { + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk_ - 1 + ? active_count_ - chunk_id * size_per_chunk_ + : size_per_chunk_; + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + if (++processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + } + } + } + } + + private: + int64_t + GetCurrentRows() { + if (segment_->is_chunked()) { + auto current_rows = + is_indexed_ && segment_->type() == SegmentType::Sealed + ? current_chunk_pos_ + : segment_->num_rows_until_chunk( + expr_->GetColumn().field_id_, current_chunk_id_) + + current_chunk_pos_; + return current_rows; + } else { + return segment_->type() == SegmentType::Growing + ? current_chunk_id_ * size_per_chunk_ + + current_chunk_pos_ + : current_chunk_pos_; + } + } + + template + MultipleChunkDataAccessor + GetChunkData(FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos); + + template + ChunkDataAccessor + GetChunkData(FieldId field_id, int chunk_id, int data_barrier); + + MultipleChunkDataAccessor + GetChunkData(DataType data_type, + FieldId field_id, + bool index, + int64_t& current_chunk_id, + int64_t& current_chunk_pos); + + ChunkDataAccessor + GetChunkData(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier); + int64_t + GetNextBatchSize(); + + template + VectorPtr + DoEval(); + + private: + bool is_indexed_; + + int64_t active_count_; + int64_t num_chunk_{0}; + int64_t current_chunk_id_{0}; + int64_t current_chunk_pos_{0}; + int64_t size_per_chunk_{0}; + + const segcore::SegmentInternalInterface* segment_; + int64_t batch_size_; + std::shared_ptr expr_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index fd9ef751387cb..e7a70103d54d5 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -395,7 +395,7 @@ class PhyCompareFilterExpr : public Expr { const FieldId right_field_; bool is_left_indexed_; bool is_right_indexed_; - int64_t active_count_{0}; + const int64_t active_count_; int64_t num_chunk_{0}; int64_t left_num_chunk_{0}; int64_t right_num_chunk_{0}; diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index 1332217f477fc..690c0e490dca5 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -16,9 +16,12 @@ #include "Expr.h" +#include "common/EasyAssert.h" #include "exec/expression/AlwaysTrueExpr.h" #include "exec/expression/BinaryArithOpEvalRangeExpr.h" #include "exec/expression/BinaryRangeExpr.h" +#include "exec/expression/CallExpr.h" +#include "exec/expression/ColumnExpr.h" #include "exec/expression/CompareExpr.h" #include "exec/expression/ConjunctExpr.h" #include "exec/expression/ExistsExpr.h" @@ -27,6 +30,10 @@ #include "exec/expression/LogicalUnaryExpr.h" #include "exec/expression/TermExpr.h" #include "exec/expression/UnaryExpr.h" +#include "exec/expression/ValueExpr.h" + +#include + namespace milvus { namespace exec { @@ -156,8 +163,14 @@ CompileExpression(const expr::TypedExprPtr& expr, }; auto input_types = GetTypes(compiled_inputs); - if (auto call = dynamic_cast(expr.get())) { - // TODO: support function register and search mode + if (auto call = std::dynamic_pointer_cast(expr)) { + result = std::make_shared( + compiled_inputs, + call, + "PhyCallExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); } else if (auto casted_expr = std::dynamic_pointer_cast< const milvus::expr::UnaryRangeFilterExpr>(expr)) { result = std::make_shared( @@ -251,6 +264,29 @@ CompileExpression(const expr::TypedExprPtr& expr, context->get_segment(), context->get_active_count(), context->query_config()->get_expr_batch_size()); + } else if (auto value_expr = + std::dynamic_pointer_cast( + expr)) { + // used for function call arguments, may emit any type + result = std::make_shared( + compiled_inputs, + value_expr, + "PhyValueExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); + } else if (auto column_expr = + std::dynamic_pointer_cast( + expr)) { + result = std::make_shared( + compiled_inputs, + column_expr, + "PhyColumnExpr", + context->get_segment(), + context->get_active_count(), + context->query_config()->get_expr_batch_size()); + } else { + PanicInfo(ExprInvalid, "unsupport expr: ", expr->ToString()); } return result; } @@ -261,4 +297,4 @@ OptimizeCompiledExprs(ExecContext* context, const std::vector& exprs) { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 25f90db4a249f..23a89f9358ac8 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -76,6 +76,7 @@ class Expr { DataType type_; const std::vector> inputs_; std::string name_; + // NOTE: unused std::shared_ptr vector_func_; }; @@ -83,6 +84,9 @@ using ExprPtr = std::shared_ptr; using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int); +/* + * The expr has only one column. + */ class SegmentExpr : public Expr { public: SegmentExpr(const std::vector&& input, @@ -610,7 +614,7 @@ CompileExpression(const expr::TypedExprPtr& expr, class ExprSet { public: explicit ExprSet(const std::vector& logical_exprs, - ExecContext* exec_ctx) { + ExecContext* exec_ctx) : exec_ctx_(exec_ctx) { exprs_ = CompileExpressions(logical_exprs, exec_ctx); } diff --git a/internal/core/src/exec/expression/ValueExpr.cpp b/internal/core/src/exec/expression/ValueExpr.cpp new file mode 100644 index 0000000000000..f442999f9ccc3 --- /dev/null +++ b/internal/core/src/exec/expression/ValueExpr.cpp @@ -0,0 +1,96 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ValueExpr.h" +#include "common/Vector.h" + +namespace milvus { +namespace exec { + +void +PhyValueExpr::Eval(EvalCtx& context, VectorPtr& result) { + int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_ + ? active_count_ - current_pos_ + : batch_size_; + + if (real_batch_size == 0) { + result = nullptr; + return; + } + + switch (expr_->type()) { + case DataType::BOOL: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().bool_val()); + break; + case DataType::INT8: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT16: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT32: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::INT64: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().int64_val()); + break; + case DataType::FLOAT: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().float_val()); + break; + case DataType::DOUBLE: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().float_val()); + break; + case DataType::STRING: + case DataType::VARCHAR: + result = std::make_shared>( + expr_->type(), + real_batch_size, + expr_->GetGenericValue().string_val()); + break; + // TODO: json and array type + case DataType::ARRAY: + case DataType::JSON: + default: + PanicInfo(DataTypeInvalid, + "PhyValueExpr not support data type " + + GetDataTypeName(expr_->type())); + } + current_pos_ += real_batch_size; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ValueExpr.h b/internal/core/src/exec/expression/ValueExpr.h new file mode 100644 index 0000000000000..044f46ac391e3 --- /dev/null +++ b/internal/core/src/exec/expression/ValueExpr.h @@ -0,0 +1,67 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyValueExpr : public Expr { + public: + PhyValueExpr(const std::vector>& input, + const std::shared_ptr expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + int64_t active_count, + int64_t batch_size) + : Expr(expr->type(), std::move(input), name), + expr_(expr), + active_count_(active_count), + batch_size_(batch_size) { + AssertInfo(input.empty(), + "PhyValueExpr should not have input, but got " + + std::to_string(input.size())); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + void + MoveCursor() override { + int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_ + ? active_count_ - current_pos_ + : batch_size_; + + current_pos_ += real_batch_size; + } + + private: + std::shared_ptr expr_; + const int64_t active_count_; + int64_t current_pos_{0}; + const int64_t batch_size_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/FunctionFactory.cpp b/internal/core/src/exec/expression/function/FunctionFactory.cpp new file mode 100644 index 0000000000000..b1fdef4fd3b7b --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionFactory.cpp @@ -0,0 +1,81 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/FunctionFactory.h" +#include +#include "exec/expression/function/impl/Empty.h" +#include "log/Log.h" + +namespace milvus { +namespace exec { +namespace expression { + +std::string +FilterFunctionRegisterKey::toString() const { + std::ostringstream oss; + oss << func_name << "("; + for (size_t i = 0; i < func_param_type_list.size(); ++i) { + oss << GetDataTypeName(func_param_type_list[i]); + if (i < func_param_type_list.size() - 1) { + oss << ", "; + } + } + + oss << ")"; + return oss.str(); +} + +FunctionFactory& +FunctionFactory::Instance() { + static FunctionFactory factory; + return factory; +} + +void +FunctionFactory::Initialize() { + std::call_once(init_flag_, &FunctionFactory::RegisterAllFunctions, this); +} + +void +FunctionFactory::RegisterAllFunctions() { + RegisterFilterFunction( + "empty", {DataType::VARCHAR}, function::EmptyVarchar); + + LOG_INFO("{} functions registered", GetFilterFunctionNum()); +} + +void +FunctionFactory::RegisterFilterFunction( + std::string func_name, + std::vector func_param_type_list, + FilterFunctionPtr func) { + filter_function_map_[FilterFunctionRegisterKey{ + func_name, func_param_type_list}] = func; +} + +const FilterFunctionPtr +FunctionFactory::GetFilterFunction( + const FilterFunctionRegisterKey& func_sig) const { + auto iter = filter_function_map_.find(func_sig); + if (iter != filter_function_map_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/FunctionFactory.h b/internal/core/src/exec/expression/function/FunctionFactory.h new file mode 100644 index 0000000000000..ed7d91a4e8739 --- /dev/null +++ b/internal/core/src/exec/expression/function/FunctionFactory.h @@ -0,0 +1,113 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "common/Vector.h" + +namespace milvus { +namespace exec { + +class EvalCtx; +class Expr; +class PhyCallExpr; + +namespace expression { + +struct FilterFunctionRegisterKey { + std::string func_name; + std::vector func_param_type_list; + + std::string + toString() const; + + bool + operator==(const FilterFunctionRegisterKey& other) const { + return func_name == other.func_name && + func_param_type_list == other.func_param_type_list; + } + + struct Hash { + size_t + operator()(const FilterFunctionRegisterKey& s) const { + size_t h1 = std::hash{}(s.func_name); + size_t h2 = boost::hash_range(s.func_param_type_list.begin(), + s.func_param_type_list.end()); + return h1 ^ h2; + } + }; +}; + +using FilterFunctionParameter = std::shared_ptr; +using FilterFunctionReturn = VectorPtr; +using FilterFunctionPtr = + void (*)(EvalCtx& context, + const RowVector& args, + FilterFunctionReturn& result); + +class FunctionFactory { + public: + static FunctionFactory& + Instance(); + + void + Initialize(); + + void + RegisterFilterFunction(std::string func_name, + std::vector func_param_type_list, + FilterFunctionPtr func); + + const FilterFunctionPtr + GetFilterFunction(const FilterFunctionRegisterKey& func_sig) const; + + size_t + GetFilterFunctionNum() const { + return filter_function_map_.size(); + } + + std::vector + ListAllFilterFunctions() const { + std::vector result; + for (const auto& [key, value] : filter_function_map_) { + result.push_back(key); + } + return result; + } + + private: + void + RegisterAllFunctions(); + + std::unordered_map + filter_function_map_; + std::once_flag init_flag_; +}; + +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/impl/Empty.cpp b/internal/core/src/exec/expression/function/impl/Empty.cpp new file mode 100644 index 0000000000000..4388aac03fbe0 --- /dev/null +++ b/internal/core/src/exec/expression/function/impl/Empty.cpp @@ -0,0 +1,48 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/impl/Empty.h" + +#include +#include +#include "common/EasyAssert.h" +#include "exec/expression/Expr.h" +#include "exec/expression/function/FunctionFactory.h" + +namespace milvus { +namespace exec { +namespace expression { +namespace function { + +void +EmptyVarchar(EvalCtx& context, + const RowVector& args, + FilterFunctionReturn& result) { + Assert(args.childrens().size() == 1); + auto arg = args.child(0); + if (auto vec = std::dynamic_pointer_cast>(arg)) { + auto bitmap = + vec->Apply([](const std::string& s) { return s.empty(); }); + result = std::make_shared(std::move(bitmap)); + } else { + PanicInfo(ExprInvalid, "invalid vector type"); + } +} + +} // namespace function +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/impl/Empty.h b/internal/core/src/exec/expression/function/impl/Empty.h new file mode 100644 index 0000000000000..5d595f4276072 --- /dev/null +++ b/internal/core/src/exec/expression/function/impl/Empty.h @@ -0,0 +1,33 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/Vector.h" +#include "exec/expression/function/FunctionFactory.h" + +namespace milvus { +namespace exec { +namespace expression { +namespace function { + +void +EmptyVarchar(EvalCtx& context, + const RowVector& args, + FilterFunctionReturn& result); + +} // namespace function +} // namespace expression +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/function/init_c.cpp b/internal/core/src/exec/expression/function/init_c.cpp new file mode 100644 index 0000000000000..072bd866a59db --- /dev/null +++ b/internal/core/src/exec/expression/function/init_c.cpp @@ -0,0 +1,23 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/expression/function/init_c.h" +#include "exec/expression/function/FunctionFactory.h" + +void +InitExecExpressionFunctionFactory() { + milvus::exec::expression::FunctionFactory::Instance().Initialize(); +} diff --git a/internal/core/src/exec/expression/function/init_c.h b/internal/core/src/exec/expression/function/init_c.h new file mode 100644 index 0000000000000..c7dbd3867f80d --- /dev/null +++ b/internal/core/src/exec/expression/function/init_c.h @@ -0,0 +1,28 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void +InitExecExpressionFunctionFactory(); + +#ifdef __cplusplus +}; +#endif diff --git a/internal/core/src/exec/operator/FilterBitsNode.cpp b/internal/core/src/exec/operator/FilterBitsNode.cpp index 7ad302cbec371..d88adb9e99dcd 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.cpp +++ b/internal/core/src/exec/operator/FilterBitsNode.cpp @@ -75,11 +75,23 @@ PhyFilterBitsNode::GetOutput() { "PhyFilterBitsNode result size should be size one and not " "be nullptr"); - auto col_vec = std::dynamic_pointer_cast(results_[0]); - auto col_vec_size = col_vec->size(); - TargetBitmapView view(col_vec->GetRawData(), col_vec_size); - bitset.append(view); - num_processed_rows_ += col_vec_size; + if (auto col_vec = + std::dynamic_pointer_cast(results_[0])) { + auto col_vec_size = col_vec->size(); + TargetBitmapView view(col_vec->GetRawData(), col_vec_size); + bitset.append(view); + num_processed_rows_ += col_vec_size; + } else if (auto vec = std::dynamic_pointer_cast>( + results_[0])) { + auto col_vec_size = col_vec->size(); + TargetBitmap result = + vec->Apply([](const bool& val) { return val; }); + bitset.append(result); + num_processed_rows_ += col_vec_size; + } else { + PanicInfo(ExprInvalid, + "PhyFilterBitsNode result should be ColumnVector"); + } } bitset.flip(); Assert(bitset.size() == need_process_rows_); @@ -97,4 +109,4 @@ PhyFilterBitsNode::GetOutput() { } } // namespace exec -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index f41b76d1a2001..320e616b4e947 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -21,6 +21,7 @@ #include #include +#include "exec/expression/function/FunctionFactory.h" #include "common/Exception.h" #include "common/Schema.h" #include "common/Types.h" @@ -211,6 +212,7 @@ class ITypeExpr { using TypedExprPtr = std::shared_ptr; +// NOTE: unused class InputTypeExpr : public ITypeExpr { public: InputTypeExpr(DataType type) : ITypeExpr(type) { @@ -224,42 +226,7 @@ class InputTypeExpr : public ITypeExpr { using InputTypeExprPtr = std::shared_ptr; -class CallTypeExpr : public ITypeExpr { - public: - CallTypeExpr(DataType type, - const std::vector& inputs, - std::string fun_name) - : ITypeExpr{type, std::move(inputs)} { - } - - virtual ~CallTypeExpr() = default; - - virtual const std::string& - name() const { - return name_; - } - - std::string - ToString() const override { - std::string str{}; - str += name(); - str += "("; - for (size_t i = 0; i < inputs_.size(); ++i) { - if (i != 0) { - str += ","; - } - str += inputs_[i]->ToString(); - } - str += ")"; - return str; - } - - private: - std::string name_; -}; - -using CallTypeExprPtr = std::shared_ptr; - +// NOTE: unused class FieldAccessTypeExpr : public ITypeExpr { public: FieldAccessTypeExpr(DataType type, const std::string& name) @@ -311,6 +278,71 @@ class ITypeFilterExpr : public ITypeExpr { virtual ~ITypeFilterExpr() = default; }; +class ColumnExpr : public ITypeExpr { + public: + explicit ColumnExpr(const ColumnInfo& column) + : ITypeExpr(column.data_type_), column_(column) { + } + + const ColumnInfo& + GetColumn() const { + return column_; + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "ColumnExpr: {columnInfo:" << column_.ToString() << "}"; + return ss.str(); + } + + private: + const ColumnInfo column_; +}; + +class ValueExpr : public ITypeExpr { + public: + explicit ValueExpr(const proto::plan::GenericValue& val) + : ITypeExpr(DataType::NONE), val_(val) { + switch (val.val_case()) { + case proto::plan::GenericValue::ValCase::kBoolVal: + type_ = DataType::BOOL; + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + type_ = DataType::INT64; + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + type_ = DataType::FLOAT; + break; + case proto::plan::GenericValue::ValCase::kStringVal: + type_ = DataType::VARCHAR; + break; + case proto::plan::GenericValue::ValCase::kArrayVal: + type_ = DataType::ARRAY; + break; + case proto::plan::GenericValue::ValCase::VAL_NOT_SET: + type_ = DataType::NONE; + break; + } + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "ValueExpr: {" + << " val:" << val_.DebugString() << "}"; + return ss.str(); + } + + const proto::plan::GenericValue + GetGenericValue() const { + return val_; + } + + private: + const proto::plan::GenericValue val_; +}; + class UnaryRangeFilterExpr : public ITypeFilterExpr { public: explicit UnaryRangeFilterExpr(const ColumnInfo& column, @@ -595,6 +627,46 @@ class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr { const proto::plan::GenericValue value_; }; +class CallExpr : public ITypeFilterExpr { + public: + CallExpr(const std::string fun_name, + const std::vector& parameters, + const exec::expression::FilterFunctionPtr function_ptr) + : fun_name_(std::move(fun_name)), function_ptr_(function_ptr) { + inputs_.insert(inputs_.end(), parameters.begin(), parameters.end()); + } + + virtual ~CallExpr() = default; + + const std::string& + fun_name() const { + return fun_name_; + } + + const exec::expression::FilterFunctionPtr + function_ptr() const { + return function_ptr_; + } + + std::string + ToString() const override { + std::string parameters; + for (auto& e : inputs_) { + parameters += e->ToString(); + parameters += ", "; + } + return fmt::format("CallExpr:[Function Name: {}, Parameters: {}]", + fun_name_, + parameters); + } + + private: + const std::string fun_name_; + const exec::expression::FilterFunctionPtr function_ptr_; +}; + +using CallExprPtr = std::shared_ptr; + class CompareExpr : public ITypeFilterExpr { public: CompareExpr(const FieldId& left_field, diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index d61ad31ce92d4..21b2e7c841105 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -15,9 +15,11 @@ #include #include +#include #include "common/VectorTrait.h" #include "common/EasyAssert.h" +#include "exec/expression/function/FunctionFactory.h" #include "pb/plan.pb.h" #include "query/Utils.h" #include "knowhere/comp/materialized_view.h" @@ -256,6 +258,30 @@ ProtoParser::ParseBinaryRangeExprs( expr_pb.upper_inclusive()); } +expr::TypedExprPtr +ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) { + std::vector parameters; + std::vector func_param_type_list; + for (auto& param_expr : expr_pb.function_parameters()) { + // function parameter can be any type + auto e = this->ParseExprs(param_expr, TypeIsAny); + parameters.push_back(e); + func_param_type_list.push_back(e->type()); + } + auto& factory = exec::expression::FunctionFactory::Instance(); + exec::expression::FilterFunctionRegisterKey func_sig{ + expr_pb.function_name(), std::move(func_param_type_list)}; + + auto function = factory.GetFilterFunction(func_sig); + if (function == nullptr) { + throw std::runtime_error("FilterScalarFunction " + func_sig.toString() + + " not found. "); + } + return std::make_shared(expr_pb.function_name(), + parameters, + function); +} + expr::TypedExprPtr ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { auto& left_column_info = expr_pb.left_column_info(); @@ -349,45 +375,80 @@ ProtoParser::ParseJsonContainsExprs( std::move(values)); } +expr::TypedExprPtr +ProtoParser::ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb) { + return std::make_shared(expr_pb.info()); +} + +expr::TypedExprPtr +ProtoParser::ParseValueExprs(const proto::plan::ValueExpr& expr_pb) { + return std::make_shared(expr_pb.value()); +} + expr::TypedExprPtr ProtoParser::CreateAlwaysTrueExprs() { return std::make_shared(); } expr::TypedExprPtr -ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { +ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb, + TypeCheckFunction type_check) { using ppe = proto::plan::Expr; + expr::TypedExprPtr result; switch (expr_pb.expr_case()) { case ppe::kUnaryRangeExpr: { - return ParseUnaryRangeExprs(expr_pb.unary_range_expr()); + result = ParseUnaryRangeExprs(expr_pb.unary_range_expr()); + break; } case ppe::kBinaryExpr: { - return ParseBinaryExprs(expr_pb.binary_expr()); + result = ParseBinaryExprs(expr_pb.binary_expr()); + break; } case ppe::kUnaryExpr: { - return ParseUnaryExprs(expr_pb.unary_expr()); + result = ParseUnaryExprs(expr_pb.unary_expr()); + break; } case ppe::kTermExpr: { - return ParseTermExprs(expr_pb.term_expr()); + result = ParseTermExprs(expr_pb.term_expr()); + break; } case ppe::kBinaryRangeExpr: { - return ParseBinaryRangeExprs(expr_pb.binary_range_expr()); + result = ParseBinaryRangeExprs(expr_pb.binary_range_expr()); + break; } case ppe::kCompareExpr: { - return ParseCompareExprs(expr_pb.compare_expr()); + result = ParseCompareExprs(expr_pb.compare_expr()); + break; } case ppe::kBinaryArithOpEvalRangeExpr: { - return ParseBinaryArithOpEvalRangeExprs( + result = ParseBinaryArithOpEvalRangeExprs( expr_pb.binary_arith_op_eval_range_expr()); + break; } case ppe::kExistsExpr: { - return ParseExistExprs(expr_pb.exists_expr()); + result = ParseExistExprs(expr_pb.exists_expr()); + break; } case ppe::kAlwaysTrueExpr: { - return CreateAlwaysTrueExprs(); + result = CreateAlwaysTrueExprs(); + break; } case ppe::kJsonContainsExpr: { - return ParseJsonContainsExprs(expr_pb.json_contains_expr()); + result = ParseJsonContainsExprs(expr_pb.json_contains_expr()); + break; + } + case ppe::kCallExpr: { + result = ParseCallExprs(expr_pb.call_expr()); + break; + } + // may emit various types + case ppe::kColumnExpr: { + result = ParseColumnExprs(expr_pb.column_expr()); + break; + } + case ppe::kValueExpr: { + result = ParseValueExprs(expr_pb.value_expr()); + break; } default: { std::string s; @@ -396,6 +457,11 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { std::string("unsupported expr proto node: ") + s); } } + if (type_check(result->type())) { + return result; + } + PanicInfo( + ExprInvalid, "expr type check failed, actual type: {}", result->type()); } } // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index 63673cefb9270..28aaaaa0cb67f 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -23,6 +23,17 @@ namespace milvus::query { class ProtoParser { + public: + using TypeCheckFunction = std::function; + static bool + TypeIsBool(const DataType type) { + return type == DataType::BOOL; + } + static bool + TypeIsAny(const DataType) { + return true; + } + public: explicit ProtoParser(const Schema& schema) : schema(schema) { } @@ -40,10 +51,15 @@ class ProtoParser { CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto); expr::TypedExprPtr - ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); + ParseExprs(const proto::plan::Expr& expr_pb, + TypeCheckFunction type_check = TypeIsBool); + + private: + expr::TypedExprPtr + CreateAlwaysTrueExprs(); expr::TypedExprPtr - ParseExprs(const proto::plan::Expr& expr_pb); + ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb); expr::TypedExprPtr ParseBinaryArithOpEvalRangeExprs( @@ -52,33 +68,39 @@ class ProtoParser { expr::TypedExprPtr ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb); + expr::TypedExprPtr + ParseCallExprs(const proto::plan::CallExpr& expr_pb); + + expr::TypedExprPtr + ParseColumnExprs(const proto::plan::ColumnExpr& expr_pb); + expr::TypedExprPtr ParseCompareExprs(const proto::plan::CompareExpr& expr_pb); expr::TypedExprPtr - ParseTermExprs(const proto::plan::TermExpr& expr_pb); + ParseExistExprs(const proto::plan::ExistsExpr& expr_pb); expr::TypedExprPtr - ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb); + ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb); expr::TypedExprPtr - ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb); + ParseTermExprs(const proto::plan::TermExpr& expr_pb); expr::TypedExprPtr - ParseExistExprs(const proto::plan::ExistsExpr& expr_pb); + ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb); expr::TypedExprPtr - ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb); + ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); expr::TypedExprPtr - CreateAlwaysTrueExprs(); + ParseValueExprs(const proto::plan::ValueExpr& expr_pb); private: const Schema& schema; }; } // namespace milvus::query -// + template <> struct fmt::formatter : formatter { diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp index e26e911997b62..452f95b09b61a 100644 --- a/internal/core/unittest/test_exec.cpp +++ b/internal/core/unittest/test_exec.cpp @@ -27,6 +27,7 @@ #include "exec/QueryContext.h" #include "expr/ITypeExpr.h" #include "exec/expression/Expr.h" +#include "exec/expression/function/FunctionFactory.h" using namespace milvus; using namespace milvus::exec; @@ -40,6 +41,10 @@ class TaskTest : public testing::TestWithParam { using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + factory.Initialize(); + auto schema = std::make_shared(); auto vec_fid = schema->AddDebugField( "fakevec", GetParam(), 16, knowhere::metric::L2); @@ -113,6 +118,62 @@ INSTANTIATE_TEST_SUITE_P(TaskTestSuite, ::testing::Values(DataType::VECTOR_FLOAT, DataType::VECTOR_SPARSE_FLOAT)); +TEST_P(TaskTest, RegisterFunction) { + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + ASSERT_LE(factory.GetFilterFunctionNum(), 1); + auto all_functions = factory.ListAllFilterFunctions(); + for (auto& f : all_functions) { + std::cout << f.toString() << std::endl; + } + + auto func_ptr = factory.GetFilterFunction( + milvus::exec::expression::FilterFunctionRegisterKey{ + "empty", {DataType::VARCHAR}}); + ASSERT_TRUE(func_ptr != nullptr); +} + +TEST_P(TaskTest, CallExprEmpty) { + expr::ColumnInfo col(field_map_["string1"], DataType::VARCHAR); + std::vector parameters; + parameters.push_back(std::make_shared(col)); + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + auto empty_function_ptr = factory.GetFilterFunction( + milvus::exec::expression::FilterFunctionRegisterKey{ + "empty", {DataType::VARCHAR}}); + auto call_expr = std::make_shared( + "empty", parameters, empty_function_ptr); + ASSERT_EQ(call_expr->inputs().size(), 1); + std::vector sources; + auto filter_node = std::make_shared( + "plannode id 1", call_expr, sources); + auto plan = plan::PlanFragment(filter_node); + auto query_context = std::make_shared( + "test1", + segment_.get(), + 1000000, + MAX_TIMESTAMP, + std::make_shared( + std::unordered_map{})); + + auto start = std::chrono::steady_clock::now(); + auto task = Task::Create("task_call_expr_empty", plan, 0, query_context); + int64_t num_rows = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + num_rows += result->size(); + } + auto cost = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + std::cout << "cost: " << cost << "us" << std::endl; + EXPECT_EQ(num_rows, num_rows_); +} + TEST_P(TaskTest, UnaryExpr) { ::milvus::proto::plan::GenericValue value; value.set_int64_val(-1); @@ -355,4 +416,4 @@ TEST_P(TaskTest, CompileInputs_or_with_and) { "PhyUnaryRangeFilterExpr"); } } -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 2bfc4646d10af..cf62245e3f35b 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include "index/IndexFactory.h" #include "exec/expression/Expr.h" #include "exec/Task.h" +#include "exec/expression/function/FunctionFactory.h" #include "expr/ITypeExpr.h" #include "index/BitmapIndex.h" #include "index/InvertedIndexTantivy.h" @@ -901,6 +903,73 @@ TEST_P(ExprTest, TestTerm) { } } + +TEST_P(ExprTest, TestCall) { + milvus::exec::expression::FunctionFactory& factory = + milvus::exec::expression::FunctionFactory::Instance(); + factory.Initialize(); + std::string raw_plan = R"(vector_anns: < + field_id: 100 + predicates: < + call_expr: < + function_name: "empty" + function_parameters: < + field_id: 101 + data_type: VarChar + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField("fakevec", data_type, 16, metric_type); + auto varchar_fid = schema->AddDebugField("address", DataType::VARCHAR); + schema->set_primary_field_id(varchar_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector address_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_address_col = raw_data.get_col(varchar_fid); + address_col.insert( + address_col.end(), new_address_col.begin(), new_address_col.end() + ); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + + auto plan_str = translate_text_plan_with_metric_type(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + BitsetType final; + final = ExecuteQueryExpr( + plan->plan_node_->plannodes_->sources()[0]->sources()[0], + seg_promote, + N * num_iters, + MAX_TIMESTAMP); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + ASSERT_EQ(ans, address_col[i].empty()) << "@" << i << "!!" << address_col[i]; + } +} + TEST_P(ExprTest, TestCompare) { std::vector>> testcases = { diff --git a/internal/parser/planparserv2/Plan.g4 b/internal/parser/planparserv2/Plan.g4 index c0644436a281b..f28f471228772 100644 --- a/internal/parser/planparserv2/Plan.g4 +++ b/internal/parser/planparserv2/Plan.g4 @@ -24,6 +24,7 @@ expr: | (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll | (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny | ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength + | Identifier '(' ( expr (',' expr )* ','? )? ')' # Call | expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range | expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange | expr op = (LT | LE | GT | GE) expr # Relational diff --git a/internal/parser/planparserv2/check_identical_test.go b/internal/parser/planparserv2/check_identical_test.go index 9f48aec504d8e..321920e8c6bbc 100644 --- a/internal/parser/planparserv2/check_identical_test.go +++ b/internal/parser/planparserv2/check_identical_test.go @@ -14,17 +14,27 @@ func TestCheckIdentical(t *testing.T) { helper, err := typeutil.CreateSchemaHelper(schema) assert.NoError(t, err) - exprStr1 := `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))` - exprStr2 := `Int32Field in [1, 2, 3]` + exprStr1Arr := []string{ + `not (((Int64Field > 0) and (FloatField <= 20.0)) or ((Int32Field in [1, 2, 3]) and (VarCharField < "str")))`, + `f1()`, + } + exprStr2Arr := []string{ + `Int32Field in [1, 2, 3]`, + `f2(Int32Field, Int64Field)`, + } + for i := range exprStr1Arr { + exprStr1 := exprStr1Arr[i] + exprStr2 := exprStr2Arr[i] - expr1, err := ParseExpr(helper, exprStr1) - assert.NoError(t, err) - expr2, err := ParseExpr(helper, exprStr2) - assert.NoError(t, err) + expr1, err := ParseExpr(helper, exprStr1) + assert.NoError(t, err) + expr2, err := ParseExpr(helper, exprStr2) + assert.NoError(t, err) - assert.True(t, CheckPredicatesIdentical(expr1, expr1)) - assert.True(t, CheckPredicatesIdentical(expr2, expr2)) - assert.False(t, CheckPredicatesIdentical(expr1, expr2)) + assert.True(t, CheckPredicatesIdentical(expr1, expr1)) + assert.True(t, CheckPredicatesIdentical(expr2, expr2)) + assert.False(t, CheckPredicatesIdentical(expr1, expr2)) + } } func TestCheckQueryInfoIdentical(t *testing.T) { diff --git a/internal/parser/planparserv2/generated/Plan.interp b/internal/parser/planparserv2/generated/Plan.interp index 41ef66eeefa7d..8cb8890c4f66f 100644 --- a/internal/parser/planparserv2/generated/Plan.interp +++ b/internal/parser/planparserv2/generated/Plan.interp @@ -101,4 +101,4 @@ expr atn: -[4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, 41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, 64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, 2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, 63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, 0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, 1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, 63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, 5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, 0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, 118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, 0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, 0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, 86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, 7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, 93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, 10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, 0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, 0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, 106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, 110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, 5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, 12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, 0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, 83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, 0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, 117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, 121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, 0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119] \ No newline at end of file +[4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0, 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134, 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0, 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5, 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65, 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73, 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0, 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5, 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79, 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0, 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0, 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43, 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83, 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5, 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89, 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3, 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0, 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3, 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4, 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0, 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108, 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134, 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3, 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0, 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0, 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5, 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127, 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131, 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1, 0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0, 133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109, 1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0, 0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0, 133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135, 136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68, 72, 74, 79, 95, 133, 135] \ No newline at end of file diff --git a/internal/parser/planparserv2/generated/plan_base_visitor.go b/internal/parser/planparserv2/generated/plan_base_visitor.go index e8ae619676116..2e7a30e771a92 100644 --- a/internal/parser/planparserv2/generated/plan_base_visitor.go +++ b/internal/parser/planparserv2/generated/plan_base_visitor.go @@ -59,6 +59,10 @@ func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} { return v.VisitChildren(ctx) } +func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} { + return v.VisitChildren(ctx) +} + func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} { return v.VisitChildren(ctx) } diff --git a/internal/parser/planparserv2/generated/plan_parser.go b/internal/parser/planparserv2/generated/plan_parser.go index e5dc91fda218a..8869e09e0a0c6 100644 --- a/internal/parser/planparserv2/generated/plan_parser.go +++ b/internal/parser/planparserv2/generated/plan_parser.go @@ -50,65 +50,73 @@ func planParserInit() { } staticData.PredictionContextCache = antlr.NewPredictionContextCache() staticData.serializedATN = []int32{ - 4, 1, 46, 123, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 4, 1, 46, 139, 2, 0, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 18, 8, 0, 10, 0, 12, 0, 21, 9, 0, 1, 0, 3, 0, 24, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, - 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 64, 8, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, + 67, 8, 0, 10, 0, 12, 0, 70, 9, 0, 1, 0, 3, 0, 73, 8, 0, 3, 0, 75, 8, 0, + 1, 0, 1, 0, 1, 0, 3, 0, 80, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 3, 0, 96, 8, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, - 1, 0, 1, 0, 1, 0, 5, 0, 118, 8, 0, 10, 0, 12, 0, 121, 9, 0, 1, 0, 0, 1, - 0, 1, 0, 0, 12, 2, 0, 15, 16, 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, - 36, 36, 2, 0, 34, 34, 37, 37, 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, - 15, 16, 1, 0, 21, 22, 1, 0, 6, 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, - 154, 0, 63, 1, 0, 0, 0, 2, 3, 6, 0, -1, 0, 3, 64, 5, 40, 0, 0, 4, 64, 5, - 41, 0, 0, 5, 64, 5, 39, 0, 0, 6, 64, 5, 43, 0, 0, 7, 64, 5, 42, 0, 0, 8, - 64, 5, 44, 0, 0, 9, 10, 5, 1, 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, - 0, 12, 64, 1, 0, 0, 0, 13, 14, 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, - 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, - 19, 17, 1, 0, 0, 0, 19, 20, 1, 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, - 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, - 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, 0, 26, 64, 1, 0, 0, 0, 27, 64, 5, 31, - 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, - 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, 0, 33, 64, 5, 2, 0, 0, 34, 35, 7, 0, - 0, 0, 35, 64, 3, 0, 0, 19, 36, 37, 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, - 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, - 0, 42, 64, 1, 0, 0, 0, 43, 44, 7, 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, - 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, - 49, 64, 1, 0, 0, 0, 50, 51, 7, 3, 0, 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, - 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, - 64, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, - 0, 0, 60, 64, 5, 2, 0, 0, 61, 62, 5, 13, 0, 0, 62, 64, 3, 0, 0, 1, 63, - 2, 1, 0, 0, 0, 63, 4, 1, 0, 0, 0, 63, 5, 1, 0, 0, 0, 63, 6, 1, 0, 0, 0, - 63, 7, 1, 0, 0, 0, 63, 8, 1, 0, 0, 0, 63, 9, 1, 0, 0, 0, 63, 13, 1, 0, - 0, 0, 63, 27, 1, 0, 0, 0, 63, 28, 1, 0, 0, 0, 63, 34, 1, 0, 0, 0, 63, 36, - 1, 0, 0, 0, 63, 43, 1, 0, 0, 0, 63, 50, 1, 0, 0, 0, 63, 57, 1, 0, 0, 0, - 63, 61, 1, 0, 0, 0, 64, 119, 1, 0, 0, 0, 65, 66, 10, 20, 0, 0, 66, 67, - 5, 20, 0, 0, 67, 118, 3, 0, 0, 21, 68, 69, 10, 18, 0, 0, 69, 70, 7, 5, - 0, 0, 70, 118, 3, 0, 0, 19, 71, 72, 10, 17, 0, 0, 72, 73, 7, 6, 0, 0, 73, - 118, 3, 0, 0, 18, 74, 75, 10, 16, 0, 0, 75, 76, 7, 7, 0, 0, 76, 118, 3, - 0, 0, 17, 77, 79, 10, 15, 0, 0, 78, 80, 5, 29, 0, 0, 79, 78, 1, 0, 0, 0, - 79, 80, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 82, 5, 30, 0, 0, 82, 118, 3, - 0, 0, 16, 83, 84, 10, 10, 0, 0, 84, 85, 7, 8, 0, 0, 85, 86, 7, 4, 0, 0, - 86, 87, 7, 8, 0, 0, 87, 118, 3, 0, 0, 11, 88, 89, 10, 9, 0, 0, 89, 90, - 7, 9, 0, 0, 90, 91, 7, 4, 0, 0, 91, 92, 7, 9, 0, 0, 92, 118, 3, 0, 0, 10, - 93, 94, 10, 8, 0, 0, 94, 95, 7, 10, 0, 0, 95, 118, 3, 0, 0, 9, 96, 97, - 10, 7, 0, 0, 97, 98, 7, 11, 0, 0, 98, 118, 3, 0, 0, 8, 99, 100, 10, 6, - 0, 0, 100, 101, 5, 23, 0, 0, 101, 118, 3, 0, 0, 7, 102, 103, 10, 5, 0, - 0, 103, 104, 5, 25, 0, 0, 104, 118, 3, 0, 0, 6, 105, 106, 10, 4, 0, 0, - 106, 107, 5, 24, 0, 0, 107, 118, 3, 0, 0, 5, 108, 109, 10, 3, 0, 0, 109, - 110, 5, 26, 0, 0, 110, 118, 3, 0, 0, 4, 111, 112, 10, 2, 0, 0, 112, 113, - 5, 27, 0, 0, 113, 118, 3, 0, 0, 3, 114, 115, 10, 22, 0, 0, 115, 116, 5, - 12, 0, 0, 116, 118, 5, 43, 0, 0, 117, 65, 1, 0, 0, 0, 117, 68, 1, 0, 0, - 0, 117, 71, 1, 0, 0, 0, 117, 74, 1, 0, 0, 0, 117, 77, 1, 0, 0, 0, 117, - 83, 1, 0, 0, 0, 117, 88, 1, 0, 0, 0, 117, 93, 1, 0, 0, 0, 117, 96, 1, 0, - 0, 0, 117, 99, 1, 0, 0, 0, 117, 102, 1, 0, 0, 0, 117, 105, 1, 0, 0, 0, - 117, 108, 1, 0, 0, 0, 117, 111, 1, 0, 0, 0, 117, 114, 1, 0, 0, 0, 118, - 121, 1, 0, 0, 0, 119, 117, 1, 0, 0, 0, 119, 120, 1, 0, 0, 0, 120, 1, 1, - 0, 0, 0, 121, 119, 1, 0, 0, 0, 6, 19, 23, 63, 79, 117, 119, + 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 5, 0, 134, + 8, 0, 10, 0, 12, 0, 137, 9, 0, 1, 0, 0, 1, 0, 1, 0, 0, 12, 2, 0, 15, 16, + 28, 29, 2, 0, 32, 32, 35, 35, 2, 0, 33, 33, 36, 36, 2, 0, 34, 34, 37, 37, + 2, 0, 42, 42, 44, 44, 1, 0, 17, 19, 1, 0, 15, 16, 1, 0, 21, 22, 1, 0, 6, + 7, 1, 0, 8, 9, 1, 0, 6, 9, 1, 0, 10, 11, 174, 0, 79, 1, 0, 0, 0, 2, 3, + 6, 0, -1, 0, 3, 80, 5, 40, 0, 0, 4, 80, 5, 41, 0, 0, 5, 80, 5, 39, 0, 0, + 6, 80, 5, 43, 0, 0, 7, 80, 5, 42, 0, 0, 8, 80, 5, 44, 0, 0, 9, 10, 5, 1, + 0, 0, 10, 11, 3, 0, 0, 0, 11, 12, 5, 2, 0, 0, 12, 80, 1, 0, 0, 0, 13, 14, + 5, 3, 0, 0, 14, 19, 3, 0, 0, 0, 15, 16, 5, 4, 0, 0, 16, 18, 3, 0, 0, 0, + 17, 15, 1, 0, 0, 0, 18, 21, 1, 0, 0, 0, 19, 17, 1, 0, 0, 0, 19, 20, 1, + 0, 0, 0, 20, 23, 1, 0, 0, 0, 21, 19, 1, 0, 0, 0, 22, 24, 5, 4, 0, 0, 23, + 22, 1, 0, 0, 0, 23, 24, 1, 0, 0, 0, 24, 25, 1, 0, 0, 0, 25, 26, 5, 5, 0, + 0, 26, 80, 1, 0, 0, 0, 27, 80, 5, 31, 0, 0, 28, 29, 5, 14, 0, 0, 29, 30, + 5, 1, 0, 0, 30, 31, 5, 42, 0, 0, 31, 32, 5, 4, 0, 0, 32, 33, 5, 43, 0, + 0, 33, 80, 5, 2, 0, 0, 34, 35, 7, 0, 0, 0, 35, 80, 3, 0, 0, 20, 36, 37, + 7, 1, 0, 0, 37, 38, 5, 1, 0, 0, 38, 39, 3, 0, 0, 0, 39, 40, 5, 4, 0, 0, + 40, 41, 3, 0, 0, 0, 41, 42, 5, 2, 0, 0, 42, 80, 1, 0, 0, 0, 43, 44, 7, + 2, 0, 0, 44, 45, 5, 1, 0, 0, 45, 46, 3, 0, 0, 0, 46, 47, 5, 4, 0, 0, 47, + 48, 3, 0, 0, 0, 48, 49, 5, 2, 0, 0, 49, 80, 1, 0, 0, 0, 50, 51, 7, 3, 0, + 0, 51, 52, 5, 1, 0, 0, 52, 53, 3, 0, 0, 0, 53, 54, 5, 4, 0, 0, 54, 55, + 3, 0, 0, 0, 55, 56, 5, 2, 0, 0, 56, 80, 1, 0, 0, 0, 57, 58, 5, 38, 0, 0, + 58, 59, 5, 1, 0, 0, 59, 60, 7, 4, 0, 0, 60, 80, 5, 2, 0, 0, 61, 62, 5, + 42, 0, 0, 62, 74, 5, 1, 0, 0, 63, 68, 3, 0, 0, 0, 64, 65, 5, 4, 0, 0, 65, + 67, 3, 0, 0, 0, 66, 64, 1, 0, 0, 0, 67, 70, 1, 0, 0, 0, 68, 66, 1, 0, 0, + 0, 68, 69, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 71, 73, + 5, 4, 0, 0, 72, 71, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 75, 1, 0, 0, 0, + 74, 63, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 80, 5, + 2, 0, 0, 77, 78, 5, 13, 0, 0, 78, 80, 3, 0, 0, 1, 79, 2, 1, 0, 0, 0, 79, + 4, 1, 0, 0, 0, 79, 5, 1, 0, 0, 0, 79, 6, 1, 0, 0, 0, 79, 7, 1, 0, 0, 0, + 79, 8, 1, 0, 0, 0, 79, 9, 1, 0, 0, 0, 79, 13, 1, 0, 0, 0, 79, 27, 1, 0, + 0, 0, 79, 28, 1, 0, 0, 0, 79, 34, 1, 0, 0, 0, 79, 36, 1, 0, 0, 0, 79, 43, + 1, 0, 0, 0, 79, 50, 1, 0, 0, 0, 79, 57, 1, 0, 0, 0, 79, 61, 1, 0, 0, 0, + 79, 77, 1, 0, 0, 0, 80, 135, 1, 0, 0, 0, 81, 82, 10, 21, 0, 0, 82, 83, + 5, 20, 0, 0, 83, 134, 3, 0, 0, 22, 84, 85, 10, 19, 0, 0, 85, 86, 7, 5, + 0, 0, 86, 134, 3, 0, 0, 20, 87, 88, 10, 18, 0, 0, 88, 89, 7, 6, 0, 0, 89, + 134, 3, 0, 0, 19, 90, 91, 10, 17, 0, 0, 91, 92, 7, 7, 0, 0, 92, 134, 3, + 0, 0, 18, 93, 95, 10, 16, 0, 0, 94, 96, 5, 29, 0, 0, 95, 94, 1, 0, 0, 0, + 95, 96, 1, 0, 0, 0, 96, 97, 1, 0, 0, 0, 97, 98, 5, 30, 0, 0, 98, 134, 3, + 0, 0, 17, 99, 100, 10, 10, 0, 0, 100, 101, 7, 8, 0, 0, 101, 102, 7, 4, + 0, 0, 102, 103, 7, 8, 0, 0, 103, 134, 3, 0, 0, 11, 104, 105, 10, 9, 0, + 0, 105, 106, 7, 9, 0, 0, 106, 107, 7, 4, 0, 0, 107, 108, 7, 9, 0, 0, 108, + 134, 3, 0, 0, 10, 109, 110, 10, 8, 0, 0, 110, 111, 7, 10, 0, 0, 111, 134, + 3, 0, 0, 9, 112, 113, 10, 7, 0, 0, 113, 114, 7, 11, 0, 0, 114, 134, 3, + 0, 0, 8, 115, 116, 10, 6, 0, 0, 116, 117, 5, 23, 0, 0, 117, 134, 3, 0, + 0, 7, 118, 119, 10, 5, 0, 0, 119, 120, 5, 25, 0, 0, 120, 134, 3, 0, 0, + 6, 121, 122, 10, 4, 0, 0, 122, 123, 5, 24, 0, 0, 123, 134, 3, 0, 0, 5, + 124, 125, 10, 3, 0, 0, 125, 126, 5, 26, 0, 0, 126, 134, 3, 0, 0, 4, 127, + 128, 10, 2, 0, 0, 128, 129, 5, 27, 0, 0, 129, 134, 3, 0, 0, 3, 130, 131, + 10, 23, 0, 0, 131, 132, 5, 12, 0, 0, 132, 134, 5, 43, 0, 0, 133, 81, 1, + 0, 0, 0, 133, 84, 1, 0, 0, 0, 133, 87, 1, 0, 0, 0, 133, 90, 1, 0, 0, 0, + 133, 93, 1, 0, 0, 0, 133, 99, 1, 0, 0, 0, 133, 104, 1, 0, 0, 0, 133, 109, + 1, 0, 0, 0, 133, 112, 1, 0, 0, 0, 133, 115, 1, 0, 0, 0, 133, 118, 1, 0, + 0, 0, 133, 121, 1, 0, 0, 0, 133, 124, 1, 0, 0, 0, 133, 127, 1, 0, 0, 0, + 133, 130, 1, 0, 0, 0, 134, 137, 1, 0, 0, 0, 135, 133, 1, 0, 0, 0, 135, + 136, 1, 0, 0, 0, 136, 1, 1, 0, 0, 0, 137, 135, 1, 0, 0, 0, 9, 19, 23, 68, + 72, 74, 79, 95, 133, 135, } deserializer := antlr.NewATNDeserializer(nil) staticData.atn = deserializer.Deserialize(staticData.serializedATN) @@ -981,6 +989,79 @@ func (s *ShiftContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { } } +type CallContext struct { + ExprContext +} + +func NewCallContext(parser antlr.Parser, ctx antlr.ParserRuleContext) *CallContext { + var p = new(CallContext) + + InitEmptyExprContext(&p.ExprContext) + p.parser = parser + p.CopyAll(ctx.(*ExprContext)) + + return p +} + +func (s *CallContext) GetRuleContext() antlr.RuleContext { + return s +} + +func (s *CallContext) Identifier() antlr.TerminalNode { + return s.GetToken(PlanParserIdentifier, 0) +} + +func (s *CallContext) AllExpr() []IExprContext { + children := s.GetChildren() + len := 0 + for _, ctx := range children { + if _, ok := ctx.(IExprContext); ok { + len++ + } + } + + tst := make([]IExprContext, len) + i := 0 + for _, ctx := range children { + if t, ok := ctx.(IExprContext); ok { + tst[i] = t.(IExprContext) + i++ + } + } + + return tst +} + +func (s *CallContext) Expr(i int) IExprContext { + var t antlr.RuleContext + j := 0 + for _, ctx := range s.GetChildren() { + if _, ok := ctx.(IExprContext); ok { + if j == i { + t = ctx.(antlr.RuleContext) + break + } + j++ + } + } + + if t == nil { + return nil + } + + return t.(IExprContext) +} + +func (s *CallContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { + switch t := visitor.(type) { + case PlanVisitor: + return t.VisitCall(s) + + default: + return t.VisitChildren(s) + } +} + type ReverseRangeContext struct { ExprContext op1 antlr.Token @@ -2231,14 +2312,14 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { var _alt int p.EnterOuterAlt(localctx, 1) - p.SetState(63) + p.SetState(79) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - switch p.GetTokenStream().LA(1) { - case PlanParserIntegerConstant: + switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) { + case 1: localctx = NewIntegerContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2252,7 +2333,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserFloatingConstant: + case 2: localctx = NewFloatingContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2265,7 +2346,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserBooleanConstant: + case 3: localctx = NewBooleanContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2278,7 +2359,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserStringLiteral: + case 4: localctx = NewStringContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2291,7 +2372,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserIdentifier: + case 5: localctx = NewIdentifierContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2304,7 +2385,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONIdentifier: + case 6: localctx = NewJSONIdentifierContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2317,7 +2398,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserT__0: + case 7: localctx = NewParensContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2342,7 +2423,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserT__2: + case 8: localctx = NewArrayContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2420,7 +2501,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserEmptyArray: + case 9: localctx = NewEmptyArrayContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2433,7 +2514,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserTEXTMATCH: + case 10: localctx = NewTextMatchContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2486,7 +2567,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserADD, PlanParserSUB, PlanParserBNOT, PlanParserNOT: + case 11: localctx = NewUnaryContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2510,10 +2591,10 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { p.SetState(35) - p.expr(19) + p.expr(20) } - case PlanParserJSONContains, PlanParserArrayContains: + case 12: localctx = NewJSONContainsContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2561,7 +2642,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONContainsAll, PlanParserArrayContainsAll: + case 13: localctx = NewJSONContainsAllContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2609,7 +2690,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserJSONContainsAny, PlanParserArrayContainsAny: + case 14: localctx = NewJSONContainsAnyContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2657,7 +2738,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserArrayLength: + case 15: localctx = NewArrayLengthContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx @@ -2697,13 +2778,13 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - case PlanParserEXISTS: - localctx = NewExistsContext(p, localctx) + case 16: + localctx = NewCallContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { p.SetState(61) - p.Match(PlanParserEXISTS) + p.Match(PlanParserIdentifier) if p.HasError() { // Recognition error - abort rule goto errorExit @@ -2711,20 +2792,115 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { p.SetState(62) + p.Match(PlanParserT__0) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + p.SetState(74) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _la = p.GetTokenStream().LA(1) + + if (int64(_la) & ^0x3f) == 0 && ((int64(1)<<_la)&35183030034442) != 0 { + { + p.SetState(63) + p.expr(0) + } + p.SetState(68) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext()) + if p.HasError() { + goto errorExit + } + for _alt != 2 && _alt != antlr.ATNInvalidAltNumber { + if _alt == 1 { + { + p.SetState(64) + p.Match(PlanParserT__3) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + { + p.SetState(65) + p.expr(0) + } + + } + p.SetState(70) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 2, p.GetParserRuleContext()) + if p.HasError() { + goto errorExit + } + } + p.SetState(72) + p.GetErrorHandler().Sync(p) + if p.HasError() { + goto errorExit + } + _la = p.GetTokenStream().LA(1) + + if _la == PlanParserT__3 { + { + p.SetState(71) + p.Match(PlanParserT__3) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + + } + + } + { + p.SetState(76) + p.Match(PlanParserT__1) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + + case 17: + localctx = NewExistsContext(p, localctx) + p.SetParserRuleContext(localctx) + _prevctx = localctx + { + p.SetState(77) + p.Match(PlanParserEXISTS) + if p.HasError() { + // Recognition error - abort rule + goto errorExit + } + } + { + p.SetState(78) p.expr(1) } - default: - p.SetError(antlr.NewNoViableAltException(p, nil, nil, nil, nil, nil)) + case antlr.ATNInvalidAltNumber: goto errorExit } p.GetParserRuleContext().SetStop(p.GetTokenStream().LT(-1)) - p.SetState(119) + p.SetState(135) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext()) if p.HasError() { goto errorExit } @@ -2734,24 +2910,24 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.TriggerExitRuleEvent() } _prevctx = localctx - p.SetState(117) + p.SetState(133) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 4, p.GetParserRuleContext()) { + switch p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 7, p.GetParserRuleContext()) { case 1: localctx = NewPowerContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(65) + p.SetState(81) - if !(p.Precpred(p.GetParserRuleContext(), 20)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 20)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 21)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 21)", "")) goto errorExit } { - p.SetState(66) + p.SetState(82) p.Match(PlanParserPOW) if p.HasError() { // Recognition error - abort rule @@ -2759,21 +2935,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(67) - p.expr(21) + p.SetState(83) + p.expr(22) } case 2: localctx = NewMulDivModContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(68) + p.SetState(84) - if !(p.Precpred(p.GetParserRuleContext(), 18)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 19)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 19)", "")) goto errorExit } { - p.SetState(69) + p.SetState(85) var _lt = p.GetTokenStream().LT(1) @@ -2791,21 +2967,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(70) - p.expr(19) + p.SetState(86) + p.expr(20) } case 3: localctx = NewAddSubContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(71) + p.SetState(87) - if !(p.Precpred(p.GetParserRuleContext(), 17)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 18)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) goto errorExit } { - p.SetState(72) + p.SetState(88) var _lt = p.GetTokenStream().LT(1) @@ -2823,21 +2999,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(73) - p.expr(18) + p.SetState(89) + p.expr(19) } case 4: localctx = NewShiftContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(74) + p.SetState(90) - if !(p.Precpred(p.GetParserRuleContext(), 16)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 17)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) goto errorExit } { - p.SetState(75) + p.SetState(91) var _lt = p.GetTokenStream().LT(1) @@ -2855,20 +3031,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(76) - p.expr(17) + p.SetState(92) + p.expr(18) } case 5: localctx = NewTermContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(77) + p.SetState(93) - if !(p.Precpred(p.GetParserRuleContext(), 15)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 15)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 16)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) goto errorExit } - p.SetState(79) + p.SetState(95) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit @@ -2877,7 +3053,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { if _la == PlanParserNOT { { - p.SetState(78) + p.SetState(94) var _m = p.Match(PlanParserNOT) @@ -2890,7 +3066,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { - p.SetState(81) + p.SetState(97) p.Match(PlanParserIN) if p.HasError() { // Recognition error - abort rule @@ -2898,21 +3074,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(82) - p.expr(16) + p.SetState(98) + p.expr(17) } case 6: localctx = NewRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(83) + p.SetState(99) if !(p.Precpred(p.GetParserRuleContext(), 10)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 10)", "")) goto errorExit } { - p.SetState(84) + p.SetState(100) var _lt = p.GetTokenStream().LT(1) @@ -2930,7 +3106,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(85) + p.SetState(101) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -2941,7 +3117,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(86) + p.SetState(102) var _lt = p.GetTokenStream().LT(1) @@ -2959,21 +3135,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(87) + p.SetState(103) p.expr(11) } case 7: localctx = NewReverseRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(88) + p.SetState(104) if !(p.Precpred(p.GetParserRuleContext(), 9)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 9)", "")) goto errorExit } { - p.SetState(89) + p.SetState(105) var _lt = p.GetTokenStream().LT(1) @@ -2991,7 +3167,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(90) + p.SetState(106) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -3002,7 +3178,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(91) + p.SetState(107) var _lt = p.GetTokenStream().LT(1) @@ -3020,21 +3196,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(92) + p.SetState(108) p.expr(10) } case 8: localctx = NewRelationalContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(93) + p.SetState(109) if !(p.Precpred(p.GetParserRuleContext(), 8)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 8)", "")) goto errorExit } { - p.SetState(94) + p.SetState(110) var _lt = p.GetTokenStream().LT(1) @@ -3052,21 +3228,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(95) + p.SetState(111) p.expr(9) } case 9: localctx = NewEqualityContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(96) + p.SetState(112) if !(p.Precpred(p.GetParserRuleContext(), 7)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 7)", "")) goto errorExit } { - p.SetState(97) + p.SetState(113) var _lt = p.GetTokenStream().LT(1) @@ -3084,21 +3260,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(98) + p.SetState(114) p.expr(8) } case 10: localctx = NewBitAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(99) + p.SetState(115) if !(p.Precpred(p.GetParserRuleContext(), 6)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 6)", "")) goto errorExit } { - p.SetState(100) + p.SetState(116) p.Match(PlanParserBAND) if p.HasError() { // Recognition error - abort rule @@ -3106,21 +3282,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(101) + p.SetState(117) p.expr(7) } case 11: localctx = NewBitXorContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(102) + p.SetState(118) if !(p.Precpred(p.GetParserRuleContext(), 5)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 5)", "")) goto errorExit } { - p.SetState(103) + p.SetState(119) p.Match(PlanParserBXOR) if p.HasError() { // Recognition error - abort rule @@ -3128,21 +3304,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(104) + p.SetState(120) p.expr(6) } case 12: localctx = NewBitOrContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(105) + p.SetState(121) if !(p.Precpred(p.GetParserRuleContext(), 4)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 4)", "")) goto errorExit } { - p.SetState(106) + p.SetState(122) p.Match(PlanParserBOR) if p.HasError() { // Recognition error - abort rule @@ -3150,21 +3326,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(107) + p.SetState(123) p.expr(5) } case 13: localctx = NewLogicalAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(108) + p.SetState(124) if !(p.Precpred(p.GetParserRuleContext(), 3)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 3)", "")) goto errorExit } { - p.SetState(109) + p.SetState(125) p.Match(PlanParserAND) if p.HasError() { // Recognition error - abort rule @@ -3172,21 +3348,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(110) + p.SetState(126) p.expr(4) } case 14: localctx = NewLogicalOrContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(111) + p.SetState(127) if !(p.Precpred(p.GetParserRuleContext(), 2)) { p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", "")) goto errorExit } { - p.SetState(112) + p.SetState(128) p.Match(PlanParserOR) if p.HasError() { // Recognition error - abort rule @@ -3194,21 +3370,21 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(113) + p.SetState(129) p.expr(3) } case 15: localctx = NewLikeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(114) + p.SetState(130) - if !(p.Precpred(p.GetParserRuleContext(), 22)) { - p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 22)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 23)) { + p.SetError(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 23)", "")) goto errorExit } { - p.SetState(115) + p.SetState(131) p.Match(PlanParserLIKE) if p.HasError() { // Recognition error - abort rule @@ -3216,7 +3392,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(116) + p.SetState(132) p.Match(PlanParserStringLiteral) if p.HasError() { // Recognition error - abort rule @@ -3229,12 +3405,12 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } - p.SetState(121) + p.SetState(137) p.GetErrorHandler().Sync(p) if p.HasError() { goto errorExit } - _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 5, p.GetParserRuleContext()) + _alt = p.GetInterpreter().AdaptivePredict(p.BaseParser, p.GetTokenStream(), 8, p.GetParserRuleContext()) if p.HasError() { goto errorExit } @@ -3270,19 +3446,19 @@ func (p *PlanParser) Sempred(localctx antlr.RuleContext, ruleIndex, predIndex in func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) bool { switch predIndex { case 0: - return p.Precpred(p.GetParserRuleContext(), 20) + return p.Precpred(p.GetParserRuleContext(), 21) case 1: - return p.Precpred(p.GetParserRuleContext(), 18) + return p.Precpred(p.GetParserRuleContext(), 19) case 2: - return p.Precpred(p.GetParserRuleContext(), 17) + return p.Precpred(p.GetParserRuleContext(), 18) case 3: - return p.Precpred(p.GetParserRuleContext(), 16) + return p.Precpred(p.GetParserRuleContext(), 17) case 4: - return p.Precpred(p.GetParserRuleContext(), 15) + return p.Precpred(p.GetParserRuleContext(), 16) case 5: return p.Precpred(p.GetParserRuleContext(), 10) @@ -3312,7 +3488,7 @@ func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) boo return p.Precpred(p.GetParserRuleContext(), 2) case 14: - return p.Precpred(p.GetParserRuleContext(), 22) + return p.Precpred(p.GetParserRuleContext(), 23) default: panic("No predicate with index: " + fmt.Sprint(predIndex)) diff --git a/internal/parser/planparserv2/generated/plan_visitor.go b/internal/parser/planparserv2/generated/plan_visitor.go index acaa0a833b233..a043068901a04 100644 --- a/internal/parser/planparserv2/generated/plan_visitor.go +++ b/internal/parser/planparserv2/generated/plan_visitor.go @@ -46,6 +46,9 @@ type PlanVisitor interface { // Visit a parse tree produced by PlanParser#Shift. VisitShift(ctx *ShiftContext) interface{} + // Visit a parse tree produced by PlanParser#Call. + VisitCall(ctx *CallContext) interface{} + // Visit a parse tree produced by PlanParser#ReverseRange. VisitReverseRange(ctx *ReverseRangeContext) interface{} diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 92af0da0c44a8..6420970df866b 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -594,6 +594,28 @@ func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode) return v.getColumnInfoFromJSONIdentifier(child.GetText()) } +// VisitCall parses the expr to call plan. +func (v *ParserVisitor) VisitCall(ctx *parser.CallContext) interface{} { + functionName := ctx.Identifier().GetText() + numParams := len(ctx.AllExpr()) + funcParameters := make([]*planpb.Expr, 0, numParams) + for _, param := range ctx.AllExpr() { + paramExpr := getExpr(param.Accept(v)) + funcParameters = append(funcParameters, paramExpr.expr) + } + return &ExprWithType{ + expr: &planpb.Expr{ + Expr: &planpb.Expr_CallExpr{ + CallExpr: &planpb.CallExpr{ + FunctionName: functionName, + FunctionParameters: funcParameters, + }, + }, + }, + dataType: schemapb.DataType_Bool, + } +} + // VisitRange translates expr to range plan. func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier()) diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index 5c74535cec8b1..9632f76c19e18 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -61,10 +61,10 @@ func assertValidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string _, err := ParseExpr(helper, exprStr) assert.NoError(t, err, exprStr) - // expr, err := ParseExpr(helper, exprStr) - // assert.NoError(t, err, exprStr) - // fmt.Printf("expr: %s\n", exprStr) - // ShowExpr(expr) + expr, err := ParseExpr(helper, exprStr) + assert.NoError(t, err, exprStr) + fmt.Printf("expr: %s\n", exprStr) + ShowExpr(expr) } func assertInvalidExpr(t *testing.T, helper *typeutil.SchemaHelper, exprStr string) { @@ -106,6 +106,43 @@ func TestExpr_Term(t *testing.T) { } } +func TestExpr_Call(t *testing.T) { + schema := newTestSchema() + helper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + testcases := []struct { + CallExpr string + FunctionName string + ParameterNum int + }{ + {`hello123()`, "hello123", 0}, + {`lt(Int32Field)`, "lt", 1}, + // test parens + {`lt((((Int32Field))))`, "lt", 1}, + {`empty(VarCharField,)`, "empty", 1}, + {`f2(Int64Field)`, "f2", 1}, + {`f2(Int64Field, 4)`, "f2", 2}, + {`f3(JSON_FIELD["A"], Int32Field)`, "f3", 2}, + {`f5(3+3, Int32Field)`, "f5", 2}, + } + for _, testcase := range testcases { + expr, err := ParseExpr(helper, testcase.CallExpr) + assert.NoError(t, err, testcase) + assert.Equal(t, testcase.FunctionName, expr.GetCallExpr().FunctionName, testcase) + assert.Equal(t, testcase.ParameterNum, len(expr.GetCallExpr().FunctionParameters), testcase) + ShowExpr(expr) + } + + expr, err := ParseExpr(helper, "xxx(1+1, !true, f(10+10))") + assert.NoError(t, err) + assert.Equal(t, "xxx", expr.GetCallExpr().FunctionName) + assert.Equal(t, 3, len(expr.GetCallExpr().FunctionParameters)) + assert.Equal(t, int64(2), expr.GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val()) + assert.Equal(t, false, expr.GetCallExpr().GetFunctionParameters()[1].GetValueExpr().GetValue().GetBoolVal()) + assert.Equal(t, int64(20), expr.GetCallExpr().GetFunctionParameters()[2].GetCallExpr().GetFunctionParameters()[0].GetValueExpr().GetValue().GetInt64Val()) +} + func TestExpr_Compare(t *testing.T) { schema := newTestSchema() helper, err := typeutil.CreateSchemaHelper(schema) @@ -286,6 +323,7 @@ func TestExpr_Value(t *testing.T) { `true`, `false`, `"str"`, + `3 > 2`, } for _, exprStr := range exprStrs { expr := handleExpr(helper, exprStr) diff --git a/internal/parser/planparserv2/show_visitor.go b/internal/parser/planparserv2/show_visitor.go index b9b263b6e0631..1a06d93d5e62d 100644 --- a/internal/parser/planparserv2/show_visitor.go +++ b/internal/parser/planparserv2/show_visitor.go @@ -46,6 +46,8 @@ func (v *ShowExprVisitor) VisitExpr(expr *planpb.Expr) interface{} { js["expr"] = v.VisitUnaryExpr(realExpr.UnaryExpr) case *planpb.Expr_BinaryExpr: js["expr"] = v.VisitBinaryExpr(realExpr.BinaryExpr) + case *planpb.Expr_CallExpr: + js["expr"] = v.VisitCallExpr(realExpr.CallExpr) case *planpb.Expr_CompareExpr: js["expr"] = v.VisitCompareExpr(realExpr.CompareExpr) case *planpb.Expr_UnaryRangeExpr: @@ -93,6 +95,18 @@ func (v *ShowExprVisitor) VisitBinaryExpr(expr *planpb.BinaryExpr) interface{} { return js } +func (v *ShowExprVisitor) VisitCallExpr(expr *planpb.CallExpr) interface{} { + js := make(map[string]interface{}) + js["expr_type"] = "call" + js["func_name"] = expr.FunctionName + params := make([]interface{}, 0, len(expr.FunctionParameters)) + for _, p := range expr.FunctionParameters { + params = append(params, v.VisitExpr(p)) + } + js["func_parameters"] = params + return js +} + func (v *ShowExprVisitor) VisitCompareExpr(expr *planpb.CompareExpr) interface{} { js := make(map[string]interface{}) js["expr_type"] = "compare" @@ -164,6 +178,6 @@ func NewShowExprVisitor() LogicalExprVisitor { func ShowExpr(expr *planpb.Expr) { v := NewShowExprVisitor() js := v.VisitExpr(expr) - b, _ := json.MarshalIndent(js, "", " ") + b, _ := json.Marshal(js) log.Info("[ShowExpr]", zap.String("expr", string(b))) } diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index 16ed9aee2b184..0ee1d6c03a171 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -105,6 +105,11 @@ message BinaryRangeExpr { GenericValue upper_value = 5; } +message CallExpr { + string function_name = 1; + repeated Expr function_parameters = 2; +} + message CompareExpr { ColumnInfo left_column_info = 1; ColumnInfo right_column_info = 2; @@ -191,6 +196,7 @@ message Expr { ExistsExpr exists_expr = 11; AlwaysTrueExpr always_true_expr = 12; JSONContainsExpr json_contains_expr = 13; + CallExpr call_expr = 14; }; } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 19ac5e7a96a81..7115e14f183cd 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -23,6 +23,7 @@ package querynodev2 #include "segcore/segment_c.h" #include "segcore/segcore_init_c.h" #include "common/init_c.h" +#include "exec/expression/function/init_c.h" */ import "C" @@ -356,6 +357,8 @@ func (node *QueryNode) Init() error { return } + C.InitExecExpressionFunctionFactory() + log.Info("query node init successfully", zap.Int64("queryNodeID", node.GetNodeID()), zap.String("Address", node.address), diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index 1a6eb67b8a13d..4873dfb3eba0c 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -46,6 +46,7 @@ h5py==3.8.0 loguru==0.7.0 # util +numpy==1.26.4 psutil==5.9.4 pandas==1.5.3 tenacity==8.1.0 diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 3311fd6666671..a54c91342a8f8 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -5597,3 +5597,20 @@ def test_query_text_match_with_unsupported_tokenizer(self): check_task=CheckTasks.err_res, check_items=error, ) + + +class TestQueryFunction(TestcaseBase): + @pytest.mark.tags(CaseLabel.L1) + def test_query_function_empty(self): + """ + target: test query data + method: create collection and insert data + query with mix expr in string field and int field + expected: query successfully + """ + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, + primary_field=ct.default_string_field_name)[0:2] + res = vectors[0].iloc[:, 1:3].to_dict('records') + output_fields = [default_float_field_name, default_string_field_name] + collection_w.query("not empty(varchar) && int64 >= 0", output_fields=output_fields, + check_task=CheckTasks.check_query_results, check_items={exp_res: res})