Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: Support hybrid search multiple vector fields #663

Merged
merged 2 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions client/ann_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package client

import (
"encoding/json"
"fmt"
"strconv"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// ANNSearchRequest describes a single vector search on one field. A slice of
// these is passed to HybridSearch as its sub-requests, and each one is turned
// into a milvuspb.SearchRequest by getMilvusSearchRequest.
type ANNSearchRequest struct {
// fieldName is sent as the "anns_field" search parameter.
fieldName string
// vectors are the query vectors; their count is sent as Nq.
vectors []entity.Vector
// metricType is sent as the "metric_type" search parameter.
metricType entity.MetricType
// expr is sent as the request Dsl; it is set through WithExpr.
expr string
// searchParam supplies the map that is JSON-encoded into the "params" search parameter.
searchParam entity.SearchParam
// options are applied to a SearchQueryOption while the request is built.
options []SearchQueryOptionFunc
// limit is sent as the "topk" search parameter.
limit int
}

// NewANNSearchRequest builds an ANNSearchRequest that searches the vector
// field fieldName with the given query vectors.
//
// metricsType and searchParam are forwarded as the metric type and index
// search parameters, limit is the per-request top-k, and options are applied
// to the SearchQueryOption when the request is converted for sending.
// The expression is left empty; use WithExpr to set one.
func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest {
	req := new(ANNSearchRequest)
	req.fieldName = fieldName
	req.metricType = metricsType
	req.vectors = vectors
	req.searchParam = searchParam
	req.limit = limit
	req.options = options
	return req
}
// WithExpr stores expr as the boolean expression of the request (it is sent
// as the request Dsl) and returns the same request so calls can be chained.
func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest {
r.expr = expr
return r
}

// getMilvusSearchRequest converts the ANN request into the
// milvuspb.SearchRequest that is sent as one sub-request of a hybrid search.
//
// collectionInfo provides the default consistency level. It may be nil (for
// example when the collection was not found in the meta cache); the
// zero-value consistency level is then used unless an option overrides it.
// The request's options are applied on top of that default.
//
// CollectionName, PartitionNames and OutputFields are not set here; the
// caller fills them in.
//
// An error is returned only when the search parameters cannot be JSON-encoded.
func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo) (*milvuspb.SearchRequest, error) {
	opt := &SearchQueryOption{}
	if collectionInfo != nil {
		opt.ConsistencyLevel = collectionInfo.ConsistencyLevel // default
	}
	for _, o := range r.options {
		if o != nil {
			o(opt)
		}
	}

	// Work on a private copy of the index parameters: Params may hand back a
	// map owned by the search param (or a nil map), and writing the tuning
	// flag into it would either leak into the caller's object or panic.
	var src map[string]interface{}
	if r.searchParam != nil {
		src = r.searchParam.Params()
	}
	params := make(map[string]interface{}, len(src)+1)
	for k, v := range src {
		params[k] = v
	}
	params[forTuningKey] = opt.ForTuning

	bs, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}

	searchParams := entity.MapKvPairs(map[string]string{
		"anns_field":     r.fieldName,
		"topk":           strconv.Itoa(r.limit),
		"params":         string(bs),
		"metric_type":    string(r.metricType),
		"round_decimal":  "-1",
		ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
		// The integer type of Offset is declared elsewhere, so keep the
		// type-agnostic formatting here.
		offsetKey:  fmt.Sprintf("%d", opt.Offset),
		groupByKey: opt.GroupByField,
	})

	return &milvuspb.SearchRequest{
		DbName:             "",
		Dsl:                r.expr,
		PlaceholderGroup:   vector2PlaceholderGroupBytes(r.vectors),
		DslType:            commonpb.DslType_BoolExprV1,
		SearchParams:       searchParams,
		GuaranteeTimestamp: opt.GuaranteeTimestamp,
		Nq:                 int64(len(r.vectors)),
	}, nil
}
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ type Client interface {
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption,
) (*entity.MessageInfo, error)

HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error)
}

// NewClient creates a client connected to a remote Milvus cluster.
Expand Down
1 change: 1 addition & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func TestGrpcClientNil(t *testing.T) {
if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
m.Name == "UsingDatabase" || // skip use database
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "HybridSearch" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
Expand Down
62 changes: 61 additions & 1 deletion client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/milvus-io/milvus-sdk-go/v2/merr"
)

const (
Expand All @@ -35,6 +36,59 @@ const (
groupByKey = `group_by_field`
)

// HybridSearch sends all subRequests for collName in one HybridSearchRequest
// and returns the parsed search results.
//
// partitions and outputFields are copied into every sub-request as well as
// into the enclosing hybrid request. limit is appended to the rank parameters
// under limitKey. reranker supplies the remaining rank parameters; when it is
// nil only the limit is sent.
//
// The collection schema and consistency level come from the meta cache, or
// from DescribeCollection when the collection is not cached.
//
// It returns ErrClientNotReady when the client has no service, and otherwise
// any error from describing the collection, building a sub-request, or the
// RPC itself.
func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error) {
	if c.Service == nil {
		return nil, ErrClientNotReady
	}

	info, ok := MetaCache.getCollectionInfo(collName)
	if !ok {
		coll, err := c.DescribeCollection(ctx, collName)
		if err != nil {
			return nil, err
		}
		// Cache miss: the cached info is not available, so build it from the
		// describe result. Previously the unset info was passed on and
		// dereferenced below, which panics on the first search of an
		// uncached collection.
		info = &collInfo{
			Schema:           coll.Schema,
			ConsistencyLevel: coll.ConsistencyLevel,
		}
	}
	schema := info.Schema

	sReqs := make([]*milvuspb.SearchRequest, 0, len(subRequests))
	nq := 0
	for _, subRequest := range subRequests {
		r, err := subRequest.getMilvusSearchRequest(info)
		if err != nil {
			return nil, err
		}
		r.CollectionName = collName
		r.PartitionNames = partitions
		r.OutputFields = outputFields
		// nq ends up as the vector count of the last sub-request.
		nq = len(subRequest.vectors)
		sReqs = append(sReqs, r)
	}

	// A nil reranker contributes no strategy parameters instead of panicking.
	var params []*commonpb.KeyValuePair
	if reranker != nil {
		params = reranker.GetParams()
	}
	params = append(params, &commonpb.KeyValuePair{Key: limitKey, Value: strconv.FormatInt(int64(limit), 10)})

	req := &milvuspb.HybridSearchRequest{
		CollectionName:   collName,
		PartitionNames:   partitions,
		Requests:         sReqs,
		OutputFields:     outputFields,
		ConsistencyLevel: commonpb.ConsistencyLevel(info.ConsistencyLevel),
		RankParams:       params,
	}

	result, err := c.Service.HybridSearch(ctx, req)

	err = merr.CheckRPCCall(result, err)
	if err != nil {
		return nil, err
	}

	return c.handleSearchResult(schema, outputFields, nq, result)
}

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
Expand Down Expand Up @@ -63,7 +117,6 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}

sr := make([]SearchResult, 0, len(vectors))
resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
Expand All @@ -72,6 +125,13 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}
// 3. parse result into result
return c.handleSearchResult(schema, outputFields, len(vectors), resp)
}

func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]SearchResult, error) {
var err error
sr := make([]SearchResult, 0, nq)
// 3. parse result into result
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
Expand Down
62 changes: 62 additions & 0 deletions client/reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client

import (
"encoding/json"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

const (
// Keys of the rank parameters emitted by the rerankers' GetParams methods.
rerankType = "strategy"
rerankParams = "params"
// NOTE(review): rffParam and weightedParam are not referenced in the visible
// reranker code; the JSON field names come from the struct tags instead.
// "rff" looks like a typo for "rrf" — confirm before renaming.
rffParam = "k"
weightedParam = "weights"

// Strategy identifiers intended as values for the rerankType key.
rrfRerankType = `rrf`
weightedRerankType = `weighted`
)

// Reranker supplies the rank parameters that HybridSearch attaches to a
// hybrid search request.
type Reranker interface {
// GetParams returns the key/value pairs describing the rerank strategy and its settings.
GetParams() []*commonpb.KeyValuePair
}

// rrfReranker is the Reranker that reports the "rrf" strategy.
type rrfReranker struct {
// K is serialized as "k" in the strategy params and omitted when it is zero.
K float64 `json:"k,omitempty"`
}

// WithK sets the K parameter and returns the same reranker so calls can be
// chained.
func (r *rrfReranker) WithK(k float64) *rrfReranker {
r.K = k
return r
}

// GetParams returns the rank parameters for this reranker: the "rrf" strategy
// name and the reranker itself encoded as a JSON object.
func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair {
	// The marshal error is dropped on purpose because the Reranker interface
	// has no error return. Encoding this struct can only fail for a
	// non-finite K, in which case the params value is the empty string.
	encoded, _ := json.Marshal(r)

	strategy := &commonpb.KeyValuePair{Key: rerankType, Value: rrfRerankType}
	settings := &commonpb.KeyValuePair{Key: rerankParams, Value: string(encoded)}
	return []*commonpb.KeyValuePair{strategy, settings}
}

// NewRRFReranker returns an rrfReranker whose K starts at 60; call WithK on
// the result to use a different value.
func NewRRFReranker() *rrfReranker {
	reranker := new(rrfReranker)
	reranker.K = 60
	return reranker
}

// weightedReranker is the Reranker that carries a list of weights.
type weightedReranker struct {
// Weights is serialized as "weights" in the strategy params and omitted when empty.
// NOTE(review): presumably one weight per sub-request, in order — confirm against the server.
Weights []float64 `json:"weights,omitempty"`
}

// GetParams returns the rank parameters for this reranker: the "weighted"
// strategy name and the reranker itself (its weights) encoded as a JSON
// object.
func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair {
	// The marshal error is dropped on purpose because the Reranker interface
	// has no error return. Encoding can only fail for a non-finite weight, in
	// which case the params value is the empty string.
	bs, _ := json.Marshal(r)

	return []*commonpb.KeyValuePair{
		// Must announce the weighted strategy; reporting rrfRerankType here
		// would send the weights under the RRF strategy.
		{Key: rerankType, Value: weightedRerankType},
		{Key: rerankParams, Value: string(bs)},
	}
}

// NewWeightedReranker returns a weightedReranker that uses the given weights.
// The slice is stored as passed, not copied.
func NewWeightedReranker(weights []float64) *weightedReranker {
	reranker := new(weightedReranker)
	reranker.Weights = weights
	return reranker
}
Loading
Loading