Skip to content

Commit

Permalink
enhance: Support hybrid search multiple vector fields (#663)
Browse files Browse the repository at this point in the history
See also: milvus-io/milvus#25639
milvus pr:  milvus-io/milvus#29433

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Feb 2, 2024
1 parent 1e03ea4 commit 43e3bd2
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 1 deletion.
73 changes: 73 additions & 0 deletions client/ann_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package client

import (
"encoding/json"
"fmt"
"strconv"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

type ANNSearchRequest struct {
fieldName string
vectors []entity.Vector
metricType entity.MetricType
expr string
searchParam entity.SearchParam
options []SearchQueryOptionFunc
limit int
}

func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest {
return &ANNSearchRequest{
fieldName: fieldName,
vectors: vectors,
metricType: metricsType,
searchParam: searchParam,
limit: limit,
options: options,
}
}
func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest {
r.expr = expr
return r
}

func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo) (*milvuspb.SearchRequest, error) {
opt := &SearchQueryOption{
ConsistencyLevel: collectionInfo.ConsistencyLevel, // default
}
for _, o := range r.options {
o(opt)
}
params := r.searchParam.Params()
params[forTuningKey] = opt.ForTuning
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}

searchParams := entity.MapKvPairs(map[string]string{
"anns_field": r.fieldName,
"topk": fmt.Sprintf("%d", r.limit),
"params": string(bs),
"metric_type": string(r.metricType),
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
groupByKey: opt.GroupByField,
})

result := &milvuspb.SearchRequest{
DbName: "",
Dsl: r.expr,
PlaceholderGroup: vector2PlaceholderGroupBytes(r.vectors),
DslType: commonpb.DslType_BoolExprV1,
SearchParams: searchParams,
GuaranteeTimestamp: opt.GuaranteeTimestamp,
Nq: int64(len(r.vectors)),
}
return result, nil
}
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ type Client interface {
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption,
) (*entity.MessageInfo, error)

HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error)
}

// NewClient create a client connected to remote milvus cluster.
Expand Down
1 change: 1 addition & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func TestGrpcClientNil(t *testing.T) {
if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
m.Name == "UsingDatabase" || // skip use database
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "HybridSearch" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
Expand Down
62 changes: 61 additions & 1 deletion client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/milvus-io/milvus-sdk-go/v2/merr"
)

const (
Expand All @@ -35,6 +36,59 @@ const (
groupByKey = `group_by_field`
)

func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error) {
if c.Service == nil {
return nil, ErrClientNotReady
}

var schema *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collName)
if !ok {
coll, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
schema = coll.Schema
} else {
schema = collInfo.Schema
}

sReqs := make([]*milvuspb.SearchRequest, 0, len(subRequests))
nq := 0
for _, subRequest := range subRequests {
r, err := subRequest.getMilvusSearchRequest(collInfo)
if err != nil {
return nil, err
}
r.CollectionName = collName
r.PartitionNames = partitions
r.OutputFields = outputFields
nq = len(subRequest.vectors)
sReqs = append(sReqs, r)
}

params := reranker.GetParams()
params = append(params, &commonpb.KeyValuePair{Key: limitKey, Value: strconv.FormatInt(int64(limit), 10)})

req := &milvuspb.HybridSearchRequest{
CollectionName: collName,
PartitionNames: partitions,
Requests: sReqs,
OutputFields: outputFields,
ConsistencyLevel: commonpb.ConsistencyLevel(collInfo.ConsistencyLevel),
RankParams: params,
}

result, err := c.Service.HybridSearch(ctx, req)

err = merr.CheckRPCCall(result, err)
if err != nil {
return nil, err
}

return c.handleSearchResult(schema, outputFields, nq, result)
}

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
Expand Down Expand Up @@ -63,7 +117,6 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}

sr := make([]SearchResult, 0, len(vectors))
resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
Expand All @@ -72,6 +125,13 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}
// 3. parse result into result
return c.handleSearchResult(schema, outputFields, len(vectors), resp)
}

func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]SearchResult, error) {
var err error
sr := make([]SearchResult, 0, nq)
// 3. parse result into result
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
Expand Down
62 changes: 62 additions & 0 deletions client/reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client

import (
"encoding/json"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

const (
rerankType = "strategy"
rerankParams = "params"
rffParam = "k"
weightedParam = "weights"

rrfRerankType = `rrf`
weightedRerankType = `weighted`
)

type Reranker interface {
GetParams() []*commonpb.KeyValuePair
}

type rrfReranker struct {
K float64 `json:"k,omitempty"`
}

func (r *rrfReranker) WithK(k float64) *rrfReranker {
r.K = k
return r
}

func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair {
bs, _ := json.Marshal(r)

return []*commonpb.KeyValuePair{
{Key: rerankType, Value: rrfRerankType},
{Key: rerankParams, Value: string(bs)},
}
}

func NewRRFReranker() *rrfReranker {
return &rrfReranker{K: 60}
}

type weightedReranker struct {
Weights []float64 `json:"weights,omitempty"`
}

func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair {
bs, _ := json.Marshal(r)

return []*commonpb.KeyValuePair{
{Key: rerankType, Value: rrfRerankType},
{Key: rerankParams, Value: string(bs)},
}
}

func NewWeightedReranker(weights []float64) *weightedReranker {
return &weightedReranker{
Weights: weights,
}
}
Loading

0 comments on commit 43e3bd2

Please sign in to comment.