diff --git a/client/ann_request.go b/client/ann_request.go new file mode 100644 index 00000000..df88aa47 --- /dev/null +++ b/client/ann_request.go @@ -0,0 +1,73 @@ +package client + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +type ANNSearchRequest struct { + fieldName string + vectors []entity.Vector + metricType entity.MetricType + expr string + searchParam entity.SearchParam + options []SearchQueryOptionFunc + limit int +} + +func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest { + return &ANNSearchRequest{ + fieldName: fieldName, + vectors: vectors, + metricType: metricsType, + searchParam: searchParam, + limit: limit, + options: options, + } +} +func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest { + r.expr = expr + return r +} + +func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo) (*milvuspb.SearchRequest, error) { + opt := &SearchQueryOption{ + ConsistencyLevel: collectionInfo.ConsistencyLevel, // default + } + for _, o := range r.options { + o(opt) + } + params := r.searchParam.Params() + params[forTuningKey] = opt.ForTuning + bs, err := json.Marshal(params) + if err != nil { + return nil, err + } + + searchParams := entity.MapKvPairs(map[string]string{ + "anns_field": r.fieldName, + "topk": fmt.Sprintf("%d", r.limit), + "params": string(bs), + "metric_type": string(r.metricType), + "round_decimal": "-1", + ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing), + offsetKey: fmt.Sprintf("%d", opt.Offset), + groupByKey: opt.GroupByField, + }) + + result := &milvuspb.SearchRequest{ + DbName: "", + Dsl: r.expr, + PlaceholderGroup: vector2PlaceholderGroupBytes(r.vectors), + DslType: commonpb.DslType_BoolExprV1, + SearchParams: searchParams, + GuaranteeTimestamp: opt.GuaranteeTimestamp, + Nq: int64(len(r.vectors)), + } + return result, nil +} diff --git a/client/client.go b/client/client.go index ed2bacfe..8787fe87 100644 --- a/client/client.go +++ b/client/client.go @@ -219,6 +219,8 @@ type Client interface { msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition, opts ...ReplicateMessageOption, ) (*entity.MessageInfo, error) + + HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error) } // NewClient create a client connected to remote milvus cluster. diff --git a/client/client_test.go b/client/client_test.go index d369cba3..aa5e99bf 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -150,6 +150,7 @@ func TestGrpcClientNil(t *testing.T) { if m.Name == "Close" || m.Name == "Connect" || // skip connect & close m.Name == "UsingDatabase" || // skip use database m.Name == "Search" || // type alias MetricType treated as string + m.Name == "HybridSearch" || // type alias MetricType treated as string m.Name == "CalcDistance" || m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ... diff --git a/client/data.go b/client/data.go index 9a85c43f..71926815 100644 --- a/client/data.go +++ b/client/data.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-sdk-go/v2/entity" + "github.com/milvus-io/milvus-sdk-go/v2/merr" ) const ( @@ -35,6 +36,59 @@ const ( groupByKey = `group_by_field` ) +func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error) { + if c.Service == nil { + return nil, ErrClientNotReady + } + + var schema *entity.Schema + collInfo, ok := MetaCache.getCollectionInfo(collName) + if !ok { + coll, err := c.DescribeCollection(ctx, collName) + if err != nil { + return nil, err + } + schema = coll.Schema + } else { + schema = collInfo.Schema + } + + sReqs := make([]*milvuspb.SearchRequest, 0, len(subRequests)) + nq := 0 + for _, subRequest := range subRequests { + r, err := subRequest.getMilvusSearchRequest(collInfo) + if err != nil { + return nil, err + } + r.CollectionName = collName + r.PartitionNames = partitions + r.OutputFields = outputFields + nq = len(subRequest.vectors) + sReqs = append(sReqs, r) + } + + params := reranker.GetParams() + params = append(params, &commonpb.KeyValuePair{Key: limitKey, Value: strconv.FormatInt(int64(limit), 10)}) + + req := &milvuspb.HybridSearchRequest{ + CollectionName: collName, + PartitionNames: partitions, + Requests: sReqs, + OutputFields: outputFields, + ConsistencyLevel: commonpb.ConsistencyLevel(collInfo.ConsistencyLevel), + RankParams: params, + } + + result, err := c.Service.HybridSearch(ctx, req) + + err = merr.CheckRPCCall(result, err) + if err != nil { + return nil, err + } + + return c.handleSearchResult(schema, outputFields, nq, result) +} + // Search with bool expression func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) { @@ -63,7 +117,6 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s return nil, err } - sr := make([]SearchResult, 0, len(vectors)) resp, err := c.Service.Search(ctx, req) if err != nil { return nil, err @@ -72,6 +125,13 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s return nil, err } // 3. parse result into result + return c.handleSearchResult(schema, outputFields, len(vectors), resp) +} + +func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]SearchResult, error) { + var err error + sr := make([]SearchResult, 0, nq) + // 3. parse result into result results := resp.GetResults() offset := 0 fieldDataList := results.GetFieldsData() diff --git a/client/reranker.go b/client/reranker.go new file mode 100644 index 00000000..a0cca80f --- /dev/null +++ b/client/reranker.go @@ -0,0 +1,62 @@ +package client + +import ( + "encoding/json" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +const ( + rerankType = "strategy" + rerankParams = "params" + rffParam = "k" + weightedParam = "weights" + + rrfRerankType = `rrf` + weightedRerankType = `weighted` +) + +type Reranker interface { + GetParams() []*commonpb.KeyValuePair +} + +type rrfReranker struct { + K float64 `json:"k,omitempty"` +} + +func (r *rrfReranker) WithK(k float64) *rrfReranker { + r.K = k + return r +} + +func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair { + bs, _ := json.Marshal(r) + + return []*commonpb.KeyValuePair{ + {Key: rerankType, Value: rrfRerankType}, + {Key: rerankParams, Value: string(bs)}, + } +} + +func NewRRFReranker() *rrfReranker { + return &rrfReranker{K: 60} +} + +type weightedReranker struct { + Weights []float64 `json:"weights,omitempty"` +} + +func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair { + bs, _ := json.Marshal(r) + + return []*commonpb.KeyValuePair{ + {Key: rerankType, Value: rrfRerankType}, + {Key: rerankParams, Value: string(bs)}, + } +} + +func NewWeightedReranker(weights []float64) *weightedReranker { + return &weightedReranker{ + Weights: weights, + } +} diff --git a/examples/multivectors/main.go b/examples/multivectors/main.go new file mode 100644 index 00000000..05e7add1 --- /dev/null +++ b/examples/multivectors/main.go @@ -0,0 +1,161 @@ +package main + +import ( + "context" + "log" + "math/rand" + "time" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +const ( + milvusAddr = `localhost:19530` + nEntities, dim = 10000, 128 + collectionName = "hello_multi_vectors" + + idCol, keyCol, embeddingCol1, embeddingCol2 = "ID", "key", "vector1", "vector2" + topK = 3 +) + +func main() { + ctx := context.Background() + + log.Println("start connecting to Milvus") + c, err := client.NewClient(ctx, client.Config{ + Address: milvusAddr, + }) + if err != nil { + log.Fatalf("failed to connect to milvus, err: %v", err) + } + defer c.Close() + + // delete collection if exists + has, err := c.HasCollection(ctx, collectionName) + if err != nil { + log.Fatalf("failed to check collection exists, err: %v", err) + } + if has { + c.DropCollection(ctx, collectionName) + } + + // create collection + log.Printf("create collection `%s`\n", collectionName) + schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_partition_key is the a demo to introduce the partition key related APIs"). + WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). + WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64)). + WithField(entity.NewField().WithName(embeddingCol1).WithDataType(entity.FieldTypeFloatVector).WithDim(dim)). + WithField(entity.NewField().WithName(embeddingCol2).WithDataType(entity.FieldTypeFloatVector).WithDim(dim)) + + if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil { // use default shard number + log.Fatalf("create collection failed, err: %v", err) + } + + var keyList []int64 + var embeddingList [][]float32 + keyList = make([]int64, 0, nEntities) + embeddingList = make([][]float32, 0, nEntities) + for i := 0; i < nEntities; i++ { + keyList = append(keyList, rand.Int63()%512) + } + for i := 0; i < nEntities; i++ { + vec := make([]float32, 0, dim) + for j := 0; j < dim; j++ { + vec = append(vec, rand.Float32()) + } + embeddingList = append(embeddingList, vec) + } + keyColData := entity.NewColumnInt64(keyCol, keyList) + embeddingColData1 := entity.NewColumnFloatVector(embeddingCol1, dim, embeddingList) + embeddingColData2 := entity.NewColumnFloatVector(embeddingCol2, dim, embeddingList) + + log.Println("start to insert data into collection") + + if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData1, embeddingColData2); err != nil { + log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err) + } + + log.Println("insert data done, start to flush") + + if err := c.Flush(ctx, collectionName, false); err != nil { + log.Fatalf("failed to flush data, err: %v", err) + } + log.Println("flush data done") + + // build index + log.Println("start creating index HNSW") + idx, err := entity.NewIndexHNSW(entity.L2, 16, 256) + if err != nil { + log.Fatalf("failed to create ivf flat index, err: %v", err) + } + if err := c.CreateIndex(ctx, collectionName, embeddingCol1, idx, false); err != nil { + log.Fatalf("failed to create index, err: %v", err) + } + if err := c.CreateIndex(ctx, collectionName, embeddingCol2, idx, false); err != nil { + log.Fatalf("failed to create index, err: %v", err) + } + + log.Printf("build HNSW index done for collection `%s`\n", collectionName) + log.Printf("start to load collection `%s`\n", collectionName) + + // load collection + if err := c.LoadCollection(ctx, collectionName, false); err != nil { + log.Fatalf("failed to load collection, err: %v", err) + } + + log.Println("load collection done") + + // currently only nq =1 is supported + vec2search1 := []entity.Vector{ + entity.FloatVector(embeddingList[len(embeddingList)-2]), + } + vec2search2 := []entity.Vector{ + entity.FloatVector(embeddingList[len(embeddingList)-1]), + } + + begin := time.Now() + sp, _ := entity.NewIndexHNSWSearchParam(30) + + log.Println("start to search vector field 1") + result, err := c.Search(ctx, collectionName, nil, "", []string{keyCol, embeddingCol1, embeddingCol2}, vec2search1, + embeddingCol1, entity.L2, topK, sp) + if err != nil { + log.Fatalf("failed to search collection, err: %v", err) + } + + log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin)) + for _, rs := range result { + for i := 0; i < rs.ResultCount; i++ { + id, _ := rs.IDs.GetAsInt64(i) + score := rs.Scores[i] + embedding, _ := rs.Fields.GetColumn(embeddingCol1).Get(i) + + log.Printf("ID: %d, score %f, embedding: %v\n", id, score, embedding) + } + } + + log.Println("start to execute hybrid search") + + result, err = c.HybridSearch(ctx, collectionName, nil, topK, []string{keyCol, embeddingCol1, embeddingCol2}, + client.NewRRFReranker(), []*client.ANNSearchRequest{ + client.NewANNSearchRequest(embeddingCol1, entity.L2, vec2search1, sp, topK), + client.NewANNSearchRequest(embeddingCol2, entity.L2, vec2search2, sp, topK), + }) + if err != nil { + log.Fatalf("failed to search collection, err: %v", err) + } + + log.Printf("hybrid search `%s` done, latency %v\n", collectionName, time.Since(begin)) + for _, rs := range result { + for i := 0; i < rs.ResultCount; i++ { + id, _ := rs.IDs.GetAsInt64(i) + score := rs.Scores[i] + embedding1, _ := rs.Fields.GetColumn(embeddingCol1).Get(i) + embedding2, _ := rs.Fields.GetColumn(embeddingCol1).Get(i) + log.Printf("ID: %d, score %f, embedding1: %v, embedding2: %v\n", id, score, embedding1, embedding2) + } + } + + c.DropCollection(ctx, collectionName) +}