From a1ed98d5e4c829291fe6b344a2dec801c214c831 Mon Sep 17 00:00:00 2001 From: Congqi Xia Date: Fri, 20 Dec 2024 01:55:08 +0800 Subject: [PATCH] enhance: [GoSDK] Sync API names and add missing APIs Related to #31293 - Rename `UsingDatabase` to `UseDatabase` - Uncomment default value methods - Add missing RBAC APIs - Add some resource group APIs Signed-off-by: Congqi Xia --- client/entity/common.go | 11 + client/entity/field.go | 10 +- client/entity/field_test.go | 14 +- client/entity/load_state.go | 34 ++ client/entity/rbac.go | 35 ++ client/milvusclient/collection.go | 20 + client/milvusclient/collection_options.go | 14 + client/milvusclient/collection_test.go | 33 ++ client/milvusclient/database.go | 2 +- client/milvusclient/database_options.go | 10 +- client/milvusclient/database_test.go | 20 + client/milvusclient/maintenance.go | 83 ++++- client/milvusclient/maintenance_options.go | 135 +++++++ client/milvusclient/maintenance_test.go | 168 +++++++++ client/milvusclient/partition.go | 18 + client/milvusclient/partition_options.go | 23 ++ client/milvusclient/partition_test.go | 33 ++ client/milvusclient/rbac.go | 148 ++++++++ client/milvusclient/rbac_options.go | 308 +++++++++++++++ client/milvusclient/rbac_test.go | 370 +++++++++++++++++++ client/milvusclient/read.go | 4 + client/milvusclient/read_option_test.go | 12 +- client/milvusclient/read_options.go | 91 ++++- client/milvusclient/resource_group.go | 65 ++++ client/milvusclient/resource_group_option.go | 92 +++++ client/milvusclient/resource_group_test.go | 108 ++++++ 26 files changed, 1824 insertions(+), 37 deletions(-) create mode 100644 client/entity/load_state.go create mode 100644 client/entity/rbac.go create mode 100644 client/milvusclient/resource_group.go create mode 100644 client/milvusclient/resource_group_option.go create mode 100644 client/milvusclient/resource_group_test.go diff --git a/client/entity/common.go b/client/entity/common.go index ec794e3db197a..f14c47f627e4d 100644 --- a/client/entity/common.go +++ b/client/entity/common.go @@ -16,6 +16,8 @@ package entity +import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + // MetricType metric type type MetricType string @@ -31,3 +33,12 @@ const ( SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE" BM25 MetricType = "BM25" ) + +// CompactionState enum type for compaction state +type CompactionState commonpb.CompactionState + +// CompactionState Constants +const ( + CompactionStateRunning CompactionState = CompactionState(commonpb.CompactionState_Executing) + CompactionStateCompleted CompactionState = CompactionState(commonpb.CompactionState_Completed) +) diff --git a/client/entity/field.go b/client/entity/field.go index d2765ae74745c..790ad36b56d0a 100644 --- a/client/entity/field.go +++ b/client/entity/field.go @@ -193,6 +193,8 @@ type Field struct { IsPartitionKey bool IsClusteringKey bool ElementType FieldType + DefaultValue *schemapb.ValueField + Nullable bool } // ProtoMessage generates corresponding FieldSchema @@ -261,7 +263,11 @@ func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field { return f } -/* +func (f *Field) WithNullable(nullable bool) *Field { + f.Nullable = nullable + return f +} + func (f *Field) WithDefaultValueBool(defaultValue bool) *Field { f.DefaultValue = &schemapb.ValueField{ Data: &schemapb.ValueField_BoolData{ @@ -314,7 +320,7 @@ func (f *Field) WithDefaultValueString(defaultValue string) *Field { }, } return f -}*/ +} func (f *Field) WithTypeParams(key string, value string) *Field { if f.TypeParams == nil { diff --git a/client/entity/field_test.go b/client/entity/field_test.go index 3528b36a2a6ad..c8a967bfe6735 100644 --- a/client/entity/field_test.go +++ b/client/entity/field_test.go @@ -30,13 +30,13 @@ func TestFieldSchema(t *testing.T) { NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128), NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true), NewField().WithName("varchar_text").WithDataType(FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true).WithAnalyzerParams(map[string]any{}), - /* - NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), - NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), - NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1), - NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1), - NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1), - NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/ + + NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), + NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), + NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1), + NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1), + NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1), + NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"), } for _, field := range fields { diff --git a/client/entity/load_state.go b/client/entity/load_state.go new file mode 100644 index 0000000000000..b4dfe7913ad18 --- /dev/null +++ b/client/entity/load_state.go @@ -0,0 +1,34 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + +type LoadStateCode commonpb.LoadState + +const ( + // LoadStateNone LoadStateCode = LoadStateCode(commonpb.LoadState) + LoadStateLoading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoading) + LoadStateLoaded LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoaded) + LoadStateUnloading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotExist) + LoadStateNotLoad LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotLoad) +) + +type LoadState struct { + State LoadStateCode + Progress int64 +} diff --git a/client/entity/rbac.go b/client/entity/rbac.go new file mode 100644 index 0000000000000..b3c3fa11d2e89 --- /dev/null +++ b/client/entity/rbac.go @@ -0,0 +1,35 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +type User struct { + UserName string + Roles []string +} + +type Role struct { + RoleName string + Privileges []GrantItem +} + +type GrantItem struct { + Object string + ObjectName string + RoleName string + Grantor string + Privilege string +} diff --git a/client/milvusclient/collection.go b/client/milvusclient/collection.go index 253c38ff82dc4..025942bcfa865 100644 --- a/client/milvusclient/collection.go +++ b/client/milvusclient/collection.go @@ -147,3 +147,23 @@ func (c *Client) AlterCollection(ctx context.Context, option AlterCollectionOpti return merr.CheckRPCCall(resp, err) }) } + +type GetCollectionOption interface { + Request() *milvuspb.GetCollectionStatisticsRequest +} + +func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption) (map[string]string, error) { + var stats map[string]string + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetCollectionStatistics(ctx, opt.Request()) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + stats = entity.KvPairsMap(resp.GetStats()) + return nil + }) + if err != nil { + return nil, err + } + return stats, nil +} diff --git a/client/milvusclient/collection_options.go b/client/milvusclient/collection_options.go index cb01cb6636921..0907cb6c8f2c0 100644 --- a/client/milvusclient/collection_options.go +++ b/client/milvusclient/collection_options.go @@ -310,3 +310,17 @@ func (opt *alterCollectionOption) Request() *milvuspb.AlterCollectionRequest { func NewAlterCollectionOption(collection string) *alterCollectionOption { return &alterCollectionOption{collectionName: collection, properties: make(map[string]string)} } + +type getCollectionStatsOption struct { + collectionName string +} + +func (opt *getCollectionStatsOption) Request() *milvuspb.GetCollectionStatisticsRequest { + return &milvuspb.GetCollectionStatisticsRequest{ + CollectionName: opt.collectionName, + } +} + +func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOption { + return &getCollectionStatsOption{collectionName: collectionName} +} diff --git a/client/milvusclient/collection_test.go b/client/milvusclient/collection_test.go index 4d4b399302f32..4dc8e62d87629 100644 --- a/client/milvusclient/collection_test.go +++ b/client/milvusclient/collection_test.go @@ -21,6 +21,7 @@ import ( "fmt" "testing" + "github.com/cockroachdb/errors" "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -315,6 +316,38 @@ func (s *CollectionSuite) TestAlterCollection() { }) } +func (s *CollectionSuite) TestGetCollectionStats() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) { + s.Equal(collName, gcsr.GetCollectionName()) + return &milvuspb.GetCollectionStatisticsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Stats: []*commonpb.KeyValuePair{ + {Key: "row_count", Value: "1000"}, + }, + }, nil + }).Once() + + stats, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName)) + s.NoError(err) + + s.Len(stats, 1) + s.Equal("1000", stats["row_count"]) + }) + + s.Run("failure", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once() + + _, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName)) + s.Error(err) + }) +} + func TestCollection(t *testing.T) { suite.Run(t, new(CollectionSuite)) } diff --git a/client/milvusclient/database.go b/client/milvusclient/database.go index b47db2d293c13..eb5b352963b85 100644 --- a/client/milvusclient/database.go +++ b/client/milvusclient/database.go @@ -25,7 +25,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) -func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error { +func (c *Client) UseDatabase(ctx context.Context, option UseDatabaseOption) error { dbName := option.DbName() c.usingDatabase(dbName) return c.connectInternal(ctx) diff --git a/client/milvusclient/database_options.go b/client/milvusclient/database_options.go index 9562b71491870..48542f9e580e0 100644 --- a/client/milvusclient/database_options.go +++ b/client/milvusclient/database_options.go @@ -18,20 +18,20 @@ package milvusclient import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" -type UsingDatabaseOption interface { +type UseDatabaseOption interface { DbName() string } -type usingDatabaseNameOpt struct { +type useDatabaseNameOpt struct { dbName string } -func (opt *usingDatabaseNameOpt) DbName() string { +func (opt *useDatabaseNameOpt) DbName() string { return opt.dbName } -func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt { - return &usingDatabaseNameOpt{ +func NewUseDatabaseOption(dbName string) *useDatabaseNameOpt { + return &useDatabaseNameOpt{ dbName: dbName, } } diff --git a/client/milvusclient/database_test.go b/client/milvusclient/database_test.go index 262970dd8afc3..3cb3b7017f4fe 100644 --- a/client/milvusclient/database_test.go +++ b/client/milvusclient/database_test.go @@ -88,6 +88,26 @@ func (s *DatabaseSuite) TestDropDatabase() { }) } +func (s *DatabaseSuite) TestUseDatabase() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + dbName := fmt.Sprintf("dt_%s", s.randString(6)) + s.mock.EXPECT().Connect(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) { + return &milvuspb.ConnectResponse{ + Status: merr.Success(), + ServerInfo: &commonpb.ServerInfo{}, + }, nil + }).Once() + + err := s.client.UseDatabase(ctx, NewUseDatabaseOption(dbName)) + s.NoError(err) + + s.Equal(dbName, s.client.currentDB) + }) +} + func TestDatabase(t *testing.T) { suite.Run(t, new(DatabaseSuite)) } diff --git a/client/milvusclient/maintenance.go b/client/milvusclient/maintenance.go index 4a5f75d763bc0..bbf7a636ebe3b 100644 --- a/client/milvusclient/maintenance.go +++ b/client/milvusclient/maintenance.go @@ -23,6 +23,7 @@ import ( "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -31,6 +32,7 @@ type LoadTask struct { collectionName string partitionNames []string interval time.Duration + refresh bool } func (t *LoadTask) Await(ctx context.Context) error { @@ -40,6 +42,7 @@ func (t *LoadTask) Await(ctx context.Context) error { select { case <-timer.C: loaded := false + refreshed := false err := t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { resp, err := milvusService.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ CollectionName: t.collectionName, @@ -49,12 +52,13 @@ func (t *LoadTask) Await(ctx context.Context) error { return err } loaded = resp.GetProgress() == 100 + refreshed = resp.GetRefreshProgress() == 100 return nil }) if err != nil { return err } - if loaded { + if (loaded && !t.refresh) || (refreshed && t.refresh) { return nil } if !timer.Stop() { @@ -85,6 +89,7 @@ func (c *Client) LoadCollection(ctx context.Context, option LoadCollectionOption client: c, collectionName: req.GetCollectionName(), interval: option.CheckInterval(), + refresh: option.IsRefresh(), } return nil @@ -108,6 +113,7 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption collectionName: req.GetCollectionName(), partitionNames: req.GetPartitionNames(), interval: option.CheckInterval(), + refresh: option.IsRefresh(), } return nil @@ -115,6 +121,35 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption return task, err } +func (c *Client) GetLoadState(ctx context.Context, option GetLoadStateOption, callOptions ...grpc.CallOption) (entity.LoadState, error) { + req := option.Request() + + var state entity.LoadState + var err error + + if err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetLoadState(ctx, req, callOptions...) + state.State = entity.LoadStateCode(resp.GetState()) + return merr.CheckRPCCall(resp, err) + }); err != nil { + return state, err + } + + // get progress if state is loading + if state.State == entity.LoadStateLoading { + err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetLoadingProgress(ctx, option.ProgressRequest(), callOptions...) + if err := merr.CheckRPCCall(resp, err); err != nil { + return err + } + + state.Progress = resp.GetProgress() + return nil + }) + } + return state, err +} + func (c *Client) ReleaseCollection(ctx context.Context, option ReleaseCollectionOption, callOptions ...grpc.CallOption) error { req := option.Request() @@ -134,6 +169,26 @@ func (c *Client) ReleasePartitions(ctx context.Context, option ReleasePartitions }) } +func (c *Client) RefreshLoad(ctx context.Context, option RefreshLoadOption, callOptions ...grpc.CallOption) (LoadTask, error) { + req := option.Request() + var task LoadTask + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.LoadCollection(ctx, req, callOptions...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + task = LoadTask{ + client: c, + collectionName: req.GetCollectionName(), + interval: option.CheckInterval(), + refresh: true, + } + return nil + }) + return task, err +} + type FlushTask struct { client *Client collectionName string @@ -206,3 +261,29 @@ func (c *Client) Flush(ctx context.Context, option FlushOption, callOptions ...g }) return task, err } + +func (c *Client) Compact(ctx context.Context, option CompactOption, callOptions ...grpc.CallOption) (int64, error) { + req := option.Request() + + var jobID int64 + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.ManualCompaction(ctx, req, callOptions...) + jobID = resp.GetCompactionID() + return merr.CheckRPCCall(resp, err) + }) + return jobID, err +} + +func (c *Client) GetCompactionState(ctx context.Context, option GetCompactionStateOption, callOptions ...grpc.CallOption) (entity.CompactionState, error) { + req := option.Request() + + var status entity.CompactionState + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetCompactionState(ctx, req, callOptions...) + status = entity.CompactionState(resp.GetState()) + return merr.CheckRPCCall(resp, err) + }) + return status, err +} diff --git a/client/milvusclient/maintenance_options.go b/client/milvusclient/maintenance_options.go index 566168ad874d7..b0476ac34294b 100644 --- a/client/milvusclient/maintenance_options.go +++ b/client/milvusclient/maintenance_options.go @@ -25,6 +25,7 @@ import ( type LoadCollectionOption interface { Request() *milvuspb.LoadCollectionRequest CheckInterval() time.Duration + IsRefresh() bool } type loadCollectionOption struct { @@ -33,6 +34,8 @@ type loadCollectionOption struct { replicaNum int loadFields []string skipLoadDynamicField bool + isRefresh bool + resourceGroups []string } func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest { @@ -41,6 +44,7 @@ func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest { ReplicaNumber: int32(opt.replicaNum), LoadFields: opt.loadFields, SkipLoadDynamicField: opt.skipLoadDynamicField, + ResourceGroups: opt.resourceGroups, } } @@ -48,11 +52,20 @@ func (opt *loadCollectionOption) CheckInterval() time.Duration { return opt.interval } +func (opt *loadCollectionOption) IsRefresh() bool { + return opt.isRefresh +} + func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption { opt.replicaNum = num return opt } +func (opt *loadCollectionOption) WithResourceGroup(resourceGroups ...string) *loadCollectionOption { + opt.resourceGroups = resourceGroups + return opt +} + func (opt *loadCollectionOption) WithLoadFields(loadFields ...string) *loadCollectionOption { opt.loadFields = loadFields return opt @@ -63,6 +76,11 @@ func (opt *loadCollectionOption) WithSkipLoadDynamicField(skipFlag bool) *loadCo return opt } +func (opt *loadCollectionOption) WithRefresh(isRefresh bool) *loadCollectionOption { + opt.isRefresh = isRefresh + return opt +} + func NewLoadCollectionOption(collectionName string) *loadCollectionOption { return &loadCollectionOption{ collectionName: collectionName, @@ -74,6 +92,7 @@ func NewLoadCollectionOption(collectionName string) *loadCollectionOption { type LoadPartitionsOption interface { Request() *milvuspb.LoadPartitionsRequest CheckInterval() time.Duration + IsRefresh() bool } var _ LoadPartitionsOption = (*loadPartitionsOption)(nil) @@ -83,8 +102,10 @@ type loadPartitionsOption struct { partitionNames []string interval time.Duration replicaNum int + resourceGroups []string loadFields []string skipLoadDynamicField bool + isRefresh bool } func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest { @@ -94,6 +115,7 @@ func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest { ReplicaNumber: int32(opt.replicaNum), LoadFields: opt.loadFields, SkipLoadDynamicField: opt.skipLoadDynamicField, + ResourceGroups: opt.resourceGroups, } } @@ -101,11 +123,20 @@ func (opt *loadPartitionsOption) CheckInterval() time.Duration { return opt.interval } +func (opt *loadPartitionsOption) IsRefresh() bool { + return opt.isRefresh +} + func (opt *loadPartitionsOption) WithReplica(num int) *loadPartitionsOption { opt.replicaNum = num return opt } +func (opt *loadPartitionsOption) WithResourceGroup(resourceGroups ...string) *loadPartitionsOption { + opt.resourceGroups = resourceGroups + return opt +} + func (opt *loadPartitionsOption) WithLoadFields(loadFields ...string) *loadPartitionsOption { opt.loadFields = loadFields return opt @@ -116,6 +147,11 @@ func (opt *loadPartitionsOption) WithSkipLoadDynamicField(skipFlag bool) *loadPa return opt } +func (opt *loadPartitionsOption) WithRefresh(isRefresh bool) *loadPartitionsOption { + opt.isRefresh = isRefresh + return opt +} + func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *loadPartitionsOption { return &loadPartitionsOption{ collectionName: collectionName, @@ -125,6 +161,65 @@ func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) * } } +type GetLoadStateOption interface { + Request() *milvuspb.GetLoadStateRequest + ProgressRequest() *milvuspb.GetLoadingProgressRequest +} + +type getLoadStateOption struct { + collectionName string + partitionNames []string +} + +func (opt *getLoadStateOption) Request() *milvuspb.GetLoadStateRequest { + return &milvuspb.GetLoadStateRequest{ + CollectionName: opt.collectionName, + PartitionNames: opt.partitionNames, + } +} + +func (opt *getLoadStateOption) ProgressRequest() *milvuspb.GetLoadingProgressRequest { + return &milvuspb.GetLoadingProgressRequest{ + CollectionName: opt.collectionName, + PartitionNames: opt.partitionNames, + } +} + +func NewGetLoadStateOption(collectionName string, partitionNames ...string) *getLoadStateOption { + return &getLoadStateOption{ + collectionName: collectionName, + partitionNames: partitionNames, + } +} + +type RefreshLoadOption interface { + Request() *milvuspb.LoadCollectionRequest + CheckInterval() time.Duration +} + +type refreshLoadOption struct { + collectionName string + checkInterval time.Duration +} + +func (opt *refreshLoadOption) Request() *milvuspb.LoadCollectionRequest { + return &milvuspb.LoadCollectionRequest{ + CollectionName: opt.collectionName, + Refresh: true, + } +} + +func (opt *refreshLoadOption) CheckInterval() time.Duration { + return opt.checkInterval +} + +func NewRefreshLoadOption(collectionName string) *refreshLoadOption { + return &refreshLoadOption{ + collectionName: collectionName, + checkInterval: time.Millisecond * 200, + } +} + type ReleaseCollectionOption interface { Request() *milvuspb.ReleaseCollectionRequest } @@ -203,3 +298,43 @@ func NewFlushOption(collName string) *flushOption { interval: time.Millisecond * 200, } } + +type CompactOption interface { + Request() *milvuspb.ManualCompactionRequest +} + +type compactOption struct { + collectionName string +} + +func (opt *compactOption) Request() *milvuspb.ManualCompactionRequest { + return &milvuspb.ManualCompactionRequest{ + CollectionName: opt.collectionName, + } +} + +func NewCompactOption(collectionName string) *compactOption { + return &compactOption{ + collectionName: collectionName, + } +} + +type GetCompactionStateOption interface { + Request() *milvuspb.GetCompactionStateRequest +} + +type getCompactionStateOption struct { + compactionID int64 +} + +func (opt *getCompactionStateOption) Request() *milvuspb.GetCompactionStateRequest { + return &milvuspb.GetCompactionStateRequest{ + CompactionID: opt.compactionID, + } +} + +func NewGetCompactionStateOption(compactionID int64) *getCompactionStateOption { + return &getCompactionStateOption{ + compactionID: compactionID, + } +} diff --git a/client/milvusclient/maintenance_test.go b/client/milvusclient/maintenance_test.go index f6363c2aa42ca..a41f639a58737 100644 --- a/client/milvusclient/maintenance_test.go +++ b/client/milvusclient/maintenance_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -44,6 +45,7 @@ func (s *MaintenanceSuite) TestLoadCollection() { collectionName := fmt.Sprintf("coll_%s", s.randString(6)) fieldNames := []string{"id", "part", "vector"} replicaNum := rand.Intn(3) + 1 + rgs := []string{"rg1", "rg2"} done := atomic.NewBool(false) s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { @@ -51,6 +53,7 @@ func (s *MaintenanceSuite) TestLoadCollection() { s.ElementsMatch(fieldNames, lcr.GetLoadFields()) s.True(lcr.SkipLoadDynamicField) s.EqualValues(replicaNum, lcr.GetReplicaNumber()) + s.ElementsMatch(rgs, lcr.GetResourceGroups()) return merr.Success(), nil }).Once() s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { @@ -70,6 +73,7 @@ func (s *MaintenanceSuite) TestLoadCollection() { task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName). WithReplica(replicaNum). + WithResourceGroup(rgs...). WithLoadFields(fieldNames...). WithSkipLoadDynamicField(true)) s.NoError(err) @@ -114,6 +118,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() { partitionName := fmt.Sprintf("part_%s", s.randString(6)) fieldNames := []string{"id", "part", "vector"} replicaNum := rand.Intn(3) + 1 + rgs := []string{"rg1", "rg2"} done := atomic.NewBool(false) s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { @@ -122,6 +127,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() { s.ElementsMatch(fieldNames, lpr.GetLoadFields()) s.True(lpr.SkipLoadDynamicField) s.EqualValues(replicaNum, lpr.GetReplicaNumber()) + s.ElementsMatch(rgs, lpr.GetResourceGroups()) return merr.Success(), nil }).Once() s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { @@ -142,6 +148,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() { task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, partitionName). WithReplica(replicaNum). + WithResourceGroup(rgs...). WithLoadFields(fieldNames...). WithSkipLoadDynamicField(true)) s.NoError(err) @@ -293,6 +300,167 @@ func (s *MaintenanceSuite) TestFlush() { }) } +func (s *MaintenanceSuite) TestRefreshLoad() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + done := atomic.NewBool(false) + s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { + s.Equal(collectionName, lcr.GetCollectionName()) + s.True(lcr.GetRefresh()) + return merr.Success(), nil + }).Once() + s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { + s.Equal(collectionName, glpr.GetCollectionName()) + + progress := int64(50) + if done.Load() { + progress = 100 + } + + return &milvuspb.GetLoadingProgressResponse{ + Status: merr.Success(), + RefreshProgress: progress, + }, nil + }) + defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset() + + task, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName)) + s.NoError(err) + + ch := make(chan struct{}) + go func() { + defer close(ch) + err := task.Await(ctx) + s.NoError(err) + }() + + select { + case <-ch: + s.FailNow("task done before index state set to finish") + case <-time.After(time.Second): + } + + done.Store(true) + + select { + case <-ch: + case <-time.After(time.Second): + s.FailNow("task not done after index set finished") + } + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName)) + s.Error(err) + }) +} + +func (s *MaintenanceSuite) TestCompact() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + compactID := rand.Int63() + + s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { + s.Equal(collectionName, cr.GetCollectionName()) + return &milvuspb.ManualCompactionResponse{ + CompactionID: compactID, + }, nil + }).Once() + + id, err := s.client.Compact(ctx, NewCompactOption(collectionName)) + s.NoError(err) + s.Equal(compactID, id) + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.Compact(ctx, NewCompactOption(collectionName)) + s.Error(err) + }) +} + +func (s *MaintenanceSuite) TestGetCompactionState() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + compactID := rand.Int63() + + s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { + s.Equal(compactID, gcsr.GetCompactionID()) + return &milvuspb.GetCompactionStateResponse{ + Status: merr.Success(), + State: commonpb.CompactionState_Completed, + }, nil + }).Once() + + state, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID)) + s.NoError(err) + s.Equal(entity.CompactionStateCompleted, state) + }) + + s.Run("failure", func() { + compactID := rand.Int63() + + s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID)) + s.Error(err) + }) +} + +func (s *MaintenanceSuite) TestGetLoadState() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + progress := rand.Int63n(100) + + s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glsr *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { + s.Equal(collectionName, glsr.GetCollectionName()) + return &milvuspb.GetLoadStateResponse{ + Status: merr.Success(), + State: commonpb.LoadState_LoadStateLoading, + }, nil + }).Once() + s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { + s.Equal(collectionName, glpr.GetCollectionName()) + return &milvuspb.GetLoadingProgressResponse{ + Status: merr.Success(), + Progress: progress, + }, nil + }).Once() + + state, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName)) + s.NoError(err) + s.Equal(entity.LoadStateLoading, state.State) + s.Equal(progress, state.Progress) + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName)) + s.Error(err) + }) +} + func TestMaintenance(t *testing.T) { suite.Run(t, new(MaintenanceSuite)) } diff --git a/client/milvusclient/partition.go b/client/milvusclient/partition.go index 99cd00ced43cb..63a48c5766fa5 100644 --- a/client/milvusclient/partition.go +++ b/client/milvusclient/partition.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -76,3 +77,20 @@ func (c *Client) ListPartitions(ctx context.Context, opt ListPartitionsOption, c }) return partitionNames, err } + +func (c *Client) GetPartitionStats(ctx context.Context, opt GetPartitionStatsOption, callOptions ...grpc.CallOption) (map[string]string, error) { + req := opt.Request() + + var result map[string]string + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetPartitionStatistics(ctx, req, callOptions...) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + result = entity.KvPairsMap(resp.GetStats()) + return nil + }) + return result, err +} diff --git a/client/milvusclient/partition_options.go b/client/milvusclient/partition_options.go index b67cce46cad94..f34e8bd1f703c 100644 --- a/client/milvusclient/partition_options.go +++ b/client/milvusclient/partition_options.go @@ -117,3 +117,26 @@ func NewListPartitionOption(collectionName string) *listPartitionsOpt { collectionName: collectionName, } } + +type GetPartitionStatsOption interface { + Request() *milvuspb.GetPartitionStatisticsRequest +} + +type getPartitionStatsOpt struct { + collectionName string + partitionName string +} + +func (opt *getPartitionStatsOpt) Request() *milvuspb.GetPartitionStatisticsRequest { + return &milvuspb.GetPartitionStatisticsRequest{ + CollectionName: opt.collectionName, + PartitionName: opt.partitionName, + } +} + +func NewGetPartitionStatsOption(collectionName string, partitionName string) *getPartitionStatsOpt { + return &getPartitionStatsOpt{ + collectionName: collectionName, + partitionName: partitionName, + } +} diff --git a/client/milvusclient/partition_test.go b/client/milvusclient/partition_test.go index 9f5b843a51612..eb2b7c26aae62 100644 --- a/client/milvusclient/partition_test.go +++ b/client/milvusclient/partition_test.go @@ -162,6 +162,39 @@ func (s *PartitionSuite) TestDropPartition() { }) } +func (s *PartitionSuite) TestGetPartitionStats() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + partitionName := fmt.Sprintf("part_%s", s.randString(6)) + s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gpsr *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { + s.Equal(collectionName, gpsr.GetCollectionName()) + s.Equal(partitionName, gpsr.GetPartitionName()) + return &milvuspb.GetPartitionStatisticsResponse{ + Status: merr.Success(), + Stats: []*commonpb.KeyValuePair{ + {Key: "rows", Value: "100"}, + }, + }, nil + }).Once() + + stats, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName)) + s.NoError(err) + s.Equal("100", stats["rows"]) + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + partitionName := fmt.Sprintf("part_%s", s.randString(6)) + s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName)) + s.Error(err) + }) +} + func TestPartition(t *testing.T) { suite.Run(t, new(PartitionSuite)) } diff --git a/client/milvusclient/rbac.go b/client/milvusclient/rbac.go index 8abfe1d74790f..f769ceae0c4a4 100644 --- a/client/milvusclient/rbac.go +++ b/client/milvusclient/rbac.go @@ -19,6 +19,7 @@ package milvusclient import ( "context" + "github.com/cockroachdb/errors" "github.com/samber/lo" "google.golang.org/grpc" @@ -27,6 +28,153 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +func (c *Client) ListUsers(ctx context.Context, opt ListUserOption, callOpts ...grpc.CallOption) ([]string, error) { + var users []string + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.ListCredUsers(ctx, opt.Request(), callOpts...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + users = resp.GetUsernames() + return nil + }) + return users, err +} + +func (c *Client) DescribeUser(ctx context.Context, opt DescribeUserOption, callOpts ...grpc.CallOption) (*entity.User, error) { + var user *entity.User + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.SelectUser(ctx, opt.Request(), callOpts...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + if len(resp.GetResults()) == 0 { + return errors.New("not user found") + } + result := resp.GetResults()[0] + user = &entity.User{ + UserName: result.GetUser().GetName(), + Roles: lo.Map(result.GetRoles(), func(r *milvuspb.RoleEntity, _ int) string { return r.GetName() }), + } + + return nil + }) + + return user, err +} + +func (c *Client) CreateUser(ctx context.Context, opt CreateUserOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.CreateCredential(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) UpdatePassword(ctx context.Context, opt UpdatePasswordOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.UpdateCredential(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) DropUser(ctx context.Context, opt DropUserOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DeleteCredential(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) ListRoles(ctx context.Context, opt ListRoleOption, callOpts ...grpc.CallOption) ([]string, error) { + var roles []string + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.SelectRole(ctx, opt.Request(), callOpts...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + roles = lo.Map(resp.GetResults(), func(r *milvuspb.RoleResult, _ int) string { + return r.GetRole().GetName() + }) + return nil + }) + return roles, err +} + +func (c *Client) CreateRole(ctx context.Context, opt CreateRoleOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.CreateRole(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) GrantRole(ctx context.Context, opt GrantRoleOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) RevokeRole(ctx context.Context, opt RevokeRoleOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) DropRole(ctx context.Context, opt DropRoleOption, callOpts ...grpc.CallOption) error { + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DropRole(ctx, opt.Request(), callOpts...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) DescribeRole(ctx context.Context, option DescribeRoleOption, callOptions ...grpc.CallOption) (*entity.Role, error) { + req := option.Request() + + var role *entity.Role + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.SelectGrant(ctx, req, callOptions...) + if err := merr.CheckRPCCall(resp, err); err != nil { + return err + } + if len(resp.GetEntities()) == 0 { + return errors.New("role not found") + } + + role = &entity.Role{ + RoleName: req.GetEntity().GetRole().GetName(), + Privileges: lo.Map(resp.GetEntities(), func(g *milvuspb.GrantEntity, _ int) entity.GrantItem { + return entity.GrantItem{ + Object: g.Object.GetName(), + ObjectName: g.GetObjectName(), + RoleName: g.GetRole().GetName(), + Grantor: g.GetGrantor().GetUser().GetName(), + Privilege: g.GetGrantor().GetPrivilege().GetName(), + } + }), + } + return nil + }) + return role, err +} + +func (c *Client) GrantPrivilege(ctx context.Context, option GrantPrivilegeOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) RevokePrivilege(ctx context.Context, option RevokePrivilegeOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + func (c *Client) GrantV2(ctx context.Context, option GrantV2Option, callOptions ...grpc.CallOption) error { req := option.Request() diff --git a/client/milvusclient/rbac_options.go b/client/milvusclient/rbac_options.go index 525bd2047b0c8..cacb72718d211 100644 --- a/client/milvusclient/rbac_options.go +++ b/client/milvusclient/rbac_options.go @@ -20,6 +20,314 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) +type ListUserOption interface { + Request() *milvuspb.ListCredUsersRequest +} + +// listUserOption is the struct to build ListCredUsersRequest +// left empty for not attribute needed right now +type listUserOption struct{} + +func (opt *listUserOption) Request() *milvuspb.ListCredUsersRequest { + return &milvuspb.ListCredUsersRequest{} +} + +func NewListUserOption() *listUserOption { + return &listUserOption{} +} + +type DescribeUserOption interface { + Request() *milvuspb.SelectUserRequest +} + +type describeUserOption struct { + userName string +} + +func (opt *describeUserOption) Request() *milvuspb.SelectUserRequest { + return &milvuspb.SelectUserRequest{ + User: &milvuspb.UserEntity{ + Name: opt.userName, + }, + IncludeRoleInfo: true, + } +} + +func NewDescribeUserOption(userName string) *describeUserOption { + return &describeUserOption{ + userName: userName, + } +} + +type CreateUserOption interface { + Request() *milvuspb.CreateCredentialRequest +} + +type createUserOption struct { + userName string + password string +} + +func (opt *createUserOption) Request() *milvuspb.CreateCredentialRequest { + return &milvuspb.CreateCredentialRequest{ + Username: opt.userName, + Password: opt.password, + } +} + +func NewCreateUserOption(userName, password string) *createUserOption { + return &createUserOption{ + userName: userName, + password: password, + } +} + +type UpdatePasswordOption interface { + Request() *milvuspb.UpdateCredentialRequest +} + +type updatePasswordOption struct { + userName string + oldPassword string + newPassword string +} + +func (opt *updatePasswordOption) Request() *milvuspb.UpdateCredentialRequest { + return &milvuspb.UpdateCredentialRequest{ + Username: opt.userName, + OldPassword: opt.oldPassword, + NewPassword: opt.newPassword, + } +} + +func NewUpdatePasswordOption(userName, oldPassword, newPassword string) *updatePasswordOption { + return &updatePasswordOption{ + userName: userName, + oldPassword: oldPassword, + newPassword: newPassword, + } +} + +type DropUserOption interface { + Request() *milvuspb.DeleteCredentialRequest +} + +type dropUserOption struct { + userName string +} + +func (opt *dropUserOption) Request() *milvuspb.DeleteCredentialRequest { + return &milvuspb.DeleteCredentialRequest{ + Username: opt.userName, + } +} + +func NewDropUserOption(userName string) *dropUserOption { + return &dropUserOption{ + userName: userName, + } +} + +type ListRoleOption interface { + Request() *milvuspb.SelectRoleRequest +} + +type listRoleOption struct{} + +func (opt *listRoleOption) Request() *milvuspb.SelectRoleRequest { + return &milvuspb.SelectRoleRequest{ + IncludeUserInfo: false, + } +} + +func NewListRoleOption() *listRoleOption { + return &listRoleOption{} +} + +type CreateRoleOption interface { + Request() *milvuspb.CreateRoleRequest +} + +type createRoleOption struct { + roleName string +} + +func (opt *createRoleOption) Request() *milvuspb.CreateRoleRequest { + return &milvuspb.CreateRoleRequest{ + Entity: &milvuspb.RoleEntity{Name: opt.roleName}, + } +} + +func NewCreateRoleOption(roleName string) *createRoleOption { + return &createRoleOption{ + roleName: roleName, + } +} + +type GrantRoleOption interface { + Request() *milvuspb.OperateUserRoleRequest +} + +type grantRoleOption struct { + roleName string + userName string +} + +func (opt *grantRoleOption) Request() *milvuspb.OperateUserRoleRequest { + return &milvuspb.OperateUserRoleRequest{ + Username: opt.userName, + RoleName: opt.roleName, + Type: milvuspb.OperateUserRoleType_AddUserToRole, + } +} + +func NewGrantRoleOption(userName, roleName string) *grantRoleOption { + return &grantRoleOption{ + roleName: roleName, + userName: userName, + } +} + +type RevokeRoleOption interface { + Request() *milvuspb.OperateUserRoleRequest +} + +type revokeRoleOption struct { + roleName string + userName string +} + +func (opt *revokeRoleOption) Request() *milvuspb.OperateUserRoleRequest { + return &milvuspb.OperateUserRoleRequest{ + Username: opt.userName, + RoleName: opt.roleName, + Type: milvuspb.OperateUserRoleType_RemoveUserFromRole, + } +} + +func NewRevokeRoleOption(userName, roleName string) *revokeRoleOption { + return &revokeRoleOption{ + roleName: roleName, + userName: userName, + } +} + +type DropRoleOption interface { + Request() *milvuspb.DropRoleRequest +} + +type dropDropRoleOption struct { + roleName string +} + +func (opt *dropDropRoleOption) Request() *milvuspb.DropRoleRequest { + return &milvuspb.DropRoleRequest{ + RoleName: opt.roleName, + } +} + +func NewDropRoleOption(roleName string) *dropDropRoleOption { + return &dropDropRoleOption{ + roleName: roleName, + } +} + +type DescribeRoleOption interface { + Request() *milvuspb.SelectGrantRequest +} + +type describeRoleOption struct { + roleName string +} + +func (opt *describeRoleOption) Request() *milvuspb.SelectGrantRequest { + return &milvuspb.SelectGrantRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: opt.roleName}, + }, + } +} + +func NewDescribeRoleOption(roleName string) *describeRoleOption { + return &describeRoleOption{ + roleName: roleName, + } +} + +type GrantPrivilegeOption interface { + Request() *milvuspb.OperatePrivilegeRequest +} + +type grantPrivilegeOption struct { + roleName string + privilegeName string + objectName string + objectType string +} + +func (opt *grantPrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest { + return &milvuspb.OperatePrivilegeRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: opt.roleName}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName}, + }, + Object: &milvuspb.ObjectEntity{ + Name: opt.objectType, + }, + ObjectName: opt.objectName, + }, + + Type: milvuspb.OperatePrivilegeType_Grant, + } +} + +func NewGrantPrivilegeOption(roleName, objectType, privilegeName, objectName string) *grantPrivilegeOption { + return &grantPrivilegeOption{ + roleName: roleName, + privilegeName: privilegeName, + objectName: objectName, + objectType: objectType, + } +} + +type RevokePrivilegeOption interface { + Request() *milvuspb.OperatePrivilegeRequest +} + +type revokePrivilegeOption struct { + roleName string + privilegeName string + objectName string + objectType string +} + +func (opt *revokePrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest { + return &milvuspb.OperatePrivilegeRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: opt.roleName}, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName}, + }, + Object: &milvuspb.ObjectEntity{ + Name: opt.objectType, + }, + ObjectName: opt.objectName, + }, + + Type: milvuspb.OperatePrivilegeType_Revoke, + } +} + +func NewRevokePrivilegeOption(roleName, objectType, privilegeName, objectName string) *revokePrivilegeOption { + return &revokePrivilegeOption{ + roleName: roleName, + privilegeName: privilegeName, + objectName: objectName, + objectType: objectType, + } +} + // GrantV2Option is the interface builds OperatePrivilegeV2Request type GrantV2Option interface { Request() *milvuspb.OperatePrivilegeV2Request diff --git a/client/milvusclient/rbac_test.go b/client/milvusclient/rbac_test.go index cc8f480da2876..f5d0429b6595e 100644 --- a/client/milvusclient/rbac_test.go +++ b/client/milvusclient/rbac_test.go @@ -29,6 +29,376 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +type UserSuite struct { + MockSuiteBase +} + +func (s *UserSuite) TestListUsers() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { + return &milvuspb.ListCredUsersResponse{ + Usernames: []string{"user1", "user2"}, + }, nil + }).Once() + + users, err := s.client.ListUsers(ctx, NewListUserOption()) + s.NoError(err) + s.Equal([]string{"user1", "user2"}, users) + }) + + s.Run("failure", func() { + s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.ListUsers(ctx, NewListUserOption()) + s.Error(err) + }) +} + +func (s *UserSuite) TestDescribeUser() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + userName := fmt.Sprintf("user_%s", s.randString(5)) + + s.Run("success", func() { + s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { + s.Equal(userName, r.GetUser().GetName()) + return &milvuspb.SelectUserResponse{ + Results: []*milvuspb.UserResult{ + { + User: &milvuspb.UserEntity{Name: userName}, + Roles: []*milvuspb.RoleEntity{ + {Name: "role1"}, + {Name: "role2"}, + }, + }, + }, + }, nil + }).Once() + + user, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName)) + s.NoError(err) + s.Equal(userName, user.UserName) + s.Equal([]string{"role1", "role2"}, user.Roles) + }) + + s.Run("failure", func() { + s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName)) + s.Error(err) + }) +} + +func (s *UserSuite) TestCreateUser() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + userName := fmt.Sprintf("user_%s", s.randString(5)) + password := s.randString(12) + s.mock.EXPECT().CreateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ccr *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { + s.Equal(userName, ccr.GetUsername()) + s.Equal(password, ccr.GetPassword()) + return merr.Success(), nil + }).Once() + + err := s.client.CreateUser(ctx, NewCreateUserOption(userName, password)) + s.NoError(err) + }) +} + +func (s *UserSuite) TestUpdatePassword() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + userName := fmt.Sprintf("user_%s", s.randString(5)) + oldPassword := s.randString(12) + newPassword := s.randString(12) + s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ucr *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) { + s.Equal(userName, ucr.GetUsername()) + s.Equal(oldPassword, ucr.GetOldPassword()) + s.Equal(newPassword, ucr.GetNewPassword()) + return merr.Success(), nil + }).Once() + + err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption(userName, oldPassword, newPassword)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption("user", "old", "new")) + s.Error(err) + }) +} + +func (s *UserSuite) TestDropUser() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + userName := fmt.Sprintf("user_%s", s.randString(5)) + s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dcr *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { + s.Equal(userName, dcr.GetUsername()) + return merr.Success(), nil + }).Once() + + err := s.client.DropUser(ctx, NewDropUserOption(userName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.DropUser(ctx, NewDropUserOption("user")) + s.Error(err) + }) +} + +func TestUserRBAC(t *testing.T) { + suite.Run(t, new(UserSuite)) +} + +type RoleSuite struct { + MockSuiteBase +} + +func (s *RoleSuite) TestListRoles() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { + return &milvuspb.SelectRoleResponse{ + Results: []*milvuspb.RoleResult{ + {Role: &milvuspb.RoleEntity{Name: "role1"}}, + {Role: &milvuspb.RoleEntity{Name: "role2"}}, + }, + }, nil + }).Once() + + roles, err := s.client.ListRoles(ctx, NewListRoleOption()) + s.NoError(err) + s.Equal([]string{"role1", "role2"}, roles) + }) + + s.Run("failure", func() { + s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.ListRoles(ctx, NewListRoleOption()) + s.Error(err) + }) +} + +func (s *RoleSuite) TestCreateRole() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { + s.Equal(roleName, r.GetEntity().GetName()) + return merr.Success(), nil + }).Once() + + err := s.client.CreateRole(ctx, NewCreateRoleOption(roleName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.CreateRole(ctx, NewCreateRoleOption("role")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestGrantRole() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + userName := fmt.Sprintf("user_%s", s.randString(5)) + roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { + s.Equal(userName, r.GetUsername()) + s.Equal(roleName, r.GetRoleName()) + return merr.Success(), nil + }).Once() + + err := s.client.GrantRole(ctx, NewGrantRoleOption(userName, roleName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.GrantRole(ctx, NewGrantRoleOption("user", "role")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestRevokeRole() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + userName := fmt.Sprintf("user_%s", s.randString(5)) + roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { + s.Equal(userName, r.GetUsername()) + s.Equal(roleName, r.GetRoleName()) + return merr.Success(), nil + }).Once() + + err := s.client.RevokeRole(ctx, NewRevokeRoleOption(userName, roleName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.RevokeRole(ctx, NewRevokeRoleOption("user", "role")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestDropRole() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.DropRoleRequest) (*commonpb.Status, error) { + s.Equal(roleName, r.GetRoleName()) + return merr.Success(), nil + }).Once() + + err := s.client.DropRole(ctx, NewDropRoleOption(roleName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.DropRole(ctx, NewDropRoleOption("role")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestDescribeRole() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { + s.Equal(roleName, r.GetEntity().GetRole().GetName()) + return &milvuspb.SelectGrantResponse{ + Entities: []*milvuspb.GrantEntity{ + { + ObjectName: "*", + Object: &milvuspb.ObjectEntity{ + Name: "collection", + }, + Role: &milvuspb.RoleEntity{Name: roleName}, + Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Insert"}}, + }, + { + ObjectName: "*", + Object: &milvuspb.ObjectEntity{ + Name: "collection", + }, + Role: &milvuspb.RoleEntity{Name: roleName}, + Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Query"}}, + }, + }, + }, nil + }).Once() + + role, err := s.client.DescribeRole(ctx, NewDescribeRoleOption(roleName)) + s.NoError(err) + s.Equal(roleName, role.RoleName) + }) + + s.Run("failure", func() { + s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + _, err := s.client.DescribeRole(ctx, NewDescribeRoleOption("role")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestGrantPrivilege() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + roleName := fmt.Sprintf("role_%s", s.randString(5)) + privilegeName := "Insert" + collectionName := fmt.Sprintf("collection_%s", s.randString(6)) + + s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { + s.Equal(roleName, r.GetEntity().GetRole().GetName()) + s.Equal("collection", r.GetEntity().GetObject().GetName()) + s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName()) + s.Equal(collectionName, r.GetEntity().GetObjectName()) + s.Equal(milvuspb.OperatePrivilegeType_Grant, r.GetType()) + return merr.Success(), nil + }).Once() + + err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption(roleName, "collection", privilegeName, collectionName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption("role", "collection", "privilege", "coll_1")) + s.Error(err) + }) +} + +func (s *RoleSuite) TestRevokePrivilege() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + roleName := fmt.Sprintf("role_%s", s.randString(5)) + privilegeName := "Insert" + collectionName := fmt.Sprintf("collection_%s", s.randString(6)) + + s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { + s.Equal(roleName, r.GetEntity().GetRole().GetName()) + s.Equal("collection", r.GetEntity().GetObject().GetName()) + s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName()) + s.Equal(collectionName, r.GetEntity().GetObjectName()) + s.Equal(milvuspb.OperatePrivilegeType_Revoke, r.GetType()) + return merr.Success(), nil + }).Once() + + err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption(roleName, "collection", privilegeName, collectionName)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption("role", "collection", "privilege", "coll_1")) + s.Error(err) + }) +} + +func TestRoleRBAC(t *testing.T) { + suite.Run(t, new(RoleSuite)) +} + type PrivilgeGroupSuite struct { MockSuiteBase } diff --git a/client/milvusclient/read.go b/client/milvusclient/read.go index e07185e4846d4..ab87598c23a80 100644 --- a/client/milvusclient/read.go +++ b/client/milvusclient/read.go @@ -192,6 +192,10 @@ func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...g return resultSet, err } +func (c *Client) Get(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) { + return c.Query(ctx, option, callOptions...) +} + func (c *Client) HybridSearch(ctx context.Context, option HybridSearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) { req, err := option.HybridRequest() if err != nil { diff --git a/client/milvusclient/read_option_test.go b/client/milvusclient/read_option_test.go index c4e1e52c99159..04a061047073a 100644 --- a/client/milvusclient/read_option_test.go +++ b/client/milvusclient/read_option_test.go @@ -52,8 +52,7 @@ func (s *SearchOptionSuite) TestBasic() { topK := rand.Intn(100) + 1 opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})}) - opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000") - + opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true) req, err := opt.Request() s.Require().NoError(err) @@ -64,6 +63,15 @@ func (s *SearchOptionSuite) TestBasic() { annField, ok := searchParams[spAnnsField] s.Require().True(ok) s.Equal("test_field", annField) + groupField, ok := searchParams[spGroupBy] + s.Require().True(ok) + s.Equal("group_field", groupField) + groupSize, ok := searchParams[spGroupSize] + s.Require().True(ok) + s.Equal("10", groupSize) + spStrictGroupSize, ok := searchParams[spStrictGroupSize] + s.Require().True(ok) + s.Equal("true", spStrictGroupSize) opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}}) _, err = opt.Request() diff --git a/client/milvusclient/read_options.go b/client/milvusclient/read_options.go index 8f040f5f531ee..89f3d4e3bb765 100644 --- a/client/milvusclient/read_options.go +++ b/client/milvusclient/read_options.go @@ -21,6 +21,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "github.com/cockroachdb/errors" "google.golang.org/protobuf/proto" @@ -28,20 +29,23 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/index" ) const ( - spAnnsField = `anns_field` - spTopK = `topk` - spOffset = `offset` - spLimit = `limit` - spParams = `params` - spMetricsType = `metric_type` - spRoundDecimal = `round_decimal` - spIgnoreGrowing = `ignore_growing` - spGroupBy = `group_by_field` + spAnnsField = `anns_field` + spTopK = `topk` + spOffset = `offset` + spLimit = `limit` + spParams = `params` + spMetricsType = `metric_type` + spRoundDecimal = `round_decimal` + spIgnoreGrowing = `ignore_growing` + spGroupBy = `group_by_field` + spGroupSize = `group_size` + spStrictGroupSize = `strict_group_size` ) type SearchOption interface { @@ -62,16 +66,18 @@ type searchOption struct { type annRequest struct { vectors []entity.Vector - annField string - metricsType entity.MetricType - searchParam map[string]string - groupByField string - annParam index.AnnParam - ignoreGrowing bool - expr string - topK int - offset int - templateParams map[string]any + annField string + metricsType entity.MetricType + searchParam map[string]string + groupByField string + groupSize int + strictGroupSize bool + annParam index.AnnParam + ignoreGrowing bool + expr string + topK int + offset int + templateParams map[string]any } func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest { @@ -108,6 +114,12 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) { if r.groupByField != "" { params[spGroupBy] = r.groupByField } + if r.groupSize != 0 { + params[spGroupSize] = strconv.Itoa(r.groupSize) + } + if r.strictGroupSize { + params[spStrictGroupSize] = "true" + } // ann param if r.annParam != nil { bs, _ := json.Marshal(r.annParam.Params()) @@ -223,6 +235,16 @@ func (r *annRequest) WithGroupByField(groupByField string) *annRequest { return r } +func (r *annRequest) WithGroupSize(groupSize int) *annRequest { + r.groupSize = groupSize + return r +} + +func (r *annRequest) WithStrictGroupSize(strictGroupSize bool) *annRequest { + r.strictGroupSize = strictGroupSize + return r +} + func (r *annRequest) WithSearchParam(key, value string) *annRequest { r.searchParam[key] = value return r @@ -309,6 +331,16 @@ func (opt *searchOption) WithGroupByField(groupByField string) *searchOption { return opt } +func (opt *searchOption) WithGroupSize(groupSize int) *searchOption { + opt.annRequest.WithGroupSize(groupSize) + return opt +} + +func (opt *searchOption) WithStrictGroupSize(strictGroupSize bool) *searchOption { + opt.annRequest.WithStrictGroupSize(strictGroupSize) + return opt +} + func (opt *searchOption) WithIgnoreGrowing(ignoreGrowing bool) *searchOption { opt.annRequest.WithIgnoreGrowing(ignoreGrowing) return opt @@ -550,6 +582,27 @@ func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption { return opt } +func (opt *queryOption) WithIDs(ids column.Column) *queryOption { + opt.expr = pks2Expr(ids) + return opt +} + +func pks2Expr(ids column.Column) string { + var expr string + pkName := ids.Name() + switch ids.Type() { + case entity.FieldTypeInt64: + expr = fmt.Sprintf("%s in %s", pkName, strings.Join(strings.Fields(fmt.Sprint(ids.FieldData().GetScalars().GetLongData().GetData())), ",")) + case entity.FieldTypeVarChar: + data := ids.FieldData().GetScalars().GetData().(*schemapb.ScalarField_StringData).StringData.GetData() + for i := range data { + data[i] = fmt.Sprintf("\"%s\"", data[i]) + } + expr = fmt.Sprintf("%s in [%s]", pkName, strings.Join(data, ",")) + } + return expr +} + func NewQueryOption(collectionName string) *queryOption { return &queryOption{ collectionName: collectionName, diff --git a/client/milvusclient/resource_group.go b/client/milvusclient/resource_group.go new file mode 100644 index 0000000000000..e91654548dfaf --- /dev/null +++ b/client/milvusclient/resource_group.go @@ -0,0 +1,65 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func (c *Client) ListResourceGroups(ctx context.Context, opt ListResourceGroupsOption, callOptions ...grpc.CallOption) ([]string, error) { + req := opt.Request() + + var rgs []string + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.ListResourceGroups(ctx, req, callOptions...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + + rgs = resp.GetResourceGroups() + return nil + }) + + return rgs, err +} + +func (c *Client) CreateResourceGroup(ctx context.Context, opt CreateResourceGroupOption, callOptions ...grpc.CallOption) error { + req := opt.Request() + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.CreateResourceGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) + + return err +} + +func (c *Client) DropResourceGroup(ctx context.Context, opt DropResourceGroupOption, callOptions ...grpc.CallOption) error { + req := opt.Request() + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DropResourceGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) + + return err +} diff --git a/client/milvusclient/resource_group_option.go b/client/milvusclient/resource_group_option.go new file mode 100644 index 0000000000000..5f70f69d0b597 --- /dev/null +++ b/client/milvusclient/resource_group_option.go @@ -0,0 +1,92 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" +) + +type ListResourceGroupsOption interface { + Request() *milvuspb.ListResourceGroupsRequest +} + +type listResourceGroupsOption struct{} + +func (opt *listResourceGroupsOption) Request() *milvuspb.ListResourceGroupsRequest { + return &milvuspb.ListResourceGroupsRequest{} +} + +func NewListResourceGroupsOption() *listResourceGroupsOption { + return &listResourceGroupsOption{} +} + +type CreateResourceGroupOption interface { + Request() *milvuspb.CreateResourceGroupRequest +} + +type createResourceGroupOption struct { + name string + nodeRequest int + nodeLimit int +} + +func (opt *createResourceGroupOption) WithNodeRequest(nodeRequest int) *createResourceGroupOption { + opt.nodeRequest = nodeRequest + return opt +} + +func (opt *createResourceGroupOption) WithNodeLimit(nodeLimit int) *createResourceGroupOption { + opt.nodeLimit = nodeLimit + return opt +} + +func (opt *createResourceGroupOption) Request() *milvuspb.CreateResourceGroupRequest { + return &milvuspb.CreateResourceGroupRequest{ + ResourceGroup: opt.name, + Config: &rgpb.ResourceGroupConfig{ + Requests: &rgpb.ResourceGroupLimit{ + NodeNum: int32(opt.nodeRequest), + }, + Limits: &rgpb.ResourceGroupLimit{ + NodeNum: int32(opt.nodeLimit), + }, + }, + } +} + +func NewCreateResourceGroupOption(name string) *createResourceGroupOption { + return &createResourceGroupOption{name: name} +} + +type DropResourceGroupOption interface { + Request() *milvuspb.DropResourceGroupRequest +} + +type dropResourceGroupOption struct { + name string +} + +func (opt *dropResourceGroupOption) Request() *milvuspb.DropResourceGroupRequest { + return &milvuspb.DropResourceGroupRequest{ + ResourceGroup: opt.name, + } +} + +func NewDropResourceGroupOption(name string) *dropResourceGroupOption { + return &dropResourceGroupOption{name: name} +} diff --git a/client/milvusclient/resource_group_test.go b/client/milvusclient/resource_group_test.go new file mode 100644 index 0000000000000..2e87647bf0965 --- /dev/null +++ b/client/milvusclient/resource_group_test.go @@ -0,0 +1,108 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +type ResourceGroupSuite struct { + MockSuiteBase +} + +func (s *ResourceGroupSuite) TestListResourceGroups() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{ + ResourceGroups: []string{"rg1", "rg2"}, + }, nil).Once() + rgs, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption()) + s.NoError(err) + s.Equal([]string{"rg1", "rg2"}, rgs) + }) + + s.Run("failure", func() { + s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once() + _, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption()) + s.Error(err) + }) +} + +func (s *ResourceGroupSuite) TestCreateResourceGroup() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + rgName := fmt.Sprintf("rg_%s", s.randString(6)) + s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, crgr *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { + s.Equal(rgName, crgr.GetResourceGroup()) + s.Equal(int32(5), crgr.GetConfig().GetRequests().GetNodeNum()) + s.Equal(int32(10), crgr.GetConfig().GetLimits().GetNodeNum()) + return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil + }).Once() + opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5) + err := s.client.CreateResourceGroup(ctx, opt) + s.NoError(err) + }) + + s.Run("failure", func() { + rgName := fmt.Sprintf("rg_%s", s.randString(6)) + s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once() + opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5) + err := s.client.CreateResourceGroup(ctx, opt) + s.Error(err) + }) +} + +func (s *ResourceGroupSuite) TestDropResourceGroup() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + rgName := fmt.Sprintf("rg_%s", s.randString(6)) + s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, drgr *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { + s.Equal(rgName, drgr.GetResourceGroup()) + return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil + }).Once() + opt := NewDropResourceGroupOption(rgName) + err := s.client.DropResourceGroup(ctx, opt) + s.NoError(err) + }) + + s.Run("failure", func() { + rgName := fmt.Sprintf("rg_%s", s.randString(6)) + s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once() + opt := NewDropResourceGroupOption(rgName) + err := s.client.DropResourceGroup(ctx, opt) + s.Error(err) + }) +} + +func TestResourceGroup(t *testing.T) { + suite.Run(t, new(ResourceGroupSuite)) +}