From 1f8299f3b295d5ee27354ee6fd25670c8b5c357d Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Thu, 5 Dec 2024 20:56:40 +0800 Subject: [PATCH] fix: [2.4] Fix checkGeneralCapacity slowly (#38074) Cache the general count to speed up checkGeneralCapacity. issue: https://github.com/milvus-io/milvus/issues/37630 pr: https://github.com/milvus-io/milvus/pull/37976 --------- Signed-off-by: bigsheeper --- internal/rootcoord/constrant.go | 25 ++----- internal/rootcoord/create_collection_task.go | 2 +- .../rootcoord/create_collection_task_test.go | 31 ++------- internal/rootcoord/create_partition_task.go | 2 +- .../rootcoord/create_partition_task_test.go | 10 +-- internal/rootcoord/meta_table.go | 19 ++++++ internal/rootcoord/mocks/meta_table.go | 66 ++++++++++++------- 7 files changed, 72 insertions(+), 83 deletions(-) diff --git a/internal/rootcoord/constrant.go b/internal/rootcoord/constrant.go index b54b5d8beca1e..199b31fb8f484 100644 --- a/internal/rootcoord/constrant.go +++ b/internal/rootcoord/constrant.go @@ -20,7 +20,6 @@ import ( "context" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -35,7 +34,6 @@ func checkGeneralCapacity(ctx context.Context, newColNum int, newParNum int64, newShardNum int32, core *Core, - ts typeutil.Timestamp, ) error { var addedNum int64 = 0 if newColNum > 0 && newParNum > 0 && newShardNum > 0 { @@ -46,25 +44,10 @@ func checkGeneralCapacity(ctx context.Context, newColNum int, addedNum += newParNum } - var generalNum int64 = 0 - collectionsMap := core.meta.ListAllAvailCollections(ctx) - for dbId, collectionIDs := range collectionsMap { - db, err := core.meta.GetDatabaseByID(ctx, dbId, ts) - if err == nil { - for _, collectionId := range collectionIDs { - collection, err := core.meta.GetCollectionByID(ctx, db.Name, collectionId, ts, true) - if err == nil { - partNum := int64(collection.GetPartitionNum(false)) - shardNum := int64(collection.ShardsNum) - generalNum += partNum * shardNum - } - } - } - } - - generalNum += addedNum - if generalNum > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64() { - return merr.WrapGeneralCapacityExceed(generalNum, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(), + generalCount := core.meta.GetGeneralCount(ctx) + generalCount += int(addedNum) + if generalCount > Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt() { + return merr.WrapGeneralCapacityExceed(generalCount, Params.RootCoordCfg.MaxGeneralCapacity.GetAsInt64(), "failed checking constraint: sum_collections(parition*shard) exceeding the max general capacity:") } return nil diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index 3bfd9a680425f..b24b14fc80245 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -103,7 +103,7 @@ func (t *createCollectionTask) validate() error { if t.Req.GetNumPartitions() > 0 { newPartNum = t.Req.GetNumPartitions() } - return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core, t.ts) + return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core) } // checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections. diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index 6cc64bafd96ac..6d1bdb56c9c3c 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -246,23 +246,7 @@ func Test_createCollectionTask_validate(t *testing.T) { meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{1: {1, 2}}) meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything). Return(&model.Database{Name: "db1"}, nil).Once() - - meta.On("GetDatabaseByID", - mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Database{ - Name: "default", - }, nil) - meta.On("GetCollectionByID", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Collection{ - Name: "default", - ShardsNum: 2, - Partitions: []*model.Partition{ - { - PartitionID: 1, - }, - }, - }, nil) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(1) core := newTestCore(withMeta(meta)) @@ -295,8 +279,7 @@ func Test_createCollectionTask_validate(t *testing.T) { }, }, }, nil).Once() - meta.EXPECT().GetDatabaseByID(mock.Anything, mock.Anything, mock.Anything). - Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) core := newTestCore(withMeta(meta)) task := createCollectionTask{ @@ -642,6 +625,7 @@ func Test_createCollectionTask_Prepare(t *testing.T) { ).Return(map[int64][]int64{ util.DefaultDBID: {1, 2}, }, nil) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64)) defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key) @@ -662,8 +646,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { }) t.Run("invalid schema", func(t *testing.T) { - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withMeta(meta)) collectionName := funcutil.GenRandomStr() task := &createCollectionTask{ @@ -692,8 +674,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { } marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withInvalidIDAllocator(), withMeta(meta)) task := createCollectionTask{ @@ -716,8 +696,6 @@ func Test_createCollectionTask_Prepare(t *testing.T) { field1 := funcutil.GenRandomStr() ticker := newRocksMqTtSynchronizer() - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) @@ -1056,8 +1034,7 @@ func Test_createCollectionTask_PartitionKey(t *testing.T) { ).Return(map[int64][]int64{ util.DefaultDBID: {1, 2}, }, nil) - meta.On("GetDatabaseByID", mock.Anything, - mock.Anything, mock.Anything).Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) paramtable.Get().Save(Params.QuotaConfig.MaxCollectionNum.Key, strconv.Itoa(math.MaxInt64)) defer paramtable.Get().Reset(Params.QuotaConfig.MaxCollectionNum.Key) diff --git a/internal/rootcoord/create_partition_task.go b/internal/rootcoord/create_partition_task.go index 875ede9f3caa1..bf53bd3fca4ad 100644 --- a/internal/rootcoord/create_partition_task.go +++ b/internal/rootcoord/create_partition_task.go @@ -44,7 +44,7 @@ func (t *createPartitionTask) Prepare(ctx context.Context) error { return err } t.collMeta = collMeta - return checkGeneralCapacity(ctx, 0, 1, 0, t.core, t.ts) + return checkGeneralCapacity(ctx, 0, 1, 0, t.core) } func (t *createPartitionTask) Execute(ctx context.Context) error { diff --git a/internal/rootcoord/create_partition_task_test.go b/internal/rootcoord/create_partition_task_test.go index 880291da9baac..bf830fd04de75 100644 --- a/internal/rootcoord/create_partition_task_test.go +++ b/internal/rootcoord/create_partition_task_test.go @@ -20,7 +20,6 @@ import ( "context" "testing" - "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -62,14 +61,7 @@ func Test_createPartitionTask_Prepare(t *testing.T) { mock.Anything, mock.Anything, ).Return(coll.Clone(), nil) - meta.On("ListAllAvailCollections", - mock.Anything, - ).Return(map[int64][]int64{ - 1: {1, 2}, - }, nil) - meta.On("GetDatabaseByID", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, errors.New("mock")) + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0) core := newTestCore(withMeta(meta)) task := &createPartitionTask{ diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index fd22fe7b0f64c..f866353ba70ae 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -72,6 +72,7 @@ type IMetaTable interface { ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp) error RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error + GetGeneralCount(ctx context.Context) int // TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog. IsAlias(db, name string) bool @@ -114,6 +115,8 @@ type MetaTable struct { dbName2Meta map[string]*model.Database // database name -> db meta collID2Meta map[typeutil.UniqueID]*model.Collection // collection id -> collection meta + generalCnt int // sum of product of partition number and shard number + // collections *collectionDb names *nameDb aliases *nameDb @@ -187,6 +190,7 @@ func (mt *MetaTable) reload() error { } for _, collection := range collections { mt.collID2Meta[collection.CollectionID] = collection + mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum) if collection.Available() { mt.names.insert(dbName, collection.Name, collection.CollectionID) collectionNum++ @@ -409,6 +413,8 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection) mt.collID2Meta[coll.CollectionID] = coll.Clone() mt.names.insert(db.Name, coll.Name, coll.CollectionID) + mt.generalCnt += len(coll.Partitions) * int(coll.ShardsNum) + log.Ctx(ctx).Info("add collection to meta table", zap.Int64("dbID", coll.DBID), zap.String("collection", coll.Name), @@ -513,6 +519,8 @@ func (mt *MetaTable) RemoveCollection(ctx context.Context, collectionID UniqueID mt.removeAllNamesIfMatchedInternal(collectionID, allNames) mt.removeCollectionByIDInternal(collectionID) + mt.generalCnt -= len(coll.Partitions) * int(coll.ShardsNum) + log.Ctx(ctx).Info("remove collection", zap.Int64("dbID", coll.DBID), zap.String("name", coll.Name), @@ -861,6 +869,8 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio } mt.collID2Meta[partition.CollectionID].Partitions = append(mt.collID2Meta[partition.CollectionID].Partitions, partition.Clone()) + mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum + log.Ctx(ctx).Info("add partition to meta table", zap.Int64("collection", partition.CollectionID), zap.String("partition", partition.PartitionName), zap.Int64("partitionid", partition.PartitionID), zap.Uint64("ts", partition.PartitionCreatedTimestamp)) @@ -925,6 +935,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection } if loc != -1 { coll.Partitions = append(coll.Partitions[:loc], coll.Partitions[loc+1:]...) + mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum } log.Info("remove partition", zap.Int64("collection", collectionID), zap.Int64("partition", partitionID), zap.Uint64("ts", ts)) return nil @@ -1193,6 +1204,14 @@ func (mt *MetaTable) ListAliasesByID(collID UniqueID) []string { return mt.listAliasesByID(collID) } +// GetGeneralCount gets the general count(sum of product of partition number and shard number). +func (mt *MetaTable) GetGeneralCount(ctx context.Context) int { + mt.ddLock.RLock() + defer mt.ddLock.RUnlock() + + return mt.generalCnt +} + // AddCredential add credential func (mt *MetaTable) AddCredential(credInfo *internalpb.CredentialInfo) error { if credInfo.Username == "" { diff --git a/internal/rootcoord/mocks/meta_table.go b/internal/rootcoord/mocks/meta_table.go index 28f679f055991..0272510d1bdf5 100644 --- a/internal/rootcoord/mocks/meta_table.go +++ b/internal/rootcoord/mocks/meta_table.go @@ -574,10 +574,6 @@ func (_c *IMetaTable_CreateDatabase_Call) RunAndReturn(run func(context.Context, func (_m *IMetaTable) CreatePrivilegeGroup(groupName string) error { ret := _m.Called(groupName) - if len(ret) == 0 { - panic("no return value specified for CreatePrivilegeGroup") - } - var r0 error if rf, ok := ret.Get(0).(func(string) error); ok { r0 = rf(groupName) @@ -892,10 +888,6 @@ func (_c *IMetaTable_DropGrant_Call) RunAndReturn(run func(string, *milvuspb.Rol func (_m *IMetaTable) DropPrivilegeGroup(groupName string) error { ret := _m.Called(groupName) - if len(ret) == 0 { - panic("no return value specified for DropPrivilegeGroup") - } - var r0 error if rf, ok := ret.Get(0).(func(string) error); ok { r0 = rf(groupName) @@ -1357,14 +1349,52 @@ func (_c *IMetaTable_GetDatabaseByName_Call) RunAndReturn(run func(context.Conte return _c } +// GetGeneralCount provides a mock function with given fields: ctx +func (_m *IMetaTable) GetGeneralCount(ctx context.Context) int { + ret := _m.Called(ctx) + + var r0 int + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// IMetaTable_GetGeneralCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGeneralCount' +type IMetaTable_GetGeneralCount_Call struct { + *mock.Call +} + +// GetGeneralCount is a helper method to define mock.On call +// - ctx context.Context +func (_e *IMetaTable_Expecter) GetGeneralCount(ctx interface{}) *IMetaTable_GetGeneralCount_Call { + return &IMetaTable_GetGeneralCount_Call{Call: _e.mock.On("GetGeneralCount", ctx)} +} + +func (_c *IMetaTable_GetGeneralCount_Call) Run(run func(ctx context.Context)) *IMetaTable_GetGeneralCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *IMetaTable_GetGeneralCount_Call) Return(_a0 int) *IMetaTable_GetGeneralCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *IMetaTable_GetGeneralCount_Call) RunAndReturn(run func(context.Context) int) *IMetaTable_GetGeneralCount_Call { + _c.Call.Return(run) + return _c +} + // GetPrivilegeGroupRoles provides a mock function with given fields: groupName func (_m *IMetaTable) GetPrivilegeGroupRoles(groupName string) ([]*milvuspb.RoleEntity, error) { ret := _m.Called(groupName) - if len(ret) == 0 { - panic("no return value specified for GetPrivilegeGroupRoles") - } - var r0 []*milvuspb.RoleEntity var r1 error if rf, ok := ret.Get(0).(func(string) ([]*milvuspb.RoleEntity, error)); ok { @@ -1462,10 +1492,6 @@ func (_c *IMetaTable_IsAlias_Call) RunAndReturn(run func(string, string) bool) * func (_m *IMetaTable) IsCustomPrivilegeGroup(groupName string) (bool, error) { ret := _m.Called(groupName) - if len(ret) == 0 { - panic("no return value specified for IsCustomPrivilegeGroup") - } - var r0 bool var r1 error if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { @@ -1925,10 +1951,6 @@ func (_c *IMetaTable_ListPolicy_Call) RunAndReturn(run func(string) ([]string, e func (_m *IMetaTable) ListPrivilegeGroups() ([]*milvuspb.PrivilegeGroupInfo, error) { ret := _m.Called() - if len(ret) == 0 { - panic("no return value specified for ListPrivilegeGroups") - } - var r0 []*milvuspb.PrivilegeGroupInfo var r1 error if rf, ok := ret.Get(0).(func() ([]*milvuspb.PrivilegeGroupInfo, error)); ok { @@ -2080,10 +2102,6 @@ func (_c *IMetaTable_OperatePrivilege_Call) RunAndReturn(run func(string, *milvu func (_m *IMetaTable) OperatePrivilegeGroup(groupName string, privileges []*milvuspb.PrivilegeEntity, operateType milvuspb.OperatePrivilegeGroupType) error { ret := _m.Called(groupName, privileges, operateType) - if len(ret) == 0 { - panic("no return value specified for OperatePrivilegeGroup") - } - var r0 error if rf, ok := ret.Get(0).(func(string, []*milvuspb.PrivilegeEntity, milvuspb.OperatePrivilegeGroupType) error); ok { r0 = rf(groupName, privileges, operateType)