From 1ef975d3279a1b158fab345c6eb638ca437d63ee Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 17 May 2024 10:39:37 +0800 Subject: [PATCH] enhance: Update latest sdk update to client pkg (#33105) Related to #31293 See also milvus-io/milvus-sdk-go#704 milvus-io/milvus-sdk-go#711 milvus-io/milvus-sdk-go#713 milvus-io/milvus-sdk-go#721 milvus-io/milvus-sdk-go#732 milvus-io/milvus-sdk-go#739 milvus-io/milvus-sdk-go#748 --------- Signed-off-by: Congqi Xia --- client/OWNERS | 1 + client/client.go | 67 +++++++++++++- client/client_config.go | 82 ++++++++---------- client/collection.go | 1 + client/collection_options.go | 8 +- client/column/columns.go | 32 ++++--- client/column/columns_test.go | 32 +++++-- client/database.go | 6 ++ client/database_options.go | 18 ++++ client/entity/collection.go | 1 + client/entity/schema.go | 24 +++-- client/index_options.go | 19 +++- client/interceptors.go | 159 ++++++++++++++++++++++++++++++++++ client/interceptors_test.go | 66 ++++++++++++++ client/read.go | 23 +++-- client/read_options.go | 9 +- client/write.go | 61 ++++++++++--- client/write_test.go | 67 +++++++++++--- 18 files changed, 553 insertions(+), 123 deletions(-) create mode 100644 client/interceptors.go create mode 100644 client/interceptors_test.go diff --git a/client/OWNERS b/client/OWNERS index 1e038a20ebbe7..e8864576b1b77 100644 --- a/client/OWNERS +++ b/client/OWNERS @@ -1,5 +1,6 @@ reviewers: - congqixia + - ThreadDao approvers: - maintainers diff --git a/client/client.go b/client/client.go index b515781106c30..803ef4935ad6e 100644 --- a/client/client.go +++ b/client/client.go @@ -18,14 +18,20 @@ package client import ( "context" + "crypto/tls" "fmt" + "math" "os" "strconv" + "sync" "time" "github.com/gogo/status" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -39,6 +45,11 @@ type Client struct { service milvuspb.MilvusServiceClient config *ClientConfig + // mutable status + stateMut sync.RWMutex + currentDB string + identifier string + collCache *CollectionCache } @@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) { // Parse remote address. addr := c.config.getParsedAddress() + // parse authentication parameters + c.config.parseAuthentication() // Parse grpc options - options := c.config.getDialOption() + options := c.dialOptions() // Connect the grpc server. if err := c.connect(ctx, addr, options...); err != nil { @@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) { return c, nil } +func (c *Client) dialOptions() []grpc.DialOption { + var options []grpc.DialOption + // Construct dial option. + if c.config.EnableTLSAuth { + options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))) + } else { + options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + if c.config.DialOptions == nil { + // Add default connection options. + options = append(options, DefaultGrpcOpts...) + } else { + options = append(options, c.config.DialOptions...) + } + + options = append(options, + grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor( + grpc_retry.WithMax(6), + grpc_retry.WithBackoff(func(attempt uint) time.Duration { + return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt))) + }), + grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)), + + // c.getRetryOnRateLimitInterceptor(), + )) + + options = append(options, grpc.WithChainUnaryInterceptor( + c.MetadataUnaryInterceptor(), + )) + + return options +} + func (c *Client) Close(ctx context.Context) error { if c.conn == nil { return nil @@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error { return nil } +func (c *Client) usingDatabase(dbName string) { + c.stateMut.Lock() + defer c.stateMut.Unlock() + c.currentDB = dbName +} + +func (c *Client) setIdentifier(identifier string) { + c.stateMut.Lock() + defer c.stateMut.Unlock() + c.identifier = identifier +} + func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error { if addr == "" { return fmt.Errorf("address is empty") @@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error { req := &milvuspb.ConnectRequest{ ClientInfo: &commonpb.ClientInfo{ - SdkType: "Golang", + SdkType: "GoMilvusClient", SdkVersion: common.SDKVersion, LocalTime: time.Now().String(), User: c.config.Username, @@ -131,8 +190,8 @@ func (c *Client) connectInternal(ctx context.Context) error { disableJSON | disableParitionKey | disableDynamicSchema) + return nil } - return nil } return err } @@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error { } c.config.setServerInfo(resp.GetServerInfo().GetBuildTags()) - c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10)) + c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10)) return nil } diff --git a/client/client_config.go b/client/client_config.go index 998adbf761eba..63a4f6d2b8565 100644 --- a/client/client_config.go +++ b/client/client_config.go @@ -1,7 +1,7 @@ package client import ( - "crypto/tls" + "context" "fmt" "math" "net/url" @@ -10,12 +10,9 @@ import ( "time" "github.com/cockroachdb/errors" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + "github.com/milvus-io/milvus/pkg/util/crypto" "google.golang.org/grpc" "google.golang.org/grpc/backoff" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" ) @@ -59,16 +56,23 @@ type ClientConfig struct { DialOptions []grpc.DialOption // Dial options for GRPC. - // RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor + RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor DisableConn bool + metadataHeaders map[string]string + identifier string // Identifier for this connection ServerVersion string // ServerVersion parsedAddress *url.URL flags uint64 // internal flags } +type RetryRateLimitOption struct { + MaxRetry uint + MaxBackoff time.Duration +} + func (cfg *ClientConfig) parse() error { // Prepend default fake tcp:// scheme for remote address. address := cfg.Address @@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) { c.ServerVersion = serverInfo } -// Get parsed grpc dial options, should be called after parse was called. -func (c *ClientConfig) getDialOption() []grpc.DialOption { - options := c.DialOptions - if c.DialOptions == nil { - // Add default connection options. - options = make([]grpc.DialOption, len(DefaultGrpcOpts)) - copy(options, DefaultGrpcOpts) +// parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key. +func (c *ClientConfig) parseAuthentication() { + c.metadataHeaders = make(map[string]string) + if c.Username != "" || c.Password != "" { + value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password)) + c.metadataHeaders[authorizationHeader] = value + } + // API overwrites username & passwd + if c.APIKey != "" { + value := crypto.Base64Encode(c.APIKey) + c.metadataHeaders[authorizationHeader] = value } +} - // Construct dial option. - if c.EnableTLSAuth { - options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))) - } else { - options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials())) +func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor { + if c.RetryRateLimit == nil { + c.RetryRateLimit = c.defaultRetryRateLimitOption() } - options = append(options, - grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(6), - grpc_retry.WithBackoff(func(attempt uint) time.Duration { - return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt))) - }), - grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)), - // c.getRetryOnRateLimitInterceptor(), - )) - - // options = append(options, grpc.WithChainUnaryInterceptor( - // createMetaDataUnaryInterceptor(c), - // )) - return options + return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration { + return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt))) + }) } -// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor { -// if c.RetryRateLimit == nil { -// c.RetryRateLimit = c.defaultRetryRateLimitOption() -// } - -// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration { -// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt))) -// }) -// } - -// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption { -// return &RetryRateLimitOption{ -// MaxRetry: 75, -// MaxBackoff: 3 * time.Second, -// } -// } +func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption { + return &RetryRateLimitOption{ + MaxRetry: 75, + MaxBackoff: 3 * time.Second, + } +} // addFlags set internal flags func (c *ClientConfig) addFlags(flags uint64) { diff --git a/client/collection.go b/client/collection.go index 7d05f5525e759..039ff2460d64c 100644 --- a/client/collection.go +++ b/client/collection.go @@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect VirtualChannels: resp.GetVirtualChannelNames(), ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel), ShardNum: resp.GetShardsNum(), + Properties: entity.KvPairsMap(resp.GetProperties()), } collection.Name = collection.Schema.CollectionName return nil diff --git a/client/collection_options.go b/client/collection_options.go index 2d3457177c7d4..adb59e37b5145 100644 --- a/client/collection_options.go +++ b/client/collection_options.go @@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti autoID: true, dim: dim, enabledDynamicSchema: true, + consistencyLevel: entity.DefaultConsistencyLevel, isFast: true, metricType: entity.COSINE, @@ -149,9 +150,10 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti // NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption { return &createCollectionOption{ - name: name, - shardNum: 1, - schema: collectionSchema, + name: name, + shardNum: 1, + schema: collectionSchema, + consistencyLevel: entity.DefaultConsistencyLevel, metricType: entity.COSINE, } diff --git a/client/column/columns.go b/client/column/columns.go index 79166634bc634..8a2a52d87941f 100644 --- a/client/column/columns.go +++ b/client/column/columns.go @@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched") // IDColumns converts schemapb.IDs to corresponding column // currently Int64 / string may be in IDs -func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) { +func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) { var idColumn Column - if idField == nil { + pkField := schema.PKField() + if pkField == nil { + return nil, errors.New("PK Field not found") + } + if ids == nil { return nil, errors.New("nil Ids from response") } - switch field := idField.GetIdField().(type) { - case *schemapb.IDs_IntId: + switch pkField.DataType { + case entity.FieldTypeInt64: + data := ids.GetIntId().GetData() + if data == nil { + return NewColumnInt64(pkField.Name, nil), nil + } if end >= 0 { - idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end]) + idColumn = NewColumnInt64(pkField.Name, data[begin:end]) } else { - idColumn = NewColumnInt64("", field.IntId.GetData()[begin:]) + idColumn = NewColumnInt64(pkField.Name, data[begin:]) + } + case entity.FieldTypeVarChar, entity.FieldTypeString: + data := ids.GetStrId().GetData() + if data == nil { + return NewColumnVarChar(pkField.Name, nil), nil } - case *schemapb.IDs_StrId: if end >= 0 { - idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end]) + idColumn = NewColumnVarChar(pkField.Name, data[begin:end]) } else { - idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:]) + idColumn = NewColumnVarChar(pkField.Name, data[begin:]) } default: - return nil, fmt.Errorf("unsupported id type %v", field) + return nil, fmt.Errorf("unsupported id type %v", pkField.DataType) } return idColumn, nil } diff --git a/client/column/columns_test.go b/client/column/columns_test.go index fad92f47d6b66..38547f384ca45 100644 --- a/client/column/columns_test.go +++ b/client/column/columns_test.go @@ -24,18 +24,34 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/client/v2/entity" ) func TestIDColumns(t *testing.T) { dataLen := rand.Intn(100) + 1 base := rand.Intn(5000) // id start point + intPKCol := entity.NewSchema().WithField( + entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64), + ) + strPKCol := entity.NewSchema().WithField( + entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar), + ) + t.Run("nil id", func(t *testing.T) { - _, err := IDColumns(nil, 0, -1) - assert.NotNil(t, err) + col, err := IDColumns(intPKCol, nil, 0, -1) + assert.NoError(t, err) + assert.EqualValues(t, 0, col.Len()) + col, err = IDColumns(strPKCol, nil, 0, -1) + assert.NoError(t, err) + assert.EqualValues(t, 0, col.Len()) idField := &schemapb.IDs{} - _, err = IDColumns(idField, 0, -1) - assert.NotNil(t, err) + col, err = IDColumns(intPKCol, idField, 0, -1) + assert.NoError(t, err) + assert.EqualValues(t, 0, col.Len()) + col, err = IDColumns(strPKCol, idField, 0, -1) + assert.NoError(t, err) + assert.EqualValues(t, 0, col.Len()) }) t.Run("int ids", func(t *testing.T) { @@ -50,12 +66,12 @@ func TestIDColumns(t *testing.T) { }, }, } - column, err := IDColumns(idField, 0, dataLen) + column, err := IDColumns(intPKCol, idField, 0, dataLen) assert.Nil(t, err) assert.NotNil(t, column) assert.Equal(t, dataLen, column.Len()) - column, err = IDColumns(idField, 0, -1) // test -1 method + column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method assert.Nil(t, err) assert.NotNil(t, column) assert.Equal(t, dataLen, column.Len()) @@ -72,12 +88,12 @@ func TestIDColumns(t *testing.T) { }, }, } - column, err := IDColumns(idField, 0, dataLen) + column, err := IDColumns(strPKCol, idField, 0, dataLen) assert.Nil(t, err) assert.NotNil(t, column) assert.Equal(t, dataLen, column.Len()) - column, err = IDColumns(idField, 0, -1) // test -1 method + column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method assert.Nil(t, err) assert.NotNil(t, column) assert.Equal(t, dataLen, column.Len()) diff --git a/client/database.go b/client/database.go index cd372ab911b54..b4ccaeaa12647 100644 --- a/client/database.go +++ b/client/database.go @@ -25,6 +25,12 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error { + dbName := option.DbName() + c.usingDatabase(dbName) + return c.connectInternal(ctx) +} + func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) { req := option.Request() diff --git a/client/database_options.go b/client/database_options.go index 3fb26d91de9d8..13a58709b6877 100644 --- a/client/database_options.go +++ b/client/database_options.go @@ -18,6 +18,24 @@ package client import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +type UsingDatabaseOption interface { + DbName() string +} + +type usingDatabaseNameOpt struct { + dbName string +} + +func (opt *usingDatabaseNameOpt) DbName() string { + return opt.dbName +} + +func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt { + return &usingDatabaseNameOpt{ + dbName: dbName, + } +} + // ListDatabaseOption is a builder interface for ListDatabase request. type ListDatabaseOption interface { Request() *milvuspb.ListDatabasesRequest diff --git a/client/entity/collection.go b/client/entity/collection.go index 72d86a8ecbc24..f30cc05f59809 100644 --- a/client/entity/collection.go +++ b/client/entity/collection.go @@ -32,6 +32,7 @@ type Collection struct { Loaded bool ConsistencyLevel ConsistencyLevel ShardNum int32 + Properties map[string]string } // Partition represent partition meta in Milvus diff --git a/client/entity/schema.go b/client/entity/schema.go index 4434969578ae4..ce30b53f51483 100644 --- a/client/entity/schema.go +++ b/client/entity/schema.go @@ -60,6 +60,8 @@ type Schema struct { AutoID bool Fields []*Field EnableDynamicField bool + + pkField *Field } // NewSchema creates an empty schema object. @@ -91,6 +93,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema { // WithField adds a field into schema and returns schema itself. func (s *Schema) WithField(f *Field) *Schema { + if f.PrimaryKey { + s.pkField = f + } s.Fields = append(s.Fields, f) return s } @@ -116,10 +121,14 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema { s.CollectionName = p.GetName() s.Fields = make([]*Field, 0, len(p.GetFields())) for _, fp := range p.GetFields() { + field := NewField().ReadProto(fp) if fp.GetAutoID() { s.AutoID = true } - s.Fields = append(s.Fields, NewField().ReadProto(fp)) + if field.PrimaryKey { + s.pkField = field + } + s.Fields = append(s.Fields, field) } s.EnableDynamicField = p.GetEnableDynamicField() return s @@ -127,12 +136,15 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema { // PKFieldName returns pk field name for this schemapb. func (s *Schema) PKFieldName() string { - for _, field := range s.Fields { - if field.PrimaryKey { - return field.Name - } + if s.pkField == nil { + return "" } - return "" + return s.pkField.Name +} + +// PKField returns PK Field schema for this schema. +func (s *Schema) PKField() *Field { + return s.pkField } // Field represent field schema in milvus diff --git a/client/index_options.go b/client/index_options.go index 272d6cdef84c7..b426ee9ade7f2 100644 --- a/client/index_options.go +++ b/client/index_options.go @@ -17,6 +17,8 @@ package client import ( + "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/index" @@ -31,15 +33,27 @@ type createIndexOption struct { fieldName string indexName string indexDef index.Index + + extraParams map[string]any +} + +func (opt *createIndexOption) WithExtraParam(key string, value any) { + opt.extraParams[key] = value } func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest { - return &milvuspb.CreateIndexRequest{ + params := opt.indexDef.Params() + for key, value := range opt.extraParams { + params[key] = fmt.Sprintf("%v", value) + } + req := &milvuspb.CreateIndexRequest{ CollectionName: opt.collectionName, FieldName: opt.fieldName, IndexName: opt.indexName, - ExtraParams: entity.MapKvPairs(opt.indexDef.Params()), + ExtraParams: entity.MapKvPairs(params), } + + return req } func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption { @@ -52,6 +66,7 @@ func NewCreateIndexOption(collectionName string, fieldName string, index index.I collectionName: collectionName, fieldName: fieldName, indexDef: index, + extraParams: make(map[string]any), } } diff --git a/client/interceptors.go b/client/interceptors.go new file mode 100644 index 0000000000000..16396c4aed7f9 --- /dev/null +++ b/client/interceptors.go @@ -0,0 +1,159 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +const ( + authorizationHeader = `authorization` + + identifierHeader = `identifier` + + databaseHeader = `dbname` +) + +func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = c.metadata(ctx) + ctx = c.state(ctx) + + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func (c *Client) metadata(ctx context.Context) context.Context { + for k, v := range c.config.metadataHeaders { + ctx = metadata.AppendToOutgoingContext(ctx, k, v) + } + return ctx +} + +func (c *Client) state(ctx context.Context) context.Context { + c.stateMut.RLock() + defer c.stateMut.RUnlock() + + if c.currentDB != "" { + ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB) + } + if c.identifier != "" { + ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier) + } + + return ctx +} + +// ref: https://github.com/grpc-ecosystem/go-grpc-middleware + +type ctxKey int + +const ( + RetryOnRateLimit ctxKey = iota +) + +// RetryOnRateLimitInterceptor returns a new retrying unary client interceptor. +func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor { + return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if maxRetry == 0 { + return invoker(parentCtx, method, req, reply, cc, opts...) + } + var lastErr error + for attempt := uint(0); attempt < maxRetry; attempt++ { + _, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc) + if err != nil { + return err + } + lastErr = invoker(parentCtx, method, req, reply, cc, opts...) + rspStatus := getResultStatus(reply) + if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit { + continue + } + return lastErr + } + return lastErr + } +} + +func retryOnRateLimit(ctx context.Context) bool { + retry, ok := ctx.Value(RetryOnRateLimit).(bool) + if !ok { + return true // default true + } + return retry +} + +// getResultStatus returns status of response. +func getResultStatus(reply interface{}) *commonpb.Status { + switch r := reply.(type) { + case *commonpb.Status: + return r + case *milvuspb.MutationResult: + return r.GetStatus() + case *milvuspb.BoolResponse: + return r.GetStatus() + case *milvuspb.SearchResults: + return r.GetStatus() + case *milvuspb.QueryResults: + return r.GetStatus() + case *milvuspb.FlushResponse: + return r.GetStatus() + default: + return nil + } +} + +func contextErrToGrpcErr(err error) error { + switch err { + case context.DeadlineExceeded: + return status.Error(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return status.Error(codes.Canceled, err.Error()) + default: + return status.Error(codes.Unknown, err.Error()) + } +} + +func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) { + var waitTime time.Duration + if attempt > 0 { + waitTime = backoffFunc(parentCtx, attempt) + } + if waitTime > 0 { + if waitTime > maxBackoff { + waitTime = maxBackoff + } + timer := time.NewTimer(waitTime) + select { + case <-parentCtx.Done(): + timer.Stop() + return waitTime, contextErrToGrpcErr(parentCtx.Err()) + case <-timer.C: + } + } + return waitTime, nil +} diff --git a/client/interceptors_test.go b/client/interceptors_test.go new file mode 100644 index 0000000000000..e3bcb34fcea66 --- /dev/null +++ b/client/interceptors_test.go @@ -0,0 +1,66 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +var mockInvokerError error +var mockInvokerReply interface{} +var mockInvokeTimes = 0 + +var mockInvoker grpc.UnaryInvoker = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + mockInvokeTimes++ + return mockInvokerError +} + +func resetMockInvokeTimes() { + mockInvokeTimes = 0 +} + +func TestRateLimitInterceptor(t *testing.T) { + maxRetry := uint(3) + maxBackoff := 3 * time.Second + inter := RetryOnRateLimitInterceptor(maxRetry, maxBackoff, func(ctx context.Context, attempt uint) time.Duration { + return 60 * time.Millisecond * time.Duration(math.Pow(2, float64(attempt))) + }) + + ctx := context.Background() + + // with retry + mockInvokerReply = &commonpb.Status{ErrorCode: commonpb.ErrorCode_RateLimit} + resetMockInvokeTimes() + err := inter(ctx, "", nil, mockInvokerReply, nil, mockInvoker) + assert.NoError(t, err) + assert.Equal(t, maxRetry, uint(mockInvokeTimes)) + + // without retry + ctx1 := context.WithValue(ctx, RetryOnRateLimit, false) + resetMockInvokeTimes() + err = inter(ctx1, "", nil, mockInvokerReply, nil, mockInvoker) + assert.NoError(t, err) + assert.Equal(t, uint(1), uint(mockInvokeTimes)) +} diff --git a/client/read.go b/client/read.go index 1d4a1e489f52a..3aeaff769d31b 100644 --- a/client/read.go +++ b/client/read.go @@ -33,7 +33,7 @@ type ResultSets struct{} type ResultSet struct { ResultCount int // the returning entry count - GroupByValue any + GroupByValue column.Column IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API Fields DataSet // output field data Scores []float32 // distance to the target vector @@ -67,35 +67,32 @@ func (c *Client) Search(ctx context.Context, option SearchOption, callOptions .. } func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) { - var err error sr := make([]ResultSet, 0, nq) results := resp.GetResults() offset := 0 fieldDataList := results.GetFieldsData() gb := results.GetGroupByFieldValue() - var gbc column.Column - if gb != nil { - gbc, err = column.FieldDataColumn(gb, 0, -1) - if err != nil { - return nil, err - } - } for i := 0; i < int(results.GetNumQueries()); i++ { rc := int(results.GetTopks()[i]) // result entry count for current query entry := ResultSet{ ResultCount: rc, Scores: results.GetScores()[offset : offset+rc], } - if gbc != nil { - entry.GroupByValue, _ = gbc.Get(i) - } // parse result set if current nq is not empty if rc > 0 { - entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc) + entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc) if entry.Err != nil { offset += rc continue } + // parse group-by values + if gb != nil { + entry.GroupByValue, entry.Err = column.FieldDataColumn(gb, offset, offset+rc) + if entry.Err != nil { + offset += rc + continue + } + } entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc) sr = append(sr, entry) } diff --git a/client/read_options.go b/client/read_options.go index 90da8c21206a5..a1f563bfc0642 100644 --- a/client/read_options.go +++ b/client/read_options.go @@ -87,7 +87,7 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb. // search param bs, _ := json.Marshal(annRequest.searchParam) - request.SearchParams = entity.MapKvPairs(map[string]string{ + params := map[string]string{ spAnnsField: annRequest.annField, spTopK: strconv.Itoa(opt.topK), spOffset: strconv.Itoa(opt.offset), @@ -95,8 +95,11 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb. spMetricsType: string(annRequest.metricsType), spRoundDecimal: "-1", spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing), - spGroupBy: annRequest.groupByField, - }) + } + if annRequest.groupByField != "" { + params[spGroupBy] = annRequest.groupByField + } + request.SearchParams = entity.MapKvPairs(params) // placeholder group request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors) diff --git a/client/write.go b/client/write.go index 0487c110d1b51..d358fc0982264 100644 --- a/client/write.go +++ b/client/write.go @@ -22,53 +22,90 @@ import ( "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/pkg/util/merr" ) -func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error { +type InsertResult struct { + InsertCount int64 + IDs column.Column +} + +func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) { + result := InsertResult{} collection, err := c.getCollection(ctx, option.CollectionName()) if err != nil { - return err + return result, err } req, err := option.InsertRequest(collection) if err != nil { - return err + return result, err } + err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { resp, err := milvusService.Insert(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + result.InsertCount = resp.GetInsertCnt() + result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) + if err != nil { + return err + } + + return nil }) - return err + return result, err +} + +type DeleteResult struct { + DeleteCount int64 } -func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error { +func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) (DeleteResult, error) { req := option.Request() - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + result := DeleteResult{} + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { resp, err := milvusService.Delete(ctx, req, callOptions...) if err = merr.CheckRPCCall(resp, err); err != nil { return err } + result.DeleteCount = resp.GetDeleteCnt() return nil }) + return result, err +} + +type UpsertResult struct { + UpsertCount int64 + IDs column.Column } -func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error { +func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) { + result := UpsertResult{} collection, err := c.getCollection(ctx, option.CollectionName()) if err != nil { - return err + return result, err } req, err := option.UpsertRequest(collection) if err != nil { - return err + return result, err } - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { resp, err := milvusService.Upsert(ctx, req, callOptions...) if err = merr.CheckRPCCall(resp, err); err != nil { return err } + result.UpsertCount = resp.GetUpsertCnt() + result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) + if err != nil { + return err + } return nil }) + return result, err } diff --git a/client/write_test.go b/client/write_test.go index 4fa27ff7c43cc..3fdb9ece0f615 100644 --- a/client/write_test.go +++ b/client/write_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/samber/lo" @@ -63,16 +64,25 @@ func (s *WriteSuite) TestInsert() { s.Require().Len(ir.GetFieldsData(), 2) s.EqualValues(3, ir.GetNumRows()) return &milvuspb.MutationResult{ - Status: merr.Success(), + Status: merr.Success(), + InsertCnt: 3, + IDs: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, }, nil }).Once() - err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). + result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) s.NoError(err) + s.EqualValues(3, result.InsertCount) }) s.Run("dynamic_schema", func() { @@ -86,17 +96,26 @@ func (s *WriteSuite) TestInsert() { s.Require().Len(ir.GetFieldsData(), 3) s.EqualValues(3, ir.GetNumRows()) return &milvuspb.MutationResult{ - Status: merr.Success(), + Status: merr.Success(), + InsertCnt: 3, + IDs: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, }, nil }).Once() - err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). + result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). WithVarcharColumn("extra", []string{"a", "b", "c"}). WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) s.NoError(err) + s.EqualValues(3, result.InsertCount) }) s.Run("bad_input", func() { @@ -141,7 +160,7 @@ func (s *WriteSuite) TestInsert() { for _, tc := range cases { s.Run(tc.tag, func() { - err := s.client.Insert(ctx, tc.input) + _, err := s.client.Insert(ctx, tc.input) s.Error(err) }) } @@ -153,7 +172,7 @@ func (s *WriteSuite) TestInsert() { s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() - err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). + _, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). @@ -177,16 +196,25 @@ func (s *WriteSuite) TestUpsert() { s.Require().Len(ur.GetFieldsData(), 2) s.EqualValues(3, ur.GetNumRows()) return &milvuspb.MutationResult{ - Status: merr.Success(), + Status: merr.Success(), + UpsertCnt: 3, + IDs: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, }, nil }).Once() - err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). + result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) s.NoError(err) + s.EqualValues(3, result.UpsertCount) }) s.Run("dynamic_schema", func() { @@ -200,17 +228,26 @@ func (s *WriteSuite) TestUpsert() { s.Require().Len(ur.GetFieldsData(), 3) s.EqualValues(3, ur.GetNumRows()) return &milvuspb.MutationResult{ - Status: merr.Success(), + Status: merr.Success(), + UpsertCnt: 3, + IDs: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, }, nil }).Once() - err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). + result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). WithVarcharColumn("extra", []string{"a", "b", "c"}). WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) s.NoError(err) + s.EqualValues(3, result.UpsertCount) }) s.Run("bad_input", func() { @@ -255,7 +292,7 @@ func (s *WriteSuite) TestUpsert() { for _, tc := range cases { s.Run(tc.tag, func() { - err := s.client.Upsert(ctx, tc.input) + _, err := s.client.Upsert(ctx, tc.input) s.Error(err) }) } @@ -267,7 +304,7 @@ func (s *WriteSuite) TestUpsert() { s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() - err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). + _, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) })). @@ -315,11 +352,13 @@ func (s *WriteSuite) TestDelete() { s.Equal(partName, dr.GetPartitionName()) s.Equal(tc.expectExpr, dr.GetExpr()) return &milvuspb.MutationResult{ - Status: merr.Success(), + Status: merr.Success(), + DeleteCnt: 100, }, nil }).Once() - err := s.client.Delete(ctx, tc.input) + result, err := s.client.Delete(ctx, tc.input) s.NoError(err) + s.EqualValues(100, result.DeleteCount) }) } })