diff --git a/internal/rootcoord/constrant.go b/internal/rootcoord/constrant.go index b54b5d8beca1e..199b31fb8f484 100644 --- a/internal/rootcoord/constrant.go +++ b/internal/rootcoord/constrant.go @@ -20,7 +20,6 @@ import ( "context" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -35,7 +34,6 @@ func checkGeneralCapacity(ctx context.Context, newColNum int, newParNum int64, newShardNum int32, core *Core, - ts typeutil.Timestamp, ) error { var addedNum int64 = 0 if newColNum > 0 && newParNum > 0 && newShardNum > 0 { @@ -46,25 +44,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int, addedNum += newParNum } - var generalNum int64 = 0 - collectionsMap := core.meta.ListAllAvailCollections(ctx) - for dbId, collectionIDs := range collectionsMap { - db, err := core.meta.GetDatabaseByID(ctx, dbId, ts) - if err == nil { - for _, collectionId := range collectionIDs { - collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true) - if err == nil { - partNum := int64(collection.GetPartitionNum(false)) - shardNum := int64(collection.ShardsNum) - generalNum += partNum * shardNum - } - } - } - } - - generalNum += addedNum - if generalNum > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64() { - return merr.WrapGeneralCapacityExceed(generalNum, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(), + generalCount := core.meta.GetGeneralCount(ctx) + generalCount += int(addedNum) + if generalCount > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt() { + return merr.WrapGeneralCapacityExceed(generalCount, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(), "failed checking constraint: sum_collections(parition*shard) exceeding the max general capacity:") } return nil diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index 3bfd9a680425f..b24b14fc80245 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -103,7 +103,7 @@ func (t *createCollectionTask) validate() error { if t.Req.GetNumPartitions() > 0 { newPartNum = t.Req.GetNumPartitions() } - return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core, t.ts) + return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core) } // checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections. diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index 6cc64bafd96ac..6d1bdb56c9c3c 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -246,23 +246,7 @@ func Test_createCollectionTask_validate(t *testing.T) { meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{1: {1, 2}}) meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). Return(&model.Database{Name: "db1"}, nil).Once() - - meta.On("GetDatabaseByID", - mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Database{ - Name: "default", - }, nil) - meta.On("GetCollectionByID", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Collection{ - Name: "default", - ShardsNum: 2, - Partitions: []*model.Partition{ - { - PartitionID: 1, - }, - }, - }, nil) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(1) core := newTestCore(withMeta(meta)) @@ -295,8 +279,7 @@ func Test_createCollectionTask_validate(t *testing.T) { }, }, }, nil).Once() - meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything). - Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) core := newTestCore(withMeta(meta)) task := createCollectionTask{ @@ -642,6 +625,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) { ).Return(map[int64][]int64{ util.DefaultDBID: {1, 2}, }, nil) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64)) defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key) @@ -662,8 +646,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { }) t.Run("invalid schema", func(t *testing.T) { - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withMeta(meta)) collectionName := funcutil.GenRandomStr() task := &createCollectionTask{ @@ -692,8 +674,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { } marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withInvalidIDAllocator(), withMeta(meta)) task := createCollectionTask{ @@ -716,8 +696,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { field1 := funcutil.GenRandomStr() ticker := newRocksMqTtSynchronizer() - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) @@ -1056,8 +1034,7 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) { ).Return(map[int64][]int64{ util.DefaultDBID: {1, 2}, }, nil) - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64)) defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key) diff --git a/internal/rootcoord/create_partition_task.go b/internal/rootcoord/create_partition_task.go index 5a97449e54ac1..53682cb018b9c 100644 --- a/internal/rootcoord/create_partition_task.go +++ b/internal/rootcoord/create_partition_task.go @@ -44,7 +44,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error { return err } t.collMeta = collMeta - return checkGeneralCapacity(ctx, 0, 1, 0, t.core, t.ts) + return checkGeneralCapacity(ctx, 0, 1, 0, t.core) } func (t *createPartitionTask) Execute(ctx context.Context) error { diff --git a/internal/rootcoord/create_partition_task_test.go b/internal/rootcoord/create_partition_task_test.go index 880291da9baac..bf830fd04de75 100644 --- a/internal/rootcoord/create_partition_task_test.go +++ b/internal/rootcoord/create_partition_task_test.go @@ -20,7 +20,6 @@ import ( "context" "testing" - "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -62,14 +61,7 @@ func Test_createPartitionTask_Prepare(t *testing.T) { mock.Anything, mock.Anything, ).Return(coll.Clone(), nil) - meta.On("ListAllAvailCollections", - mock.Anything, - ).Return(map[int64][]int64{ - 1: {1, 2}, - }, nil) - meta.On("GetDatabaseByID", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) core := newTestCore(withMeta(meta)) task := &createPartitionTask{ diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index 31364f431b055..ca58353f85c85 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -72,6 +72,7 @@ type IMetaTable interface { ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error + GetGeneralCount(ctx context.Context) int // TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog. IsAlias(db, name string) bool @@ -114,6 +115,8 @@ type MetaTable struct { dbName2Meta map[string]*model.Database // database name -> db meta collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta + generalCnt int // sum of product of partition number and shard number + // collections *collectionDb names *nameDb aliases *nameDb @@ -187,6 +190,7 @@ func (mt *MetaTable) reload() error { } for _, collection := range collections { mt.collID2Meta[collection.CollectionID] = collection + mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum) if collection.Available() { mt.names.insert(dbName, collection.Name, collection.CollectionID) collectionNum++ @@ -409,6 +413,8 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection) mt.collID2Meta[coll.CollectionID] = coll.Clone() mt.names.insert(db.Name, coll.Name, coll.CollectionID) + mt.generalCnt += len(coll.Partitions) * int(coll.ShardsNum) + log.Ctx(ctx).Info("add collection to meta table", zap.Int64("dbID", coll.DBID), zap.String("collection", coll.Name), @@ -513,6 +519,8 @@ func (mt *MetaTable) RemoveCollection(ctx context.Context, collectionID UniqueID mt.removeAllNamesIfMatchedInternal(collectionID, allNames) mt.removeCollectionByIDInternal(collectionID) + mt.generalCnt -= len(coll.Partitions) * int(coll.ShardsNum) + log.Ctx(ctx).Info("remove collection", zap.Int64("dbID", coll.DBID), zap.String("name", coll.Name), @@ -861,6 +869,8 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio } mt.collID2Meta[partition.CollectionID].Partitions = append(mt.collID2Meta[partition.CollectionID].Partitions, partition.Clone()) + mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum + metrics.RootCoordNumOfPartitions.WithLabelValues().Inc() log.Ctx(ctx).Info("add partition to meta table", @@ -927,6 +937,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection } if loc != -1 { coll.Partitions = append(coll.Partitions[:loc], coll.Partitions[loc+1:]...) + mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum } log.Info("remove partition", zap.Int64("collection", collectionID), zap.Int64("partition", partitionID), zap.Uint64("ts", ts)) return nil @@ -1195,6 +1206,14 @@ func (mt *MetaTable) ListAliasesByID(collID UniqueID) []string { return mt.listAliasesByID(collID) } +// GetGeneralCount gets the general count(sum of product of partition number and shard number). +func (mt *MetaTable) GetGeneralCount(ctx context.Context) int { + mt.ddLock.RLock() + defer mt.ddLock.RUnlock() + + return mt.generalCnt +} + // AddCredential add credential func (mt *MetaTable) AddCredential(credInfo *internalpb.CredentialInfo) error { if credInfo.Username == "" { diff --git a/internal/rootcoord/mocks/meta_table.go b/internal/rootcoord/mocks/meta_table.go index 28f679f055991..b7916680aad24 100644 --- a/internal/rootcoord/mocks/meta_table.go +++ b/internal/rootcoord/mocks/meta_table.go @@ -1357,9 +1357,104 @@ func (_c *IMetaTable_GetDatabaseByName_Call) RunAndReturn(run func(context.Conte return _c } -// GetPrivilegeGroupRoles provides a mock function with given fields: groupName -func (_m *IMetaTable) GetPrivilegeGroupRoles(groupName string) ([]*milvuspb.RoleEntity, error) { - ret := _m.Called(groupName) +// GetGeneralCount provides a mock function with given fields: ctx +func (_m *IMetaTable) GetGeneralCount(ctx context.Context) int { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetGeneralCount") + } + + var r0 int + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// IMetaTable_GetGeneralCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGeneralCount' +type IMetaTable_GetGeneralCount_Call struct { + *mock.Call +} + +// GetGeneralCount is a helper method to define mock.On call +// - ctx context.Context +func (_e *IMetaTable_Expecter) GetGeneralCount(ctx interface{}) *IMetaTable_GetGeneralCount_Call { + return &IMetaTable_GetGeneralCount_Call{Call: _e.mock.On("GetGeneralCount", ctx)} +} + +func (_c *IMetaTable_GetGeneralCount_Call) Run(run func(ctx context.Context)) *IMetaTable_GetGeneralCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *IMetaTable_GetGeneralCount_Call) Return(_a0 int) *IMetaTable_GetGeneralCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *IMetaTable_GetGeneralCount_Call) RunAndReturn(run func(context.Context) int) *IMetaTable_GetGeneralCount_Call { + _c.Call.Return(run) + return _c +} + +// GetPChannelInfo provides a mock function with given fields: ctx, pchannel +func (_m *IMetaTable) GetPChannelInfo(ctx context.Context, pchannel string) *rootcoordpb.GetPChannelInfoResponse { + ret := _m.Called(ctx, pchannel) + + if len(ret) == 0 { + panic("no return value specified for GetPChannelInfo") + } + + var r0 *rootcoordpb.GetPChannelInfoResponse + if rf, ok := ret.Get(0).(func(context.Context, string) *rootcoordpb.GetPChannelInfoResponse); ok { + r0 = rf(ctx, pchannel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.GetPChannelInfoResponse) + } + } + + return r0 +} + +// IMetaTable_GetPChannelInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPChannelInfo' +type IMetaTable_GetPChannelInfo_Call struct { + *mock.Call +} + +// GetPChannelInfo is a helper method to define mock.On call +// - ctx context.Context +// - pchannel string +func (_e *IMetaTable_Expecter) GetPChannelInfo(ctx interface{}, pchannel interface{}) *IMetaTable_GetPChannelInfo_Call { + return &IMetaTable_GetPChannelInfo_Call{Call: _e.mock.On("GetPChannelInfo", ctx, pchannel)} +} + +func (_c *IMetaTable_GetPChannelInfo_Call) Run(run func(ctx context.Context, pchannel string)) *IMetaTable_GetPChannelInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *IMetaTable_GetPChannelInfo_Call) Return(_a0 *rootcoordpb.GetPChannelInfoResponse) *IMetaTable_GetPChannelInfo_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *IMetaTable_GetPChannelInfo_Call) RunAndReturn(run func(context.Context, string) *rootcoordpb.GetPChannelInfoResponse) *IMetaTable_GetPChannelInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetPrivilegeGroupRoles provides a mock function with given fields: ctx, groupName +func (_m *IMetaTable) GetPrivilegeGroupRoles(ctx context.Context, groupName string) ([]*milvuspb.RoleEntity, error) { + ret := _m.Called(ctx, groupName) if len(ret) == 0 { panic("no return value specified for GetPrivilegeGroupRoles")