Skip to content

Commit

Permalink
enhance: simplify the structure of search_params
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Dec 6, 2024
1 parent d7a5ad4 commit cd0f327
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 16 deletions.
8 changes: 4 additions & 4 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: ParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
Expand Down Expand Up @@ -1617,7 +1617,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: ParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
Expand Down Expand Up @@ -1714,7 +1714,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
searchParams := []*commonpb.KeyValuePair{
{Key: MetricTypeKey, Value: metric.L2},
{Key: SearchParamsKey, Value: string(b)},
{Key: ParamsKey, Value: string(b)},
{Key: AnnsFieldKey, Value: floatVecField},
{Key: TopKKey, Value: strconv.Itoa(topk)},
{Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
Expand Down Expand Up @@ -1798,7 +1798,7 @@ func TestProxy(t *testing.T) {
// Value: distance.L2,
// },
// {
// Key: SearchParamsKey,
// Key: ParamsKey,
// Value: string(b),
// },
// {
Expand Down
91 changes: 90 additions & 1 deletion internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,39 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

const (
// float64
nlistKey = "nlist"
nprobeKey = "nprobe"
maxEmptyResultBuckets = "max_empty_result_buckets"
reorderKKey = "reorder_k"
searchListKey = "search_list"
itopkSizeKey = "itopk_size"
searchWidthKey = "search_width"
minIterationsKey = "min_iterations"
maxIterationsKey = "max_iterations"
teamSizeKey = "team_size"
radiusKey = "radius"
rangeFilterKey = "range_filter"
levelKey = "level"
// bool
pageRetainOrderKey = "page_retain_order"
)

var ParamsKeyList = []string{
nlistKey, nprobeKey, maxEmptyResultBuckets, reorderKKey, searchListKey,
itopkSizeKey, searchWidthKey, minIterationsKey, maxIterationsKey,
teamSizeKey, radiusKey, rangeFilterKey, levelKey, pageRetainOrderKey,
}

type rankParams struct {
limit int64
offset int64
Expand Down Expand Up @@ -163,10 +189,36 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}

// 4. parse search param str
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ParamsKey, searchParamsPair)
var searchParamMap map[string]any
if err != nil {
searchParamStr = ""
} else {
err = json.Unmarshal([]byte(searchParamStr), &searchParamMap)
if err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err}
}
}
// related with https://github.com/milvus-io/milvus/issues/37972
// before 2.5.1, all key in ParamsKeyList will be write in search_params.params
// after 2.5.1, allow user to write all this key in search_params
// SearchParams in planpb.QueryInfo is the params set passed to segcore
// so if you want to use the params in segcore/knowhere, remember to add the new params into ParamsKeyList after 2.5.1
for _, key := range ParamsKeyList {
stringValue, err := funcutil.GetAttrByKeyFromRepeatedKV(key, searchParamsPair)
if err == nil {
err := checkParams(searchParamMap, key, stringValue)
if err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err}
}
}
}

jsonStrBytes, err := json.Marshal(searchParamMap)
if err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: err}
}
searchParamStr = string(jsonStrBytes)

// 5. parse group by field and group by size
var groupByFieldId, groupSize int64
Expand Down Expand Up @@ -494,3 +546,40 @@ func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.Se
}
return ret
}

func checkParams(m map[string]any, key string, stringValue string) error {
switch key {
case pageRetainOrderKey:
value, err := strconv.ParseBool(stringValue)
if err != nil {
return err
}
if v, ok := m[key]; ok {
v, ok := v.(bool)
if !ok {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter(%s) has the wrong type, expect bool type", key))
}
if v != value {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("inconsistent parameter(%s), search_param(%t),search_param.params(%t)", key, v, value))
}
}
m[key] = value
default:
// all number will be convert to float64 in json
value, err := strconv.ParseFloat(stringValue, 64)
if err != nil {
return err
}
if v, ok := m[key]; ok {
v, ok := v.(float64)
if !ok {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter(%s) has the wrong type, expect bool type", key))
}
if v != value {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("inconsistent parameter(%s), search_param(%f),search_param.params(%f)", key, v, value))
}
}
m[key] = value
}
return nil
}
2 changes: 1 addition & 1 deletion internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ const (
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = common.MetricTypeKey
SearchParamsKey = "params"
ParamsKey = "params"
ExprParamsKey = "expr_params"
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
Expand Down
2 changes: 0 additions & 2 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ const (
// a second query request will be initiated to retrieve output fields data.
// In this case, the first search will not return any output field from QueryNodes.
requeryThreshold = 0.5 * 1024 * 1024
radiusKey = "radius"
rangeFilterKey = "range_filter"
)

type searchTask struct {
Expand Down
124 changes: 120 additions & 4 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@ import (
"time"

"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
Expand Down Expand Up @@ -156,7 +159,7 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
Value: metric.L2,
},
{
Key: SearchParamsKey,
Key: ParamsKey,
Value: `{"nprobe": 10}`,
},
{
Expand Down Expand Up @@ -2259,7 +2262,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

noMetricTypeParams := getBaseSearchParams()
noMetricTypeParams = append(noMetricTypeParams, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Key: ParamsKey,
Value: `{"nprobe": 10}`,
})

Expand Down Expand Up @@ -2405,7 +2408,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

// no roundDecimal is valid
noRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{
Key: SearchParamsKey,
Key: ParamsKey,
Value: `{"nprobe": 10}`,
})

Expand Down Expand Up @@ -2484,7 +2487,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
})
t.Run("check range-search and groupBy", func(t *testing.T) {
normalParam := getValidSearchParams()
resetSearchParamsValue(normalParam, SearchParamsKey, `{"nprobe": 10, "radius":0.2}`)
resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2}`)
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "string_field",
Expand Down Expand Up @@ -2599,6 +2602,119 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
}
})

t.Run("parameters in searchParams and searchParams.Params are inconsistent with same key", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: nprobeKey,
Value: "100",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
})

t.Run("type of param is incorrect", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: nprobeKey,
Value: "true",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.NotNil(t, searchInfo.parseError)
})

t.Run("type of param is incorrect", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: pageRetainOrderKey,
Value: "10",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.NotNil(t, searchInfo.parseError)
})

t.Run("old parameters forms", func(t *testing.T) {
normalParam := getValidSearchParams()
resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2, "range_filter": 1, "page_retain_order":true}`)
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
var searchParamMap map[string]any
_ = json.Unmarshal([]byte(searchInfo.planInfo.SearchParams), &searchParamMap)
assert.Nil(t, searchInfo.parseError)
assert.Equal(t, 4, len(searchParamMap))
assert.Equal(t, searchParamMap[nprobeKey], float64(10))
assert.Equal(t, searchParamMap[radiusKey], 0.2)
assert.Equal(t, searchParamMap[rangeFilterKey], float64(1))
assert.Equal(t, searchParamMap[pageRetainOrderKey], true)
})

t.Run("new parameters forms", func(t *testing.T) {
normalParam := getValidSearchParams()
resetSearchParamsValue(normalParam, ParamsKey, `{"nprobe": 10, "radius":0.2}`)
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: nprobeKey,
Value: "10",
})
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: rangeFilterKey,
Value: "1",
})
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: pageRetainOrderKey,
Value: "true",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
var searchParamMap map[string]any
_ = json.Unmarshal([]byte(searchInfo.planInfo.SearchParams), &searchParamMap)
assert.Nil(t, searchInfo.parseError)
assert.Equal(t, 4, len(searchParamMap))
assert.Equal(t, searchParamMap[nprobeKey], float64(10))
assert.Equal(t, searchParamMap[radiusKey], 0.2)
assert.Equal(t, searchParamMap[rangeFilterKey], float64(1))
assert.Equal(t, searchParamMap[pageRetainOrderKey], true)
})
}

func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ func constructSearchRequest(
Value: metric.L2,
},
{
Key: SearchParamsKey,
Key: ParamsKey,
Value: string(b),
},
{
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/util_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const (
TopKKey = "topk"
NQKey = "nq"
MetricTypeKey = common.MetricTypeKey
SearchParamsKey = common.IndexParamsKey
ParamsKey = common.IndexParamsKey
RoundDecimalKey = "round_decimal"
OffsetKey = "offset"
LimitKey = "limit"
Expand Down Expand Up @@ -196,7 +196,7 @@ func ConstructSearchRequest(
Value: metricType,
},
{
Key: SearchParamsKey,
Key: ParamsKey,
Value: string(b),
},
{
Expand Down Expand Up @@ -255,7 +255,7 @@ func ConstructSearchRequestWithConsistencyLevel(
Value: metricType,
},
{
Key: SearchParamsKey,
Key: ParamsKey,
Value: string(b),
},
{
Expand Down

0 comments on commit cd0f327

Please sign in to comment.