diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index b1068dd650392..754766f54388b 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -51,6 +51,15 @@ class Schema { return field_id; } + FieldId + AddDebugArrayField(const std::string& name, DataType element_type) { + auto field_id = FieldId(debug_id); + debug_id++; + this->AddField( + FieldName(name), field_id, DataType::ARRAY, element_type); + return field_id; + } + // auto gen field_id for convenience FieldId AddDebugField(const std::string& name, diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index ea9eeac92cef9..a300515560b2d 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -280,6 +280,22 @@ class SegmentExpr : public Expr { return result; } + template + void + ProcessIndexChunksV2(FUNC func, ValTypes... values) { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + + for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { + const Index& index = + segment_->chunk_scalar_index(field_id_, i); + auto* index_ptr = const_cast(&index); + func(index_ptr, values...); + } + } + template bool CanUseIndex(OpType op) const { diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index 72251c301fb14..bbcc852c2a8e2 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -23,7 +23,14 @@ namespace exec { void PhyJsonContainsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { switch (expr_->column_.data_type_) { - case DataType::ARRAY: + case DataType::ARRAY: { + if (is_index_mode_) { + result = EvalArrayContainsForIndexSegment(); + } else { + result = EvalJsonContainsForDataSegment(); + } + break; + } case DataType::JSON: { if (is_index_mode_) { PanicInfo( @@ -94,7 +101,6 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() { return ExecJsonContainsWithDiffType(); } } - break; } case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { if (IsArrayDataType(data_type)) { @@ -145,7 +151,6 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() { return ExecJsonContainsAllWithDiffType(); } } - break; } default: PanicInfo(ExprInvalid, @@ -748,5 +753,92 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return res_vec; } +VectorPtr +PhyJsonContainsFilterExpr::EvalArrayContainsForIndexSegment() { + switch (expr_->column_.element_type_) { + case DataType::BOOL: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT8: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT16: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT32: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT64: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::FLOAT: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::DOUBLE: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::VARCHAR: + case DataType::STRING: { + return ExecArrayContainsForIndexSegmentImpl(); + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type for " + "ExecArrayContainsForIndexSegmentImpl: {}", + expr_->column_.element_type_)); + } +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecArrayContainsForIndexSegmentImpl() { + typedef std::conditional_t, + std::string, + ExprValueType> + GetType; + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + std::unordered_set elements; + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + boost::container::vector elems(elements.begin(), elements.end()); + auto execute_sub_batch = + [this](Index* index_ptr, + const boost::container::vector& vals) { + switch (expr_->op_) { + case proto::plan::JSONContainsExpr_JSONOp_Contains: + case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: { + return index_ptr->In(vals.size(), vals.data()); + } + case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { + TargetBitmap result(index_ptr->Count()); + result.set(); + for (size_t i = 0; i < vals.size(); i++) { + auto sub = index_ptr->In(1, &vals[i]); + result &= sub; + } + return result; + } + default: + PanicInfo( + ExprInvalid, + "unsupported array contains type {}", + proto::plan::JSONContainsExpr_JSONOp_Name(expr_->op_)); + } + }; + auto res = ProcessIndexChunks(execute_sub_batch, elems); + AssertInfo(res.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size); + return std::make_shared(std::move(res)); +} + } //namespace exec } // namespace milvus diff --git a/internal/core/src/exec/expression/JsonContainsExpr.h b/internal/core/src/exec/expression/JsonContainsExpr.h index c757dc0d3fb92..a0cfdfdea0841 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.h +++ b/internal/core/src/exec/expression/JsonContainsExpr.h @@ -80,6 +80,13 @@ class PhyJsonContainsFilterExpr : public SegmentExpr { VectorPtr ExecJsonContainsWithDiffType(); + VectorPtr + EvalArrayContainsForIndexSegment(); + + template + VectorPtr + ExecArrayContainsForIndexSegmentImpl(); + private: std::shared_ptr expr_; }; diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index f780ec487ba47..b9567133de801 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -20,6 +20,66 @@ namespace milvus { namespace exec { +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArrayForIndex() { + return ExecRangeVisitorImplArray(); +} + +template <> +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArrayForIndex< + proto::plan::Array>() { + switch (expr_->op_type_) { + case proto::plan::Equal: + case proto::plan::NotEqual: { + switch (expr_->column_.element_type_) { + case DataType::BOOL: { + return ExecArrayEqualForIndex(expr_->op_type_ == + proto::plan::NotEqual); + } + case DataType::INT8: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT16: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT32: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT64: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::FLOAT: + case DataType::DOUBLE: { + // not accurate on floating point number, rollback to bruteforce. + return ExecRangeVisitorImplArray(); + } + case DataType::VARCHAR: { + if (segment_->type() == SegmentType::Growing) { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } else { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + } + default: + PanicInfo(DataTypeInvalid, + "unsupported element type when execute array " + "equal for index: {}", + expr_->column_.element_type_); + } + } + default: + return ExecRangeVisitorImplArray(); + } +} + void PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { switch (expr_->column_.data_type_) { @@ -99,7 +159,13 @@ PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { result = ExecRangeVisitorImplArray(); break; case proto::plan::GenericValue::ValCase::kArrayVal: - result = ExecRangeVisitorImplArray(); + if (is_index_mode_) { + result = ExecRangeVisitorImplArrayForIndex< + proto::plan::Array>(); + } else { + result = + ExecRangeVisitorImplArray(); + } break; default: PanicInfo( @@ -196,6 +262,104 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { return res_vec; } +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecArrayEqualForIndex(bool reverse) { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + // get all elements. + auto val = GetValueFromProto(expr_->val_); + if (val.array_size() == 0) { + // rollback to bruteforce. no candidates will be filtered out via index. + return ExecRangeVisitorImplArray(); + } + + // cache the result to suit the framework. + auto batch_res = + ProcessIndexChunks([this, &val, reverse](Index* _) { + boost::container::vector elems; + for (auto const& element : val.array()) { + auto e = GetValueFromProto(element); + if (std::find(elems.begin(), elems.end(), e) == elems.end()) { + elems.push_back(e); + } + } + + // filtering by index, get candidates. + auto size_per_chunk = segment_->size_per_chunk(); + auto retrieve = [ size_per_chunk, this ](int64_t offset) -> auto { + auto chunk_idx = offset / size_per_chunk; + auto chunk_offset = offset % size_per_chunk; + const auto& chunk = + segment_->template chunk_data(field_id_, + chunk_idx); + return chunk.data() + chunk_offset; + }; + + // compare the array via the raw data. + auto filter = [&retrieve, &val, reverse](size_t offset) -> bool { + auto data_ptr = retrieve(offset); + return data_ptr->is_same_array(val) ^ reverse; + }; + + // collect all candidates. + std::unordered_set candidates; + std::unordered_set tmp_candidates; + auto first_callback = [&candidates](size_t offset) -> void { + candidates.insert(offset); + }; + auto callback = [&candidates, + &tmp_candidates](size_t offset) -> void { + if (candidates.find(offset) != candidates.end()) { + tmp_candidates.insert(offset); + } + }; + auto execute_sub_batch = + [](Index* index_ptr, + const IndexInnerType& val, + const std::function& callback) { + index_ptr->InApplyCallback(1, &val, callback); + }; + + // run in-filter. + for (size_t idx = 0; idx < elems.size(); idx++) { + if (idx == 0) { + ProcessIndexChunksV2( + execute_sub_batch, elems[idx], first_callback); + } else { + ProcessIndexChunksV2( + execute_sub_batch, elems[idx], callback); + candidates = std::move(tmp_candidates); + } + // the size of candidates is small enough. + if (candidates.size() * 100 < active_count_) { + break; + } + } + TargetBitmap res(active_count_); + // run post-filter. The filter will only be executed once in the framework. + for (const auto& candidate : candidates) { + res[candidate] = filter(candidate); + } + return res; + }); + AssertInfo(batch_res.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + batch_res.size(), + real_batch_size); + + // return the result. + return std::make_shared(std::move(batch_res)); +} + template VectorPtr PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index e6342eda86434..40371e0e51f38 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -310,6 +310,14 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr { VectorPtr ExecRangeVisitorImplArray(); + template + VectorPtr + ExecRangeVisitorImplArrayForIndex(); + + template + VectorPtr + ExecArrayEqualForIndex(bool reverse); + // Check overflow and cache result for performace template ColumnVectorPtr diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index 102709aa16b83..6716f8af2f66f 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -113,11 +113,13 @@ IsMaterializedViewSupported(const DataType& data_type) { struct ColumnInfo { FieldId field_id_; DataType data_type_; + DataType element_type_; std::vector nested_path_; ColumnInfo(const proto::plan::ColumnInfo& column_info) : field_id_(column_info.field_id()), data_type_(static_cast(column_info.data_type())), + element_type_(static_cast(column_info.element_type())), nested_path_(column_info.nested_path().begin(), column_info.nested_path().end()) { } @@ -127,6 +129,7 @@ struct ColumnInfo { std::vector nested_path = {}) : field_id_(field_id), data_type_(data_type), + element_type_(DataType::NONE), nested_path_(std::move(nested_path)) { } @@ -140,6 +143,10 @@ struct ColumnInfo { return false; } + if (element_type_ != other.element_type_) { + return false; + } + for (int i = 0; i < nested_path_.size(); ++i) { if (nested_path_[i] != other.nested_path_[i]) { return false; @@ -151,10 +158,12 @@ struct ColumnInfo { std::string ToString() const { - return fmt::format("[FieldId:{}, data_type:{}, nested_path:{}]", - std::to_string(field_id_.get()), - data_type_, - milvus::Join(nested_path_, ",")); + return fmt::format( + "[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}]", + std::to_string(field_id_.get()), + data_type_, + element_type_, + milvus::Join(nested_path_, ",")); } }; diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index a593d087eb270..8c0ada968aab8 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -34,13 +34,9 @@ template ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, - const storage::FileManagerContext& file_manager_context, - DataType d_type) { + const storage::FileManagerContext& file_manager_context) { if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; - return std::make_unique>(cfg, - file_manager_context); + return std::make_unique>(file_manager_context); } return CreateScalarIndexSort(file_manager_context); } @@ -56,14 +52,11 @@ template <> ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, - const storage::FileManagerContext& file_manager_context, - DataType d_type) { + const storage::FileManagerContext& file_manager_context) { #if defined(__linux__) || defined(__APPLE__) if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; return std::make_unique>( - cfg, file_manager_context); + file_manager_context); } return CreateStringIndexMarisa(file_manager_context); #else @@ -76,13 +69,10 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType d_type) { + std::shared_ptr space) { if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; - return std::make_unique>( - cfg, file_manager_context, space); + return std::make_unique>(file_manager_context, + space); } return CreateScalarIndexSort(file_manager_context, space); } @@ -92,14 +82,11 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType d_type) { + std::shared_ptr space) { #if defined(__linux__) || defined(__APPLE__) if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; return std::make_unique>( - cfg, file_manager_context, space); + file_manager_context, space); } return CreateStringIndexMarisa(file_manager_context, space); #else @@ -132,41 +119,32 @@ IndexFactory::CreateIndex( } IndexBasePtr -IndexFactory::CreateScalarIndex( - const CreateIndexInfo& create_index_info, +IndexFactory::CreatePrimitiveScalarIndex( + DataType data_type, + IndexType index_type, const storage::FileManagerContext& file_manager_context) { - auto data_type = create_index_info.field_type; - auto index_type = create_index_info.index_type; - switch (data_type) { // create scalar index case DataType::BOOL: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT8: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT16: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT32: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT64: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::FLOAT: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::DOUBLE: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); // create string index case DataType::STRING: case DataType::VARCHAR: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, + file_manager_context); default: throw SegcoreError( DataTypeInvalid, @@ -174,6 +152,24 @@ IndexFactory::CreateScalarIndex( } } +IndexBasePtr +IndexFactory::CreateScalarIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context) { + switch (create_index_info.field_type) { + case DataType::ARRAY: + return CreatePrimitiveScalarIndex( + static_cast( + file_manager_context.fieldDataMeta.schema.element_type()), + create_index_info.index_type, + file_manager_context); + default: + return CreatePrimitiveScalarIndex(create_index_info.field_type, + create_index_info.index_type, + file_manager_context); + } +} + IndexBasePtr IndexFactory::CreateVectorIndex( const CreateIndexInfo& create_index_info, @@ -249,32 +245,25 @@ IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info, switch (data_type) { // create scalar index case DataType::BOOL: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT8: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT16: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT32: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT64: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::FLOAT: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::DOUBLE: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); // create string index case DataType::STRING: case DataType::VARCHAR: return CreateScalarIndex( - index_type, file_manager, space, data_type); + index_type, file_manager, space); default: throw SegcoreError( DataTypeInvalid, diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 75bd090292907..47b255ab4e912 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -65,6 +65,13 @@ class IndexFactory { CreateVectorIndex(const CreateIndexInfo& create_index_info, const storage::FileManagerContext& file_manager_context); + IndexBasePtr + CreatePrimitiveScalarIndex( + DataType data_type, + IndexType index_type, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, const storage::FileManagerContext& file_manager_context = @@ -89,15 +96,13 @@ class IndexFactory { ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, const storage::FileManagerContext& file_manager = - storage::FileManagerContext(), - DataType d_type = DataType::NONE); + storage::FileManagerContext()); template ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, const storage::FileManagerContext& file_manager, - std::shared_ptr space, - DataType d_type = DataType::NONE); + std::shared_ptr space); }; // template <> @@ -112,6 +117,5 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType d_type); + std::shared_ptr space); } // namespace milvus::index diff --git a/internal/core/src/index/InvertedIndexTantivy.cpp b/internal/core/src/index/InvertedIndexTantivy.cpp index 5bb8ba3b16103..3b9a54fae940b 100644 --- a/internal/core/src/index/InvertedIndexTantivy.cpp +++ b/internal/core/src/index/InvertedIndexTantivy.cpp @@ -23,12 +23,50 @@ #include "InvertedIndexTantivy.h" namespace milvus::index { +inline TantivyDataType +get_tantivy_data_type(proto::schema::DataType data_type) { + switch (data_type) { + case proto::schema::DataType::Bool: { + return TantivyDataType::Bool; + } + + case proto::schema::DataType::Int8: + case proto::schema::DataType::Int16: + case proto::schema::DataType::Int32: + case proto::schema::DataType::Int64: { + return TantivyDataType::I64; + } + + case proto::schema::DataType::Float: + case proto::schema::DataType::Double: { + return TantivyDataType::F64; + } + + case proto::schema::DataType::VarChar: { + return TantivyDataType::Keyword; + } + + default: + PanicInfo(ErrorCode::NotImplemented, + fmt::format("not implemented data type: {}", data_type)); + } +} + +inline TantivyDataType +get_tantivy_data_type(const proto::schema::FieldSchema& schema) { + switch (schema.data_type()) { + case proto::schema::Array: + return get_tantivy_data_type(schema.element_type()); + default: + return get_tantivy_data_type(schema.data_type()); + } +} + template InvertedIndexTantivy::InvertedIndexTantivy( - const TantivyConfig& cfg, const storage::FileManagerContext& ctx, std::shared_ptr space) - : cfg_(cfg), space_(space) { + : space_(space), schema_(ctx.fieldDataMeta.schema) { mem_file_manager_ = std::make_shared(ctx, ctx.space_); disk_file_manager_ = std::make_shared(ctx, ctx.space_); auto field = @@ -36,7 +74,7 @@ InvertedIndexTantivy::InvertedIndexTantivy( auto prefix = disk_file_manager_->GetLocalIndexObjectPrefix(); path_ = prefix; boost::filesystem::create_directories(path_); - d_type_ = cfg_.to_tantivy_data_type(); + d_type_ = get_tantivy_data_type(schema_); if (tantivy_index_exist(path_.c_str())) { LOG_INFO( "index {} already exists, which should happen in loading progress", @@ -114,83 +152,7 @@ InvertedIndexTantivy::Build(const Config& config) { AssertInfo(insert_files.has_value(), "insert_files were empty"); auto field_datas = mem_file_manager_->CacheRawDataToMemory(insert_files.value()); - switch (cfg_.data_type_) { - case DataType::BOOL: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data(static_cast(data->Data()), - n); - } - break; - } - - case DataType::INT8: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT16: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT32: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT64: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::FLOAT: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::DOUBLE: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::VARCHAR: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - default: - PanicInfo(ErrorCode::NotImplemented, - fmt::format("todo: not supported, {}", cfg_.data_type_)); - } + build_index(field_datas); } template @@ -211,84 +173,7 @@ InvertedIndexTantivy::BuildV2(const Config& config) { field_data->FillFieldData(col_data); field_datas.push_back(field_data); } - - switch (cfg_.data_type_) { - case DataType::BOOL: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data(static_cast(data->Data()), - n); - } - break; - } - - case DataType::INT8: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT16: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT32: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT64: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::FLOAT: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::DOUBLE: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::VARCHAR: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - default: - PanicInfo(ErrorCode::NotImplemented, - fmt::format("todo: not supported, {}", cfg_.data_type_)); - } + build_index(field_datas); } template @@ -319,6 +204,25 @@ apply_hits(TargetBitmap& bitset, const RustArrayWrapper& w, bool v) { } } +inline void +apply_hits_with_filter(TargetBitmap& bitset, + const RustArrayWrapper& w, + const std::function& filter) { + for (size_t j = 0; j < w.array_.len; j++) { + auto the_offset = w.array_.array[j]; + bitset[the_offset] = filter(the_offset); + } +} + +inline void +apply_hits_with_callback( + const RustArrayWrapper& w, + const std::function& callback) { + for (size_t j = 0; j < w.array_.len; j++) { + callback(w.array_.array[j]); + } +} + template const TargetBitmap InvertedIndexTantivy::In(size_t n, const T* values) { @@ -330,10 +234,33 @@ InvertedIndexTantivy::In(size_t n, const T* values) { return bitset; } +template +const TargetBitmap +InvertedIndexTantivy::InApplyFilter( + size_t n, const T* values, const std::function& filter) { + TargetBitmap bitset(Count()); + for (size_t i = 0; i < n; ++i) { + auto array = wrapper_->term_query(values[i]); + apply_hits_with_filter(bitset, array, filter); + } + return bitset; +} + +template +void +InvertedIndexTantivy::InApplyCallback( + size_t n, const T* values, const std::function& callback) { + for (size_t i = 0; i < n; ++i) { + auto array = wrapper_->term_query(values[i]); + apply_hits_with_callback(array, callback); + } +} + template const TargetBitmap InvertedIndexTantivy::NotIn(size_t n, const T* values) { - TargetBitmap bitset(Count(), true); + TargetBitmap bitset(Count()); + bitset.set(); for (size_t i = 0; i < n; ++i) { auto array = wrapper_->term_query(values[i]); apply_hits(bitset, array, false); @@ -425,25 +352,118 @@ void InvertedIndexTantivy::BuildWithRawData(size_t n, const void* values, const Config& config) { - if constexpr (!std::is_same_v) { - PanicInfo(Unsupported, - "InvertedIndex.BuildWithRawData only support string"); + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Bool); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int8); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int16); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int32); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int64); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Float); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Double); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::VarChar); + } + boost::uuids::random_generator generator; + auto uuid = generator(); + auto prefix = boost::uuids::to_string(uuid); + path_ = fmt::format("/tmp/{}", prefix); + boost::filesystem::create_directories(path_); + d_type_ = get_tantivy_data_type(schema_); + std::string field = "test_inverted_index"; + wrapper_ = std::make_shared( + field.c_str(), d_type_, path_.c_str()); + if (config.find("is_array") != config.end()) { + // only used in ut. + auto arr = static_cast*>(values); + for (size_t i = 0; i < n; i++) { + wrapper_->template add_multi_data(arr[i].data(), arr[i].size()); + } } else { - boost::uuids::random_generator generator; - auto uuid = generator(); - auto prefix = boost::uuids::to_string(uuid); - path_ = fmt::format("/tmp/{}", prefix); - boost::filesystem::create_directories(path_); - cfg_ = TantivyConfig{ - .data_type_ = DataType::VARCHAR, - }; - d_type_ = cfg_.to_tantivy_data_type(); - std::string field = "test_inverted_index"; - wrapper_ = std::make_shared( - field.c_str(), d_type_, path_.c_str()); - wrapper_->add_data(static_cast(values), - n); - finish(); + wrapper_->add_data(static_cast(values), n); + } + finish(); +} + +template +void +InvertedIndexTantivy::build_index( + const std::vector>& field_datas) { + switch (schema_.data_type()) { + case proto::schema::DataType::Bool: + case proto::schema::DataType::Int8: + case proto::schema::DataType::Int16: + case proto::schema::DataType::Int32: + case proto::schema::DataType::Int64: + case proto::schema::DataType::Float: + case proto::schema::DataType::Double: + case proto::schema::DataType::String: + case proto::schema::DataType::VarChar: { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + wrapper_->add_data(static_cast(data->Data()), n); + } + break; + } + + case proto::schema::DataType::Array: { + build_index_for_array(field_datas); + break; + } + + default: + PanicInfo(ErrorCode::NotImplemented, + fmt::format("Inverted index not supported on {}", + schema_.data_type())); + } +} + +template +void +InvertedIndexTantivy::build_index_for_array( + const std::vector>& field_datas) { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + auto array_column = static_cast(data->Data()); + for (int64_t i = 0; i < n; i++) { + assert(array_column[i].get_element_type() == + static_cast(schema_.element_type())); + wrapper_->template add_multi_data( + reinterpret_cast(array_column[i].data()), + array_column[i].length()); + } + } +} + +template <> +void +InvertedIndexTantivy::build_index_for_array( + const std::vector>& field_datas) { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + auto array_column = static_cast(data->Data()); + for (int64_t i = 0; i < n; i++) { + assert(array_column[i].get_element_type() == + static_cast(schema_.element_type())); + std::vector output; + for (int64_t j = 0; j < array_column[i].length(); j++) { + output.push_back( + array_column[i].template get_data(j)); + } + wrapper_->template add_multi_data(output.data(), output.size()); + } } } diff --git a/internal/core/src/index/InvertedIndexTantivy.h b/internal/core/src/index/InvertedIndexTantivy.h index 0ea2f64d869d3..53fb9c2d687ac 100644 --- a/internal/core/src/index/InvertedIndexTantivy.h +++ b/internal/core/src/index/InvertedIndexTantivy.h @@ -18,7 +18,6 @@ #include "tantivy-binding.h" #include "tantivy-wrapper.h" #include "index/StringIndex.h" -#include "index/TantivyConfig.h" #include "storage/space.h" namespace milvus::index { @@ -36,13 +35,11 @@ class InvertedIndexTantivy : public ScalarIndex { InvertedIndexTantivy() = default; - explicit InvertedIndexTantivy(const TantivyConfig& cfg, - const storage::FileManagerContext& ctx) - : InvertedIndexTantivy(cfg, ctx, nullptr) { + explicit InvertedIndexTantivy(const storage::FileManagerContext& ctx) + : InvertedIndexTantivy(ctx, nullptr) { } - explicit InvertedIndexTantivy(const TantivyConfig& cfg, - const storage::FileManagerContext& ctx, + explicit InvertedIndexTantivy(const storage::FileManagerContext& ctx, std::shared_ptr space); ~InvertedIndexTantivy(); @@ -114,6 +111,18 @@ class InvertedIndexTantivy : public ScalarIndex { const TargetBitmap In(size_t n, const T* values) override; + const TargetBitmap + InApplyFilter( + size_t n, + const T* values, + const std::function& filter) override; + + void + InApplyCallback( + size_t n, + const T* values, + const std::function& callback) override; + const TargetBitmap NotIn(size_t n, const T* values) override; @@ -160,11 +169,18 @@ class InvertedIndexTantivy : public ScalarIndex { void finish(); + void + build_index(const std::vector>& field_datas); + + void + build_index_for_array( + const std::vector>& field_datas); + private: std::shared_ptr wrapper_; - TantivyConfig cfg_; TantivyDataType d_type_; std::string path_; + proto::schema::FieldSchema schema_; /* * To avoid IO amplification, we use both mem file manager & disk file manager diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index aacef521f5db3..37d22a288d80b 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -50,6 +50,20 @@ class ScalarIndex : public IndexBase { virtual const TargetBitmap In(size_t n, const T* values) = 0; + virtual const TargetBitmap + InApplyFilter(size_t n, + const T* values, + const std::function& filter) { + PanicInfo(ErrorCode::Unsupported, "InApplyFilter is not implemented"); + } + + virtual void + InApplyCallback(size_t n, + const T* values, + const std::function& callback) { + PanicInfo(ErrorCode::Unsupported, "InApplyCallback is not implemented"); + } + virtual const TargetBitmap NotIn(size_t n, const T* values) = 0; diff --git a/internal/core/src/index/TantivyConfig.h b/internal/core/src/index/TantivyConfig.h deleted file mode 100644 index 355b4c76efc9d..0000000000000 --- a/internal/core/src/index/TantivyConfig.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once - -#include "storage/Types.h" -#include "tantivy-binding.h" - -namespace milvus::index { -struct TantivyConfig { - DataType data_type_; - - TantivyDataType - to_tantivy_data_type() { - switch (data_type_) { - case DataType::BOOL: { - return TantivyDataType::Bool; - } - - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: { - return TantivyDataType::I64; - } - - case DataType::FLOAT: - case DataType::DOUBLE: { - return TantivyDataType::F64; - } - - case DataType::VARCHAR: { - return TantivyDataType::Keyword; - } - - default: - PanicInfo( - ErrorCode::NotImplemented, - fmt::format("not implemented data type: {}", data_type_)); - } - } -}; -} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/indexbuilder/IndexFactory.h b/internal/core/src/indexbuilder/IndexFactory.h index cd361499b4065..1380a6e9817d3 100644 --- a/internal/core/src/indexbuilder/IndexFactory.h +++ b/internal/core/src/indexbuilder/IndexFactory.h @@ -60,6 +60,7 @@ class IndexFactory { case DataType::DOUBLE: case DataType::VARCHAR: case DataType::STRING: + case DataType::ARRAY: return CreateScalarIndex(type, config, context); case DataType::VECTOR_FLOAT: diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 28a629052cad7..7ccaf7c414a24 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -84,29 +84,95 @@ CreateIndexV0(enum CDataType dtype, return status; } +milvus::storage::StorageConfig +get_storage_config(const milvus::proto::indexcgo::StorageConfig& config) { + auto storage_config = milvus::storage::StorageConfig(); + storage_config.address = std::string(config.address()); + storage_config.bucket_name = std::string(config.bucket_name()); + storage_config.access_key_id = std::string(config.access_keyid()); + storage_config.access_key_value = std::string(config.secret_access_key()); + storage_config.root_path = std::string(config.root_path()); + storage_config.storage_type = std::string(config.storage_type()); + storage_config.cloud_provider = std::string(config.cloud_provider()); + storage_config.iam_endpoint = std::string(config.iamendpoint()); + storage_config.cloud_provider = std::string(config.cloud_provider()); + storage_config.useSSL = config.usessl(); + storage_config.sslCACert = config.sslcacert(); + storage_config.useIAM = config.useiam(); + storage_config.region = config.region(); + storage_config.useVirtualHost = config.use_virtual_host(); + storage_config.requestTimeoutMs = config.request_timeout_ms(); + return storage_config; +} + +milvus::OptFieldT +get_opt_field(const ::google::protobuf::RepeatedPtrField< + milvus::proto::indexcgo::OptionalFieldInfo>& field_infos) { + milvus::OptFieldT opt_fields_map; + for (const auto& field_info : field_infos) { + auto field_id = field_info.fieldid(); + if (opt_fields_map.find(field_id) == opt_fields_map.end()) { + opt_fields_map[field_id] = { + field_info.field_name(), + static_cast(field_info.field_type()), + {}}; + } + for (const auto& str : field_info.data_paths()) { + std::get<2>(opt_fields_map[field_id]).emplace_back(str); + } + } + + return opt_fields_map; +} + +milvus::Config +get_config(std::unique_ptr& info) { + milvus::Config config; + for (auto i = 0; i < info->index_params().size(); ++i) { + const auto& param = info->index_params(i); + config[param.key()] = param.value(); + } + + for (auto i = 0; i < info->type_params().size(); ++i) { + const auto& param = info->type_params(i); + config[param.key()] = param.value(); + } + + config["insert_files"] = info->insert_files(); + if (info->opt_fields().size()) { + config["opt_fields"] = get_opt_field(info->opt_fields()); + } + + return config; +} + CStatus -CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) { +CreateIndex(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len) { try { - auto build_index_info = (BuildIndexInfo*)c_build_index_info; - auto field_type = build_index_info->field_type; + auto build_index_info = + std::make_unique(); + auto res = + build_index_info->ParseFromArray(serialized_build_index_info, len); + AssertInfo(res, "Unmarshall build index info failed"); - milvus::index::CreateIndexInfo index_info; - index_info.field_type = build_index_info->field_type; + auto field_type = + static_cast(build_index_info->field_schema().data_type()); - auto& config = build_index_info->config; - config["insert_files"] = build_index_info->insert_files; - if (build_index_info->opt_fields.size()) { - config["opt_fields"] = build_index_info->opt_fields; - } + milvus::index::CreateIndexInfo index_info; + index_info.field_type = field_type; + auto storage_config = + get_storage_config(build_index_info->storage_config()); + auto config = get_config(build_index_info); // get index type auto index_type = milvus::index::GetValueFromConfig( config, "index_type"); AssertInfo(index_type.has_value(), "index type is empty"); index_info.index_type = index_type.value(); - auto engine_version = build_index_info->index_engine_version; - + auto engine_version = build_index_info->current_index_version(); index_info.index_engine_version = engine_version; config[milvus::index::INDEX_ENGINE_VERSION] = std::to_string(engine_version); @@ -121,24 +187,31 @@ CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) { // init file manager milvus::storage::FieldDataMeta field_meta{ - build_index_info->collection_id, - build_index_info->partition_id, - build_index_info->segment_id, - build_index_info->field_id}; - - milvus::storage::IndexMeta index_meta{build_index_info->segment_id, - build_index_info->field_id, - build_index_info->index_build_id, - build_index_info->index_version}; - auto chunk_manager = milvus::storage::CreateChunkManager( - build_index_info->storage_config); + build_index_info->collectionid(), + build_index_info->partitionid(), + build_index_info->segmentid(), + build_index_info->field_schema().fieldid(), + build_index_info->field_schema()}; + + milvus::storage::IndexMeta index_meta{ + build_index_info->segmentid(), + build_index_info->field_schema().fieldid(), + build_index_info->buildid(), + build_index_info->index_version(), + "", + build_index_info->field_schema().name(), + field_type, + build_index_info->dim(), + }; + auto chunk_manager = + milvus::storage::CreateChunkManager(storage_config); milvus::storage::FileManagerContext fileManagerContext( field_meta, index_meta, chunk_manager); auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - build_index_info->field_type, config, fileManagerContext); + field_type, config, fileManagerContext); index->Build(); *res_index = index.release(); auto status = CStatus(); @@ -159,22 +232,32 @@ CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) { } CStatus -CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { +CreateIndexV2(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len) { try { - auto build_index_info = (BuildIndexInfo*)c_build_index_info; - auto field_type = build_index_info->field_type; + auto build_index_info = + std::make_unique(); + auto res = + build_index_info->ParseFromArray(serialized_build_index_info, len); + AssertInfo(res, "Unmarshall build index info failed"); + auto field_type = + static_cast(build_index_info->field_schema().data_type()); + milvus::index::CreateIndexInfo index_info; - index_info.field_type = build_index_info->field_type; - index_info.dim = build_index_info->dim; + index_info.field_type = field_type; + index_info.dim = build_index_info->dim(); - auto& config = build_index_info->config; + auto storage_config = + get_storage_config(build_index_info->storage_config()); + auto config = get_config(build_index_info); // get index type auto index_type = milvus::index::GetValueFromConfig( config, "index_type"); AssertInfo(index_type.has_value(), "index type is empty"); index_info.index_type = index_type.value(); - auto engine_version = build_index_info->index_engine_version; + auto engine_version = build_index_info->current_index_version(); index_info.index_engine_version = engine_version; config[milvus::index::INDEX_ENGINE_VERSION] = std::to_string(engine_version); @@ -188,39 +271,39 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { } milvus::storage::FieldDataMeta field_meta{ - build_index_info->collection_id, - build_index_info->partition_id, - build_index_info->segment_id, - build_index_info->field_id}; + build_index_info->collectionid(), + build_index_info->partitionid(), + build_index_info->segmentid(), + build_index_info->field_schema().fieldid()}; milvus::storage::IndexMeta index_meta{ - build_index_info->segment_id, - build_index_info->field_id, - build_index_info->index_build_id, - build_index_info->index_version, - build_index_info->field_name, + build_index_info->segmentid(), + build_index_info->field_schema().fieldid(), + build_index_info->buildid(), + build_index_info->index_version(), "", - build_index_info->field_type, - build_index_info->dim, + build_index_info->field_schema().name(), + field_type, + build_index_info->dim(), }; auto store_space = milvus_storage::Space::Open( - build_index_info->data_store_path, + build_index_info->store_path(), milvus_storage::Options{nullptr, - build_index_info->data_store_version}); + build_index_info->store_version()}); AssertInfo(store_space.ok() && store_space.has_value(), "create space failed: {}", store_space.status().ToString()); auto index_space = milvus_storage::Space::Open( - build_index_info->index_store_path, + build_index_info->index_store_path(), milvus_storage::Options{.schema = store_space.value()->schema()}); AssertInfo(index_space.ok() && index_space.has_value(), "create space failed: {}", index_space.status().ToString()); LOG_INFO("init space success"); - auto chunk_manager = milvus::storage::CreateChunkManager( - build_index_info->storage_config); + auto chunk_manager = + milvus::storage::CreateChunkManager(storage_config); milvus::storage::FileManagerContext fileManagerContext( field_meta, index_meta, @@ -229,9 +312,9 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - build_index_info->field_type, - build_index_info->field_name, - build_index_info->dim, + field_type, + build_index_info->field_schema().name(), + build_index_info->dim(), config, fileManagerContext, std::move(store_space.value())); diff --git a/internal/core/src/indexbuilder/index_c.h b/internal/core/src/indexbuilder/index_c.h index 16cd76e4531ce..53ce5552fef0a 100644 --- a/internal/core/src/indexbuilder/index_c.h +++ b/internal/core/src/indexbuilder/index_c.h @@ -28,7 +28,9 @@ CreateIndexV0(enum CDataType dtype, CIndex* res_index); CStatus -CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info); +CreateIndex(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len); CStatus DeleteIndex(CIndex index); @@ -130,7 +132,9 @@ CStatus SerializeIndexAndUpLoadV2(CIndex index, CBinarySet* c_binary_set); CStatus -CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info); +CreateIndexV2(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len); CStatus AppendIndexStorageInfo(CBuildIndexInfo c_build_index_info, diff --git a/internal/core/src/pb/CMakeLists.txt b/internal/core/src/pb/CMakeLists.txt index 3c00203cf4c25..35726d9c24c65 100644 --- a/internal/core/src/pb/CMakeLists.txt +++ b/internal/core/src/pb/CMakeLists.txt @@ -11,12 +11,10 @@ find_package(Protobuf REQUIRED) +file(GLOB_RECURSE milvus_proto_srcs + "${CMAKE_CURRENT_SOURCE_DIR}/*.cc") add_library(milvus_proto STATIC - common.pb.cc - index_cgo_msg.pb.cc - plan.pb.cc - schema.pb.cc - segcore.pb.cc + ${milvus_proto_srcs} ) message(STATUS "milvus proto sources: " ${milvus_proto_srcs}) diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 73ba7fcb188b6..106799ce2610f 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -46,6 +46,7 @@ struct LoadIndexInfo { std::string uri; int64_t index_store_version; IndexVersion index_engine_version; + proto::schema::FieldSchema schema; }; } // namespace milvus::segcore diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 7f851948545d3..3df3a92879751 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -25,6 +25,7 @@ #include "storage/Util.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/LocalChunkManagerSingleton.h" +#include "pb/cgo_msg.pb.h" bool IsLoadWithDisk(const char* index_type, int index_engine_version) { @@ -258,7 +259,8 @@ AppendIndexV2(CTraceContext c_trace, CLoadIndexInfo c_load_index_info) { load_index_info->collection_id, load_index_info->partition_id, load_index_info->segment_id, - load_index_info->field_id}; + load_index_info->field_id, + load_index_info->schema}; milvus::storage::IndexMeta index_meta{load_index_info->segment_id, load_index_info->field_id, load_index_info->index_build_id, @@ -484,3 +486,50 @@ AppendStorageInfo(CLoadIndexInfo c_load_index_info, load_index_info->uri = uri; load_index_info->index_store_version = version; } + +CStatus +FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info, + const uint8_t* serialized_load_index_info, + const uint64_t len) { + try { + auto info_proto = std::make_unique(); + info_proto->ParseFromArray(serialized_load_index_info, len); + auto load_index_info = + static_cast(c_load_index_info); + // TODO: keep this since LoadIndexInfo is used by SegmentSealed. + { + load_index_info->collection_id = info_proto->collectionid(); + load_index_info->partition_id = info_proto->partitionid(); + load_index_info->segment_id = info_proto->segmentid(); + load_index_info->field_id = info_proto->field().fieldid(); + load_index_info->field_type = + static_cast(info_proto->field().data_type()); + load_index_info->enable_mmap = info_proto->enable_mmap(); + load_index_info->mmap_dir_path = info_proto->mmap_dir_path(); + load_index_info->index_id = info_proto->indexid(); + load_index_info->index_build_id = info_proto->index_buildid(); + load_index_info->index_version = info_proto->index_version(); + for (const auto& [k, v] : info_proto->index_params()) { + load_index_info->index_params[k] = v; + } + load_index_info->index_files.assign( + info_proto->index_files().begin(), + info_proto->index_files().end()); + load_index_info->uri = info_proto->uri(); + load_index_info->index_store_version = + info_proto->index_store_version(); + load_index_info->index_engine_version = + info_proto->index_engine_version(); + load_index_info->schema = info_proto->field(); + } + auto status = CStatus(); + status.error_code = milvus::Success; + status.error_msg = ""; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = milvus::UnexpectedError; + status.error_msg = strdup(e.what()); + return status; + } +} diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h index 7a3d89b797670..8755aa7396162 100644 --- a/internal/core/src/segcore/load_index_c.h +++ b/internal/core/src/segcore/load_index_c.h @@ -76,6 +76,11 @@ void AppendStorageInfo(CLoadIndexInfo c_load_index_info, const char* uri, int64_t version); + +CStatus +FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info, + const uint8_t* serialized_load_index_info, + const uint64_t len); #ifdef __cplusplus } #endif diff --git a/internal/core/src/storage/Types.h b/internal/core/src/storage/Types.h index 924873dccda64..fbd72d0a59a78 100644 --- a/internal/core/src/storage/Types.h +++ b/internal/core/src/storage/Types.h @@ -64,6 +64,7 @@ struct FieldDataMeta { int64_t partition_id; int64_t segment_id; int64_t field_id; + proto::schema::FieldSchema schema; }; enum CodecType { diff --git a/internal/core/thirdparty/tantivy/CMakeLists.txt b/internal/core/thirdparty/tantivy/CMakeLists.txt index f4d928922874f..c1435a032a85e 100644 --- a/internal/core/thirdparty/tantivy/CMakeLists.txt +++ b/internal/core/thirdparty/tantivy/CMakeLists.txt @@ -71,3 +71,9 @@ target_link_libraries(bench_tantivy boost_filesystem dl ) + +add_executable(ffi_demo ffi_demo.cpp) +target_link_libraries(ffi_demo + tantivy_binding + dl + ) diff --git a/internal/core/thirdparty/tantivy/ffi_demo.cpp b/internal/core/thirdparty/tantivy/ffi_demo.cpp new file mode 100644 index 0000000000000..1626d655f175d --- /dev/null +++ b/internal/core/thirdparty/tantivy/ffi_demo.cpp @@ -0,0 +1,17 @@ +#include +#include + +#include "tantivy-binding.h" + +int +main(int argc, char* argv[]) { + std::vector data{"data1", "data2", "data3"}; + std::vector datas{}; + for (auto& s : data) { + datas.push_back(s.c_str()); + } + + print_vector_of_strings(datas.data(), datas.size()); + + return 0; +} diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h index 3b22018bf047e..045d4a50e6a2c 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h +++ b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h @@ -97,6 +97,24 @@ void tantivy_index_add_bools(void *ptr, const bool *array, uintptr_t len); void tantivy_index_add_keyword(void *ptr, const char *s); +void tantivy_index_add_multi_int8s(void *ptr, const int8_t *array, uintptr_t len); + +void tantivy_index_add_multi_int16s(void *ptr, const int16_t *array, uintptr_t len); + +void tantivy_index_add_multi_int32s(void *ptr, const int32_t *array, uintptr_t len); + +void tantivy_index_add_multi_int64s(void *ptr, const int64_t *array, uintptr_t len); + +void tantivy_index_add_multi_f32s(void *ptr, const float *array, uintptr_t len); + +void tantivy_index_add_multi_f64s(void *ptr, const double *array, uintptr_t len); + +void tantivy_index_add_multi_bools(void *ptr, const bool *array, uintptr_t len); + +void tantivy_index_add_multi_keywords(void *ptr, const char *const *array, uintptr_t len); + bool tantivy_index_exist(const char *path); +void print_vector_of_strings(const char *const *ptr, uintptr_t len); + } // extern "C" diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs new file mode 100644 index 0000000000000..257a41f17a891 --- /dev/null +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs @@ -0,0 +1,14 @@ +use std::{ffi::{c_char, CStr}, slice}; + +#[no_mangle] +pub extern "C" fn print_vector_of_strings(ptr: *const *const c_char, len: usize) { + let arr : &[*const c_char] = unsafe { + slice::from_raw_parts(ptr, len) + }; + for element in arr { + let c_str = unsafe { + CStr::from_ptr(*element) + }; + println!("{}", c_str.to_str().unwrap()); + } +} \ No newline at end of file diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs index ce96a5b4d5a30..2c8d56bf38694 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs @@ -1,10 +1,11 @@ -use futures::executor::block_on; +use std::ffi::CStr; +use libc::c_char; use tantivy::schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, INDEXED}; -use tantivy::{doc, tokenizer, Index, IndexWriter, SingleSegmentIndexWriter}; +use tantivy::{doc, tokenizer, Index, SingleSegmentIndexWriter, Document}; use crate::data_type::TantivyDataType; -use crate::index_writer; + use crate::log::init_log; pub struct IndexWriterWrapper { @@ -98,7 +99,74 @@ impl IndexWriterWrapper { .unwrap(); } - pub fn finish(mut self) { + pub fn add_multi_i8s(&mut self, datas: &[i8]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i16s(&mut self, datas: &[i16]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i32s(&mut self, datas: &[i32]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i64s(&mut self, datas: &[i64]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_f32s(&mut self, datas: &[f32]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as f64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_f64s(&mut self, datas: &[f64]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_bools(&mut self, datas: &[bool]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_keywords(&mut self, datas: &[*const c_char]) { + let mut document = Document::default(); + for element in datas { + let data = unsafe { + CStr::from_ptr(*element) + }; + document.add_field_value(self.field, data.to_str().unwrap()); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn finish(self) { self.index_writer .finalize() .expect("failed to build inverted index"); diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs index c8822781158e8..b13f550d7cb00 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs @@ -122,3 +122,77 @@ pub extern "C" fn tantivy_index_add_keyword(ptr: *mut c_void, s: *const c_char) let c_str = unsafe { CStr::from_ptr(s) }; unsafe { (*real).add_keyword(c_str.to_str().unwrap()) } } + +// --------------------------------------------- array ------------------------------------------ + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int8s(ptr: *mut c_void, array: *const i8, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len); + (*real).add_multi_i8s(arr) + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int16s(ptr: *mut c_void, array: *const i16, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i16s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int32s(ptr: *mut c_void, array: *const i32, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i32s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int64s(ptr: *mut c_void, array: *const i64, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i64s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_f32s(ptr: *mut c_void, array: *const f32, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_f32s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_f64s(ptr: *mut c_void, array: *const f64, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_f64s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_bools(ptr: *mut c_void, array: *const bool, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_bools(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_keywords(ptr: *mut c_void, array: *const *const c_char, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len); + (*real).add_multi_keywords(arr) + } +} diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs index aa069cb3b32b6..c6193de3f6908 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs @@ -10,6 +10,7 @@ mod log; mod util; mod util_c; mod vec_collector; +mod demo_c; pub fn add(left: usize, right: usize) -> usize { left + right diff --git a/internal/core/thirdparty/tantivy/tantivy-wrapper.h b/internal/core/thirdparty/tantivy/tantivy-wrapper.h index 358f14ea49ed0..7574d3875ca24 100644 --- a/internal/core/thirdparty/tantivy/tantivy-wrapper.h +++ b/internal/core/thirdparty/tantivy/tantivy-wrapper.h @@ -1,5 +1,7 @@ #include #include +#include +#include #include "tantivy-binding.h" namespace milvus::tantivy { @@ -186,6 +188,60 @@ struct TantivyIndexWrapper { typeid(T).name()); } + template + void + add_multi_data(const T* array, uintptr_t len) { + assert(!finished_); + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_bools(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_int8s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_int16s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_int32s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_int64s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_f32s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + tantivy_index_add_multi_f64s(writer_, array, len); + return; + } + + if constexpr (std::is_same_v) { + std::vector views; + for (uintptr_t i = 0; i < len; i++) { + views.push_back(array[i].c_str()); + } + tantivy_index_add_multi_keywords(writer_, views.data(), len); + return; + } + + throw fmt::format( + "InvertedIndex.add_multi_data: unsupported data type: {}", + typeid(T).name()); + } + inline void finish() { if (!finished_) { diff --git a/internal/core/thirdparty/tantivy/test.cpp b/internal/core/thirdparty/tantivy/test.cpp index 1c67a69673a5c..a380481042487 100644 --- a/internal/core/thirdparty/tantivy/test.cpp +++ b/internal/core/thirdparty/tantivy/test.cpp @@ -200,6 +200,83 @@ test_32717() { } } +std::set +to_set(const RustArrayWrapper& w) { + std::set s(w.array_.array, w.array_.array + w.array_.len); + return s; +} + +template +std::map> +build_inverted_index(const std::vector>& vec_of_array) { + std::map> inverted_index; + for (uint32_t i = 0; i < vec_of_array.size(); i++) { + for (const auto& term : vec_of_array[i]) { + inverted_index[term].insert(i); + } + } + return inverted_index; +} + +void +test_array_int() { + using T = int64_t; + + auto path = "/tmp/inverted-index/test-binding/"; + boost::filesystem::remove_all(path); + boost::filesystem::create_directories(path); + auto w = TantivyIndexWrapper("test_field_name", guess_data_type(), path); + + std::vector> vec_of_array{ + {10, 40, 50}, + {20, 50}, + {10, 50, 60}, + }; + + for (const auto& arr : vec_of_array) { + w.add_multi_data(arr.data(), arr.size()); + } + w.finish(); + + assert(w.count() == vec_of_array.size()); + + auto inverted_index = build_inverted_index(vec_of_array); + for (const auto& [term, posting_list] : inverted_index) { + auto hits = to_set(w.term_query(term)); + assert(posting_list == hits); + } +} + +void +test_array_string() { + using T = std::string; + + auto path = "/tmp/inverted-index/test-binding/"; + boost::filesystem::remove_all(path); + boost::filesystem::create_directories(path); + auto w = + TantivyIndexWrapper("test_field_name", TantivyDataType::Keyword, path); + + std::vector> vec_of_array{ + {"10", "40", "50"}, + {"20", "50"}, + {"10", "50", "60"}, + }; + + for (const auto& arr : vec_of_array) { + w.add_multi_data(arr.data(), arr.size()); + } + w.finish(); + + assert(w.count() == vec_of_array.size()); + + auto inverted_index = build_inverted_index(vec_of_array); + for (const auto& [term, posting_list] : inverted_index) { + auto hits = to_set(w.term_query(term)); + assert(posting_list == hits); + } +} + int main(int argc, char* argv[]) { test_32717(); @@ -216,5 +293,8 @@ main(int argc, char* argv[]) { run(); + test_array_int(); + test_array_string(); + return 0; } diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 657198c9b88c2..e742e25a5a2bb 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -66,6 +66,7 @@ set(MILVUS_TEST_FILES test_group_by.cpp test_regex_query_util.cpp test_regex_query.cpp + test_array_inverted_index.cpp ) if ( BUILD_DISK_ANN STREQUAL "ON" ) diff --git a/internal/core/unittest/test_array_inverted_index.cpp b/internal/core/unittest/test_array_inverted_index.cpp new file mode 100644 index 0000000000000..cd4833b52bf38 --- /dev/null +++ b/internal/core/unittest/test_array_inverted_index.cpp @@ -0,0 +1,297 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICEN_SE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRAN_TIES OR CON_DITION_S OF AN_Y KIN_D, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "pb/plan.pb.h" +#include "index/InvertedIndexTantivy.h" +#include "common/Schema.h" +#include "segcore/SegmentSealedImpl.h" +#include "test_utils/DataGen.h" +#include "test_utils/GenExprProto.h" +#include "query/PlanProto.h" +#include "query/generated/ExecPlanNodeVisitor.h" + +using namespace milvus; +using namespace milvus::query; +using namespace milvus::segcore; + +template +SchemaPtr +GenTestSchema() { + auto schema_ = std::make_shared(); + schema_->AddDebugField( + "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto pk = schema_->AddDebugField("pk", DataType::INT64); + schema_->set_primary_field_id(pk); + + if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::BOOL); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::INT8); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::INT16); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::INT32); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::INT64); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::FLOAT); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::DOUBLE); + } else if constexpr (std::is_same_v) { + schema_->AddDebugArrayField("array", DataType::VARCHAR); + } + + return schema_; +} + +template +class ArrayInvertedIndexTest : public ::testing::Test { + public: + void + SetUp() override { + schema_ = GenTestSchema(); + seg_ = CreateSealedSegment(schema_); + N_ = 3000; + uint64_t seed = 19190504; + auto raw_data = DataGen(schema_, N_, seed); + auto array_col = + raw_data.get_col(schema_->get_field_id(FieldName("array"))) + ->scalars() + .array_data() + .data(); + for (size_t i = 0; i < N_; i++) { + boost::container::vector array; + if constexpr (std::is_same_v) { + for (size_t j = 0; j < array_col[i].bool_data().data_size(); + j++) { + array.push_back(array_col[i].bool_data().data(j)); + } + } else if constexpr (std::is_same_v) { + for (size_t j = 0; j < array_col[i].long_data().data_size(); + j++) { + array.push_back(array_col[i].long_data().data(j)); + } + } else if constexpr (std::is_integral_v) { + for (size_t j = 0; j < array_col[i].int_data().data_size(); + j++) { + array.push_back(array_col[i].int_data().data(j)); + } + } else if constexpr (std::is_floating_point_v) { + for (size_t j = 0; j < array_col[i].float_data().data_size(); + j++) { + array.push_back(array_col[i].float_data().data(j)); + } + } else if constexpr (std::is_same_v) { + for (size_t j = 0; j < array_col[i].string_data().data_size(); + j++) { + array.push_back(array_col[i].string_data().data(j)); + } + } + vec_of_array_.push_back(array); + } + SealedLoadFieldData(raw_data, *seg_); + LoadInvertedIndex(); + } + + void + TearDown() override { + } + + void + LoadInvertedIndex() { + auto index = std::make_unique>(); + Config cfg; + cfg["is_array"] = true; + index->BuildWithRawData(N_, vec_of_array_.data(), cfg); + LoadIndexInfo info{ + .field_id = schema_->get_field_id(FieldName("array")).get(), + .index = std::move(index), + }; + seg_->LoadIndex(info); + } + + public: + SchemaPtr schema_; + SegmentSealedUPtr seg_; + int64_t N_; + std::vector> vec_of_array_; +}; + +TYPED_TEST_SUITE_P(ArrayInvertedIndexTest); + +TYPED_TEST_P(ArrayInvertedIndexTest, ArrayContainsAny) { + const auto& meta = this->schema_->operator[](FieldName("array")); + auto column_info = test::GenColumnInfo( + meta.get_id().get(), + static_cast(meta.get_data_type()), + false, + false, + static_cast(meta.get_element_type())); + auto contains_expr = std::make_unique(); + contains_expr->set_allocated_column_info(column_info); + contains_expr->set_op(proto::plan::JSONContainsExpr_JSONOp:: + JSONContainsExpr_JSONOp_ContainsAny); + contains_expr->set_elements_same_type(true); + for (const auto& elem : this->vec_of_array_[0]) { + auto t = test::GenGenericValue(elem); + contains_expr->mutable_elements()->AddAllocated(t); + } + auto expr = test::GenExpr(); + expr->set_allocated_json_contains_expr(contains_expr.release()); + + auto parser = ProtoParser(*this->schema_); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(this->seg_.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + + std::unordered_set elems(this->vec_of_array_[0].begin(), + this->vec_of_array_[0].end()); + auto ref = [this, &elems](size_t offset) -> bool { + std::unordered_set row(this->vec_of_array_[offset].begin(), + this->vec_of_array_[offset].end()); + for (const auto& elem : elems) { + if (row.find(elem) != row.end()) { + return true; + } + } + return false; + }; + ASSERT_EQ(final.size(), this->N_); + for (size_t i = 0; i < this->N_; i++) { + ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i] + << ", ref(i): " << ref(i); + } +} + +TYPED_TEST_P(ArrayInvertedIndexTest, ArrayContainsAll) { + const auto& meta = this->schema_->operator[](FieldName("array")); + auto column_info = test::GenColumnInfo( + meta.get_id().get(), + static_cast(meta.get_data_type()), + false, + false, + static_cast(meta.get_element_type())); + auto contains_expr = std::make_unique(); + contains_expr->set_allocated_column_info(column_info); + contains_expr->set_op(proto::plan::JSONContainsExpr_JSONOp:: + JSONContainsExpr_JSONOp_ContainsAll); + contains_expr->set_elements_same_type(true); + for (const auto& elem : this->vec_of_array_[0]) { + auto t = test::GenGenericValue(elem); + contains_expr->mutable_elements()->AddAllocated(t); + } + auto expr = test::GenExpr(); + expr->set_allocated_json_contains_expr(contains_expr.release()); + + auto parser = ProtoParser(*this->schema_); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(this->seg_.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + + std::unordered_set elems(this->vec_of_array_[0].begin(), + this->vec_of_array_[0].end()); + auto ref = [this, &elems](size_t offset) -> bool { + std::unordered_set row(this->vec_of_array_[offset].begin(), + this->vec_of_array_[offset].end()); + for (const auto& elem : elems) { + if (row.find(elem) == row.end()) { + return false; + } + } + return true; + }; + ASSERT_EQ(final.size(), this->N_); + for (size_t i = 0; i < this->N_; i++) { + ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i] + << ", ref(i): " << ref(i); + } +} + +TYPED_TEST_P(ArrayInvertedIndexTest, ArrayEqual) { + if (std::is_floating_point_v) { + GTEST_SKIP() << "not accurate to perform equal comparison on floating " + "point number"; + } + + const auto& meta = this->schema_->operator[](FieldName("array")); + auto column_info = test::GenColumnInfo( + meta.get_id().get(), + static_cast(meta.get_data_type()), + false, + false, + static_cast(meta.get_element_type())); + auto unary_range_expr = std::make_unique(); + unary_range_expr->set_allocated_column_info(column_info); + unary_range_expr->set_op(proto::plan::OpType::Equal); + auto arr = new proto::plan::GenericValue; + arr->mutable_array_val()->set_element_type( + static_cast(meta.get_element_type())); + arr->mutable_array_val()->set_same_type(true); + for (const auto& elem : this->vec_of_array_[0]) { + auto e = test::GenGenericValue(elem); + arr->mutable_array_val()->mutable_array()->AddAllocated(e); + } + unary_range_expr->set_allocated_value(arr); + auto expr = test::GenExpr(); + expr->set_allocated_unary_range_expr(unary_range_expr.release()); + + auto parser = ProtoParser(*this->schema_); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(this->seg_.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + + auto ref = [this](size_t offset) -> bool { + if (this->vec_of_array_[0].size() != + this->vec_of_array_[offset].size()) { + return false; + } + auto size = this->vec_of_array_[0].size(); + for (size_t i = 0; i < size; i++) { + if (this->vec_of_array_[0][i] != this->vec_of_array_[offset][i]) { + return false; + } + } + return true; + }; + ASSERT_EQ(final.size(), this->N_); + for (size_t i = 0; i < this->N_; i++) { + ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i] + << ", ref(i): " << ref(i); + } +} + +using ElementType = testing:: + Types; + +REGISTER_TYPED_TEST_CASE_P(ArrayInvertedIndexTest, + ArrayContainsAny, + ArrayContainsAll, + ArrayEqual); + +INSTANTIATE_TYPED_TEST_SUITE_P(Naive, ArrayInvertedIndexTest, ElementType); diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp index 39f6841957dc4..79581bc96947b 100644 --- a/internal/core/unittest/test_index_wrapper.cpp +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -23,7 +23,7 @@ using namespace milvus; using namespace milvus::segcore; -using namespace milvus::proto::indexcgo; +using namespace milvus::proto; using Param = std::pair; diff --git a/internal/core/unittest/test_inverted_index.cpp b/internal/core/unittest/test_inverted_index.cpp index eeddfe6e9d81a..c8b9bf3663235 100644 --- a/internal/core/unittest/test_inverted_index.cpp +++ b/internal/core/unittest/test_inverted_index.cpp @@ -25,20 +25,25 @@ using namespace milvus; -// TODO: I would suggest that our all indexes use this test to simulate the real production environment. - namespace milvus::test { auto gen_field_meta(int64_t collection_id = 1, int64_t partition_id = 2, int64_t segment_id = 3, - int64_t field_id = 101) -> storage::FieldDataMeta { - return storage::FieldDataMeta{ + int64_t field_id = 101, + DataType data_type = DataType::NONE, + DataType element_type = DataType::NONE) + -> storage::FieldDataMeta { + auto meta = storage::FieldDataMeta{ .collection_id = collection_id, .partition_id = partition_id, .segment_id = segment_id, .field_id = field_id, }; + meta.schema.set_data_type(static_cast(data_type)); + meta.schema.set_element_type( + static_cast(element_type)); + return meta; } auto @@ -86,7 +91,7 @@ struct ChunkManagerWrapper { }; } // namespace milvus::test -template +template void test_run() { int64_t collection_id = 1; @@ -96,8 +101,8 @@ test_run() { int64_t index_build_id = 1000; int64_t index_version = 10000; - auto field_meta = - test::gen_field_meta(collection_id, partition_id, segment_id, field_id); + auto field_meta = test::gen_field_meta( + collection_id, partition_id, segment_id, field_id, dtype, element_type); auto index_meta = test::gen_index_meta( segment_id, field_id, index_build_id, index_version); @@ -305,8 +310,12 @@ test_string() { int64_t index_build_id = 1000; int64_t index_version = 10000; - auto field_meta = - test::gen_field_meta(collection_id, partition_id, segment_id, field_id); + auto field_meta = test::gen_field_meta(collection_id, + partition_id, + segment_id, + field_id, + dtype, + DataType::NONE); auto index_meta = test::gen_index_meta( segment_id, field_id, index_build_id, index_version); diff --git a/internal/core/unittest/test_scalar_index.cpp b/internal/core/unittest/test_scalar_index.cpp index 8b11c89530e9b..f7becf13b492f 100644 --- a/internal/core/unittest/test_scalar_index.cpp +++ b/internal/core/unittest/test_scalar_index.cpp @@ -49,6 +49,14 @@ TYPED_TEST_P(TypedScalarIndexTest, Dummy) { std::cout << milvus::GetDType() << std::endl; } +auto +GetTempFileManagerCtx(CDataType data_type) { + auto ctx = milvus::storage::FileManagerContext(); + ctx.fieldDataMeta.schema.set_data_type( + static_cast(data_type)); + return ctx; +} + TYPED_TEST_P(TypedScalarIndexTest, Constructor) { using T = TypeParam; auto dtype = milvus::GetDType(); @@ -59,7 +67,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Constructor) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); } } @@ -73,7 +81,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Count) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -92,7 +100,7 @@ TYPED_TEST_P(TypedScalarIndexTest, HasRawData) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -112,7 +120,7 @@ TYPED_TEST_P(TypedScalarIndexTest, In) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -131,7 +139,7 @@ TYPED_TEST_P(TypedScalarIndexTest, NotIn) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -150,7 +158,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Reverse) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -169,7 +177,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Range) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -188,7 +196,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -197,7 +205,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) { auto binary_set = index->Serialize(nullptr); auto copy_index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); copy_index->Load(binary_set); auto copy_scalar_index = @@ -368,6 +376,8 @@ TYPED_TEST_P(TypedScalarIndexTestV2, Base) { auto space = TestSpace(temp_path, vec_size, dataset, scalars); milvus::storage::FileManagerContext file_manager_context( {}, {.field_name = "scalar"}, chunk_manager, space); + file_manager_context.fieldDataMeta.schema.set_data_type( + static_cast(dtype)); auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( create_index_info, file_manager_context, space); diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 37c3d6f27676d..3b69ed98e8ec0 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -480,8 +480,30 @@ inline GeneratedData DataGen(SchemaPtr schema, } break; } - case DataType::INT8: - case DataType::INT16: + case DataType::INT8: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_int_data()->add_data( + static_cast(random())); + } + data[i] = field_data; + } + break; + } + case DataType::INT16: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_int_data()->add_data( + static_cast(random())); + } + data[i] = field_data; + } + break; + } case DataType::INT32: { for (int i = 0; i < N / repeat_count; i++) { milvus::proto::schema::ScalarField field_data; diff --git a/internal/core/unittest/test_utils/GenExprProto.h b/internal/core/unittest/test_utils/GenExprProto.h index 171273b1fc7fd..77f0a4964e4bb 100644 --- a/internal/core/unittest/test_utils/GenExprProto.h +++ b/internal/core/unittest/test_utils/GenExprProto.h @@ -15,15 +15,18 @@ namespace milvus::test { inline auto -GenColumnInfo(int64_t field_id, - proto::schema::DataType field_type, - bool auto_id, - bool is_pk) { +GenColumnInfo( + int64_t field_id, + proto::schema::DataType field_type, + bool auto_id, + bool is_pk, + proto::schema::DataType element_type = proto::schema::DataType::None) { auto column_info = new proto::plan::ColumnInfo(); column_info->set_field_id(field_id); column_info->set_data_type(field_type); column_info->set_is_autoid(auto_id); column_info->set_is_primary_key(is_pk); + column_info->set_element_type(element_type); return column_info; } diff --git a/internal/datacoord/index_builder.go b/internal/datacoord/index_builder.go index be03a613ef634..3c87b94d23f60 100644 --- a/internal/datacoord/index_builder.go +++ b/internal/datacoord/index_builder.go @@ -347,28 +347,29 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { } } var req *indexpb.CreateJobRequest - if Params.CommonCfg.EnableStorageV2.GetAsBool() { - collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID()) - if err != nil { - log.Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) - return false - } + collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID()) + if err != nil { + log.Ctx(ib.ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) + return false + } - schema := collectionInfo.Schema - var field *schemapb.FieldSchema + schema := collectionInfo.Schema + var field *schemapb.FieldSchema - for _, f := range schema.Fields { - if f.FieldID == fieldID { - field = f - break - } - } - - dim, err := storage.GetDimFromParams(field.TypeParams) - if err != nil { - return false + for _, f := range schema.Fields { + if f.FieldID == fieldID { + field = f + break } + } + dim, err := storage.GetDimFromParams(field.TypeParams) + if err != nil { + log.Ctx(ib.ctx).Warn("failed to get dim from field type params", + zap.String("field type", field.GetDataType().String()), zap.Error(err)) + // don't return, maybe field is scalar field or sparseFloatVector + } + if Params.CommonCfg.EnableStorageV2.GetAsBool() { storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID()) if err != nil { log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err)) @@ -402,6 +403,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(), DataIds: binlogIDs, OptionalScalarFields: optionalFields, + Field: field, } } else { req = &indexpb.CreateJobRequest{ @@ -420,6 +422,8 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { SegmentID: segment.GetID(), FieldID: fieldID, OptionalScalarFields: optionalFields, + Dim: int64(dim), + Field: field, } } diff --git a/internal/datacoord/index_builder_test.go b/internal/datacoord/index_builder_test.go index 46d8c7fe3f43e..9488c70f5e818 100644 --- a/internal/datacoord/index_builder_test.go +++ b/internal/datacoord/index_builder_test.go @@ -675,7 +675,30 @@ func TestIndexBuilder(t *testing.T) { chunkManager := &mocks.ChunkManager{} chunkManager.EXPECT().RootPath().Return("root") - ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), nil) + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + + ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), handler) assert.Equal(t, 6, len(ib.tasks)) assert.Equal(t, indexTaskInit, ib.tasks[buildID]) @@ -741,6 +764,30 @@ func TestIndexBuilder_Error(t *testing.T) { chunkManager := &mocks.ChunkManager{} chunkManager.EXPECT().RootPath().Return("root") + + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + ib := &indexBuilder{ ctx: context.Background(), tasks: map[int64]indexTaskState{ @@ -749,6 +796,7 @@ func TestIndexBuilder_Error(t *testing.T) { meta: createMetaTable(ec), chunkManager: chunkManager, indexEngineVersionManager: newIndexEngineVersionManager(), + handler: handler, } t.Run("meta not exist", func(t *testing.T) { @@ -1414,9 +1462,32 @@ func TestVecIndexWithOptionalScalarField(t *testing.T) { mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_VarChar } + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true") defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false") - ib := newIndexBuilder(ctx, &mt, nodeManager, cm, newIndexEngineVersionManager(), nil) + ib := newIndexBuilder(ctx, &mt, nodeManager, cm, newIndexEngineVersionManager(), handler) t.Run("success to get opt field on startup", func(t *testing.T) { ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index a690e35e4a10a..fb9d5a0cc19a1 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -55,6 +55,8 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest defer i.lifetime.Done() log.Info("IndexNode building index ...", zap.Int64("collectionID", req.GetCollectionID()), + zap.Int64("partitionID", req.GetPartitionID()), + zap.Int64("segmentID", req.GetSegmentID()), zap.Int64("indexID", req.GetIndexID()), zap.String("indexName", req.GetIndexName()), zap.String("indexFilePrefix", req.GetIndexFilePrefix()), diff --git a/internal/indexnode/task.go b/internal/indexnode/task.go index b14343900d99c..54c8b3fe45a66 100644 --- a/internal/indexnode/task.go +++ b/internal/indexnode/task.go @@ -18,7 +18,6 @@ package indexnode import ( "context" - "encoding/json" "fmt" "runtime/debug" "strconv" @@ -30,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexcgopb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" @@ -84,12 +84,21 @@ type indexBuildTaskV2 struct { } func (it *indexBuildTaskV2) parseParams(ctx context.Context) error { - it.collectionID = it.req.CollectionID - it.partitionID = it.req.PartitionID - it.segmentID = it.req.SegmentID - it.fieldType = it.req.FieldType - it.fieldID = it.req.FieldID - it.fieldName = it.req.FieldName + it.collectionID = it.req.GetCollectionID() + it.partitionID = it.req.GetPartitionID() + it.segmentID = it.req.GetSegmentID() + it.fieldType = it.req.GetFieldType() + if it.fieldType == schemapb.DataType_None { + it.fieldType = it.req.GetField().GetDataType() + } + it.fieldID = it.req.GetFieldID() + if it.fieldID == 0 { + it.fieldID = it.req.GetField().GetFieldID() + } + it.fieldName = it.req.GetFieldName() + if it.fieldName == "" { + it.fieldName = it.req.GetField().GetName() + } return nil } @@ -138,61 +147,66 @@ func (it *indexBuildTaskV2) BuildIndex(ctx context.Context) error { } } - var buildIndexInfo *indexcgowrapper.BuildIndexInfo - buildIndexInfo, err = indexcgowrapper.NewBuildIndexInfo(it.req.GetStorageConfig()) - defer indexcgowrapper.DeleteBuildIndexInfo(buildIndexInfo) - if err != nil { - log.Ctx(ctx).Warn("create build index info failed", zap.Error(err)) - return err - } - err = buildIndexInfo.AppendFieldMetaInfoV2(it.collectionID, it.partitionID, it.segmentID, it.fieldID, it.fieldType, it.fieldName, it.req.Dim) - if err != nil { - log.Ctx(ctx).Warn("append field meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexMetaInfo(it.req.IndexID, it.req.BuildID, it.req.IndexVersion) - if err != nil { - log.Ctx(ctx).Warn("append index meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendBuildIndexParam(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Warn("append index params failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexStorageInfo(it.req.StorePath, it.req.IndexStorePath, it.req.StoreVersion) - if err != nil { - log.Ctx(ctx).Warn("append storage info failed", zap.Error(err)) - return err - } - - jsonIndexParams, err := json.Marshal(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Error("failed to json marshal index params", zap.Error(err)) - return err - } - - log.Ctx(ctx).Info("index params are ready", - zap.Int64("buildID", it.BuildID), - zap.String("index params", string(jsonIndexParams))) - - err = buildIndexInfo.AppendBuildTypeParam(it.newTypeParams) - if err != nil { - log.Ctx(ctx).Warn("append type params failed", zap.Error(err)) - return err + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) } - for _, optField := range it.req.GetOptionalScalarFields() { - if err := buildIndexInfo.AppendOptionalField(optField); err != nil { - log.Ctx(ctx).Warn("append optional field failed", zap.Error(err)) - return err + it.currentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) + field := it.req.GetField() + if field == nil || field.GetDataType() == schemapb.DataType_None { + field = &schemapb.FieldSchema{ + FieldID: it.fieldID, + Name: it.fieldName, + DataType: it.fieldType, } } - it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexInfo) + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.ClusterID, + BuildID: it.BuildID, + CollectionID: it.collectionID, + PartitionID: it.partitionID, + SegmentID: it.segmentID, + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: it.currentIndexVersion, + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: field, + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: it.req.GetIndexStorePath(), + OptFields: optFields, + } + + it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams) if err != nil { if it.index != nil && it.index.CleanLocalData() != nil { log.Ctx(ctx).Error("failed to clean cached data on disk after build index failed", @@ -328,7 +342,7 @@ func (it *indexBuildTask) Prepare(ctx context.Context) error { if len(it.req.DataPaths) == 0 { for _, id := range it.req.GetDataIds() { - path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetFieldID(), id) + path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id) it.req.DataPaths = append(it.req.DataPaths, path) } } @@ -362,16 +376,10 @@ func (it *indexBuildTask) Prepare(ctx context.Context) error { } it.newTypeParams = typeParams it.newIndexParams = indexParams + it.statistic.IndexParams = it.req.GetIndexParams() - // ugly codes to get dimension - if dimStr, ok := typeParams[common.DimKey]; ok { - var err error - it.statistic.Dim, err = strconv.ParseInt(dimStr, 10, 64) - if err != nil { - log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err)) - // ignore error - } - } + it.statistic.Dim = it.req.GetDim() + log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.BuildID), zap.Int64("Collection", it.collectionID), zap.Int64("SegmentID", it.segmentID)) return nil @@ -482,69 +490,65 @@ func (it *indexBuildTask) BuildIndex(ctx context.Context) error { } } - var buildIndexInfo *indexcgowrapper.BuildIndexInfo - buildIndexInfo, err = indexcgowrapper.NewBuildIndexInfo(it.req.GetStorageConfig()) - defer indexcgowrapper.DeleteBuildIndexInfo(buildIndexInfo) - if err != nil { - log.Ctx(ctx).Warn("create build index info failed", zap.Error(err)) - return err - } - err = buildIndexInfo.AppendFieldMetaInfo(it.collectionID, it.partitionID, it.segmentID, it.fieldID, it.fieldType) - if err != nil { - log.Ctx(ctx).Warn("append field meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexMetaInfo(it.req.IndexID, it.req.BuildID, it.req.IndexVersion) - if err != nil { - log.Ctx(ctx).Warn("append index meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendBuildIndexParam(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Warn("append index params failed", zap.Error(err)) - return err - } - - jsonIndexParams, err := json.Marshal(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Error("failed to json marshal index params", zap.Error(err)) - return err - } - - log.Ctx(ctx).Info("index params are ready", - zap.Int64("buildID", it.BuildID), - zap.String("index params", string(jsonIndexParams))) - - err = buildIndexInfo.AppendBuildTypeParam(it.newTypeParams) - if err != nil { - log.Ctx(ctx).Warn("append type params failed", zap.Error(err)) - return err - } - - for _, path := range it.req.GetDataPaths() { - err = buildIndexInfo.AppendInsertFile(path) - if err != nil { - log.Ctx(ctx).Warn("append insert binlog path failed", zap.Error(err)) - return err - } + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) } it.currentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) - if err := buildIndexInfo.AppendIndexEngineVersion(it.currentIndexVersion); err != nil { - log.Ctx(ctx).Warn("append index engine version failed", zap.Error(err)) - return err - } - - for _, optField := range it.req.GetOptionalScalarFields() { - if err := buildIndexInfo.AppendOptionalField(optField); err != nil { - log.Ctx(ctx).Warn("append optional field failed", zap.Error(err)) - return err + field := it.req.GetField() + if field == nil || field.GetDataType() == schemapb.DataType_None { + field = &schemapb.FieldSchema{ + FieldID: it.fieldID, + Name: it.fieldName, + DataType: it.fieldType, } } - - it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexInfo) + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.ClusterID, + BuildID: it.BuildID, + CollectionID: it.collectionID, + PartitionID: it.partitionID, + SegmentID: it.segmentID, + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: it.currentIndexVersion, + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: field, + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: it.req.GetIndexStorePath(), + OptFields: optFields, + } + + it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams) if err != nil { if it.index != nil && it.index.CleanLocalData() != nil { log.Ctx(ctx).Error("failed to clean cached data on disk after build index failed", @@ -653,8 +657,6 @@ func (it *indexBuildTask) decodeBlobs(ctx context.Context, blobs []*storage.Blob deserializeDur := it.tr.RecordSpan() log.Ctx(ctx).Info("IndexNode deserialize data success", - zap.Int64("index id", it.req.IndexID), - zap.String("index name", it.req.IndexName), zap.Int64("collectionID", it.collectionID), zap.Int64("partitionID", it.partitionID), zap.Int64("segmentID", it.segmentID), diff --git a/internal/indexnode/task_test.go b/internal/indexnode/task_test.go index dc30abd800eec..6450c3e504a71 100644 --- a/internal/indexnode/task_test.go +++ b/internal/indexnode/task_test.go @@ -283,12 +283,14 @@ func (suite *IndexBuildTaskV2Suite) TestBuildIndex() { RootPath: "/tmp/milvus/data", StorageType: "local", }, - CollectionID: 1, - PartitionID: 1, - SegmentID: 1, - FieldID: 3, - FieldName: "vec", - FieldType: schemapb.DataType_FloatVector, + CollectionID: 1, + PartitionID: 1, + SegmentID: 1, + Field: &schemapb.FieldSchema{ + FieldID: 3, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + }, StorePath: "file://" + suite.space.Path(), StoreVersion: suite.space.GetCurrentVersion(), IndexStorePath: "file://" + suite.space.Path(), diff --git a/internal/indexnode/util.go b/internal/indexnode/util.go index 9186f9855a81b..8aaa92910503f 100644 --- a/internal/indexnode/util.go +++ b/internal/indexnode/util.go @@ -19,6 +19,7 @@ package indexnode import ( "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -36,3 +37,14 @@ func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType) return 0, nil } } + +func mapToKVPairs(m map[string]string) []*commonpb.KeyValuePair { + kvs := make([]*commonpb.KeyValuePair, 0, len(m)) + for k, v := range m { + kvs = append(kvs, &commonpb.KeyValuePair{ + Key: k, + Value: v, + }) + } + return kvs +} diff --git a/internal/indexnode/util_test.go b/internal/indexnode/util_test.go new file mode 100644 index 0000000000000..6d7d98e823240 --- /dev/null +++ b/internal/indexnode/util_test.go @@ -0,0 +1,41 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexnode + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type utilSuite struct { + suite.Suite +} + +func (s *utilSuite) Test_mapToKVPairs() { + indexParams := map[string]string{ + "index_type": "IVF_FLAT", + "dim": "128", + "nlist": "1024", + } + + s.Equal(3, len(mapToKVPairs(indexParams))) +} + +func Test_utilSuite(t *testing.T) { + suite.Run(t, new(utilSuite)) +} diff --git a/internal/proto/cgo_msg.proto b/internal/proto/cgo_msg.proto new file mode 100644 index 0000000000000..6d851e95e0550 --- /dev/null +++ b/internal/proto/cgo_msg.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package milvus.proto.cgo; +option go_package="github.com/milvus-io/milvus/internal/proto/cgopb"; + +import "schema.proto"; + +message LoadIndexInfo { + int64 collectionID = 1; + int64 partitionID = 2; + int64 segmentID = 3; + schema.FieldSchema field = 5; + bool enable_mmap = 6; + string mmap_dir_path = 7; + int64 indexID = 8; + int64 index_buildID = 9; + int64 index_version = 10; + map index_params = 11; + repeated string index_files = 12; + string uri = 13; + int64 index_store_version = 14; + int32 index_engine_version = 15; +} diff --git a/internal/proto/index_cgo_msg.proto b/internal/proto/index_cgo_msg.proto index 50b1ea5dde5a5..688f871f55aed 100644 --- a/internal/proto/index_cgo_msg.proto +++ b/internal/proto/index_cgo_msg.proto @@ -4,6 +4,7 @@ package milvus.proto.indexcgo; option go_package="github.com/milvus-io/milvus/internal/proto/indexcgopb"; import "common.proto"; +import "schema.proto"; message TypeParams { repeated common.KeyValuePair params = 1; @@ -30,3 +31,52 @@ message Binary { message BinarySet { repeated Binary datas = 1; } + +// Synchronously modify StorageConfig in index_coord.proto file +message StorageConfig { + string address = 1; + string access_keyID = 2; + string secret_access_key = 3; + bool useSSL = 4; + string bucket_name = 5; + string root_path = 6; + bool useIAM = 7; + string IAMEndpoint = 8; + string storage_type = 9; + bool use_virtual_host = 10; + string region = 11; + string cloud_provider = 12; + int64 request_timeout_ms = 13; + string sslCACert = 14; +} + +// Synchronously modify OptionalFieldInfo in index_coord.proto file +message OptionalFieldInfo { + int64 fieldID = 1; + string field_name = 2; + int32 field_type = 3; + repeated string data_paths = 4; +} + +message BuildIndexInfo { + string clusterID = 1; + int64 buildID = 2; + int64 collectionID = 3; + int64 partitionID = 4; + int64 segmentID = 5; + int64 index_version = 6; + int32 current_index_version = 7; + int64 num_rows = 8; + int64 dim = 9; + string index_file_prefix = 10; + repeated string insert_files = 11; +// repeated int64 data_ids = 12; + schema.FieldSchema field_schema = 12; + StorageConfig storage_config = 13; + repeated common.KeyValuePair index_params = 14; + repeated common.KeyValuePair type_params = 15; + string store_path = 16; + int64 store_version = 17; + string index_store_path = 18; + repeated OptionalFieldInfo opt_fields = 19; +} diff --git a/internal/proto/index_coord.proto b/internal/proto/index_coord.proto index d59452b17d2de..9204d7da2a9c7 100644 --- a/internal/proto/index_coord.proto +++ b/internal/proto/index_coord.proto @@ -8,6 +8,7 @@ import "common.proto"; import "internal.proto"; import "milvus.proto"; import "schema.proto"; +import "index_cgo_msg.proto"; service IndexCoord { rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {} @@ -226,6 +227,7 @@ message GetIndexBuildProgressResponse { int64 pending_index_rows = 4; } +// Synchronously modify StorageConfig in index_cgo_msg.proto file message StorageConfig { string address = 1; string access_keyID = 2; @@ -243,6 +245,7 @@ message StorageConfig { string sslCACert = 14; } +// Synchronously modify OptionalFieldInfo in index_cgo_msg.proto file message OptionalFieldInfo { int64 fieldID = 1; string field_name = 2; @@ -276,6 +279,7 @@ message CreateJobRequest { int64 dim = 22; repeated int64 data_ids = 23; repeated OptionalFieldInfo optional_scalar_fields = 24; + schema.FieldSchema field = 25; } message QueryJobsRequest { diff --git a/internal/querynodev2/segments/load_index_info.go b/internal/querynodev2/segments/load_index_info.go index c5c1572475c40..04632bed95f2d 100644 --- a/internal/querynodev2/segments/load_index_info.go +++ b/internal/querynodev2/segments/load_index_info.go @@ -29,11 +29,13 @@ import ( "runtime" "unsafe" + "github.com/golang/protobuf/proto" "github.com/pingcap/log" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datacoord" + "github.com/milvus-io/milvus/internal/proto/cgopb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/pkg/common" @@ -245,3 +247,33 @@ func (li *LoadIndexInfo) appendIndexEngineVersion(ctx context.Context, indexEngi return HandleCStatus(ctx, &status, "AppendIndexEngineVersion failed") } + +func (li *LoadIndexInfo) finish(ctx context.Context, info *cgopb.LoadIndexInfo) error { + marshaled, err := proto.Marshal(info) + if err != nil { + return err + } + + var status C.CStatus + _, _ = GetDynamicPool().Submit(func() (any, error) { + status = C.FinishLoadIndexInfo(li.cLoadIndexInfo, (*C.uint8_t)(unsafe.Pointer(&marshaled[0])), (C.uint64_t)(len(marshaled))) + return nil, nil + }).Await() + + if err := HandleCStatus(ctx, &status, "FinishLoadIndexInfo failed"); err != nil { + return err + } + + _, _ = GetLoadPool().Submit(func() (any, error) { + if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() { + status = C.AppendIndexV3(li.cLoadIndexInfo) + } else { + traceCtx := ParseCTraceContext(ctx) + status = C.AppendIndexV2(traceCtx.ctx, li.cLoadIndexInfo) + runtime.KeepAlive(traceCtx) + } + return nil, nil + }).Await() + + return HandleCStatus(ctx, &status, "AppendIndex failed") +} diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index ea864607e0090..075111e7b2b04 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -45,6 +45,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" "github.com/milvus-io/milvus-storage/go/storage/options" + "github.com/milvus-io/milvus/internal/proto/cgopb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" @@ -56,6 +57,9 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" + "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -1262,18 +1266,58 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn return err } defer deleteLoadIndexInfo(loadIndexInfo) + + schema, err := typeutil.CreateSchemaHelper(s.GetCollection().Schema()) + if err != nil { + return err + } + fieldSchema, err := schema.GetFieldFromID(indexInfo.GetFieldID()) + if err != nil { + return err + } + + indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams) + // as Knowhere reports error if encounter an unknown param, we need to delete it + delete(indexParams, common.MmapEnabledKey) + + // some build params also exist in indexParams, which are useless during loading process + if indexParams["index_type"] == indexparamcheck.IndexDISKANN { + if err := indexparams.SetDiskIndexLoadParams(paramtable.Get(), indexParams, indexInfo.GetNumRows()); err != nil { + return err + } + } + + if err := indexparams.AppendPrepareLoadParams(paramtable.Get(), indexParams); err != nil { + return err + } + + indexInfoProto := &cgopb.LoadIndexInfo{ + CollectionID: s.Collection(), + PartitionID: s.Partition(), + SegmentID: s.ID(), + Field: fieldSchema, + EnableMmap: isIndexMmapEnable(indexInfo), + MmapDirPath: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(), + IndexID: indexInfo.GetIndexID(), + IndexBuildID: indexInfo.GetBuildID(), + IndexVersion: indexInfo.GetIndexVersion(), + IndexParams: indexParams, + IndexFiles: indexInfo.GetIndexFilePaths(), + IndexEngineVersion: indexInfo.GetCurrentIndexVersion(), + IndexStoreVersion: indexInfo.GetIndexStoreVersion(), + } + if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() { uri, err := typeutil_internal.GetStorageURI(paramtable.Get().CommonCfg.StorageScheme.GetValue(), paramtable.Get().CommonCfg.StoragePathPrefix.GetValue(), s.ID()) if err != nil { return err } - loadIndexInfo.appendStorageInfo(uri, indexInfo.IndexStoreVersion) + indexInfoProto.Uri = uri } newLoadIndexInfoSpan := tr.RecordSpan() // 2. - err = loadIndexInfo.appendLoadIndexInfo(ctx, indexInfo, s.Collection(), s.Partition(), s.ID(), fieldType) - if err != nil { + if err := loadIndexInfo.finish(ctx, indexInfoProto); err != nil { if loadIndexInfo.cleanLocalData(ctx) != nil { log.Warn("failed to clean cached data on disk after append index failed", zap.Int64("buildID", indexInfo.BuildID), diff --git a/internal/util/indexcgowrapper/index.go b/internal/util/indexcgowrapper/index.go index f0850b3b916de..a7cc7d0e9b21c 100644 --- a/internal/util/indexcgowrapper/index.go +++ b/internal/util/indexcgowrapper/index.go @@ -16,6 +16,7 @@ import ( "unsafe" "github.com/golang/protobuf/proto" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -94,9 +95,17 @@ func NewCgoIndex(dtype schemapb.DataType, typeParams, indexParams map[string]str return index, nil } -func CreateIndex(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecIndex, error) { +func CreateIndex(ctx context.Context, buildIndexInfo *indexcgopb.BuildIndexInfo) (CodecIndex, error) { + buildIndexInfoBlob, err := proto.Marshal(buildIndexInfo) + if err != nil { + log.Ctx(ctx).Warn("marshal buildIndexInfo failed", + zap.String("clusterID", buildIndexInfo.GetClusterID()), + zap.Int64("buildID", buildIndexInfo.GetBuildID()), + zap.Error(err)) + return nil, err + } var indexPtr C.CIndex - status := C.CreateIndex(&indexPtr, buildIndexInfo.cBuildIndexInfo) + status := C.CreateIndex(&indexPtr, (*C.uint8_t)(unsafe.Pointer(&buildIndexInfoBlob[0])), (C.uint64_t)(len(buildIndexInfoBlob))) if err := HandleCStatus(&status, "failed to create index"); err != nil { return nil, err } @@ -109,9 +118,17 @@ func CreateIndex(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecInde return index, nil } -func CreateIndexV2(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecIndex, error) { +func CreateIndexV2(ctx context.Context, buildIndexInfo *indexcgopb.BuildIndexInfo) (CodecIndex, error) { + buildIndexInfoBlob, err := proto.Marshal(buildIndexInfo) + if err != nil { + log.Ctx(ctx).Warn("marshal buildIndexInfo failed", + zap.String("clusterID", buildIndexInfo.GetClusterID()), + zap.Int64("buildID", buildIndexInfo.GetBuildID()), + zap.Error(err)) + return nil, err + } var indexPtr C.CIndex - status := C.CreateIndexV2(&indexPtr, buildIndexInfo.cBuildIndexInfo) + status := C.CreateIndexV2(&indexPtr, (*C.uint8_t)(unsafe.Pointer(&buildIndexInfoBlob[0])), (C.uint64_t)(len(buildIndexInfoBlob))) if err := HandleCStatus(&status, "failed to create index"); err != nil { return nil, err } diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/pkg/util/indexparamcheck/inverted_checker.go index b15549cd4b7a6..dfc24127d3569 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/pkg/util/indexparamcheck/inverted_checker.go @@ -17,7 +17,8 @@ func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { } func (c *INVERTEDChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) { + if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && + !typeutil.IsArrayType(dType) { return fmt.Errorf("INVERTED are not supported on %s field", dType.String()) } return nil diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go index afe41f89f1193..7a31290061490 100644 --- a/pkg/util/indexparamcheck/inverted_checker_test.go +++ b/pkg/util/indexparamcheck/inverted_checker_test.go @@ -18,8 +18,8 @@ func Test_INVERTEDIndexChecker(t *testing.T) { assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Bool)) assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) + assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Array)) assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Array)) assert.Error(t, c.CheckValidDataType(schemapb.DataType_FloatVector)) } diff --git a/scripts/generate_proto.sh b/scripts/generate_proto.sh index 2551f586c9f9c..286570b842aa8 100755 --- a/scripts/generate_proto.sh +++ b/scripts/generate_proto.sh @@ -44,6 +44,7 @@ pushd ${PROTO_DIR} mkdir -p etcdpb mkdir -p indexcgopb +mkdir -p cgopb mkdir -p internalpb mkdir -p rootcoordpb @@ -62,6 +63,7 @@ protoc_opt="${PROTOC_BIN} --proto_path=${API_PROTO_DIR} --proto_path=." ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./etcdpb etcd_meta.proto || { echo 'generate etcd_meta.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./indexcgopb index_cgo_msg.proto || { echo 'generate index_cgo_msg failed '; exit 1; } +${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./cgopb cgo_msg.proto || { echo 'generate cgo_msg failed '; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./rootcoordpb root_coord.proto || { echo 'generate root_coord.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./internalpb internal.proto || { echo 'generate internal.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./proxypb proxy.proto|| { echo 'generate proxy.proto failed'; exit 1; } @@ -78,6 +80,7 @@ ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb schema.proto|| { echo 'generate sche ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb common.proto|| { echo 'generate common.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb index_cgo_msg.proto|| { echo 'generate index_cgo_msg.proto failed'; exit 1; } +${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb cgo_msg.proto|| { echo 'generate cgo_msg.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb plan.proto|| { echo 'generate plan.proto failed'; exit 1; } popd diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 21962385028d1..6e9d914625e67 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -1313,10 +1313,7 @@ def test_create_inverted_index_on_array_field(self): collection_w = self.init_collection_wrap(schema=schema) # 2. create index scalar_index_params = {"index_type": "INVERTED"} - collection_w.create_index(ct.default_int32_array_field_name, index_params=scalar_index_params, - check_task=CheckTasks.err_res, - check_items={ct.err_code: 1100, - ct.err_msg: "create index on Array field is not supported"}) + collection_w.create_index(ct.default_int32_array_field_name, index_params=scalar_index_params) @pytest.mark.tags(CaseLabel.L1) def test_create_inverted_index_no_vector_index(self):