enhance: Support query iterator (#749)
Support the query iterator interface in the Go SDK

---------

Signed-off-by: Congqi Xia <[email protected]>
congqixia authored May 28, 2024
1 parent 06f447f commit 8309bdf
Showing 18 changed files with 614 additions and 2 deletions.
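
For orientation before the per-file diffs, here is a minimal usage sketch of the new iterator API. The server address, collection name ("films"), filter, output field, and batch size are assumptions for illustration only; the client is created with client.NewClient as in other SDK examples.

package main

import (
    "context"
    "io"
    "log"

    "github.com/milvus-io/milvus-sdk-go/v2/client"
)

func main() {
    ctx := context.Background()

    // Assumed: a Milvus instance at this address and an existing collection named "films".
    cli, err := client.NewClient(ctx, client.Config{Address: "localhost:19530"})
    if err != nil {
        log.Fatal(err)
    }
    defer cli.Close()

    itr, err := cli.QueryIterator(ctx, client.NewQueryIteratorOption("films").
        WithExpr("year > 1990").
        WithOutputFields("title").
        WithBatchSize(500))
    if err != nil {
        log.Fatal(err)
    }

    for {
        rs, err := itr.Next(ctx)
        if err == io.EOF {
            break // all matching entities have been consumed
        }
        if err != nil {
            log.Fatal(err)
        }
        log.Printf("fetched a batch of %d rows", rs.Len())
    }
}

Each Next call serves up to the configured batch size from an internal cache and returns io.EOF once every matching entity has been returned.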
2 changes: 2 additions & 0 deletions client/client.go
@@ -150,6 +150,8 @@ type Client interface {
Query(ctx context.Context, collectionName string, partitionNames []string, expr string, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error)
// Get grabs the inserted entities using the primary key from the Collection.
Get(ctx context.Context, collectionName string, ids entity.Column, opts ...GetOption) (ResultSet, error)
// QueryIterator returns entities matching the provided criteria in iterator mode.
QueryIterator(ctx context.Context, opt *QueryIteratorOption) (*QueryIterator, error)

// CalcDistance calculate the distance between vectors specified by ids or provided
CalcDistance(ctx context.Context, collName string, partitions []string,
1 change: 1 addition & 0 deletions client/client_test.go
@@ -150,6 +150,7 @@ func TestGrpcClientNil(t *testing.T) {
if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
m.Name == "UsingDatabase" || // skip use database
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "QueryIterator" ||
m.Name == "HybridSearch" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
8 changes: 8 additions & 0 deletions client/data.go
@@ -34,6 +34,8 @@ const (
ignoreGrowingKey = `ignore_growing`
forTuningKey = `for_tuning`
groupByKey = `group_by_field`
iteratorKey = `iterator`
reduceForBestKey = `reduce_stop_for_best`
)

func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
@@ -352,6 +354,12 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition
if option.IgnoreGrowing {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ignoreGrowingKey, Value: strconv.FormatBool(option.IgnoreGrowing)})
}
if option.isIterator {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: iteratorKey, Value: strconv.FormatBool(true)})
}
if option.reduceForBest {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: reduceForBestKey, Value: strconv.FormatBool(true)})
}

resp, err := c.Service.Query(ctx, req)
if err != nil {
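
For reference, a small sketch of what the two new flags add to the query request. The key names come from the constants added above; the standalone program and the milvus-proto import path are assumptions for illustration.

package main

import (
    "fmt"
    "strconv"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

func main() {
    // With isIterator and reduceForBest set, Query appends these pairs to req.QueryParams.
    params := []*commonpb.KeyValuePair{
        {Key: "iterator", Value: strconv.FormatBool(true)},             // iteratorKey
        {Key: "reduce_stop_for_best", Value: strconv.FormatBool(true)}, // reduceForBestKey
    }
    for _, kv := range params {
        fmt.Printf("%s=%s\n", kv.Key, kv.Value)
    }
}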
179 changes: 179 additions & 0 deletions client/iterator.go
@@ -0,0 +1,179 @@
package client

import (
"context"
"fmt"
"io"
"strings"

"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// NewQueryIteratorOption composes a QueryIteratorOption for the given collection with the default batch size (1000).
func NewQueryIteratorOption(collectionName string) *QueryIteratorOption {
return &QueryIteratorOption{
collectionName: collectionName,
batchSize: 1000,
}
}

// QueryIteratorOption bundles the parameters used to build a QueryIterator.
type QueryIteratorOption struct {
collectionName string
partitionNames []string
expr string
outputFields []string
batchSize int
}

// WithPartitions restricts the iterator to the provided partitions.
func (opt *QueryIteratorOption) WithPartitions(partitionNames ...string) *QueryIteratorOption {
opt.partitionNames = partitionNames
return opt
}

// WithExpr sets the filter expression for the query.
func (opt *QueryIteratorOption) WithExpr(expr string) *QueryIteratorOption {
opt.expr = expr
return opt
}

// WithOutputFields sets the fields returned for each entity.
func (opt *QueryIteratorOption) WithOutputFields(outputFields ...string) *QueryIteratorOption {
opt.outputFields = outputFields
return opt
}

// WithBatchSize overrides the default number of entities fetched per request.
func (opt *QueryIteratorOption) WithBatchSize(batchSize int) *QueryIteratorOption {
opt.batchSize = batchSize
return opt
}

// QueryIterator creates a QueryIterator for the collection described in opt.
// It fetches the first batch eagerly so that invalid parameters are reported here rather than on the first Next call.
func (c *GrpcClient) QueryIterator(ctx context.Context, opt *QueryIteratorOption) (*QueryIterator, error) {
collectionName := opt.collectionName
var sch *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collectionName)
if !ok {
coll, err := c.DescribeCollection(ctx, collectionName)
if err != nil {
return nil, err
}
sch = coll.Schema
} else {
sch = collInfo.Schema
}

itr := &QueryIterator{
client: c,

collectionName: opt.collectionName,
partitionNames: opt.partitionNames,
outputFields: opt.outputFields,
sch: sch,
pkField: sch.PKField(),

batchSize: opt.batchSize,
expr: opt.expr,
}

err := itr.init(ctx)
if err != nil {
return nil, err
}
return itr, nil
}

// QueryIterator iterates over the entities matching the expression in batches,
// paginating with the primary key of the last returned entity.
type QueryIterator struct {
// user-provided expression
expr string

batchSize int

cached ResultSet

collectionName string
partitionNames []string
outputFields []string
sch *entity.Schema
pkField *entity.Field

lastPK interface{}

// internal grpc client
client *GrpcClient
}

// init fetches the first batch of data and puts it into the cache.
// It also lets parameter errors surface before the iterator is returned.
func (itr *QueryIterator) init(ctx context.Context) error {
rs, err := itr.fetchNextBatch(ctx)
if err != nil {
return err
}
itr.cached = rs
return nil
}

// composeIteratorExpr ANDs a primary-key range filter ("pk > lastPK") onto the
// user expression so the next batch starts right after the last returned entity.
func (itr *QueryIterator) composeIteratorExpr() string {
if itr.lastPK == nil {
return itr.expr
}

expr := strings.TrimSpace(itr.expr)
if expr != "" {
expr += " AND "
}

switch itr.pkField.DataType {
case entity.FieldTypeInt64:
expr += fmt.Sprintf("%s > %d", itr.pkField.Name, itr.lastPK)
case entity.FieldTypeVarChar:
expr += fmt.Sprintf(`%s > "%s"`, itr.pkField.Name, itr.lastPK)
default:
return itr.expr
}
return expr
}

// fetchNextBatch queries the next batch in iterator mode, limited to batchSize entities.
func (itr *QueryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) {
return itr.client.Query(ctx, itr.collectionName, itr.partitionNames, itr.composeIteratorExpr(), itr.outputFields,
WithLimit(int64(itr.batchSize)), withIterator(), reduceForBest(true))
}

// cachedSufficient reports whether the cache already holds a full batch.
func (itr *QueryIterator) cachedSufficient() bool {
return itr.cached != nil && itr.cached.Len() >= itr.batchSize
}

// cacheNextBatch cuts one batch from rs, caches the remainder, and records the
// last primary key so the next fetch can continue after it.
func (itr *QueryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) {
result := rs.Slice(0, itr.batchSize)
itr.cached = rs.Slice(itr.batchSize, -1)

pkColumn := result.GetColumn(itr.pkField.Name)
switch itr.pkField.DataType {
case entity.FieldTypeInt64:
itr.lastPK, _ = pkColumn.GetAsInt64(pkColumn.Len() - 1)
case entity.FieldTypeVarChar:
itr.lastPK, _ = pkColumn.GetAsString(pkColumn.Len() - 1)
default:
return nil, errors.Newf("unsupported pk type: %v", itr.pkField.DataType)
}
return result, nil
}

// Next returns the next batch of matching entities.
// It returns io.EOF once all matching entities have been consumed.
func (itr *QueryIterator) Next(ctx context.Context) (ResultSet, error) {
var rs ResultSet
var err error

// check whether the cache already holds enough rows for the next batch
if !itr.cachedSufficient() {
rs, err = itr.fetchNextBatch(ctx)
if err != nil {
return nil, err
}
} else {
rs = itr.cached
}

// an empty result set means the iterator is exhausted
if rs.Len() == 0 {
return nil, io.EOF
}

return itr.cacheNextBatch(rs)
}
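
To make the pagination mechanics concrete, here is a standalone sketch that mirrors composeIteratorExpr for an Int64 primary key; nextBatchExpr, the field names, and the sample value 42 are hypothetical, not part of the commit.

package main

import (
    "fmt"
    "strings"
)

// nextBatchExpr mirrors composeIteratorExpr for an Int64 primary key:
// once a batch has been returned, a "pk > lastPK" filter is ANDed onto
// the user expression so the next batch starts after the last entity seen.
func nextBatchExpr(userExpr, pkName string, lastPK *int64) string {
    if lastPK == nil {
        return userExpr // first batch: user expression only
    }
    expr := strings.TrimSpace(userExpr)
    if expr != "" {
        expr += " AND "
    }
    return expr + fmt.Sprintf("%s > %d", pkName, *lastPK)
}

func main() {
    fmt.Println(nextBatchExpr("year > 1990", "id", nil)) // year > 1990
    last := int64(42)
    fmt.Println(nextBatchExpr("year > 1990", "id", &last)) // year > 1990 AND id > 42
}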
15 changes: 15 additions & 0 deletions client/options.go
@@ -134,11 +134,26 @@ type SearchQueryOption struct {
ForTuning bool

GroupByField string

isIterator bool
reduceForBest bool
}

// SearchQueryOptionFunc is a function which modifies SearchOption
type SearchQueryOptionFunc func(option *SearchQueryOption)

// withIterator marks the request as issued by a query iterator.
func withIterator() SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
option.isIterator = true
}
}

// reduceForBest sets the reduce_stop_for_best flag used by iterator queries.
func reduceForBest(value bool) SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
option.reduceForBest = value
}
}

func WithForTuning() SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
option.ForTuning = true
15 changes: 15 additions & 0 deletions client/results.go
@@ -18,6 +18,21 @@ type SearchResult struct {
// ResultSet is an alias type for column slice.
type ResultSet []entity.Column

// Len returns the row count of the result set, i.e. the length of its first column; it is 0 for an empty result set.
func (rs ResultSet) Len() int {
if len(rs) == 0 {
return 0
}
return rs[0].Len()
}

// Slice returns a result set with every column restricted to the [start, end) window.
func (rs ResultSet) Slice(start, end int) ResultSet {
result := make([]entity.Column, 0, len(rs))
for _, col := range rs {
result = append(result, col.Slice(start, end))
}
return result
}

// GetColumn returns column with provided field name.
func (rs ResultSet) GetColumn(fieldName string) entity.Column {
for _, column := range rs {
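
A short sketch of the new Len and Slice semantics on ResultSet. splitResultSet is a hypothetical helper, and the example assumes the generated column types such as ColumnInt64 also gain the new Slice method required by the entity.Column interface change below.

package main

import (
    "fmt"

    "github.com/milvus-io/milvus-sdk-go/v2/client"
    "github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// splitResultSet splits a result set into a head of at most n rows and the rest.
// Slice applies the same [start, end) window to every column; end == -1 means
// "to the end of the column".
func splitResultSet(rs client.ResultSet, n int) (head, rest client.ResultSet) {
    if rs.Len() <= n {
        return rs, nil
    }
    return rs.Slice(0, n), rs.Slice(n, -1)
}

func main() {
    rs := client.ResultSet{entity.NewColumnInt64("id", []int64{1, 2, 3, 4, 5})}
    head, rest := splitResultSet(rs, 3)
    fmt.Println(head.Len(), rest.Len()) // 3 2
}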
1 change: 1 addition & 0 deletions entity/columns.go
@@ -28,6 +28,7 @@ type Column interface {
Name() string
Type() FieldType
Len() int
Slice(int, int) Column
FieldData() *schemapb.FieldData
AppendValue(interface{}) error
Get(int) (interface{}, error)
11 changes: 11 additions & 0 deletions entity/columns_array.go
@@ -29,6 +29,17 @@ func (c *ColumnVarCharArray) Len() int {
return len(c.values)
}

// Slice returns a column holding the values in the [start, end) window;
// end == -1, or any end beyond the column length, is clamped to Len().
func (c *ColumnVarCharArray) Slice(start, end int) Column {
if end == -1 || end > c.Len() {
end = c.Len()
}
return &ColumnVarCharArray{
ColumnBase: c.ColumnBase,
name: c.name,
values: c.values[start:end],
}
}

// Get returns value at index as interface{}.
func (c *ColumnVarCharArray) Get(idx int) (interface{}, error) {
var r []string // use default value