From f888267d37f5eb8f428fcf76b1a5ff9ccbf32231 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 5 Dec 2024 14:08:41 +0800 Subject: [PATCH 1/4] enhance: [2.4] Fill version for load delta request (#38212) (#38228) Cherry-pick from master pr: #38212 Version is needed for load delta request in case of false alarm warning about version go backward Signed-off-by: Congqi Xia --- internal/querynodev2/delegator/delta_forward.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/querynodev2/delegator/delta_forward.go b/internal/querynodev2/delegator/delta_forward.go index 2de278b42ee5f..09905c4243389 100644 --- a/internal/querynodev2/delegator/delta_forward.go +++ b/internal/querynodev2/delegator/delta_forward.go @@ -176,6 +176,7 @@ func (sd *shardDelegator) forwardL0RemoteLoad(ctx context.Context, LoadScope: querypb.LoadScope_Delta, Schema: req.GetSchema(), IndexInfoList: req.GetIndexInfoList(), + Version: req.GetVersion(), }) } From 67a004ca20afca23b88267046dfeb987b411e5a5 Mon Sep 17 00:00:00 2001 From: jaime Date: Thu, 5 Dec 2024 14:20:41 +0800 Subject: [PATCH 2/4] fix: invalid rate limit for time tick delay (#38218) issue: #38217 Signed-off-by: jaime --- internal/datanode/flow_graph_write_node.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/datanode/flow_graph_write_node.go b/internal/datanode/flow_graph_write_node.go index 2425a9591bf50..16e973ef4454b 100644 --- a/internal/datanode/flow_graph_write_node.go +++ b/internal/datanode/flow_graph_write_node.go @@ -98,6 +98,7 @@ func (wNode *writeNode) Operate(in []Msg) []Msg { }) wNode.updater.update(wNode.channelName, end.GetTimestamp(), stats) + rateCol.updateFlowGraphTt(wNode.channelName, end.GetTimestamp()) res := flowGraphMsg{ timeRange: fgMsg.timeRange, From 3d98e8e690b6b3ec9e3130d348cc021c49b8fc1d Mon Sep 17 00:00:00 2001 From: smellthemoon <64083300+smellthemoon@users.noreply.github.com> Date: Thu, 5 Dec 2024 14:26:42 +0800 Subject: [PATCH 3/4] enhance: support templates for expression in Restful api(#38040) (#38161) pr: #38040 issue: #36672 Signed-off-by: lixinguo Co-authored-by: lixinguo --- .../proxy/httpserver/handler_v2.go | 10 +- .../proxy/httpserver/handler_v2_test.go | 4 +- .../proxy/httpserver/request_v2.go | 66 +++---- .../distributed/proxy/httpserver/utils.go | 167 ++++++++++++++++++ .../proxy/httpserver/utils_test.go | 140 +++++++++++++++ 5 files changed, 351 insertions(+), 36 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 5ab7a2be52f1f..6240b0a1705f8 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -622,6 +622,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa QueryParams: []*commonpb.KeyValuePair{}, UseDefaultConsistency: true, } + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) c.Set(ContextRequest, req) if httpReq.Offset > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) @@ -713,6 +714,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN PartitionName: httpReq.PartitionName, Expr: httpReq.Filter, } + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) c.Set(ContextRequest, req) if req.Expr == "" { body, _ := c.Get(gin.BodyBytesKey) @@ -925,7 +927,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair { +func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair { var searchParams []*commonpb.KeyValuePair if reqSearchParams.Params == nil { reqSearchParams.Params = make(map[string]any) @@ -965,7 +967,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN return nil, err } - searchParams := generateSearchParams(ctx, c, httpReq.SearchParams) + searchParams := generateSearchParams(httpReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) @@ -982,6 +984,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN } req.SearchParams = searchParams req.PlaceholderGroup = placeholderGroup + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest)) }) @@ -1034,7 +1037,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams := generateSearchParams(ctx, c, subReq.SearchParams) + searchParams := generateSearchParams(subReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) @@ -1058,6 +1061,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq PartitionNames: httpReq.PartitionNames, SearchParams: searchParams, } + searchReq.ExprTemplateValues = generateExpressionTemplate(subReq.ExprParams) req.Requests = append(req.Requests, searchReq) } bs, _ := json.Marshal(httpReq.Rerank.Params) diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 9cb86a26cbe4f..137360264ac68 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1531,7 +1531,7 @@ func TestSearchV2(t *testing.T) { queryTestCases := []requestBodyTestCase{} queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]}, "limit": 4, "outputFields": ["word_count"]}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1612,7 +1612,7 @@ func TestSearchV2(t *testing.T) { queryTestCases = append(queryTestCases, requestBodyTestCase{ path: AdvancedSearchAction, requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + - `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]},"metricType": "L2", "limit": 3},` + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 5ae7babd346a0..ec792894b1333 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -112,13 +112,14 @@ type JobIDReq struct { func (req *JobIDReq) GetJobID() string { return req.JobID } type QueryReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - OutputFields []string `json:"outputFields"` - Filter string `json:"filter"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + OutputFields []string `json:"outputFields"` + Filter string `json:"filter"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + ExprParams map[string]interface{} `json:"exprParams"` } func (req *QueryReqV2) GetDbName() string { return req.DbName } @@ -135,10 +136,11 @@ type CollectionIDReq struct { func (req *CollectionIDReq) GetDbName() string { return req.DbName } type CollectionFilterReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionName string `json:"partitionName"` - Filter string `json:"filter" binding:"required"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + Filter string `json:"filter" binding:"required"` + ExprParams map[string]interface{} `json:"exprParams"` } func (req *CollectionFilterReq) GetDbName() string { return req.DbName } @@ -160,18 +162,19 @@ type searchParams struct { } type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - SearchParams searchParams `json:"searchParams"` - ConsistencyLevel string `json:"consistencyLevel"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + ExprParams map[string]interface{} `json:"exprParams"` // not use Params any more, just for compatibility Params map[string]float64 `json:"params"` } @@ -184,14 +187,15 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - SearchParams searchParams `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"params"` + ExprParams map[string]interface{} `json:"exprParams"` } type HybridSearchReq struct { diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 4d2ed437c3b66..e6f6edc546658 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1340,3 +1340,170 @@ func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLe // ConsistencyLevel_Bounded default in PyMilvus return commonpb.ConsistencyLevel_Bounded, true, nil } + +func generateTemplateArrayData(list []interface{}) *schemapb.TemplateArrayValue { + dtype := getTemplateArrayType(list) + var data *schemapb.TemplateArrayValue + switch dtype { + case schemapb.DataType_Bool: + result := make([]bool, len(list)) + for i, item := range list { + result[i] = item.(bool) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: result, + }, + }, + } + case schemapb.DataType_String: + result := make([]string, len(list)) + for i, item := range list { + result[i] = item.(string) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_StringData{ + StringData: &schemapb.StringArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Int64: + result := make([]int64, len(list)) + for i, item := range list { + result[i] = int64(item.(float64)) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_LongData{ + LongData: &schemapb.LongArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Float: + result := make([]float64, len(list)) + for i, item := range list { + result[i] = item.(float64) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Array: + result := make([]*schemapb.TemplateArrayValue, len(list)) + for i, item := range list { + result[i] = generateTemplateArrayData(item.([]interface{})) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_ArrayData{ + ArrayData: &schemapb.TemplateArrayValueArray{ + Data: result, + }, + }, + } + case schemapb.DataType_JSON: + result := make([][]byte, len(list)) + for i, item := range list { + bytes, err := json.Marshal(item) + // won't happen + if err != nil { + panic(fmt.Sprintf("marshal data(%v) fail, please check it!", item)) + } + result[i] = bytes + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: result, + }, + }, + } + // won't happen + default: + panic(fmt.Sprintf("Unexpected data(%v) type when generateTemplateArrayData, please check it!", list)) + } + return data +} + +func getTemplateArrayType(value []interface{}) schemapb.DataType { + dtype := getTemplateType(value[0]) + + for _, v := range value { + if getTemplateType(v) != dtype { + return schemapb.DataType_JSON + } + } + return dtype +} + +func getTemplateType(value interface{}) schemapb.DataType { + switch v := value.(type) { + case bool: + return schemapb.DataType_Bool + case string: + return schemapb.DataType_String + case float64: + // note: all passed number is float64 type + // if field type is float64, but value in ExpressionTemplate is int64, it's ok to use TemplateValue_Int64Val to store it + // it will convert to float64 in ./internal/parser/planparserv2/utils.go, Line 233 + if v == math.Trunc(v) && v >= math.MinInt64 && v <= math.MaxInt64 { + return schemapb.DataType_Int64 + } + return schemapb.DataType_Float + // it won't happen + // case int64: + case []interface{}: + return schemapb.DataType_Array + default: + panic(fmt.Sprintf("Unexpected data(%v) when getTemplateType, please check it!", value)) + } +} + +func generateExpressionTemplate(params map[string]interface{}) map[string]*schemapb.TemplateValue { + expressionTemplate := make(map[string]*schemapb.TemplateValue, len(params)) + + for name, value := range params { + dtype := getTemplateType(value) + var data *schemapb.TemplateValue + switch dtype { + case schemapb.DataType_Bool: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_BoolVal{ + BoolVal: value.(bool), + }, + } + case schemapb.DataType_String: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_StringVal{ + StringVal: value.(string), + }, + } + case schemapb.DataType_Int64: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: int64(value.(float64)), + }, + } + case schemapb.DataType_Float: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_FloatVal{ + FloatVal: value.(float64), + }, + } + case schemapb.DataType_Array: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: generateTemplateArrayData(value.([]interface{})), + }, + } + default: + panic(fmt.Sprintf("Unexpected data(%v) when generateExpressionTemplate, please check it!", data)) + } + expressionTemplate[name] = data + } + return expressionTemplate +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index d05f0e0661f42..c2dd7f8742f7a 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1592,3 +1592,143 @@ func TestConvertConsistencyLevel(t *testing.T) { _, _, err = convertConsistencyLevel("test") assert.NotNil(t, err) } + +func TestGenerateExpressionTemplate(t *testing.T) { + var mixedList []interface{} + var mixedAns [][]byte + + mixedList = append(mixedList, float64(1)) + mixedList = append(mixedList, "10") + mixedList = append(mixedList, true) + + val, _ := json.Marshal(1) + mixedAns = append(mixedAns, val) + val, _ = json.Marshal("10") + mixedAns = append(mixedAns, val) + val, _ = json.Marshal(true) + mixedAns = append(mixedAns, val) + // all passed number is float64 type, so all the number type has convert to float64 + expressionTemplates := []map[string]interface{}{ + {"str": "10"}, + {"min": float64(1), "max": float64(10)}, + {"bool": true}, + {"float64": 1.1}, + {"int64": float64(1)}, + {"list_of_str": []interface{}{"1", "10", "100"}}, + {"list_of_bool": []interface{}{true, false, true}}, + {"list_of_float": []interface{}{1.1, 10.1, 100.1}}, + {"list_of_int": []interface{}{float64(1), float64(10), float64(100)}}, + {"list_of_json": mixedList}, + } + ans := []map[string]*schemapb.TemplateValue{ + { + "str": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_StringVal{ + StringVal: "10", + }, + }, + }, + { + "min": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 1, + }, + }, + "max": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 10, + }, + }, + }, + { + "bool": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_BoolVal{ + BoolVal: true, + }, + }, + }, + { + "float64": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_FloatVal{ + FloatVal: 1.1, + }, + }, + }, + { + "int64": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 1, + }, + }, + }, + { + "list_of_str": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "10", "100"}, + }, + }, + }, + }, + }, + }, + { + "list_of_bool": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false, true}, + }, + }, + }, + }, + }, + }, + { + "list_of_float": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1.1, 10.1, 100.1}, + }, + }, + }, + }, + }, + }, + { + "list_of_int": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 10, 100}, + }, + }, + }, + }, + }, + }, + { + "list_of_json": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: mixedAns, + }, + }, + }, + }, + }, + }, + } + for i, template := range expressionTemplates { + actual := generateExpressionTemplate(template) + assert.Equal(t, actual, ans[i]) + } +} From d4ef89f1c63bfdea4c32b0df0506d52f49f3ed08 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Thu, 5 Dec 2024 14:36:41 +0800 Subject: [PATCH 4/4] test: relax the checks on range search (#36542) (#38234) /kind improvement pr: #36542 --- .../testcases/test_vector_operations.py | 12 ++++++++++-- tests/restful_client_v2/utils/utils.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/restful_client_v2/testcases/test_vector_operations.py b/tests/restful_client_v2/testcases/test_vector_operations.py index e21ee6dfa98ea..ce70e71fe406a 100644 --- a/tests/restful_client_v2/testcases/test_vector_operations.py +++ b/tests/restful_client_v2/testcases/test_vector_operations.py @@ -1580,10 +1580,11 @@ def test_search_vector_with_range_search(self, metric_type): vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() training_data = [item[vector_field] for item in data] distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type) - r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct + r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.5*limit))] # recall is not 100% so add 50% to make sure the range is more than limit if metric_type == "L2": r1, r2 = r2, r1 output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + logger.info(f"r1: {r1}, r2: {r2}") payload = { "collectionName": name, "data": [vector_to_search], @@ -1601,7 +1602,14 @@ def test_search_vector_with_range_search(self, metric_type): assert rsp['code'] == 0 res = rsp['data'] logger.info(f"res: {len(res)}") - assert len(res) == limit + assert len(res) >= limit*0.8 + # add buffer to the distance of comparison + if metric_type == "L2": + r1 = r1 + 10**-6 + r2 = r2 - 10**-6 + else: + r1 = r1 - 10**-6 + r2 = r2 + 10**-6 for item in res: distance = item.get("distance") if metric_type == "L2": diff --git a/tests/restful_client_v2/utils/utils.py b/tests/restful_client_v2/utils/utils.py index 0c93e566cd99d..cf4f23eb99b52 100644 --- a/tests/restful_client_v2/utils/utils.py +++ b/tests/restful_client_v2/utils/utils.py @@ -262,6 +262,6 @@ def get_sorted_distance(train_emb, test_emb, metric_type): "IP": ip_distance } distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1) - distance = np.array(distance.T, order='C', dtype=np.float16) + distance = np.array(distance.T, order='C', dtype=np.float32) distance_sorted = np.sort(distance, axis=1).tolist() return distance_sorted