From 01cfb1fd97852b7336dc452e2473c8334011d366 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 19 Dec 2024 11:20:47 +0800 Subject: [PATCH] enhance: [GoSDK] support expression template (#38568) Related to #36672 This PR add - Expression template for search, query & hybrid search - fix hybrid search rerank param - add reranker interface(migrate from go sdk old repo) --------- Signed-off-by: Congqi Xia --- client/milvusclient/read.go | 5 +- client/milvusclient/read_option_test.go | 169 +++++++++++++++++++++ client/milvusclient/read_options.go | 186 +++++++++++++++++++++--- client/milvusclient/read_test.go | 18 ++- client/milvusclient/reranker.go | 62 ++++++++ client/milvusclient/reranker_test.go | 55 +++++++ 6 files changed, 471 insertions(+), 24 deletions(-) create mode 100644 client/milvusclient/reranker.go create mode 100644 client/milvusclient/reranker_test.go diff --git a/client/milvusclient/read.go b/client/milvusclient/read.go index 8839728c22e06..e07185e4846d4 100644 --- a/client/milvusclient/read.go +++ b/client/milvusclient/read.go @@ -158,8 +158,11 @@ func (c *Client) parseSearchResult(sch *entity.Schema, outputFields []string, fi } func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) { - req := option.Request() var resultSet ResultSet + req, err := option.Request() + if err != nil { + return resultSet, err + } collection, err := c.getCollection(ctx, req.GetCollectionName()) if err != nil { diff --git a/client/milvusclient/read_option_test.go b/client/milvusclient/read_option_test.go index 0e50db0580878..c4e1e52c99159 100644 --- a/client/milvusclient/read_option_test.go +++ b/client/milvusclient/read_option_test.go @@ -17,9 +17,11 @@ package milvusclient import ( + "fmt" "math/rand" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -137,3 +139,170 @@ func (s *SearchOptionSuite) TestPlaceHolder() { func TestSearchOption(t *testing.T) { suite.Run(t, new(SearchOptionSuite)) } + +func TestAny2TmplValue(t *testing.T) { + t.Run("primitives", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + v := rand.Int() + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetInt64Val()) + }) + + t.Run("int32", func(t *testing.T) { + v := rand.Int31() + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetInt64Val()) + }) + + t.Run("int64", func(t *testing.T) { + v := rand.Int63() + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetInt64Val()) + }) + + t.Run("float32", func(t *testing.T) { + v := rand.Float32() + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetFloatVal()) + }) + + t.Run("float64", func(t *testing.T) { + v := rand.Float64() + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetFloatVal()) + }) + + t.Run("bool", func(t *testing.T) { + val, err := any2TmplValue(true) + assert.NoError(t, err) + assert.True(t, val.GetBoolVal()) + }) + + t.Run("string", func(t *testing.T) { + v := fmt.Sprintf("%v", rand.Int()) + val, err := any2TmplValue(v) + assert.NoError(t, err) + assert.EqualValues(t, v, val.GetStringVal()) + }) + }) + + t.Run("slice", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]int, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Int()) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetLongData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("int32", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]int32, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Int31()) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetLongData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("int64", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]int64, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Int63()) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetLongData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("float32", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]float32, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Float32()) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetDoubleData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("float64", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]float64, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Float64()) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetDoubleData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("bool", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]bool, 0, l) + for i := 0; i < l; i++ { + v = append(v, rand.Int()%2 == 0) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetBoolData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + + t.Run("string", func(t *testing.T) { + l := rand.Intn(10) + 1 + v := make([]string, 0, l) + for i := 0; i < l; i++ { + v = append(v, fmt.Sprintf("%v", rand.Int())) + } + val, err := any2TmplValue(v) + assert.NoError(t, err) + data := val.GetArrayVal().GetStringData().GetData() + assert.Equal(t, l, len(data)) + for i, val := range data { + assert.EqualValues(t, v[i], val) + } + }) + }) + + t.Run("unsupported", func(*testing.T) { + _, err := any2TmplValue(struct{}{}) + assert.Error(t, err) + + _, err = any2TmplValue([]struct{}{}) + assert.Error(t, err) + }) +} diff --git a/client/milvusclient/read_options.go b/client/milvusclient/read_options.go index 883deda64b278..8f040f5f531ee 100644 --- a/client/milvusclient/read_options.go +++ b/client/milvusclient/read_options.go @@ -18,6 +18,8 @@ package milvusclient import ( "encoding/json" + "fmt" + "reflect" "strconv" "github.com/cockroachdb/errors" @@ -25,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/index" ) @@ -59,22 +62,24 @@ type searchOption struct { type annRequest struct { vectors []entity.Vector - annField string - metricsType entity.MetricType - searchParam map[string]string - groupByField string - annParam index.AnnParam - ignoreGrowing bool - expr string - topK int - offset int + annField string + metricsType entity.MetricType + searchParam map[string]string + groupByField string + annParam index.AnnParam + ignoreGrowing bool + expr string + topK int + offset int + templateParams map[string]any } func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest { return &annRequest{ - annField: annField, - vectors: vectors, - topK: limit, + annField: annField, + vectors: vectors, + topK: limit, + templateParams: make(map[string]any), } } @@ -116,9 +121,98 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) { } request.SearchParams = entity.MapKvPairs(params) + request.ExprTemplateValues = make(map[string]*schemapb.TemplateValue) + for key, value := range r.templateParams { + tmplVal, err := any2TmplValue(value) + if err != nil { + return nil, err + } + request.ExprTemplateValues[key] = tmplVal + } + return request, nil } +func any2TmplValue(val any) (*schemapb.TemplateValue, error) { + result := &schemapb.TemplateValue{} + switch v := val.(type) { + case int, int8, int16, int32: + result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: reflect.ValueOf(v).Int()} + case int64: + result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: v} + case float32: + result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: float64(v)} + case float64: + result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: v} + case bool: + result.Val = &schemapb.TemplateValue_BoolVal{BoolVal: v} + case string: + result.Val = &schemapb.TemplateValue_StringVal{StringVal: v} + default: + if reflect.TypeOf(val).Kind() == reflect.Slice { + return slice2TmplValue(val) + } + return nil, fmt.Errorf("unsupported template value type: %T", val) + } + return result, nil +} + +func slice2TmplValue(val any) (*schemapb.TemplateValue, error) { + arrVal := &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{}, + } + + rv := reflect.ValueOf(val) + switch t := reflect.TypeOf(val).Elem().Kind(); t { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + data := make([]int64, 0, rv.Len()) + for i := 0; i < rv.Len(); i++ { + data = append(data, rv.Index(i).Int()) + } + arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_LongData{ + LongData: &schemapb.LongArray{ + Data: data, + }, + } + case reflect.Bool: + data := make([]bool, 0, rv.Len()) + for i := 0; i < rv.Len(); i++ { + data = append(data, rv.Index(i).Bool()) + } + arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: data, + }, + } + case reflect.Float32, reflect.Float64: + data := make([]float64, 0, rv.Len()) + for i := 0; i < rv.Len(); i++ { + data = append(data, rv.Index(i).Float()) + } + arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: data, + }, + } + case reflect.String: + data := make([]string, 0, rv.Len()) + for i := 0; i < rv.Len(); i++ { + data = append(data, rv.Index(i).String()) + } + arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_StringData{ + StringData: &schemapb.StringArray{ + Data: data, + }, + } + default: + return nil, fmt.Errorf("unsupported template type: slice of %v", t) + } + + return &schemapb.TemplateValue{ + Val: arrVal, + }, nil +} + func (r *annRequest) WithANNSField(annsField string) *annRequest { r.annField = annsField return r @@ -144,6 +238,11 @@ func (r *annRequest) WithFilter(expr string) *annRequest { return r } +func (r *annRequest) WithTemplateParam(key string, val any) *annRequest { + r.templateParams[key] = val + return r +} + func (r *annRequest) WithOffset(offset int) *annRequest { r.offset = offset return r @@ -179,6 +278,11 @@ func (opt *searchOption) WithFilter(expr string) *searchOption { return opt } +func (opt *searchOption) WithTemplateParam(key string, val any) *searchOption { + opt.annRequest.WithTemplateParam(key, val) + return opt +} + func (opt *searchOption) WithOffset(offset int) *searchOption { opt.annRequest.WithOffset(offset) return opt @@ -223,9 +327,10 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption { func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption { return &searchOption{ annRequest: &annRequest{ - vectors: vectors, - searchParam: make(map[string]string), - topK: limit, + vectors: vectors, + searchParam: make(map[string]string), + topK: limit, + templateParams: make(map[string]any), }, collectionName: collectionName, useDefaultConsistencyLevel: true, @@ -293,6 +398,10 @@ type hybridSearchOption struct { outputFields []string useDefaultConsistency bool consistencyLevel entity.ConsistencyLevel + + limit int + offset int + reranker Reranker } func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption { @@ -311,6 +420,16 @@ func (opt *hybridSearchOption) WithOutputFields(outputFields ...string) *hybridS return opt } +func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOption { + opt.reranker = reranker + return opt +} + +func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption { + opt.offset = offset + return opt +} + func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, error) { requests := make([]*milvuspb.SearchRequest, 0, len(opt.reqs)) for _, annRequest := range opt.reqs { @@ -321,6 +440,15 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e requests = append(requests, req) } + var params []*commonpb.KeyValuePair + if opt.reranker != nil { + params = opt.reranker.GetParams() + } + params = append(params, &commonpb.KeyValuePair{Key: spLimit, Value: strconv.FormatInt(int64(opt.limit), 10)}) + if opt.offset > 0 { + params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)}) + } + return &milvuspb.HybridSearchRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, @@ -328,20 +456,22 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e UseDefaultConsistency: opt.useDefaultConsistency, ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel), OutputFields: opt.outputFields, + RankParams: params, }, nil } -func NewHybridSearchOption(collectionName string, annRequests ...*annRequest) *hybridSearchOption { +func NewHybridSearchOption(collectionName string, limit int, annRequests ...*annRequest) *hybridSearchOption { return &hybridSearchOption{ collectionName: collectionName, reqs: annRequests, useDefaultConsistency: true, + limit: limit, } } type QueryOption interface { - Request() *milvuspb.QueryRequest + Request() (*milvuspb.QueryRequest, error) } type queryOption struct { @@ -352,10 +482,11 @@ type queryOption struct { consistencyLevel entity.ConsistencyLevel useDefaultConsistencyLevel bool expr string + templateParams map[string]any } -func (opt *queryOption) Request() *milvuspb.QueryRequest { - return &milvuspb.QueryRequest{ +func (opt *queryOption) Request() (*milvuspb.QueryRequest, error) { + req := &milvuspb.QueryRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, OutputFields: opt.outputFields, @@ -364,6 +495,17 @@ func (opt *queryOption) Request() *milvuspb.QueryRequest { QueryParams: entity.MapKvPairs(opt.queryParams), ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(), } + + req.ExprTemplateValues = make(map[string]*schemapb.TemplateValue) + for key, value := range opt.templateParams { + tmplVal, err := any2TmplValue(value) + if err != nil { + return nil, err + } + req.ExprTemplateValues[key] = tmplVal + } + + return req, nil } func (opt *queryOption) WithFilter(expr string) *queryOption { @@ -371,6 +513,11 @@ func (opt *queryOption) WithFilter(expr string) *queryOption { return opt } +func (opt *queryOption) WithTemplateParam(key string, val any) *queryOption { + opt.templateParams[key] = val + return opt +} + func (opt *queryOption) WithOffset(offset int) *queryOption { if opt.queryParams == nil { opt.queryParams = make(map[string]string) @@ -408,5 +555,6 @@ func NewQueryOption(collectionName string) *queryOption { collectionName: collectionName, useDefaultConsistencyLevel: true, consistencyLevel: entity.ClBounded, + templateParams: make(map[string]any), } } diff --git a/client/milvusclient/read_test.go b/client/milvusclient/read_test.go index 9c73b2e68b8db..2e5a1099a17a3 100644 --- a/client/milvusclient/read_test.go +++ b/client/milvusclient/read_test.go @@ -75,6 +75,8 @@ func (s *ReadSuite) TestSearch() { return rand.Float32() })), }).WithPartitions(partitionName). + WithFilter("id > {tmpl_id}"). + WithTemplateParam("tmpl_id", 100). WithGroupByField("group_by"). WithSearchParam("ignore_growing", "true"). WithAnnParam(ap), @@ -178,11 +180,11 @@ func (s *ReadSuite) TestHybridSearch() { }, nil }).Once() - _, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + _, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { return rand.Float32() }))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { return rand.Float32() - })))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithOutputFields("*")) + })))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithReranker(NewRRFReranker()).WithOutputFields("*")) s.NoError(err) }) @@ -190,14 +192,14 @@ func (s *ReadSuite) TestHybridSearch() { collectionName := fmt.Sprintf("coll_%s", s.randString(6)) s.setupCache(collectionName, s.schemaDyn) - _, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, nonSupportData{}))) + _, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, nonSupportData{}))) s.Error(err) s.mock.EXPECT().HybridSearch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hsr *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) { return nil, merr.WrapErrServiceInternal("mocked") }).Once() - _, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + _, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { return rand.Float32() }))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { return rand.Float32() @@ -224,6 +226,14 @@ func (s *ReadSuite) TestQuery() { _, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName)) s.NoError(err) }) + + s.Run("bad_request", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + s.setupCache(collectionName, s.schema) + + _, err := s.client.Query(ctx, NewQueryOption(collectionName).WithFilter("id > {tmpl_id}").WithTemplateParam("tmpl_id", struct{}{})) + s.Error(err) + }) } func TestRead(t *testing.T) { diff --git a/client/milvusclient/reranker.go b/client/milvusclient/reranker.go new file mode 100644 index 0000000000000..32e832e616132 --- /dev/null +++ b/client/milvusclient/reranker.go @@ -0,0 +1,62 @@ +package milvusclient + +import ( + "encoding/json" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +const ( + rerankType = "strategy" + rerankParams = "params" + rffParam = "k" + weightedParam = "weights" + + rrfRerankType = `rrf` + weightedRerankType = `weighted` +) + +type Reranker interface { + GetParams() []*commonpb.KeyValuePair +} + +type rrfReranker struct { + K float64 `json:"k,omitempty"` +} + +func (r *rrfReranker) WithK(k float64) *rrfReranker { + r.K = k + return r +} + +func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair { + bs, _ := json.Marshal(r) + + return []*commonpb.KeyValuePair{ + {Key: rerankType, Value: rrfRerankType}, + {Key: rerankParams, Value: string(bs)}, + } +} + +func NewRRFReranker() *rrfReranker { + return &rrfReranker{K: 60} +} + +type weightedReranker struct { + Weights []float64 `json:"weights,omitempty"` +} + +func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair { + bs, _ := json.Marshal(r) + + return []*commonpb.KeyValuePair{ + {Key: rerankType, Value: weightedRerankType}, + {Key: rerankParams, Value: string(bs)}, + } +} + +func NewWeightedReranker(weights []float64) *weightedReranker { + return &weightedReranker{ + Weights: weights, + } +} diff --git a/client/milvusclient/reranker_test.go b/client/milvusclient/reranker_test.go new file mode 100644 index 0000000000000..7ed7d0807c4bc --- /dev/null +++ b/client/milvusclient/reranker_test.go @@ -0,0 +1,55 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +func TestReranker(t *testing.T) { + checkParam := func(params []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range params { + if kv.Key == key && kv.Value == value { + return true + } + } + return false + } + + t.Run("rffReranker", func(t *testing.T) { + rr := NewRRFReranker() + params := rr.GetParams() + assert.True(t, checkParam(params, rerankType, rrfRerankType)) + assert.True(t, checkParam(params, rerankParams, `{"k":60}`), "default k shall be 60") + + rr.WithK(50) + params = rr.GetParams() + assert.True(t, checkParam(params, rerankType, rrfRerankType)) + assert.True(t, checkParam(params, rerankParams, `{"k":50}`)) + }) + + t.Run("weightedReranker", func(t *testing.T) { + rr := NewWeightedReranker([]float64{1, 2, 1}) + params := rr.GetParams() + assert.True(t, checkParam(params, rerankType, weightedRerankType)) + assert.True(t, checkParam(params, rerankParams, `{"weights":[1,2,1]}`)) + }) +}