diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index b8eed64bdf20c..32cab80112024 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1584,7 +1584,7 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) searchParams := []*commonpb.KeyValuePair{ {Key: MetricTypeKey, Value: metric.L2}, - {Key: SearchParamsKey, Value: string(b)}, + {Key: ParamsKey, Value: string(b)}, {Key: AnnsFieldKey, Value: floatVecField}, {Key: TopKKey, Value: strconv.Itoa(topk)}, {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, @@ -1617,7 +1617,7 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) searchParams := []*commonpb.KeyValuePair{ {Key: MetricTypeKey, Value: metric.L2}, - {Key: SearchParamsKey, Value: string(b)}, + {Key: ParamsKey, Value: string(b)}, {Key: AnnsFieldKey, Value: floatVecField}, {Key: TopKKey, Value: strconv.Itoa(topk)}, {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, @@ -1714,7 +1714,7 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) searchParams := []*commonpb.KeyValuePair{ {Key: MetricTypeKey, Value: metric.L2}, - {Key: SearchParamsKey, Value: string(b)}, + {Key: ParamsKey, Value: string(b)}, {Key: AnnsFieldKey, Value: floatVecField}, {Key: TopKKey, Value: strconv.Itoa(topk)}, {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, @@ -1798,7 +1798,7 @@ func TestProxy(t *testing.T) { // Value: distance.L2, // }, // { - // Key: SearchParamsKey, + // Key: ParamsKey, // Value: string(b), // }, // { diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index cea3dc49c63a6..8ed45a1dd91e0 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -20,6 +21,31 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +const ( + // float64 + nlistKey = "nlist" + nprobeKey = "nprobe" + maxEmptyResultBuckets = "max_empty_result_buckets" + reorderKKey = "reorder_k" + searchListKey = "search_list" + itopkSizeKey = "itopk_size" + searchWidthKey = "search_width" + minIterationsKey = "min_iterations" + maxIterationsKey = "max_iterations" + teamSizeKey = "team_size" + radiusKey = "radius" + rangeFilterKey = "range_filter" + levelKey = "level" + // bool + pageRetainOrderKey = "page_retain_order" +) + +var ParamsKeyList = []string{ + nlistKey, nprobeKey, maxEmptyResultBuckets, reorderKKey, searchListKey, + itopkSizeKey, searchWidthKey, minIterationsKey, maxIterationsKey, + teamSizeKey, radiusKey, rangeFilterKey, levelKey, pageRetainOrderKey, +} + type rankParams struct { limit int64 offset int64 @@ -163,10 +189,36 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } // 4. parse search param str - searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair) + searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ParamsKey, searchParamsPair) + var searchParamMap map[string]any if err != nil { searchParamStr = "" + } else { + err = json.Unmarshal([]byte(searchParamStr), &searchParamMap) + if err != nil { + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err} + } + } + // related with https://github.com/milvus-io/milvus/issues/37972 + // before 2.5.1, all key in ParamsKeyList will be write in search_params.params + // after 2.5.1, allow user to write all this key in search_params + // SearchParams in planpb.QueryInfo is the params set passed to segcore + // so if you want to use the params in segcore/knowhere, remember to add the new params into ParamsKeyList after 2.5.1 + for _, key := range ParamsKeyList { + stringValue, err := funcutil.GetAttrByKeyFromRepeatedKV(key, searchParamsPair) + if err == nil { + err := checkParams(searchParamMap, key, stringValue) + if err != nil { + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err} + } + } + } + + jsonStrBytes, err := json.Marshal(searchParamMap) + if err != nil { + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err} } + searchParamStr = string(jsonStrBytes) // 5. parse group by field and group by size var groupByFieldId, groupSize int64 @@ -494,3 +546,40 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se } return ret } + +func checkParams(m map[string]any, key string, stringValue string) error { + switch key { + case pageRetainOrderKey: + value, err := strconv.ParseBool(stringValue) + if err != nil { + return err + } + if v, ok := m[key]; ok { + v, ok := v.(bool) + if !ok { + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter(%s) has the wrong type, expect bool type", key)) + } + if v != value { + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("inconsistent parameter(%s), search_param(%t),search_param.params(%t)", key, v, value)) + } + } + m[key] = value + default: + // all number will be convert to float64 in json + value, err := strconv.ParseFloat(stringValue, 64) + if err != nil { + return err + } + if v, ok := m[key]; ok { + v, ok := v.(float64) + if !ok { + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter(%s) has the wrong type, expect bool type", key)) + } + if v != value { + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("inconsistent parameter(%s), search_param(%f),search_param.params(%f)", key, v, value)) + } + } + m[key] = value + } + return nil +} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 07b6cf6f864e1..fdb03b60043bb 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -61,7 +61,7 @@ const ( TopKKey = "topk" NQKey = "nq" MetricTypeKey = common.MetricTypeKey - SearchParamsKey = "params" + ParamsKey = "params" ExprParamsKey = "expr_params" RoundDecimalKey = "round_decimal" OffsetKey = "offset" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 3dc48cfe9503c..19f4600cab1e5 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -43,8 +43,6 @@ const ( // a second query request will be initiated to retrieve output fields data. // In this case, the first search will not return any output field from QueryNodes. requeryThreshold = 0.5 * 1024 * 1024 - radiusKey = "radius" - rangeFilterKey = "range_filter" ) type searchTask struct { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index bab8ee02bd436..f6b24b684be25 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -24,16 +24,19 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" @@ -156,7 +159,7 @@ func getValidSearchParams() []*commonpb.KeyValuePair { Value: metric.L2, }, { - Key: SearchParamsKey, + Key: ParamsKey, Value: `{"nprobe": 10}`, }, { @@ -2259,7 +2262,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { noMetricTypeParams := getBaseSearchParams() noMetricTypeParams = append(noMetricTypeParams, &commonpb.KeyValuePair{ - Key: SearchParamsKey, + Key: ParamsKey, Value: `{"nprobe": 10}`, }) @@ -2405,7 +2408,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { // no roundDecimal is valid noRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{ - Key: SearchParamsKey, + Key: ParamsKey, Value: `{"nprobe": 10}`, }) @@ -2484,7 +2487,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { }) t.Run("check range-search and groupBy", func(t *testing.T) { normalParam := getValidSearchParams() - resetSearchParamsValue(normalParam, SearchParamsKey, `{"nprobe": 10, "radius":0.2}`) + resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2}`) normalParam = append(normalParam, &commonpb.KeyValuePair{ Key: GroupByFieldKey, Value: "string_field", @@ -2599,6 +2602,119 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size")) } }) + + t.Run("parameters in searchParams and searchParams.Params are inconsistent with same key", func(t *testing.T) { + normalParam := getValidSearchParams() + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: nprobeKey, + Value: "100", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid) + }) + + t.Run("type of param is incorrect", func(t *testing.T) { + normalParam := getValidSearchParams() + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: nprobeKey, + Value: "true", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.NotNil(t, searchInfo.parseError) + }) + + t.Run("type of param is incorrect", func(t *testing.T) { + normalParam := getValidSearchParams() + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: pageRetainOrderKey, + Value: "10", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.NotNil(t, searchInfo.parseError) + }) + + t.Run("old parameters forms", func(t *testing.T) { + normalParam := getValidSearchParams() + resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2, "range_filter": 1, "page_retain_order":true}`) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := parseSearchInfo(normalParam, schema, nil) + var searchParamMap map[string]any + _ = json.Unmarshal([]byte(searchInfo.planInfo.SearchParams), &searchParamMap) + assert.Nil(t, searchInfo.parseError) + assert.Equal(t, 4, len(searchParamMap)) + assert.Equal(t, searchParamMap[nprobeKey], float64(10)) + assert.Equal(t, searchParamMap[radiusKey], 0.2) + assert.Equal(t, searchParamMap[rangeFilterKey], float64(1)) + assert.Equal(t, searchParamMap[pageRetainOrderKey], true) + }) + + t.Run("new parameters forms", func(t *testing.T) { + normalParam := getValidSearchParams() + resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2}`) + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: nprobeKey, + Value: "10", + }) + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: rangeFilterKey, + Value: "1", + }) + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: pageRetainOrderKey, + Value: "true", + }) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + searchInfo := parseSearchInfo(normalParam, schema, nil) + var searchParamMap map[string]any + _ = json.Unmarshal([]byte(searchInfo.planInfo.SearchParams), &searchParamMap) + assert.Nil(t, searchInfo.parseError) + assert.Equal(t, 4, len(searchParamMap)) + assert.Equal(t, searchParamMap[nprobeKey], float64(10)) + assert.Equal(t, searchParamMap[radiusKey], 0.2) + assert.Equal(t, searchParamMap[rangeFilterKey], float64(1)) + assert.Equal(t, searchParamMap[pageRetainOrderKey], true) + }) } func getSearchResultData(nq, topk int64) *schemapb.SearchResultData { diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 6847383ac0580..4da777f85e23f 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -437,7 +437,7 @@ func constructSearchRequest( Value: metric.L2, }, { - Key: SearchParamsKey, + Key: ParamsKey, Value: string(b), }, { diff --git a/tests/integration/util_query.go b/tests/integration/util_query.go index 500b4d34ac6e4..d7182107bc463 100644 --- a/tests/integration/util_query.go +++ b/tests/integration/util_query.go @@ -41,7 +41,7 @@ const ( TopKKey = "topk" NQKey = "nq" MetricTypeKey = common.MetricTypeKey - SearchParamsKey = common.IndexParamsKey + ParamsKey = common.IndexParamsKey RoundDecimalKey = "round_decimal" OffsetKey = "offset" LimitKey = "limit" @@ -196,7 +196,7 @@ func ConstructSearchRequest( Value: metricType, }, { - Key: SearchParamsKey, + Key: ParamsKey, Value: string(b), }, { @@ -255,7 +255,7 @@ func ConstructSearchRequestWithConsistencyLevel( Value: metricType, }, { - Key: SearchParamsKey, + Key: ParamsKey, Value: string(b), }, {