From 511edd29fd8dd6c98d0c778d869fd20df1173d69 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Wed, 20 Nov 2024 14:22:30 +0800 Subject: [PATCH] enhance: disallow get raw vector data of a BM25 Function output field (#37800) issue: https://github.com/milvus-io/milvus/issues/35853 Signed-off-by: Buqian Zheng --- internal/proxy/meta_cache.go | 4 +++ internal/proxy/task_test.go | 33 +++++++++++++++---- internal/proxy/util.go | 17 ++++++---- pkg/util/typeutil/schema.go | 14 ++++++++ .../testcases/test_full_text_search.py | 24 +++++++------- 5 files changed, 66 insertions(+), 26 deletions(-) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 5e4fe3ce5ddda..ff247048dfd13 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -252,6 +252,10 @@ func (s *schemaInfo) IsFieldLoaded(fieldID int64) bool { return s.schemaHelper.IsFieldLoaded(fieldID) } +func (s *schemaInfo) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool { + return s.schemaHelper.CanRetrieveRawFieldData(field) +} + // partitionInfos contains the cached collection partition informations. type partitionInfos struct { partitionInfos []*partitionInfo diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index c153c75e2c97e..66c823a9f2e2c 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -460,12 +460,13 @@ func constructSearchRequest( func TestTranslateOutputFields(t *testing.T) { const ( - idFieldName = "id" - tsFieldName = "timestamp" - floatVectorFieldName = "float_vector" - binaryVectorFieldName = "binary_vector" - float16VectorFieldName = "float16_vector" - bfloat16VectorFieldName = "bfloat16_vector" + idFieldName = "id" + tsFieldName = "timestamp" + floatVectorFieldName = "float_vector" + binaryVectorFieldName = "binary_vector" + float16VectorFieldName = "float16_vector" + bfloat16VectorFieldName = "bfloat16_vector" + sparseFloatVectorFieldName = "sparse_float_vector" ) var outputFields []string var userOutputFields []string @@ -483,6 +484,15 @@ func TestTranslateOutputFields(t *testing.T) { {Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector}, {Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector}, {Name: bfloat16VectorFieldName, FieldID: 103, DataType: schemapb.DataType_BFloat16Vector}, + {Name: sparseFloatVectorFieldName, FieldID: 104, DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "bm25", + Type: schemapb.FunctionType_BM25, + OutputFieldNames: []string{sparseFloatVectorFieldName}, + // omit other fields for brevity + }, }, } schema := newSchemaInfo(collSchema) @@ -511,6 +521,7 @@ func TestTranslateOutputFields(t *testing.T) { assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, userOutputFields) assert.ElementsMatch(t, []string{}, userDynamicFields) + // sparse_float_vector is a BM25 function output field, so it should not be included in the output fields outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*"}, schema, false) assert.Equal(t, nil, err) assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields) @@ -535,6 +546,14 @@ func TestTranslateOutputFields(t *testing.T) { assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields) assert.ElementsMatch(t, []string{}, userDynamicFields) + // sparse_float_vector is a BM25 function output field, so it should not be included in the output fields + _, _, _, err = translateOutputFields([]string{"*", sparseFloatVectorFieldName}, schema, false) + assert.Error(t, err) + _, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, false) + assert.Error(t, err) + _, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, true) + assert.Error(t, err) + //========================================================================= outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{}, schema, true) assert.Equal(t, nil, err) @@ -578,7 +597,7 @@ func TestTranslateOutputFields(t *testing.T) { assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields) assert.ElementsMatch(t, []string{}, userDynamicFields) - outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"A"}, schema, true) + _, _, _, err = translateOutputFields([]string{"A"}, schema, true) assert.Error(t, err) t.Run("enable dynamic schema", func(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 9541d4089551d..326823fc32ec4 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1204,7 +1204,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, []string, error) { var primaryFieldName string var dynamicField *schemapb.FieldSchema - allFieldNameMap := make(map[string]int64) + allFieldNameMap := make(map[string]*schemapb.FieldSchema) resultFieldNameMap := make(map[string]bool) resultFieldNames := make([]string, 0) userOutputFieldsMap := make(map[string]bool) @@ -1219,23 +1219,26 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary if field.IsDynamic { dynamicField = field } - allFieldNameMap[field.Name] = field.GetFieldID() + allFieldNameMap[field.Name] = field } for _, outputFieldName := range outputFields { outputFieldName = strings.TrimSpace(outputFieldName) if outputFieldName == "*" { - for fieldName, fieldID := range allFieldNameMap { - // skip Cold field - if schema.IsFieldLoaded(fieldID) { + for fieldName, field := range allFieldNameMap { + // skip Cold field and fields that can't be output + if schema.IsFieldLoaded(field.GetFieldID()) && schema.CanRetrieveRawFieldData(field) { resultFieldNameMap[fieldName] = true userOutputFieldsMap[fieldName] = true } } useAllDyncamicFields = true } else { - if fieldID, ok := allFieldNameMap[outputFieldName]; ok { - if schema.IsFieldLoaded(fieldID) { + if field, ok := allFieldNameMap[outputFieldName]; ok { + if !schema.CanRetrieveRawFieldData(field) { + return nil, nil, nil, fmt.Errorf("not allowed to retrieve raw data of field %s", outputFieldName) + } + if schema.IsFieldLoaded(field.GetFieldID()) { resultFieldNameMap[outputFieldName] = true userOutputFieldsMap[outputFieldName] = true } else { diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 8739552c4d9f1..d69687603f72d 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -444,6 +444,20 @@ func (helper *SchemaHelper) GetFunctionByOutputField(field *schemapb.FieldSchema return nil, fmt.Errorf("function not exist") } +// As of now, only BM25 function output field is not supported to retrieve raw field data +func (helper *SchemaHelper) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool { + if !field.IsFunctionOutput { + return true + } + + f, err := helper.GetFunctionByOutputField(field) + if err != nil { + return false + } + + return f.GetType() != schemapb.FunctionType_BM25 +} + func (helper *SchemaHelper) GetCollectionName() string { return helper.schema.Name } diff --git a/tests/python_client/testcases/test_full_text_search.py b/tests/python_client/testcases/test_full_text_search.py index acb66952c9669..f37ae4632e8a1 100644 --- a/tests/python_client/testcases/test_full_text_search.py +++ b/tests/python_client/testcases/test_full_text_search.py @@ -952,7 +952,7 @@ def test_insert_for_full_text_search_with_part_of_empty_string(self, tokenizer): # query with expr res, _ = collection_w.query( expr="id >= 0", - output_fields=["text_sparse_emb", "text"] + output_fields=["text"] ) assert len(res) == len(data) @@ -965,7 +965,7 @@ def test_insert_for_full_text_search_with_part_of_empty_string(self, tokenizer): anns_field="text_sparse_emb", param={}, limit=limit, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) assert len(res_list) == nq for i in range(nq): assert len(res_list[i]) == limit @@ -1536,7 +1536,7 @@ def test_delete_for_full_text_search(self, tokenizer): anns_field="text_sparse_emb", param={}, limit=100, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) for i in range(len(res_list)): query_text = search_data[i] result_texts = [r.text for r in res_list[i]] @@ -2262,7 +2262,7 @@ def test_full_text_search_default( param={}, limit=limit + offset, offset=0, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) full_res_id_list = [] for i in range(nq): res = full_res_list[i] @@ -2278,7 +2278,7 @@ def test_full_text_search_default( param={}, limit=limit, offset=offset, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) # verify correctness for i in range(nq): @@ -2462,7 +2462,7 @@ def test_full_text_search_with_jieba_tokenizer( param={}, limit=limit + offset, offset=0, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) full_res_id_list = [] for i in range(nq): res = full_res_list[i] @@ -2478,7 +2478,7 @@ def test_full_text_search_with_jieba_tokenizer( param={}, limit=limit, offset=offset, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) # verify correctness for i in range(nq): @@ -2637,7 +2637,7 @@ def test_full_text_search_with_range_search( param={ }, limit=limit, # get a wider range of search result - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) distance_list = [] for i in range(nq): @@ -2660,7 +2660,7 @@ def test_full_text_search_with_range_search( } }, limit=limit, - output_fields=["id", "text", "text_sparse_emb"]) + output_fields=["id", "text"]) # verify correctness for i in range(nq): log.info(f"res: {len(res_list[i])}") @@ -2804,7 +2804,7 @@ def test_full_text_search_with_search_iterator( param={ "metric_type": "BM25", }, - output_fields=["id", "text", "text_sparse_emb"], + output_fields=["id", "text"], limit=limit ) iter_result = [] @@ -2948,7 +2948,7 @@ def test_search_for_full_text_search_with_empty_string_search_data( anns_field="text_sparse_emb", param={}, limit=limit, - output_fields=["id", "text", "text_sparse_emb"], + output_fields=["id", "text"], ) assert len(res) == nq for r in res: @@ -3089,7 +3089,7 @@ def test_search_for_full_text_search_with_invalid_search_data( anns_field="text_sparse_emb", param={}, limit=limit, - output_fields=["id", "text", "text_sparse_emb"], + output_fields=["id", "text"], check_task=CheckTasks.err_res, check_items=error )