From e76802f91054e3b724fea5758f9305933bb22b9b Mon Sep 17 00:00:00 2001 From: tinswzy Date: Mon, 25 Nov 2024 11:14:34 +0800 Subject: [PATCH] enhance: refine querycoord meta/catalog related interfaces to ensure that each method includes a ctx parameter (#37916) issue: #35917 This PR refine the querycoord meta related interfaces to ensure that each method includes a ctx parameter. Signed-off-by: tinswzy --- internal/metastore/catalog.go | 34 +- .../metastore/kv/querycoord/kv_catalog.go | 39 +- .../kv/querycoord/kv_catalog_test.go | 86 +-- .../mocks/mock_querycoord_catalog.go | 471 ++++++++------- internal/querycoordv2/balance/balance.go | 13 +- internal/querycoordv2/balance/balance_test.go | 7 +- .../balance/channel_level_score_balancer.go | 35 +- .../channel_level_score_balancer_test.go | 142 ++--- .../querycoordv2/balance/mock_balancer.go | 73 +-- .../balance/multi_target_balance.go | 15 +- .../balance/rowcount_based_balancer.go | 34 +- .../balance/rowcount_based_balancer_test.go | 88 +-- .../balance/score_based_balancer.go | 23 +- .../balance/score_based_balancer_test.go | 124 ++-- .../querycoordv2/checkers/balance_checker.go | 32 +- .../checkers/balance_checker_test.go | 106 ++-- .../querycoordv2/checkers/channel_checker.go | 30 +- .../checkers/channel_checker_test.go | 42 +- .../querycoordv2/checkers/controller_test.go | 18 +- .../querycoordv2/checkers/index_checker.go | 8 +- .../checkers/index_checker_test.go | 42 +- .../querycoordv2/checkers/leader_checker.go | 20 +- .../checkers/leader_checker_test.go | 122 ++-- .../querycoordv2/checkers/segment_checker.go | 61 +- .../checkers/segment_checker_test.go | 157 ++--- internal/querycoordv2/dist/dist_controller.go | 2 +- internal/querycoordv2/dist/dist_handler.go | 22 +- .../querycoordv2/dist/dist_handler_test.go | 12 +- internal/querycoordv2/handlers.go | 16 +- internal/querycoordv2/job/job_load.go | 38 +- internal/querycoordv2/job/job_release.go | 18 +- internal/querycoordv2/job/job_sync.go | 6 +- internal/querycoordv2/job/job_test.go | 142 ++--- internal/querycoordv2/job/job_update.go | 20 +- internal/querycoordv2/job/undo.go | 4 +- internal/querycoordv2/job/utils.go | 4 +- .../querycoordv2/meta/collection_manager.go | 116 ++-- .../meta/collection_manager_test.go | 188 +++--- .../querycoordv2/meta/mock_target_manager.go | 414 ++++++------- internal/querycoordv2/meta/replica_manager.go | 63 +- .../querycoordv2/meta/replica_manager_test.go | 105 ++-- .../querycoordv2/meta/resource_manager.go | 101 ++-- .../meta/resource_manager_test.go | 542 +++++++++--------- internal/querycoordv2/meta/target_manager.go | 90 +-- .../querycoordv2/meta/target_manager_test.go | 194 ++++--- .../observers/collection_observer.go | 50 +- .../observers/collection_observer_test.go | 63 +- .../observers/replica_observer.go | 13 +- .../observers/replica_observer_test.go | 28 +- .../observers/resource_observer.go | 18 +- .../observers/resource_observer_test.go | 90 +-- .../querycoordv2/observers/target_observer.go | 60 +- .../observers/target_observer_test.go | 41 +- internal/querycoordv2/ops_service_test.go | 28 +- internal/querycoordv2/ops_services.go | 16 +- internal/querycoordv2/server.go | 38 +- internal/querycoordv2/server_test.go | 24 +- internal/querycoordv2/services.go | 78 +-- internal/querycoordv2/services_test.go | 270 ++++----- internal/querycoordv2/task/executor.go | 24 +- internal/querycoordv2/task/scheduler.go | 24 +- internal/querycoordv2/task/task_test.go | 64 ++- internal/querycoordv2/utils/meta.go | 46 +- internal/querycoordv2/utils/meta_test.go | 80 +-- internal/querycoordv2/utils/util.go | 26 +- internal/querycoordv2/utils/util_test.go | 20 +- 66 files changed, 2667 insertions(+), 2353 deletions(-) diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go index 926339e188f62..382e103708e73 100644 --- a/internal/metastore/catalog.go +++ b/internal/metastore/catalog.go @@ -182,23 +182,23 @@ type DataCoordCatalog interface { } type QueryCoordCatalog interface { - SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error - SavePartition(info ...*querypb.PartitionLoadInfo) error - SaveReplica(replicas ...*querypb.Replica) error - GetCollections() ([]*querypb.CollectionLoadInfo, error) - GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) - GetReplicas() ([]*querypb.Replica, error) - ReleaseCollection(collection int64) error - ReleasePartition(collection int64, partitions ...int64) error - ReleaseReplicas(collectionID int64) error - ReleaseReplica(collection int64, replicas ...int64) error - SaveResourceGroup(rgs ...*querypb.ResourceGroup) error - RemoveResourceGroup(rgName string) error - GetResourceGroups() ([]*querypb.ResourceGroup, error) - - SaveCollectionTargets(target ...*querypb.CollectionTarget) error - RemoveCollectionTarget(collectionID int64) error - GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) + SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error + SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error + SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error + GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error) + GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error) + GetReplicas(ctx context.Context) ([]*querypb.Replica, error) + ReleaseCollection(ctx context.Context, collection int64) error + ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error + ReleaseReplicas(ctx context.Context, collectionID int64) error + ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error + SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error + RemoveResourceGroup(ctx context.Context, rgName string) error + GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error) + + SaveCollectionTargets(ctx context.Context, target ...*querypb.CollectionTarget) error + RemoveCollectionTarget(ctx context.Context, collectionID int64) error + GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error) } // StreamingCoordCataLog is the interface for streamingcoord catalog diff --git a/internal/metastore/kv/querycoord/kv_catalog.go b/internal/metastore/kv/querycoord/kv_catalog.go index ce546d52340c4..2c7d5edc38f41 100644 --- a/internal/metastore/kv/querycoord/kv_catalog.go +++ b/internal/metastore/kv/querycoord/kv_catalog.go @@ -2,6 +2,7 @@ package querycoord import ( "bytes" + "context" "fmt" "io" @@ -42,7 +43,7 @@ func NewCatalog(cli kv.MetaKv) Catalog { } } -func (s Catalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error { +func (s Catalog) SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error { k := EncodeCollectionLoadInfoKey(collection.GetCollectionID()) v, err := proto.Marshal(collection) if err != nil { @@ -52,10 +53,10 @@ func (s Catalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitio if err != nil { return err } - return s.SavePartition(partitions...) + return s.SavePartition(ctx, partitions...) } -func (s Catalog) SavePartition(info ...*querypb.PartitionLoadInfo) error { +func (s Catalog) SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error { for _, partition := range info { k := EncodePartitionLoadInfoKey(partition.GetCollectionID(), partition.GetPartitionID()) v, err := proto.Marshal(partition) @@ -70,7 +71,7 @@ func (s Catalog) SavePartition(info ...*querypb.PartitionLoadInfo) error { return nil } -func (s Catalog) SaveReplica(replicas ...*querypb.Replica) error { +func (s Catalog) SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error { kvs := make(map[string]string) for _, replica := range replicas { key := encodeReplicaKey(replica.GetCollectionID(), replica.GetID()) @@ -83,7 +84,7 @@ func (s Catalog) SaveReplica(replicas ...*querypb.Replica) error { return s.cli.MultiSave(kvs) } -func (s Catalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error { +func (s Catalog) SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error { ret := make(map[string]string) for _, rg := range rgs { key := encodeResourceGroupKey(rg.GetName()) @@ -98,12 +99,12 @@ func (s Catalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error { return s.cli.MultiSave(ret) } -func (s Catalog) RemoveResourceGroup(rgName string) error { +func (s Catalog) RemoveResourceGroup(ctx context.Context, rgName string) error { key := encodeResourceGroupKey(rgName) return s.cli.Remove(key) } -func (s Catalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) { +func (s Catalog) GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error) { _, values, err := s.cli.LoadWithPrefix(CollectionLoadInfoPrefix) if err != nil { return nil, err @@ -120,7 +121,7 @@ func (s Catalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) { return ret, nil } -func (s Catalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) { +func (s Catalog) GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error) { _, values, err := s.cli.LoadWithPrefix(PartitionLoadInfoPrefix) if err != nil { return nil, err @@ -137,7 +138,7 @@ func (s Catalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) return ret, nil } -func (s Catalog) GetReplicas() ([]*querypb.Replica, error) { +func (s Catalog) GetReplicas(ctx context.Context) ([]*querypb.Replica, error) { _, values, err := s.cli.LoadWithPrefix(ReplicaPrefix) if err != nil { return nil, err @@ -151,7 +152,7 @@ func (s Catalog) GetReplicas() ([]*querypb.Replica, error) { ret = append(ret, &info) } - replicasV1, err := s.getReplicasFromV1() + replicasV1, err := s.getReplicasFromV1(ctx) if err != nil { return nil, err } @@ -160,7 +161,7 @@ func (s Catalog) GetReplicas() ([]*querypb.Replica, error) { return ret, nil } -func (s Catalog) getReplicasFromV1() ([]*querypb.Replica, error) { +func (s Catalog) getReplicasFromV1(ctx context.Context) ([]*querypb.Replica, error) { _, replicaValues, err := s.cli.LoadWithPrefix(ReplicaMetaPrefixV1) if err != nil { return nil, err @@ -183,7 +184,7 @@ func (s Catalog) getReplicasFromV1() ([]*querypb.Replica, error) { return ret, nil } -func (s Catalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) { +func (s Catalog) GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error) { _, rgs, err := s.cli.LoadWithPrefix(ResourceGroupPrefix) if err != nil { return nil, err @@ -202,7 +203,7 @@ func (s Catalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) { return ret, nil } -func (s Catalog) ReleaseCollection(collection int64) error { +func (s Catalog) ReleaseCollection(ctx context.Context, collection int64) error { // remove collection and obtained partitions collectionKey := EncodeCollectionLoadInfoKey(collection) err := s.cli.Remove(collectionKey) @@ -213,7 +214,7 @@ func (s Catalog) ReleaseCollection(collection int64) error { return s.cli.RemoveWithPrefix(partitionsPrefix) } -func (s Catalog) ReleasePartition(collection int64, partitions ...int64) error { +func (s Catalog) ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error { keys := lo.Map(partitions, func(partition int64, _ int) string { return EncodePartitionLoadInfoKey(collection, partition) }) @@ -235,12 +236,12 @@ func (s Catalog) ReleasePartition(collection int64, partitions ...int64) error { return s.cli.MultiRemove(keys) } -func (s Catalog) ReleaseReplicas(collectionID int64) error { +func (s Catalog) ReleaseReplicas(ctx context.Context, collectionID int64) error { key := encodeCollectionReplicaKey(collectionID) return s.cli.RemoveWithPrefix(key) } -func (s Catalog) ReleaseReplica(collection int64, replicas ...int64) error { +func (s Catalog) ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error { keys := lo.Map(replicas, func(replica int64, _ int) string { return encodeReplicaKey(collection, replica) }) @@ -262,7 +263,7 @@ func (s Catalog) ReleaseReplica(collection int64, replicas ...int64) error { return s.cli.MultiRemove(keys) } -func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) error { +func (s Catalog) SaveCollectionTargets(ctx context.Context, targets ...*querypb.CollectionTarget) error { kvs := make(map[string]string) for _, target := range targets { k := encodeCollectionTargetKey(target.GetCollectionID()) @@ -283,12 +284,12 @@ func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) err return nil } -func (s Catalog) RemoveCollectionTarget(collectionID int64) error { +func (s Catalog) RemoveCollectionTarget(ctx context.Context, collectionID int64) error { k := encodeCollectionTargetKey(collectionID) return s.cli.Remove(k) } -func (s Catalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) { +func (s Catalog) GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error) { keys, values, err := s.cli.LoadWithPrefix(CollectionTargetPrefix) if err != nil { return nil, err diff --git a/internal/metastore/kv/querycoord/kv_catalog_test.go b/internal/metastore/kv/querycoord/kv_catalog_test.go index 6dbdadfb1f004..9bde8491a6604 100644 --- a/internal/metastore/kv/querycoord/kv_catalog_test.go +++ b/internal/metastore/kv/querycoord/kv_catalog_test.go @@ -1,6 +1,7 @@ package querycoord import ( + "context" "sort" "testing" @@ -50,53 +51,55 @@ func (suite *CatalogTestSuite) TearDownTest() { } func (suite *CatalogTestSuite) TestCollection() { - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + ctx := context.Background() + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 1, }) - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 2, }) - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 3, }) - suite.catalog.ReleaseCollection(1) - suite.catalog.ReleaseCollection(2) + suite.catalog.ReleaseCollection(ctx, 1) + suite.catalog.ReleaseCollection(ctx, 2) - collections, err := suite.catalog.GetCollections() + collections, err := suite.catalog.GetCollections(ctx) suite.NoError(err) suite.Len(collections, 1) } func (suite *CatalogTestSuite) TestCollectionWithPartition() { - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + ctx := context.Background() + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 1, }) - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 2, }, &querypb.PartitionLoadInfo{ CollectionID: 2, PartitionID: 102, }) - suite.catalog.SaveCollection(&querypb.CollectionLoadInfo{ + suite.catalog.SaveCollection(ctx, &querypb.CollectionLoadInfo{ CollectionID: 3, }, &querypb.PartitionLoadInfo{ CollectionID: 3, PartitionID: 103, }) - suite.catalog.ReleaseCollection(1) - suite.catalog.ReleaseCollection(2) + suite.catalog.ReleaseCollection(ctx, 1) + suite.catalog.ReleaseCollection(ctx, 2) - collections, err := suite.catalog.GetCollections() + collections, err := suite.catalog.GetCollections(ctx) suite.NoError(err) suite.Len(collections, 1) suite.Equal(int64(3), collections[0].GetCollectionID()) - partitions, err := suite.catalog.GetPartitions() + partitions, err := suite.catalog.GetPartitions(ctx) suite.NoError(err) suite.Len(partitions, 1) suite.Len(partitions[int64(3)], 1) @@ -104,88 +107,92 @@ func (suite *CatalogTestSuite) TestCollectionWithPartition() { } func (suite *CatalogTestSuite) TestPartition() { - suite.catalog.SavePartition(&querypb.PartitionLoadInfo{ + ctx := context.Background() + suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{ PartitionID: 1, }) - suite.catalog.SavePartition(&querypb.PartitionLoadInfo{ + suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{ PartitionID: 2, }) - suite.catalog.SavePartition(&querypb.PartitionLoadInfo{ + suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{ PartitionID: 3, }) - suite.catalog.ReleasePartition(1) - suite.catalog.ReleasePartition(2) + suite.catalog.ReleasePartition(ctx, 1) + suite.catalog.ReleasePartition(ctx, 2) - partitions, err := suite.catalog.GetPartitions() + partitions, err := suite.catalog.GetPartitions(ctx) suite.NoError(err) suite.Len(partitions, 1) } func (suite *CatalogTestSuite) TestReleaseManyPartitions() { + ctx := context.Background() partitionIDs := make([]int64, 0) for i := 1; i <= 150; i++ { - suite.catalog.SavePartition(&querypb.PartitionLoadInfo{ + suite.catalog.SavePartition(ctx, &querypb.PartitionLoadInfo{ CollectionID: 1, PartitionID: int64(i), }) partitionIDs = append(partitionIDs, int64(i)) } - err := suite.catalog.ReleasePartition(1, partitionIDs...) + err := suite.catalog.ReleasePartition(ctx, 1, partitionIDs...) suite.NoError(err) - partitions, err := suite.catalog.GetPartitions() + partitions, err := suite.catalog.GetPartitions(ctx) suite.NoError(err) suite.Len(partitions, 0) } func (suite *CatalogTestSuite) TestReplica() { - suite.catalog.SaveReplica(&querypb.Replica{ + ctx := context.Background() + suite.catalog.SaveReplica(ctx, &querypb.Replica{ CollectionID: 1, ID: 1, }) - suite.catalog.SaveReplica(&querypb.Replica{ + suite.catalog.SaveReplica(ctx, &querypb.Replica{ CollectionID: 1, ID: 2, }) - suite.catalog.SaveReplica(&querypb.Replica{ + suite.catalog.SaveReplica(ctx, &querypb.Replica{ CollectionID: 1, ID: 3, }) - suite.catalog.ReleaseReplica(1, 1) - suite.catalog.ReleaseReplica(1, 2) + suite.catalog.ReleaseReplica(ctx, 1, 1) + suite.catalog.ReleaseReplica(ctx, 1, 2) - replicas, err := suite.catalog.GetReplicas() + replicas, err := suite.catalog.GetReplicas(ctx) suite.NoError(err) suite.Len(replicas, 1) } func (suite *CatalogTestSuite) TestResourceGroup() { - suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{ + ctx := context.Background() + suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{ Name: "rg1", Capacity: 3, Nodes: []int64{1, 2, 3}, }) - suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{ + suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{ Name: "rg2", Capacity: 3, Nodes: []int64{4, 5}, }) - suite.catalog.SaveResourceGroup(&querypb.ResourceGroup{ + suite.catalog.SaveResourceGroup(ctx, &querypb.ResourceGroup{ Name: "rg3", Capacity: 0, Nodes: []int64{}, }) - suite.catalog.RemoveResourceGroup("rg3") + suite.catalog.RemoveResourceGroup(ctx, "rg3") - groups, err := suite.catalog.GetResourceGroups() + groups, err := suite.catalog.GetResourceGroups(ctx) suite.NoError(err) suite.Len(groups, 2) @@ -203,7 +210,8 @@ func (suite *CatalogTestSuite) TestResourceGroup() { } func (suite *CatalogTestSuite) TestCollectionTarget() { - suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{ + ctx := context.Background() + suite.catalog.SaveCollectionTargets(ctx, &querypb.CollectionTarget{ CollectionID: 1, Version: 1, }, @@ -219,9 +227,9 @@ func (suite *CatalogTestSuite) TestCollectionTarget() { CollectionID: 1, Version: 4, }) - suite.catalog.RemoveCollectionTarget(2) + suite.catalog.RemoveCollectionTarget(ctx, 2) - targets, err := suite.catalog.GetCollectionTargets() + targets, err := suite.catalog.GetCollectionTargets(ctx) suite.NoError(err) suite.Len(targets, 2) suite.Equal(int64(4), targets[1].Version) @@ -234,14 +242,14 @@ func (suite *CatalogTestSuite) TestCollectionTarget() { mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr) suite.catalog.cli = mockStore - err = suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{}) + err = suite.catalog.SaveCollectionTargets(ctx, &querypb.CollectionTarget{}) suite.ErrorIs(err, mockErr) - _, err = suite.catalog.GetCollectionTargets() + _, err = suite.catalog.GetCollectionTargets(ctx) suite.ErrorIs(err, mockErr) // test invalid message - err = suite.catalog.SaveCollectionTargets(nil) + err = suite.catalog.SaveCollectionTargets(ctx) suite.Error(err) } diff --git a/internal/metastore/mocks/mock_querycoord_catalog.go b/internal/metastore/mocks/mock_querycoord_catalog.go index 3ce66e8441206..c8368bfcbf31a 100644 --- a/internal/metastore/mocks/mock_querycoord_catalog.go +++ b/internal/metastore/mocks/mock_querycoord_catalog.go @@ -1,10 +1,13 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. +// Code generated by mockery v2.46.0. DO NOT EDIT. package mocks import ( - querypb "github.com/milvus-io/milvus/internal/proto/querypb" + context "context" + mock "github.com/stretchr/testify/mock" + + querypb "github.com/milvus-io/milvus/internal/proto/querypb" ) // QueryCoordCatalog is an autogenerated mock type for the QueryCoordCatalog type @@ -20,25 +23,29 @@ func (_m *QueryCoordCatalog) EXPECT() *QueryCoordCatalog_Expecter { return &QueryCoordCatalog_Expecter{mock: &_m.Mock} } -// GetCollectionTargets provides a mock function with given fields: -func (_m *QueryCoordCatalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) { - ret := _m.Called() +// GetCollectionTargets provides a mock function with given fields: ctx +func (_m *QueryCoordCatalog) GetCollectionTargets(ctx context.Context) (map[int64]*querypb.CollectionTarget, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetCollectionTargets") + } var r0 map[int64]*querypb.CollectionTarget var r1 error - if rf, ok := ret.Get(0).(func() (map[int64]*querypb.CollectionTarget, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*querypb.CollectionTarget, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() map[int64]*querypb.CollectionTarget); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) map[int64]*querypb.CollectionTarget); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[int64]*querypb.CollectionTarget) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -52,13 +59,14 @@ type QueryCoordCatalog_GetCollectionTargets_Call struct { } // GetCollectionTargets is a helper method to define mock.On call -func (_e *QueryCoordCatalog_Expecter) GetCollectionTargets() *QueryCoordCatalog_GetCollectionTargets_Call { - return &QueryCoordCatalog_GetCollectionTargets_Call{Call: _e.mock.On("GetCollectionTargets")} +// - ctx context.Context +func (_e *QueryCoordCatalog_Expecter) GetCollectionTargets(ctx interface{}) *QueryCoordCatalog_GetCollectionTargets_Call { + return &QueryCoordCatalog_GetCollectionTargets_Call{Call: _e.mock.On("GetCollectionTargets", ctx)} } -func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Run(run func()) *QueryCoordCatalog_GetCollectionTargets_Call { +func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetCollectionTargets_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -68,30 +76,34 @@ func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Return(_a0 map[int64]*que return _c } -func (_c *QueryCoordCatalog_GetCollectionTargets_Call) RunAndReturn(run func() (map[int64]*querypb.CollectionTarget, error)) *QueryCoordCatalog_GetCollectionTargets_Call { +func (_c *QueryCoordCatalog_GetCollectionTargets_Call) RunAndReturn(run func(context.Context) (map[int64]*querypb.CollectionTarget, error)) *QueryCoordCatalog_GetCollectionTargets_Call { _c.Call.Return(run) return _c } -// GetCollections provides a mock function with given fields: -func (_m *QueryCoordCatalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) { - ret := _m.Called() +// GetCollections provides a mock function with given fields: ctx +func (_m *QueryCoordCatalog) GetCollections(ctx context.Context) ([]*querypb.CollectionLoadInfo, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetCollections") + } var r0 []*querypb.CollectionLoadInfo var r1 error - if rf, ok := ret.Get(0).(func() ([]*querypb.CollectionLoadInfo, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.CollectionLoadInfo, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []*querypb.CollectionLoadInfo); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*querypb.CollectionLoadInfo); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*querypb.CollectionLoadInfo) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -105,13 +117,14 @@ type QueryCoordCatalog_GetCollections_Call struct { } // GetCollections is a helper method to define mock.On call -func (_e *QueryCoordCatalog_Expecter) GetCollections() *QueryCoordCatalog_GetCollections_Call { - return &QueryCoordCatalog_GetCollections_Call{Call: _e.mock.On("GetCollections")} +// - ctx context.Context +func (_e *QueryCoordCatalog_Expecter) GetCollections(ctx interface{}) *QueryCoordCatalog_GetCollections_Call { + return &QueryCoordCatalog_GetCollections_Call{Call: _e.mock.On("GetCollections", ctx)} } -func (_c *QueryCoordCatalog_GetCollections_Call) Run(run func()) *QueryCoordCatalog_GetCollections_Call { +func (_c *QueryCoordCatalog_GetCollections_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetCollections_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -121,30 +134,34 @@ func (_c *QueryCoordCatalog_GetCollections_Call) Return(_a0 []*querypb.Collectio return _c } -func (_c *QueryCoordCatalog_GetCollections_Call) RunAndReturn(run func() ([]*querypb.CollectionLoadInfo, error)) *QueryCoordCatalog_GetCollections_Call { +func (_c *QueryCoordCatalog_GetCollections_Call) RunAndReturn(run func(context.Context) ([]*querypb.CollectionLoadInfo, error)) *QueryCoordCatalog_GetCollections_Call { _c.Call.Return(run) return _c } -// GetPartitions provides a mock function with given fields: -func (_m *QueryCoordCatalog) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) { - ret := _m.Called() +// GetPartitions provides a mock function with given fields: ctx +func (_m *QueryCoordCatalog) GetPartitions(ctx context.Context) (map[int64][]*querypb.PartitionLoadInfo, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetPartitions") + } var r0 map[int64][]*querypb.PartitionLoadInfo var r1 error - if rf, ok := ret.Get(0).(func() (map[int64][]*querypb.PartitionLoadInfo, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (map[int64][]*querypb.PartitionLoadInfo, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() map[int64][]*querypb.PartitionLoadInfo); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) map[int64][]*querypb.PartitionLoadInfo); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[int64][]*querypb.PartitionLoadInfo) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -158,13 +175,14 @@ type QueryCoordCatalog_GetPartitions_Call struct { } // GetPartitions is a helper method to define mock.On call -func (_e *QueryCoordCatalog_Expecter) GetPartitions() *QueryCoordCatalog_GetPartitions_Call { - return &QueryCoordCatalog_GetPartitions_Call{Call: _e.mock.On("GetPartitions")} +// - ctx context.Context +func (_e *QueryCoordCatalog_Expecter) GetPartitions(ctx interface{}) *QueryCoordCatalog_GetPartitions_Call { + return &QueryCoordCatalog_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx)} } -func (_c *QueryCoordCatalog_GetPartitions_Call) Run(run func()) *QueryCoordCatalog_GetPartitions_Call { +func (_c *QueryCoordCatalog_GetPartitions_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetPartitions_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -174,30 +192,34 @@ func (_c *QueryCoordCatalog_GetPartitions_Call) Return(_a0 map[int64][]*querypb. return _c } -func (_c *QueryCoordCatalog_GetPartitions_Call) RunAndReturn(run func() (map[int64][]*querypb.PartitionLoadInfo, error)) *QueryCoordCatalog_GetPartitions_Call { +func (_c *QueryCoordCatalog_GetPartitions_Call) RunAndReturn(run func(context.Context) (map[int64][]*querypb.PartitionLoadInfo, error)) *QueryCoordCatalog_GetPartitions_Call { _c.Call.Return(run) return _c } -// GetReplicas provides a mock function with given fields: -func (_m *QueryCoordCatalog) GetReplicas() ([]*querypb.Replica, error) { - ret := _m.Called() +// GetReplicas provides a mock function with given fields: ctx +func (_m *QueryCoordCatalog) GetReplicas(ctx context.Context) ([]*querypb.Replica, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetReplicas") + } var r0 []*querypb.Replica var r1 error - if rf, ok := ret.Get(0).(func() ([]*querypb.Replica, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.Replica, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []*querypb.Replica); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*querypb.Replica); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*querypb.Replica) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -211,13 +233,14 @@ type QueryCoordCatalog_GetReplicas_Call struct { } // GetReplicas is a helper method to define mock.On call -func (_e *QueryCoordCatalog_Expecter) GetReplicas() *QueryCoordCatalog_GetReplicas_Call { - return &QueryCoordCatalog_GetReplicas_Call{Call: _e.mock.On("GetReplicas")} +// - ctx context.Context +func (_e *QueryCoordCatalog_Expecter) GetReplicas(ctx interface{}) *QueryCoordCatalog_GetReplicas_Call { + return &QueryCoordCatalog_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx)} } -func (_c *QueryCoordCatalog_GetReplicas_Call) Run(run func()) *QueryCoordCatalog_GetReplicas_Call { +func (_c *QueryCoordCatalog_GetReplicas_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetReplicas_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -227,30 +250,34 @@ func (_c *QueryCoordCatalog_GetReplicas_Call) Return(_a0 []*querypb.Replica, _a1 return _c } -func (_c *QueryCoordCatalog_GetReplicas_Call) RunAndReturn(run func() ([]*querypb.Replica, error)) *QueryCoordCatalog_GetReplicas_Call { +func (_c *QueryCoordCatalog_GetReplicas_Call) RunAndReturn(run func(context.Context) ([]*querypb.Replica, error)) *QueryCoordCatalog_GetReplicas_Call { _c.Call.Return(run) return _c } -// GetResourceGroups provides a mock function with given fields: -func (_m *QueryCoordCatalog) GetResourceGroups() ([]*querypb.ResourceGroup, error) { - ret := _m.Called() +// GetResourceGroups provides a mock function with given fields: ctx +func (_m *QueryCoordCatalog) GetResourceGroups(ctx context.Context) ([]*querypb.ResourceGroup, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetResourceGroups") + } var r0 []*querypb.ResourceGroup var r1 error - if rf, ok := ret.Get(0).(func() ([]*querypb.ResourceGroup, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]*querypb.ResourceGroup, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []*querypb.ResourceGroup); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*querypb.ResourceGroup); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*querypb.ResourceGroup) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -264,13 +291,14 @@ type QueryCoordCatalog_GetResourceGroups_Call struct { } // GetResourceGroups is a helper method to define mock.On call -func (_e *QueryCoordCatalog_Expecter) GetResourceGroups() *QueryCoordCatalog_GetResourceGroups_Call { - return &QueryCoordCatalog_GetResourceGroups_Call{Call: _e.mock.On("GetResourceGroups")} +// - ctx context.Context +func (_e *QueryCoordCatalog_Expecter) GetResourceGroups(ctx interface{}) *QueryCoordCatalog_GetResourceGroups_Call { + return &QueryCoordCatalog_GetResourceGroups_Call{Call: _e.mock.On("GetResourceGroups", ctx)} } -func (_c *QueryCoordCatalog_GetResourceGroups_Call) Run(run func()) *QueryCoordCatalog_GetResourceGroups_Call { +func (_c *QueryCoordCatalog_GetResourceGroups_Call) Run(run func(ctx context.Context)) *QueryCoordCatalog_GetResourceGroups_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -280,18 +308,22 @@ func (_c *QueryCoordCatalog_GetResourceGroups_Call) Return(_a0 []*querypb.Resour return _c } -func (_c *QueryCoordCatalog_GetResourceGroups_Call) RunAndReturn(run func() ([]*querypb.ResourceGroup, error)) *QueryCoordCatalog_GetResourceGroups_Call { +func (_c *QueryCoordCatalog_GetResourceGroups_Call) RunAndReturn(run func(context.Context) ([]*querypb.ResourceGroup, error)) *QueryCoordCatalog_GetResourceGroups_Call { _c.Call.Return(run) return _c } -// ReleaseCollection provides a mock function with given fields: collection -func (_m *QueryCoordCatalog) ReleaseCollection(collection int64) error { - ret := _m.Called(collection) +// ReleaseCollection provides a mock function with given fields: ctx, collection +func (_m *QueryCoordCatalog) ReleaseCollection(ctx context.Context, collection int64) error { + ret := _m.Called(ctx, collection) + + if len(ret) == 0 { + panic("no return value specified for ReleaseCollection") + } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(collection) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, collection) } else { r0 = ret.Error(0) } @@ -305,14 +337,15 @@ type QueryCoordCatalog_ReleaseCollection_Call struct { } // ReleaseCollection is a helper method to define mock.On call +// - ctx context.Context // - collection int64 -func (_e *QueryCoordCatalog_Expecter) ReleaseCollection(collection interface{}) *QueryCoordCatalog_ReleaseCollection_Call { - return &QueryCoordCatalog_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)} +func (_e *QueryCoordCatalog_Expecter) ReleaseCollection(ctx interface{}, collection interface{}) *QueryCoordCatalog_ReleaseCollection_Call { + return &QueryCoordCatalog_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, collection)} } -func (_c *QueryCoordCatalog_ReleaseCollection_Call) Run(run func(collection int64)) *QueryCoordCatalog_ReleaseCollection_Call { +func (_c *QueryCoordCatalog_ReleaseCollection_Call) Run(run func(ctx context.Context, collection int64)) *QueryCoordCatalog_ReleaseCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -322,25 +355,29 @@ func (_c *QueryCoordCatalog_ReleaseCollection_Call) Return(_a0 error) *QueryCoor return _c } -func (_c *QueryCoordCatalog_ReleaseCollection_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_ReleaseCollection_Call { +func (_c *QueryCoordCatalog_ReleaseCollection_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_ReleaseCollection_Call { _c.Call.Return(run) return _c } -// ReleasePartition provides a mock function with given fields: collection, partitions -func (_m *QueryCoordCatalog) ReleasePartition(collection int64, partitions ...int64) error { +// ReleasePartition provides a mock function with given fields: ctx, collection, partitions +func (_m *QueryCoordCatalog) ReleasePartition(ctx context.Context, collection int64, partitions ...int64) error { _va := make([]interface{}, len(partitions)) for _i := range partitions { _va[_i] = partitions[_i] } var _ca []interface{} - _ca = append(_ca, collection) + _ca = append(_ca, ctx, collection) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for ReleasePartition") + } + var r0 error - if rf, ok := ret.Get(0).(func(int64, ...int64) error); ok { - r0 = rf(collection, partitions...) + if rf, ok := ret.Get(0).(func(context.Context, int64, ...int64) error); ok { + r0 = rf(ctx, collection, partitions...) } else { r0 = ret.Error(0) } @@ -354,22 +391,23 @@ type QueryCoordCatalog_ReleasePartition_Call struct { } // ReleasePartition is a helper method to define mock.On call +// - ctx context.Context // - collection int64 // - partitions ...int64 -func (_e *QueryCoordCatalog_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *QueryCoordCatalog_ReleasePartition_Call { +func (_e *QueryCoordCatalog_Expecter) ReleasePartition(ctx interface{}, collection interface{}, partitions ...interface{}) *QueryCoordCatalog_ReleasePartition_Call { return &QueryCoordCatalog_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition", - append([]interface{}{collection}, partitions...)...)} + append([]interface{}{ctx, collection}, partitions...)...)} } -func (_c *QueryCoordCatalog_ReleasePartition_Call) Run(run func(collection int64, partitions ...int64)) *QueryCoordCatalog_ReleasePartition_Call { +func (_c *QueryCoordCatalog_ReleasePartition_Call) Run(run func(ctx context.Context, collection int64, partitions ...int64)) *QueryCoordCatalog_ReleasePartition_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]int64, len(args)-1) - for i, a := range args[1:] { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { if a != nil { variadicArgs[i] = a.(int64) } } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), variadicArgs...) }) return _c } @@ -379,25 +417,29 @@ func (_c *QueryCoordCatalog_ReleasePartition_Call) Return(_a0 error) *QueryCoord return _c } -func (_c *QueryCoordCatalog_ReleasePartition_Call) RunAndReturn(run func(int64, ...int64) error) *QueryCoordCatalog_ReleasePartition_Call { +func (_c *QueryCoordCatalog_ReleasePartition_Call) RunAndReturn(run func(context.Context, int64, ...int64) error) *QueryCoordCatalog_ReleasePartition_Call { _c.Call.Return(run) return _c } -// ReleaseReplica provides a mock function with given fields: collection, replicas -func (_m *QueryCoordCatalog) ReleaseReplica(collection int64, replicas ...int64) error { +// ReleaseReplica provides a mock function with given fields: ctx, collection, replicas +func (_m *QueryCoordCatalog) ReleaseReplica(ctx context.Context, collection int64, replicas ...int64) error { _va := make([]interface{}, len(replicas)) for _i := range replicas { _va[_i] = replicas[_i] } var _ca []interface{} - _ca = append(_ca, collection) + _ca = append(_ca, ctx, collection) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for ReleaseReplica") + } + var r0 error - if rf, ok := ret.Get(0).(func(int64, ...int64) error); ok { - r0 = rf(collection, replicas...) + if rf, ok := ret.Get(0).(func(context.Context, int64, ...int64) error); ok { + r0 = rf(ctx, collection, replicas...) } else { r0 = ret.Error(0) } @@ -411,22 +453,23 @@ type QueryCoordCatalog_ReleaseReplica_Call struct { } // ReleaseReplica is a helper method to define mock.On call +// - ctx context.Context // - collection int64 // - replicas ...int64 -func (_e *QueryCoordCatalog_Expecter) ReleaseReplica(collection interface{}, replicas ...interface{}) *QueryCoordCatalog_ReleaseReplica_Call { +func (_e *QueryCoordCatalog_Expecter) ReleaseReplica(ctx interface{}, collection interface{}, replicas ...interface{}) *QueryCoordCatalog_ReleaseReplica_Call { return &QueryCoordCatalog_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", - append([]interface{}{collection}, replicas...)...)} + append([]interface{}{ctx, collection}, replicas...)...)} } -func (_c *QueryCoordCatalog_ReleaseReplica_Call) Run(run func(collection int64, replicas ...int64)) *QueryCoordCatalog_ReleaseReplica_Call { +func (_c *QueryCoordCatalog_ReleaseReplica_Call) Run(run func(ctx context.Context, collection int64, replicas ...int64)) *QueryCoordCatalog_ReleaseReplica_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]int64, len(args)-1) - for i, a := range args[1:] { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { if a != nil { variadicArgs[i] = a.(int64) } } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), variadicArgs...) }) return _c } @@ -436,18 +479,22 @@ func (_c *QueryCoordCatalog_ReleaseReplica_Call) Return(_a0 error) *QueryCoordCa return _c } -func (_c *QueryCoordCatalog_ReleaseReplica_Call) RunAndReturn(run func(int64, ...int64) error) *QueryCoordCatalog_ReleaseReplica_Call { +func (_c *QueryCoordCatalog_ReleaseReplica_Call) RunAndReturn(run func(context.Context, int64, ...int64) error) *QueryCoordCatalog_ReleaseReplica_Call { _c.Call.Return(run) return _c } -// ReleaseReplicas provides a mock function with given fields: collectionID -func (_m *QueryCoordCatalog) ReleaseReplicas(collectionID int64) error { - ret := _m.Called(collectionID) +// ReleaseReplicas provides a mock function with given fields: ctx, collectionID +func (_m *QueryCoordCatalog) ReleaseReplicas(ctx context.Context, collectionID int64) error { + ret := _m.Called(ctx, collectionID) + + if len(ret) == 0 { + panic("no return value specified for ReleaseReplicas") + } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, collectionID) } else { r0 = ret.Error(0) } @@ -461,14 +508,15 @@ type QueryCoordCatalog_ReleaseReplicas_Call struct { } // ReleaseReplicas is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *QueryCoordCatalog_Expecter) ReleaseReplicas(collectionID interface{}) *QueryCoordCatalog_ReleaseReplicas_Call { - return &QueryCoordCatalog_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)} +func (_e *QueryCoordCatalog_Expecter) ReleaseReplicas(ctx interface{}, collectionID interface{}) *QueryCoordCatalog_ReleaseReplicas_Call { + return &QueryCoordCatalog_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", ctx, collectionID)} } -func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Run(run func(collectionID int64)) *QueryCoordCatalog_ReleaseReplicas_Call { +func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Run(run func(ctx context.Context, collectionID int64)) *QueryCoordCatalog_ReleaseReplicas_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -478,18 +526,22 @@ func (_c *QueryCoordCatalog_ReleaseReplicas_Call) Return(_a0 error) *QueryCoordC return _c } -func (_c *QueryCoordCatalog_ReleaseReplicas_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_ReleaseReplicas_Call { +func (_c *QueryCoordCatalog_ReleaseReplicas_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_ReleaseReplicas_Call { _c.Call.Return(run) return _c } -// RemoveCollectionTarget provides a mock function with given fields: collectionID -func (_m *QueryCoordCatalog) RemoveCollectionTarget(collectionID int64) error { - ret := _m.Called(collectionID) +// RemoveCollectionTarget provides a mock function with given fields: ctx, collectionID +func (_m *QueryCoordCatalog) RemoveCollectionTarget(ctx context.Context, collectionID int64) error { + ret := _m.Called(ctx, collectionID) + + if len(ret) == 0 { + panic("no return value specified for RemoveCollectionTarget") + } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, collectionID) } else { r0 = ret.Error(0) } @@ -503,14 +555,15 @@ type QueryCoordCatalog_RemoveCollectionTarget_Call struct { } // RemoveCollectionTarget is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *QueryCoordCatalog_Expecter) RemoveCollectionTarget(collectionID interface{}) *QueryCoordCatalog_RemoveCollectionTarget_Call { - return &QueryCoordCatalog_RemoveCollectionTarget_Call{Call: _e.mock.On("RemoveCollectionTarget", collectionID)} +func (_e *QueryCoordCatalog_Expecter) RemoveCollectionTarget(ctx interface{}, collectionID interface{}) *QueryCoordCatalog_RemoveCollectionTarget_Call { + return &QueryCoordCatalog_RemoveCollectionTarget_Call{Call: _e.mock.On("RemoveCollectionTarget", ctx, collectionID)} } -func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Run(run func(collectionID int64)) *QueryCoordCatalog_RemoveCollectionTarget_Call { +func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *QueryCoordCatalog_RemoveCollectionTarget_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -520,18 +573,22 @@ func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Return(_a0 error) *Quer return _c } -func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_RemoveCollectionTarget_Call { +func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) RunAndReturn(run func(context.Context, int64) error) *QueryCoordCatalog_RemoveCollectionTarget_Call { _c.Call.Return(run) return _c } -// RemoveResourceGroup provides a mock function with given fields: rgName -func (_m *QueryCoordCatalog) RemoveResourceGroup(rgName string) error { - ret := _m.Called(rgName) +// RemoveResourceGroup provides a mock function with given fields: ctx, rgName +func (_m *QueryCoordCatalog) RemoveResourceGroup(ctx context.Context, rgName string) error { + ret := _m.Called(ctx, rgName) + + if len(ret) == 0 { + panic("no return value specified for RemoveResourceGroup") + } var r0 error - if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(rgName) + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, rgName) } else { r0 = ret.Error(0) } @@ -545,14 +602,15 @@ type QueryCoordCatalog_RemoveResourceGroup_Call struct { } // RemoveResourceGroup is a helper method to define mock.On call +// - ctx context.Context // - rgName string -func (_e *QueryCoordCatalog_Expecter) RemoveResourceGroup(rgName interface{}) *QueryCoordCatalog_RemoveResourceGroup_Call { - return &QueryCoordCatalog_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)} +func (_e *QueryCoordCatalog_Expecter) RemoveResourceGroup(ctx interface{}, rgName interface{}) *QueryCoordCatalog_RemoveResourceGroup_Call { + return &QueryCoordCatalog_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", ctx, rgName)} } -func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Run(run func(rgName string)) *QueryCoordCatalog_RemoveResourceGroup_Call { +func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Run(run func(ctx context.Context, rgName string)) *QueryCoordCatalog_RemoveResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(context.Context), args[1].(string)) }) return _c } @@ -562,25 +620,29 @@ func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) Return(_a0 error) *QueryCo return _c } -func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) RunAndReturn(run func(string) error) *QueryCoordCatalog_RemoveResourceGroup_Call { +func (_c *QueryCoordCatalog_RemoveResourceGroup_Call) RunAndReturn(run func(context.Context, string) error) *QueryCoordCatalog_RemoveResourceGroup_Call { _c.Call.Return(run) return _c } -// SaveCollection provides a mock function with given fields: collection, partitions -func (_m *QueryCoordCatalog) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error { +// SaveCollection provides a mock function with given fields: ctx, collection, partitions +func (_m *QueryCoordCatalog) SaveCollection(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error { _va := make([]interface{}, len(partitions)) for _i := range partitions { _va[_i] = partitions[_i] } var _ca []interface{} - _ca = append(_ca, collection) + _ca = append(_ca, ctx, collection) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for SaveCollection") + } + var r0 error - if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok { - r0 = rf(collection, partitions...) + if rf, ok := ret.Get(0).(func(context.Context, *querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok { + r0 = rf(ctx, collection, partitions...) } else { r0 = ret.Error(0) } @@ -594,22 +656,23 @@ type QueryCoordCatalog_SaveCollection_Call struct { } // SaveCollection is a helper method to define mock.On call +// - ctx context.Context // - collection *querypb.CollectionLoadInfo // - partitions ...*querypb.PartitionLoadInfo -func (_e *QueryCoordCatalog_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *QueryCoordCatalog_SaveCollection_Call { +func (_e *QueryCoordCatalog_Expecter) SaveCollection(ctx interface{}, collection interface{}, partitions ...interface{}) *QueryCoordCatalog_SaveCollection_Call { return &QueryCoordCatalog_SaveCollection_Call{Call: _e.mock.On("SaveCollection", - append([]interface{}{collection}, partitions...)...)} + append([]interface{}{ctx, collection}, partitions...)...)} } -func (_c *QueryCoordCatalog_SaveCollection_Call) Run(run func(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SaveCollection_Call { +func (_c *QueryCoordCatalog_SaveCollection_Call) Run(run func(ctx context.Context, collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SaveCollection_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1) - for i, a := range args[1:] { + variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-2) + for i, a := range args[2:] { if a != nil { variadicArgs[i] = a.(*querypb.PartitionLoadInfo) } } - run(args[0].(*querypb.CollectionLoadInfo), variadicArgs...) + run(args[0].(context.Context), args[1].(*querypb.CollectionLoadInfo), variadicArgs...) }) return _c } @@ -619,24 +682,29 @@ func (_c *QueryCoordCatalog_SaveCollection_Call) Return(_a0 error) *QueryCoordCa return _c } -func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SaveCollection_Call { +func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(context.Context, *querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SaveCollection_Call { _c.Call.Return(run) return _c } -// SaveCollectionTargets provides a mock function with given fields: target -func (_m *QueryCoordCatalog) SaveCollectionTargets(target ...*querypb.CollectionTarget) error { +// SaveCollectionTargets provides a mock function with given fields: ctx, target +func (_m *QueryCoordCatalog) SaveCollectionTargets(ctx context.Context, target ...*querypb.CollectionTarget) error { _va := make([]interface{}, len(target)) for _i := range target { _va[_i] = target[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for SaveCollectionTargets") + } + var r0 error - if rf, ok := ret.Get(0).(func(...*querypb.CollectionTarget) error); ok { - r0 = rf(target...) + if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.CollectionTarget) error); ok { + r0 = rf(ctx, target...) } else { r0 = ret.Error(0) } @@ -650,21 +718,22 @@ type QueryCoordCatalog_SaveCollectionTargets_Call struct { } // SaveCollectionTargets is a helper method to define mock.On call +// - ctx context.Context // - target ...*querypb.CollectionTarget -func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call { +func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(ctx interface{}, target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call { return &QueryCoordCatalog_SaveCollectionTargets_Call{Call: _e.mock.On("SaveCollectionTargets", - append([]interface{}{}, target...)...)} + append([]interface{}{ctx}, target...)...)} } -func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call { +func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(ctx context.Context, target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.CollectionTarget, len(args)-0) - for i, a := range args[0:] { + variadicArgs := make([]*querypb.CollectionTarget, len(args)-1) + for i, a := range args[1:] { if a != nil { variadicArgs[i] = a.(*querypb.CollectionTarget) } } - run(variadicArgs...) + run(args[0].(context.Context), variadicArgs...) }) return _c } @@ -674,24 +743,29 @@ func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Return(_a0 error) *Query return _c } -func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call { +func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(context.Context, ...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call { _c.Call.Return(run) return _c } -// SavePartition provides a mock function with given fields: info -func (_m *QueryCoordCatalog) SavePartition(info ...*querypb.PartitionLoadInfo) error { +// SavePartition provides a mock function with given fields: ctx, info +func (_m *QueryCoordCatalog) SavePartition(ctx context.Context, info ...*querypb.PartitionLoadInfo) error { _va := make([]interface{}, len(info)) for _i := range info { _va[_i] = info[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for SavePartition") + } + var r0 error - if rf, ok := ret.Get(0).(func(...*querypb.PartitionLoadInfo) error); ok { - r0 = rf(info...) + if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.PartitionLoadInfo) error); ok { + r0 = rf(ctx, info...) } else { r0 = ret.Error(0) } @@ -705,21 +779,22 @@ type QueryCoordCatalog_SavePartition_Call struct { } // SavePartition is a helper method to define mock.On call +// - ctx context.Context // - info ...*querypb.PartitionLoadInfo -func (_e *QueryCoordCatalog_Expecter) SavePartition(info ...interface{}) *QueryCoordCatalog_SavePartition_Call { +func (_e *QueryCoordCatalog_Expecter) SavePartition(ctx interface{}, info ...interface{}) *QueryCoordCatalog_SavePartition_Call { return &QueryCoordCatalog_SavePartition_Call{Call: _e.mock.On("SavePartition", - append([]interface{}{}, info...)...)} + append([]interface{}{ctx}, info...)...)} } -func (_c *QueryCoordCatalog_SavePartition_Call) Run(run func(info ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SavePartition_Call { +func (_c *QueryCoordCatalog_SavePartition_Call) Run(run func(ctx context.Context, info ...*querypb.PartitionLoadInfo)) *QueryCoordCatalog_SavePartition_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-0) - for i, a := range args[0:] { + variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1) + for i, a := range args[1:] { if a != nil { variadicArgs[i] = a.(*querypb.PartitionLoadInfo) } } - run(variadicArgs...) + run(args[0].(context.Context), variadicArgs...) }) return _c } @@ -729,24 +804,29 @@ func (_c *QueryCoordCatalog_SavePartition_Call) Return(_a0 error) *QueryCoordCat return _c } -func (_c *QueryCoordCatalog_SavePartition_Call) RunAndReturn(run func(...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SavePartition_Call { +func (_c *QueryCoordCatalog_SavePartition_Call) RunAndReturn(run func(context.Context, ...*querypb.PartitionLoadInfo) error) *QueryCoordCatalog_SavePartition_Call { _c.Call.Return(run) return _c } -// SaveReplica provides a mock function with given fields: replicas -func (_m *QueryCoordCatalog) SaveReplica(replicas ...*querypb.Replica) error { +// SaveReplica provides a mock function with given fields: ctx, replicas +func (_m *QueryCoordCatalog) SaveReplica(ctx context.Context, replicas ...*querypb.Replica) error { _va := make([]interface{}, len(replicas)) for _i := range replicas { _va[_i] = replicas[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for SaveReplica") + } + var r0 error - if rf, ok := ret.Get(0).(func(...*querypb.Replica) error); ok { - r0 = rf(replicas...) + if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.Replica) error); ok { + r0 = rf(ctx, replicas...) } else { r0 = ret.Error(0) } @@ -760,21 +840,22 @@ type QueryCoordCatalog_SaveReplica_Call struct { } // SaveReplica is a helper method to define mock.On call +// - ctx context.Context // - replicas ...*querypb.Replica -func (_e *QueryCoordCatalog_Expecter) SaveReplica(replicas ...interface{}) *QueryCoordCatalog_SaveReplica_Call { +func (_e *QueryCoordCatalog_Expecter) SaveReplica(ctx interface{}, replicas ...interface{}) *QueryCoordCatalog_SaveReplica_Call { return &QueryCoordCatalog_SaveReplica_Call{Call: _e.mock.On("SaveReplica", - append([]interface{}{}, replicas...)...)} + append([]interface{}{ctx}, replicas...)...)} } -func (_c *QueryCoordCatalog_SaveReplica_Call) Run(run func(replicas ...*querypb.Replica)) *QueryCoordCatalog_SaveReplica_Call { +func (_c *QueryCoordCatalog_SaveReplica_Call) Run(run func(ctx context.Context, replicas ...*querypb.Replica)) *QueryCoordCatalog_SaveReplica_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.Replica, len(args)-0) - for i, a := range args[0:] { + variadicArgs := make([]*querypb.Replica, len(args)-1) + for i, a := range args[1:] { if a != nil { variadicArgs[i] = a.(*querypb.Replica) } } - run(variadicArgs...) + run(args[0].(context.Context), variadicArgs...) }) return _c } @@ -784,24 +865,29 @@ func (_c *QueryCoordCatalog_SaveReplica_Call) Return(_a0 error) *QueryCoordCatal return _c } -func (_c *QueryCoordCatalog_SaveReplica_Call) RunAndReturn(run func(...*querypb.Replica) error) *QueryCoordCatalog_SaveReplica_Call { +func (_c *QueryCoordCatalog_SaveReplica_Call) RunAndReturn(run func(context.Context, ...*querypb.Replica) error) *QueryCoordCatalog_SaveReplica_Call { _c.Call.Return(run) return _c } -// SaveResourceGroup provides a mock function with given fields: rgs -func (_m *QueryCoordCatalog) SaveResourceGroup(rgs ...*querypb.ResourceGroup) error { +// SaveResourceGroup provides a mock function with given fields: ctx, rgs +func (_m *QueryCoordCatalog) SaveResourceGroup(ctx context.Context, rgs ...*querypb.ResourceGroup) error { _va := make([]interface{}, len(rgs)) for _i := range rgs { _va[_i] = rgs[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for SaveResourceGroup") + } + var r0 error - if rf, ok := ret.Get(0).(func(...*querypb.ResourceGroup) error); ok { - r0 = rf(rgs...) + if rf, ok := ret.Get(0).(func(context.Context, ...*querypb.ResourceGroup) error); ok { + r0 = rf(ctx, rgs...) } else { r0 = ret.Error(0) } @@ -815,21 +901,22 @@ type QueryCoordCatalog_SaveResourceGroup_Call struct { } // SaveResourceGroup is a helper method to define mock.On call +// - ctx context.Context // - rgs ...*querypb.ResourceGroup -func (_e *QueryCoordCatalog_Expecter) SaveResourceGroup(rgs ...interface{}) *QueryCoordCatalog_SaveResourceGroup_Call { +func (_e *QueryCoordCatalog_Expecter) SaveResourceGroup(ctx interface{}, rgs ...interface{}) *QueryCoordCatalog_SaveResourceGroup_Call { return &QueryCoordCatalog_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup", - append([]interface{}{}, rgs...)...)} + append([]interface{}{ctx}, rgs...)...)} } -func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Run(run func(rgs ...*querypb.ResourceGroup)) *QueryCoordCatalog_SaveResourceGroup_Call { +func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Run(run func(ctx context.Context, rgs ...*querypb.ResourceGroup)) *QueryCoordCatalog_SaveResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.ResourceGroup, len(args)-0) - for i, a := range args[0:] { + variadicArgs := make([]*querypb.ResourceGroup, len(args)-1) + for i, a := range args[1:] { if a != nil { variadicArgs[i] = a.(*querypb.ResourceGroup) } } - run(variadicArgs...) + run(args[0].(context.Context), variadicArgs...) }) return _c } @@ -839,7 +926,7 @@ func (_c *QueryCoordCatalog_SaveResourceGroup_Call) Return(_a0 error) *QueryCoor return _c } -func (_c *QueryCoordCatalog_SaveResourceGroup_Call) RunAndReturn(run func(...*querypb.ResourceGroup) error) *QueryCoordCatalog_SaveResourceGroup_Call { +func (_c *QueryCoordCatalog_SaveResourceGroup_Call) RunAndReturn(run func(context.Context, ...*querypb.ResourceGroup) error) *QueryCoordCatalog_SaveResourceGroup_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index 0ccdc90ddafea..2c1b3ffe97d26 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -17,6 +17,7 @@ package balance import ( + "context" "fmt" "sort" @@ -57,9 +58,9 @@ func (chanPlan *ChannelAssignPlan) String() string { } type Balance interface { - AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan - AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan - BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) + AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan + AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan + BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) } type RoundRobinBalancer struct { @@ -67,7 +68,7 @@ type RoundRobinBalancer struct { nodeManager *session.NodeManager } -func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { +func (b *RoundRobinBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { // skip out suspend node and stopping node during assignment, but skip this check for manual balance if !manualBalance { nodes = lo.Filter(nodes, func(node int64, _ int) bool { @@ -103,7 +104,7 @@ func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta. return ret } -func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { +func (b *RoundRobinBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { // skip out suspend node and stopping node during assignment, but skip this check for manual balance if !manualBalance { versionRangeFilter := semver.MustParseRange(">2.3.x") @@ -136,7 +137,7 @@ func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []i return ret } -func (b *RoundRobinBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { +func (b *RoundRobinBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { // TODO by chun.han return nil, nil } diff --git a/internal/querycoordv2/balance/balance_test.go b/internal/querycoordv2/balance/balance_test.go index b99e1272cfdab..c76cf019112c3 100644 --- a/internal/querycoordv2/balance/balance_test.go +++ b/internal/querycoordv2/balance/balance_test.go @@ -17,6 +17,7 @@ package balance import ( + "context" "testing" "github.com/stretchr/testify/mock" @@ -50,6 +51,7 @@ func (suite *BalanceTestSuite) SetupTest() { } func (suite *BalanceTestSuite) TestAssignBalance() { + ctx := context.Background() cases := []struct { name string nodeIDs []int64 @@ -108,13 +110,14 @@ func (suite *BalanceTestSuite) TestAssignBalance() { suite.mockScheduler.EXPECT().GetSegmentTaskDelta(c.nodeIDs[i], int64(-1)).Return(c.deltaCnts[i]) } } - plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs, false) + plans := suite.roundRobinBalancer.AssignSegment(ctx, 0, c.assignments, c.nodeIDs, false) suite.ElementsMatch(c.expectPlans, plans) }) } } func (suite *BalanceTestSuite) TestAssignChannel() { + ctx := context.Background() cases := []struct { name string nodeIDs []int64 @@ -174,7 +177,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() { suite.mockScheduler.EXPECT().GetChannelTaskDelta(c.nodeIDs[i], int64(-1)).Return(c.deltaCnts[i]) } } - plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs, false) + plans := suite.roundRobinBalancer.AssignChannel(ctx, c.assignments, c.nodeIDs, false) suite.ElementsMatch(c.expectPlans, plans) }) } diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go index ba0a3398bf700..fb4afd1521fab 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -17,6 +17,7 @@ package balance import ( + "context" "fmt" "math" "sort" @@ -49,7 +50,7 @@ func NewChannelLevelScoreBalancer(scheduler task.Scheduler, } } -func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { +func (b *ChannelLevelScoreBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { log := log.With( zap.Int64("collection", replica.GetCollectionID()), zap.Int64("replica id", replica.GetID()), @@ -67,7 +68,7 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme }() exclusiveMode := true - channels := b.targetMgr.GetDmChannelsByCollection(replica.GetCollectionID(), meta.CurrentTarget) + channels := b.targetMgr.GetDmChannelsByCollection(ctx, replica.GetCollectionID(), meta.CurrentTarget) for channelName := range channels { if len(replica.GetChannelRWNodes(channelName)) == 0 { exclusiveMode = false @@ -77,7 +78,7 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme // if some channel doesn't own nodes, exit exclusive mode if !exclusiveMode { - return b.ScoreBasedBalancer.BalanceReplica(replica) + return b.ScoreBasedBalancer.BalanceReplica(ctx, replica) } channelPlans = make([]ChannelAssignPlan, 0) @@ -122,19 +123,19 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score if b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, channelName, rwNodes, roNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, channelName, rwNodes, roNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, channelName, rwNodes, roNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, channelName, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genChannelPlan(replica, channelName, rwNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(ctx, replica, channelName, rwNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, channelName, rwNodes)...) + segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, br, replica, channelName, rwNodes)...) } } } @@ -142,11 +143,11 @@ func (b *ChannelLevelScoreBalancer) BalanceReplica(replica *meta.Replica) (segme return segmentPlans, channelPlans } -func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan { +func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) for _, nodeID := range offlineNodes { dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID), meta.WithChannelName2Channel(channelName)) - plans := b.AssignChannel(dmChannels, onlineNodes, false) + plans := b.AssignChannel(ctx, dmChannels, onlineNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -156,14 +157,14 @@ func (b *ChannelLevelScoreBalancer) genStoppingChannelPlan(replica *meta.Replica return channelPlans } -func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { +func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { segmentPlans := make([]SegmentAssignPlan, 0) for _, nodeID := range offlineNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID), meta.WithChannel(channelName)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) - plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false) + plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, onlineNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -173,7 +174,7 @@ func (b *ChannelLevelScoreBalancer) genStoppingSegmentPlan(replica *meta.Replica return segmentPlans } -func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *meta.Replica, channelName string, onlineNodes []int64) []SegmentAssignPlan { +func (b *ChannelLevelScoreBalancer) genSegmentPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, channelName string, onlineNodes []int64) []SegmentAssignPlan { segmentDist := make(map[int64][]*meta.Segment) nodeItemsMap := b.convertToNodeItems(br, replica.GetCollectionID(), onlineNodes) if len(nodeItemsMap) == 0 { @@ -189,7 +190,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m for _, node := range onlineNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node), meta.WithChannel(channelName)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) segmentDist[node] = segments } @@ -224,7 +225,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m return nil } - segmentPlans := b.AssignSegment(replica.GetCollectionID(), segmentsToMove, onlineNodes, false) + segmentPlans := b.AssignSegment(ctx, replica.GetCollectionID(), segmentsToMove, onlineNodes, false) for i := range segmentPlans { segmentPlans[i].From = segmentPlans[i].Segment.Node segmentPlans[i].Replica = replica @@ -233,7 +234,7 @@ func (b *ChannelLevelScoreBalancer) genSegmentPlan(br *balanceReport, replica *m return segmentPlans } -func (b *ChannelLevelScoreBalancer) genChannelPlan(replica *meta.Replica, channelName string, onlineNodes []int64) []ChannelAssignPlan { +func (b *ChannelLevelScoreBalancer) genChannelPlan(ctx context.Context, replica *meta.Replica, channelName string, onlineNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) if len(onlineNodes) > 1 { // start to balance channels on all available nodes @@ -261,7 +262,7 @@ func (b *ChannelLevelScoreBalancer) genChannelPlan(replica *meta.Replica, channe return nil } - channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false) + channelPlans := b.AssignChannel(ctx, channelsToMove, nodeWithLessChannel, false) for i := range channelPlans { channelPlans[i].From = channelPlans[i].Channel.Node channelPlans[i].Replica = replica diff --git a/internal/querycoordv2/balance/channel_level_score_balancer_test.go b/internal/querycoordv2/balance/channel_level_score_balancer_test.go index 64508cd41e136..2b8f0ff79b4f4 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer_test.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer_test.go @@ -16,6 +16,7 @@ package balance import ( + "context" "testing" "github.com/samber/lo" @@ -85,6 +86,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TearDownTest() { } func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() { + ctx := context.Background() cases := []struct { name string comment string @@ -240,7 +242,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() { suite.balancer.nodeManager.Add(nodeInfo) } for i := range c.collectionIDs { - plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false) + plans := balancer.AssignSegment(ctx, c.collectionIDs[i], c.assignments[i], c.nodes, false) if c.unstableAssignment { suite.Equal(len(plans), len(c.expectPlans[i])) } else { @@ -252,6 +254,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegment() { } func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegmentWithGrowing() { + ctx := context.Background() suite.SetupSuite() defer suite.TearDownTest() balancer := suite.balancer @@ -293,13 +296,14 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestAssignSegmentWithGrowing() CollectionID: 1, } suite.balancer.dist.LeaderViewManager.Update(1, leaderView) - plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false) + plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false) for _, p := range plans { suite.Equal(int64(2), p.To) } } func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -376,11 +380,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -400,7 +404,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } // 4. balance and verify result @@ -412,6 +416,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceOneRound() { } func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() { + ctx := context.Background() balanceCase := struct { name string nodes []int64 @@ -495,12 +500,12 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i], + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i], append(balanceCase.nodes, balanceCase.notExistedNodes...))) - balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i]) - balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, balanceCase.collectionIDs[i]) } // 2. set up target for distribution for multi collections @@ -517,7 +522,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() { }) nodeInfo.SetState(balanceCase.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(balanceCase.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, balanceCase.nodes[i]) } // 4. first round balance @@ -535,6 +540,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestBalanceMultiRound() { } func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -654,11 +660,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -678,11 +684,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } for i := range c.outBoundNodes { - suite.balancer.meta.ResourceManager.HandleNodeDown(c.outBoundNodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, c.outBoundNodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -695,6 +701,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestStoppedBalance() { } func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() { + ctx := context.Background() cases := []struct { name string collectionID int64 @@ -771,13 +778,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) for replicaID, nodes := range c.replicaWithNodes { - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes)) } - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.segmentDist { @@ -798,7 +805,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i]) } } @@ -824,10 +831,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestMultiReplicaBalance() { func (suite *ChannelLevelScoreBalancerTestSuite) getCollectionBalancePlans(balancer *ChannelLevelScoreBalancer, collectionID int64, ) ([]SegmentAssignPlan, []ChannelAssignPlan) { - replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) + ctx := context.Background() + replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID) segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) for _, replica := range replicas { - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) segmentPlans = append(segmentPlans, sPlans...) channelPlans = append(channelPlans, cPlans...) } @@ -835,6 +843,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) getCollectionBalancePlans(balan } func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_ChannelOutBound() { + ctx := context.Background() Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") @@ -865,11 +874,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha collection := utils.CreateTestCollection(collectionID, int32(1)) collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) - balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) - balancer.targetMgr.UpdateCollectionNextTarget(collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) // 3. set up nodes info and resourceManager for balancer nodeCount := 4 @@ -883,11 +892,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(session.NodeStateNormal) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID()) } utils.RecoverAllCollection(balancer.meta) - replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0] ch1Nodes := replica.GetChannelRWNodes("channel1") ch2Nodes := replica.GetChannelRWNodes("channel2") suite.Len(ch1Nodes, 2) @@ -903,12 +912,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Cha }, }...) - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) suite.Len(sPlans, 0) suite.Len(cPlans, 1) } func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentOutbound() { + ctx := context.Background() Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") @@ -939,11 +949,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg collection := utils.CreateTestCollection(collectionID, int32(1)) collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) - balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) - balancer.targetMgr.UpdateCollectionNextTarget(collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) // 3. set up nodes info and resourceManager for balancer nodeCount := 4 @@ -957,11 +967,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(session.NodeStateNormal) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID()) } utils.RecoverAllCollection(balancer.meta) - replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0] ch1Nodes := replica.GetChannelRWNodes("channel1") ch2Nodes := replica.GetChannelRWNodes("channel2") suite.Len(ch1Nodes, 2) @@ -1000,12 +1010,13 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg }, }...) - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) suite.Len(sPlans, 1) suite.Len(cPlans, 0) } func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_NodeStopping() { + ctx := context.Background() Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") @@ -1036,11 +1047,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod collection := utils.CreateTestCollection(collectionID, int32(1)) collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) - balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) - balancer.targetMgr.UpdateCollectionNextTarget(collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) // 3. set up nodes info and resourceManager for balancer nodeCount := 4 @@ -1054,11 +1065,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(session.NodeStateNormal) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID()) } utils.RecoverAllCollection(balancer.meta) - replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0] ch1Nodes := replica.GetChannelRWNodes("channel1") ch2Nodes := replica.GetChannelRWNodes("channel2") suite.Len(ch1Nodes, 2) @@ -1112,24 +1123,25 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Nod balancer.nodeManager.Stopping(ch1Nodes[0]) balancer.nodeManager.Stopping(ch2Nodes[0]) - suite.balancer.meta.ResourceManager.HandleNodeStopping(ch1Nodes[0]) - suite.balancer.meta.ResourceManager.HandleNodeStopping(ch2Nodes[0]) + suite.balancer.meta.ResourceManager.HandleNodeStopping(ctx, ch1Nodes[0]) + suite.balancer.meta.ResourceManager.HandleNodeStopping(ctx, ch2Nodes[0]) utils.RecoverAllCollection(balancer.meta) - replica = balancer.meta.ReplicaManager.Get(replica.GetID()) - sPlans, cPlans := balancer.BalanceReplica(replica) + replica = balancer.meta.ReplicaManager.Get(ctx, replica.GetID()) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) suite.Len(sPlans, 0) suite.Len(cPlans, 2) balancer.dist.ChannelDistManager.Update(ch1Nodes[0]) balancer.dist.ChannelDistManager.Update(ch2Nodes[0]) - sPlans, cPlans = balancer.BalanceReplica(replica) + sPlans, cPlans = balancer.BalanceReplica(ctx, replica) suite.Len(sPlans, 2) suite.Len(cPlans, 0) } func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_SegmentUnbalance() { + ctx := context.Background() Params.Save(Params.QueryCoordCfg.Balancer.Key, meta.ChannelLevelScoreBalancerName) defer Params.Reset(Params.QueryCoordCfg.Balancer.Key) Params.Save(Params.QueryCoordCfg.ChannelExclusiveNodeFactor.Key, "2") @@ -1160,11 +1172,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg collection := utils.CreateTestCollection(collectionID, int32(1)) collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) - balancer.meta.ReplicaManager.Spawn(1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) - balancer.targetMgr.UpdateCollectionNextTarget(collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID)) + balancer.meta.ReplicaManager.Spawn(ctx, 1, map[string]int{meta.DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) // 3. set up nodes info and resourceManager for balancer nodeCount := 4 @@ -1178,11 +1190,11 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg // nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(session.NodeStateNormal) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodeInfo.ID()) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodeInfo.ID()) } utils.RecoverAllCollection(balancer.meta) - replica := balancer.meta.ReplicaManager.GetByCollection(collectionID)[0] + replica := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID)[0] ch1Nodes := replica.GetChannelRWNodes("channel1") ch2Nodes := replica.GetChannelRWNodes("channel2") suite.Len(ch1Nodes, 2) @@ -1254,7 +1266,7 @@ func (suite *ChannelLevelScoreBalancerTestSuite) TestExclusiveChannelBalance_Seg }, }...) - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) suite.Len(sPlans, 2) suite.Len(cPlans, 0) } diff --git a/internal/querycoordv2/balance/mock_balancer.go b/internal/querycoordv2/balance/mock_balancer.go index 003451a9ca83d..3ba72b89fb729 100644 --- a/internal/querycoordv2/balance/mock_balancer.go +++ b/internal/querycoordv2/balance/mock_balancer.go @@ -3,6 +3,8 @@ package balance import ( + context "context" + meta "github.com/milvus-io/milvus/internal/querycoordv2/meta" mock "github.com/stretchr/testify/mock" ) @@ -20,17 +22,17 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter { return &MockBalancer_Expecter{mock: &_m.Mock} } -// AssignChannel provides a mock function with given fields: channels, nodes, manualBalance -func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { - ret := _m.Called(channels, nodes, manualBalance) +// AssignChannel provides a mock function with given fields: ctx, channels, nodes, manualBalance +func (_m *MockBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { + ret := _m.Called(ctx, channels, nodes, manualBalance) if len(ret) == 0 { panic("no return value specified for AssignChannel") } var r0 []ChannelAssignPlan - if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok { - r0 = rf(channels, nodes, manualBalance) + if rf, ok := ret.Get(0).(func(context.Context, []*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok { + r0 = rf(ctx, channels, nodes, manualBalance) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ChannelAssignPlan) @@ -46,16 +48,17 @@ type MockBalancer_AssignChannel_Call struct { } // AssignChannel is a helper method to define mock.On call +// - ctx context.Context // - channels []*meta.DmChannel // - nodes []int64 // - manualBalance bool -func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call { - return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes, manualBalance)} +func (_e *MockBalancer_Expecter) AssignChannel(ctx interface{}, channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call { + return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", ctx, channels, nodes, manualBalance)} } -func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call { +func (_c *MockBalancer_AssignChannel_Call) Run(run func(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*meta.DmChannel), args[1].([]int64), args[2].(bool)) + run(args[0].(context.Context), args[1].([]*meta.DmChannel), args[2].([]int64), args[3].(bool)) }) return _c } @@ -65,22 +68,22 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock return _c } -func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call { +func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func(context.Context, []*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call { _c.Call.Return(run) return _c } -// AssignSegment provides a mock function with given fields: collectionID, segments, nodes, manualBalance -func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { - ret := _m.Called(collectionID, segments, nodes, manualBalance) +// AssignSegment provides a mock function with given fields: ctx, collectionID, segments, nodes, manualBalance +func (_m *MockBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { + ret := _m.Called(ctx, collectionID, segments, nodes, manualBalance) if len(ret) == 0 { panic("no return value specified for AssignSegment") } var r0 []SegmentAssignPlan - if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok { - r0 = rf(collectionID, segments, nodes, manualBalance) + if rf, ok := ret.Get(0).(func(context.Context, int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok { + r0 = rf(ctx, collectionID, segments, nodes, manualBalance) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]SegmentAssignPlan) @@ -96,17 +99,18 @@ type MockBalancer_AssignSegment_Call struct { } // AssignSegment is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - segments []*meta.Segment // - nodes []int64 // - manualBalance bool -func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call { - return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes, manualBalance)} +func (_e *MockBalancer_Expecter) AssignSegment(ctx interface{}, collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call { + return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", ctx, collectionID, segments, nodes, manualBalance)} } -func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call { +func (_c *MockBalancer_AssignSegment_Call) Run(run func(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64), args[3].(bool)) + run(args[0].(context.Context), args[1].(int64), args[2].([]*meta.Segment), args[3].([]int64), args[4].(bool)) }) return _c } @@ -116,14 +120,14 @@ func (_c *MockBalancer_AssignSegment_Call) Return(_a0 []SegmentAssignPlan) *Mock return _c } -func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call { +func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(context.Context, int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call { _c.Call.Return(run) return _c } -// BalanceReplica provides a mock function with given fields: replica -func (_m *MockBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { - ret := _m.Called(replica) +// BalanceReplica provides a mock function with given fields: ctx, replica +func (_m *MockBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { + ret := _m.Called(ctx, replica) if len(ret) == 0 { panic("no return value specified for BalanceReplica") @@ -131,19 +135,19 @@ func (_m *MockBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPl var r0 []SegmentAssignPlan var r1 []ChannelAssignPlan - if rf, ok := ret.Get(0).(func(*meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)); ok { - return rf(replica) + if rf, ok := ret.Get(0).(func(context.Context, *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)); ok { + return rf(ctx, replica) } - if rf, ok := ret.Get(0).(func(*meta.Replica) []SegmentAssignPlan); ok { - r0 = rf(replica) + if rf, ok := ret.Get(0).(func(context.Context, *meta.Replica) []SegmentAssignPlan); ok { + r0 = rf(ctx, replica) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]SegmentAssignPlan) } } - if rf, ok := ret.Get(1).(func(*meta.Replica) []ChannelAssignPlan); ok { - r1 = rf(replica) + if rf, ok := ret.Get(1).(func(context.Context, *meta.Replica) []ChannelAssignPlan); ok { + r1 = rf(ctx, replica) } else { if ret.Get(1) != nil { r1 = ret.Get(1).([]ChannelAssignPlan) @@ -159,14 +163,15 @@ type MockBalancer_BalanceReplica_Call struct { } // BalanceReplica is a helper method to define mock.On call +// - ctx context.Context // - replica *meta.Replica -func (_e *MockBalancer_Expecter) BalanceReplica(replica interface{}) *MockBalancer_BalanceReplica_Call { - return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", replica)} +func (_e *MockBalancer_Expecter) BalanceReplica(ctx interface{}, replica interface{}) *MockBalancer_BalanceReplica_Call { + return &MockBalancer_BalanceReplica_Call{Call: _e.mock.On("BalanceReplica", ctx, replica)} } -func (_c *MockBalancer_BalanceReplica_Call) Run(run func(replica *meta.Replica)) *MockBalancer_BalanceReplica_Call { +func (_c *MockBalancer_BalanceReplica_Call) Run(run func(ctx context.Context, replica *meta.Replica)) *MockBalancer_BalanceReplica_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*meta.Replica)) + run(args[0].(context.Context), args[1].(*meta.Replica)) }) return _c } @@ -176,7 +181,7 @@ func (_c *MockBalancer_BalanceReplica_Call) Return(_a0 []SegmentAssignPlan, _a1 return _c } -func (_c *MockBalancer_BalanceReplica_Call) RunAndReturn(run func(*meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)) *MockBalancer_BalanceReplica_Call { +func (_c *MockBalancer_BalanceReplica_Call) RunAndReturn(run func(context.Context, *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)) *MockBalancer_BalanceReplica_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/balance/multi_target_balance.go b/internal/querycoordv2/balance/multi_target_balance.go index 3f0f199a42000..ca078d5992743 100644 --- a/internal/querycoordv2/balance/multi_target_balance.go +++ b/internal/querycoordv2/balance/multi_target_balance.go @@ -1,6 +1,7 @@ package balance import ( + "context" "fmt" "math" "math/rand" @@ -468,7 +469,7 @@ type MultiTargetBalancer struct { targetMgr meta.TargetManagerInterface } -func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { +func (b *MultiTargetBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { log := log.With( zap.Int64("collection", replica.GetCollectionID()), zap.Int64("replica id", replica.GetID()), @@ -510,32 +511,32 @@ func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlan ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score if b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = b.genSegmentPlan(replica, rwNodes) + segmentPlans = b.genSegmentPlan(ctx, replica, rwNodes) } } return segmentPlans, channelPlans } -func (b *MultiTargetBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { +func (b *MultiTargetBalancer) genSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { // get segments distribution on replica level and global level nodeSegments := make(map[int64][]*meta.Segment) globalNodeSegments := make(map[int64][]*meta.Segment) for _, node := range rwNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) nodeSegments[node] = segments globalNodeSegments[node] = b.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(node)) diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index a664c5885ac63..60d36607136d3 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -42,7 +42,7 @@ type RowCountBasedBalancer struct { // AssignSegment, when row count based balancer assign segments, it will assign segment to node with least global row count. // try to make every query node has same row count. -func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { +func (b *RowCountBasedBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { // skip out suspend node and stopping node during assignment, but skip this check for manual balance if !manualBalance { nodes = lo.Filter(nodes, func(node int64, _ int) bool { @@ -87,7 +87,7 @@ func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*me // AssignSegment, when row count based balancer assign segments, it will assign channel to node with least global channel count. // try to make every query node has channel count -func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { +func (b *RowCountBasedBalancer) AssignChannel(ctx context.Context, channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan { // skip out suspend node and stopping node during assignment, but skip this check for manual balance if !manualBalance { versionRangeFilter := semver.MustParseRange(">2.3.x") @@ -167,7 +167,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []* return ret } -func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { +func (b *RowCountBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { log := log.Ctx(context.TODO()).WithRateGroup("qcv2.RowCountBasedBalancer", 1, 60).With( zap.Int64("collectionID", replica.GetCollectionID()), zap.Int64("replicaID", replica.GetCollectionID()), @@ -206,33 +206,33 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPl ) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score if b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genSegmentPlan(replica, rwNodes)...) + segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, replica, rwNodes)...) } } return segmentPlans, channelPlans } -func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []SegmentAssignPlan { +func (b *RowCountBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []SegmentAssignPlan { segmentPlans := make([]SegmentAssignPlan, 0) for _, nodeID := range roNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) - plans := b.AssignSegment(replica.GetCollectionID(), segments, rwNodes, false) + plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, rwNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -242,7 +242,7 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, rw return segmentPlans } -func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { +func (b *RowCountBasedBalancer) genSegmentPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64) []SegmentAssignPlan { segmentsToMove := make([]*meta.Segment, 0) nodeRowCount := make(map[int64]int, 0) @@ -251,7 +251,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes [] for _, node := range rwNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) rowCount := 0 for _, s := range segments { @@ -298,7 +298,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes [] return nil } - segmentPlans := b.AssignSegment(replica.GetCollectionID(), segmentsToMove, nodesWithLessRow, false) + segmentPlans := b.AssignSegment(ctx, replica.GetCollectionID(), segmentsToMove, nodesWithLessRow, false) for i := range segmentPlans { segmentPlans[i].From = segmentPlans[i].Segment.Node segmentPlans[i].Replica = replica @@ -307,11 +307,11 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, rwNodes [] return segmentPlans } -func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan { +func (b *RowCountBasedBalancer) genStoppingChannelPlan(ctx context.Context, replica *meta.Replica, rwNodes []int64, roNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) for _, nodeID := range roNodes { dmChannels := b.dist.ChannelDistManager.GetByCollectionAndFilter(replica.GetCollectionID(), meta.WithNodeID2Channel(nodeID)) - plans := b.AssignChannel(dmChannels, rwNodes, false) + plans := b.AssignChannel(ctx, dmChannels, rwNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -321,7 +321,7 @@ func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, rw return channelPlans } -func (b *RowCountBasedBalancer) genChannelPlan(br *balanceReport, replica *meta.Replica, rwNodes []int64) []ChannelAssignPlan { +func (b *RowCountBasedBalancer) genChannelPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, rwNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) if len(rwNodes) > 1 { // start to balance channels on all available nodes @@ -349,7 +349,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(br *balanceReport, replica *meta. return nil } - channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false) + channelPlans := b.AssignChannel(ctx, channelsToMove, nodeWithLessChannel, false) for i := range channelPlans { channelPlans[i].From = channelPlans[i].Channel.Node channelPlans[i].Replica = replica diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go index d423192c825e3..628c5159ca9cf 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go @@ -17,6 +17,7 @@ package balance import ( + "context" "fmt" "testing" @@ -90,6 +91,7 @@ func (suite *RowCountBasedBalancerTestSuite) TearDownTest() { } func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { + ctx := context.Background() cases := []struct { name string distributions map[int64][]*meta.Segment @@ -142,13 +144,14 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() { nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) } - plans := balancer.AssignSegment(0, c.assignments, c.nodes, false) + plans := balancer.AssignSegment(ctx, 0, c.assignments, c.nodes, false) assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans) }) } } func (suite *RowCountBasedBalancerTestSuite) TestBalance() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -403,13 +406,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes)) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, c.nodes)) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil) - balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) - balancer.targetMgr.UpdateCollectionCurrentTarget(1) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -427,7 +430,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -443,7 +446,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { // clear distribution for _, node := range c.nodes { - balancer.meta.ResourceManager.HandleNodeDown(node) + balancer.meta.ResourceManager.HandleNodeDown(ctx, node) balancer.nodeManager.Remove(node) balancer.dist.SegmentDistManager.Update(node) balancer.dist.ChannelDistManager.Update(node) @@ -453,6 +456,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { } func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -614,15 +618,15 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { collection.LoadPercentage = 100 collection.LoadType = querypb.LoadType_LoadCollection collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInCurrent, nil) - balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) - balancer.targetMgr.UpdateCollectionCurrentTarget(1) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInNext, nil) - balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -640,7 +644,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -652,6 +656,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { } func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -759,12 +764,12 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil) - balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) - balancer.targetMgr.UpdateCollectionCurrentTarget(1) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -784,8 +789,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() { suite.balancer.nodeManager.Add(nodeInfo) } // make node-3 outbound - balancer.meta.ResourceManager.HandleNodeUp(1) - balancer.meta.ResourceManager.HandleNodeUp(2) + balancer.meta.ResourceManager.HandleNodeUp(ctx, 1) + balancer.meta.ResourceManager.HandleNodeUp(ctx, 2) utils.RecoverAllCollection(balancer.meta) segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, 1) assertChannelAssignPlanElementMatch(&suite.Suite, c.expectChannelPlans, channelPlans) @@ -801,6 +806,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() { } func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -830,8 +836,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loading collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes)) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, c.nodes)) for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -845,10 +851,11 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() { func (suite *RowCountBasedBalancerTestSuite) getCollectionBalancePlans(balancer *RowCountBasedBalancer, collectionID int64, ) ([]SegmentAssignPlan, []ChannelAssignPlan) { - replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) + ctx := context.Background() + replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID) segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) for _, replica := range replicas { - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) segmentPlans = append(segmentPlans, sPlans...) channelPlans = append(channelPlans, cPlans...) } @@ -859,6 +866,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { suite.SetupSuite() defer suite.TearDownTest() balancer := suite.balancer + ctx := context.Background() distributions := map[int64][]*meta.Segment{ 1: { @@ -895,13 +903,14 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { NumOfGrowingRows: 50, } suite.balancer.dist.LeaderViewManager.Update(1, leaderView) - plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false) + plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false) for _, p := range plans { suite.Equal(int64(2), p.To) } } func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -989,13 +998,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil) - balancer.targetMgr.UpdateCollectionNextTarget(int64(1)) - balancer.targetMgr.UpdateCollectionCurrentTarget(1) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -1013,7 +1022,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } Params.Save(Params.QueryCoordCfg.AutoBalanceChannel.Key, fmt.Sprint(c.enableBalanceChannel)) @@ -1039,6 +1048,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestDisableBalanceChannel() { } func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() { + ctx := context.Background() cases := []struct { name string collectionID int64 @@ -1115,13 +1125,13 @@ func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) for replicaID, nodes := range c.replicaWithNodes { - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes)) } - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.segmentDist { @@ -1142,7 +1152,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestMultiReplicaBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i]) } } diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index ea1187813135d..af1092d016273 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -17,6 +17,7 @@ package balance import ( + "context" "fmt" "math" "sort" @@ -50,7 +51,7 @@ func NewScoreBasedBalancer(scheduler task.Scheduler, } // AssignSegment got a segment list, and try to assign each segment to node's with lowest score -func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { +func (b *ScoreBasedBalancer) AssignSegment(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { br := NewBalanceReport() return b.assignSegment(br, collectionID, segments, nodes, manualBalance) } @@ -263,7 +264,7 @@ func (b *ScoreBasedBalancer) calculateSegmentScore(s *meta.Segment) float64 { return float64(s.GetNumOfRows()) * (1 + params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()) } -func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { +func (b *ScoreBasedBalancer) BalanceReplica(ctx context.Context, replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { log := log.With( zap.Int64("collection", replica.GetCollectionID()), zap.Int64("replica id", replica.GetID()), @@ -308,32 +309,32 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans br.AddRecord(StrRecordf("executing stopping balance: %v", roNodes)) // handle stopped nodes here, have to assign segments on stopping nodes to nodes with the smallest score if b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genStoppingChannelPlan(replica, rwNodes, roNodes)...) + channelPlans = append(channelPlans, b.genStoppingChannelPlan(ctx, replica, rwNodes, roNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(replica, rwNodes, roNodes)...) + segmentPlans = append(segmentPlans, b.genStoppingSegmentPlan(ctx, replica, rwNodes, roNodes)...) } } else { if paramtable.Get().QueryCoordCfg.AutoBalanceChannel.GetAsBool() && b.permitBalanceChannel(replica.GetCollectionID()) { - channelPlans = append(channelPlans, b.genChannelPlan(br, replica, rwNodes)...) + channelPlans = append(channelPlans, b.genChannelPlan(ctx, br, replica, rwNodes)...) } if len(channelPlans) == 0 && b.permitBalanceSegment(replica.GetCollectionID()) { - segmentPlans = append(segmentPlans, b.genSegmentPlan(br, replica, rwNodes)...) + segmentPlans = append(segmentPlans, b.genSegmentPlan(ctx, br, replica, rwNodes)...) } } return segmentPlans, channelPlans } -func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { +func (b *ScoreBasedBalancer) genStoppingSegmentPlan(ctx context.Context, replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []SegmentAssignPlan { segmentPlans := make([]SegmentAssignPlan, 0) for _, nodeID := range offlineNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(nodeID)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) - plans := b.AssignSegment(replica.GetCollectionID(), segments, onlineNodes, false) + plans := b.AssignSegment(ctx, replica.GetCollectionID(), segments, onlineNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -343,7 +344,7 @@ func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlin return segmentPlans } -func (b *ScoreBasedBalancer) genSegmentPlan(br *balanceReport, replica *meta.Replica, onlineNodes []int64) []SegmentAssignPlan { +func (b *ScoreBasedBalancer) genSegmentPlan(ctx context.Context, br *balanceReport, replica *meta.Replica, onlineNodes []int64) []SegmentAssignPlan { segmentDist := make(map[int64][]*meta.Segment) nodeItemsMap := b.convertToNodeItems(br, replica.GetCollectionID(), onlineNodes) if len(nodeItemsMap) == 0 { @@ -359,7 +360,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(br *balanceReport, replica *meta.Rep for _, node := range onlineNodes { dist := b.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(node)) segments := lo.Filter(dist, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.CanSegmentBeMoved(segment.GetCollectionID(), segment.GetID()) + return b.targetMgr.CanSegmentBeMoved(ctx, segment.GetCollectionID(), segment.GetID()) }) segmentDist[node] = segments } diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index 52b7e2d7a7b89..225c7c0f0bc91 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -16,6 +16,7 @@ package balance import ( + "context" "testing" "github.com/samber/lo" @@ -85,6 +86,7 @@ func (suite *ScoreBasedBalancerTestSuite) TearDownTest() { } func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { + ctx := context.Background() cases := []struct { name string comment string @@ -240,7 +242,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { suite.balancer.nodeManager.Add(nodeInfo) } for i := range c.collectionIDs { - plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false) + plans := balancer.AssignSegment(ctx, c.collectionIDs[i], c.assignments[i], c.nodes, false) if c.unstableAssignment { suite.Len(plans, len(c.expectPlans[i])) } else { @@ -255,9 +257,10 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { suite.SetupSuite() defer suite.TearDownTest() balancer := suite.balancer + ctx := context.Background() paramtable.Get().Save(paramtable.Get().QueryCoordCfg.DelegatorMemoryOverloadFactor.Key, "0.3") - suite.balancer.meta.PutCollection(&meta.Collection{ + suite.balancer.meta.PutCollection(ctx, &meta.Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: 1, }, @@ -300,13 +303,14 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { CollectionID: 1, } suite.balancer.dist.LeaderViewManager.Update(1, leaderView) - plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false) + plans := balancer.AssignSegment(ctx, 1, toAssign, lo.Keys(distributions), false) for _, p := range plans { suite.Equal(int64(2), p.To) } } func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -377,11 +381,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -401,7 +405,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -414,6 +418,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { } func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -463,12 +468,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -494,7 +499,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -520,6 +525,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestDelegatorPreserveMemory() { } func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -572,11 +578,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -596,7 +602,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -618,6 +624,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceWithExecutingTask() { } func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { + ctx := context.Background() balanceCase := struct { name string nodes []int64 @@ -695,12 +702,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded collection.LoadType = querypb.LoadType_LoadCollection - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i], + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(balanceCase.replicaIDs[i], balanceCase.collectionIDs[i], append(balanceCase.nodes, balanceCase.notExistedNodes...))) - balancer.targetMgr.UpdateCollectionNextTarget(balanceCase.collectionIDs[i]) - balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, balanceCase.collectionIDs[i]) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, balanceCase.collectionIDs[i]) } // 2. set up target for distribution for multi collections @@ -717,7 +724,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { }) nodeInfo.SetState(balanceCase.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(balanceCase.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, balanceCase.nodes[i]) } // 4. first round balance @@ -735,6 +742,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { } func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -838,11 +846,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -862,11 +870,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } for i := range c.outBoundNodes { - suite.balancer.meta.ResourceManager.HandleNodeDown(c.outBoundNodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeDown(ctx, c.outBoundNodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -879,6 +887,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { } func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() { + ctx := context.Background() cases := []struct { name string collectionID int64 @@ -955,13 +964,13 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) for replicaID, nodes := range c.replicaWithNodes { - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, c.collectionID, nodes)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, c.collectionID, nodes)) } - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.segmentDist { @@ -982,7 +991,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() { nodeInfo.UpdateStats(session.WithChannelCnt(len(c.channelDist[nodes[i]]))) nodeInfo.SetState(c.states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i]) } } @@ -1006,6 +1015,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestMultiReplicaBalance() { } func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() { + ctx := context.Background() cases := []struct { name string nodes []int64 @@ -1054,12 +1064,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() { suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionID).Return([]int64{c.collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(c.collectionID, c.collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) - balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(c.collectionID, c.collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(c.replicaID, c.collectionID, c.nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, c.collectionID) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, c.collectionID) // 2. set up target for distribution for multi collections for node, s := range c.distributions { @@ -1081,7 +1091,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestQNMemoryCapacity() { nodeInfo.SetState(c.states[i]) nodeInfoMap[c.nodes[i]] = nodeInfo suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(c.nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, c.nodes[i]) } utils.RecoverAllCollection(balancer.meta) @@ -1113,10 +1123,11 @@ func TestScoreBasedBalancerSuite(t *testing.T) { func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *ScoreBasedBalancer, collectionID int64, ) ([]SegmentAssignPlan, []ChannelAssignPlan) { - replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) + ctx := context.Background() + replicas := balancer.meta.ReplicaManager.GetByCollection(ctx, collectionID) segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) for _, replica := range replicas { - sPlans, cPlans := balancer.BalanceReplica(replica) + sPlans, cPlans := balancer.BalanceReplica(ctx, replica) segmentPlans = append(segmentPlans, sPlans...) channelPlans = append(channelPlans, cPlans...) } @@ -1124,6 +1135,7 @@ func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *Sc } func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() { + ctx := context.Background() nodes := []int64{1, 2, 3} collectionID := int64(1) replicaID := int64(1) @@ -1140,11 +1152,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() { suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{collectionID}, nil).Maybe() collection.LoadPercentage = 100 collection.Status = querypb.LoadStatus_Loaded - balancer.meta.CollectionManager.PutCollection(collection) - balancer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, collectionID)) - balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(replicaID, collectionID, nodes)) - balancer.targetMgr.UpdateCollectionNextTarget(collectionID) - balancer.targetMgr.UpdateCollectionCurrentTarget(collectionID) + balancer.meta.CollectionManager.PutCollection(ctx, collection) + balancer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, collectionID)) + balancer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(replicaID, collectionID, nodes)) + balancer.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + balancer.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) for i := range nodes { nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -1155,7 +1167,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceSegmentAndChannel() { }) nodeInfo.SetState(states[i]) suite.balancer.nodeManager.Add(nodeInfo) - suite.balancer.meta.ResourceManager.HandleNodeUp(nodes[i]) + suite.balancer.meta.ResourceManager.HandleNodeUp(ctx, nodes[i]) } utils.RecoverAllCollection(balancer.meta) diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 7acdb898e681e..fb7fa0447f40c 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -73,19 +73,19 @@ func (b *BalanceChecker) Description() string { return "BalanceChecker checks the cluster distribution and generates balance tasks" } -func (b *BalanceChecker) readyToCheck(collectionID int64) bool { - metaExist := (b.meta.GetCollection(collectionID) != nil) - targetExist := b.targetMgr.IsNextTargetExist(collectionID) || b.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID) +func (b *BalanceChecker) readyToCheck(ctx context.Context, collectionID int64) bool { + metaExist := (b.meta.GetCollection(ctx, collectionID) != nil) + targetExist := b.targetMgr.IsNextTargetExist(ctx, collectionID) || b.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) return metaExist && targetExist } -func (b *BalanceChecker) replicasToBalance() []int64 { - ids := b.meta.GetAll() +func (b *BalanceChecker) replicasToBalance(ctx context.Context) []int64 { + ids := b.meta.GetAll(ctx) // all replicas belonging to loading collection will be skipped loadedCollections := lo.Filter(ids, func(cid int64, _ int) bool { - collection := b.meta.GetCollection(cid) + collection := b.meta.GetCollection(ctx, cid) return collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded }) sort.Slice(loadedCollections, func(i, j int) bool { @@ -97,10 +97,10 @@ func (b *BalanceChecker) replicasToBalance() []int64 { stoppingReplicas := make([]int64, 0) for _, cid := range loadedCollections { // if target and meta isn't ready, skip balance this collection - if !b.readyToCheck(cid) { + if !b.readyToCheck(ctx, cid) { continue } - replicas := b.meta.ReplicaManager.GetByCollection(cid) + replicas := b.meta.ReplicaManager.GetByCollection(ctx, cid) for _, replica := range replicas { if replica.RONodesCount() > 0 { stoppingReplicas = append(stoppingReplicas, replica.GetID()) @@ -130,7 +130,7 @@ func (b *BalanceChecker) replicasToBalance() []int64 { } hasUnbalancedCollection = true b.normalBalanceCollectionsCurrentRound.Insert(cid) - for _, replica := range b.meta.ReplicaManager.GetByCollection(cid) { + for _, replica := range b.meta.ReplicaManager.GetByCollection(ctx, cid) { normalReplicasToBalance = append(normalReplicasToBalance, replica.GetID()) } break @@ -144,14 +144,14 @@ func (b *BalanceChecker) replicasToBalance() []int64 { return normalReplicasToBalance } -func (b *BalanceChecker) balanceReplicas(replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { +func (b *BalanceChecker) balanceReplicas(ctx context.Context, replicaIDs []int64) ([]balance.SegmentAssignPlan, []balance.ChannelAssignPlan) { segmentPlans, channelPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) for _, rid := range replicaIDs { - replica := b.meta.ReplicaManager.Get(rid) + replica := b.meta.ReplicaManager.Get(ctx, rid) if replica == nil { continue } - sPlans, cPlans := b.getBalancerFunc().BalanceReplica(replica) + sPlans, cPlans := b.getBalancerFunc().BalanceReplica(ctx, replica) segmentPlans = append(segmentPlans, sPlans...) channelPlans = append(channelPlans, cPlans...) if len(segmentPlans) != 0 || len(channelPlans) != 0 { @@ -164,12 +164,12 @@ func (b *BalanceChecker) balanceReplicas(replicaIDs []int64) ([]balance.SegmentA func (b *BalanceChecker) Check(ctx context.Context) []task.Task { ret := make([]task.Task, 0) - replicasToBalance := b.replicasToBalance() - segmentPlans, channelPlans := b.balanceReplicas(replicasToBalance) + replicasToBalance := b.replicasToBalance(ctx) + segmentPlans, channelPlans := b.balanceReplicas(ctx, replicasToBalance) // iterate all collection to find a collection to balance for len(segmentPlans) == 0 && len(channelPlans) == 0 && b.normalBalanceCollectionsCurrentRound.Len() > 0 { - replicasToBalance := b.replicasToBalance() - segmentPlans, channelPlans = b.balanceReplicas(replicasToBalance) + replicasToBalance := b.replicasToBalance(ctx) + segmentPlans, channelPlans = b.balanceReplicas(ctx, replicasToBalance) } tasks := balance.CreateSegmentTasksFromPlans(ctx, b.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), segmentPlans) diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 744d9a2fc7dd7..2bc24627f8721 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -86,6 +86,7 @@ func (suite *BalanceCheckerTestSuite) TearDownTest() { } func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { + ctx := context.Background() // set up nodes info nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -98,8 +99,8 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { Address: "localhost", Hostname: "localhost", })) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) // set collections meta segments := []*datapb.SegmentInfo{ @@ -123,46 +124,47 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { collection1.Status = querypb.LoadStatus_Loaded replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(collection1, partition1) - suite.checker.meta.ReplicaManager.Put(replica1) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) cid2, replicaID2, partitionID2 := 2, 2, 2 collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) collection2.Status = querypb.LoadStatus_Loaded replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(collection2, partition2) - suite.checker.meta.ReplicaManager.Put(replica2) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) // test disable auto balance paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { return 0 }) - replicasToBalance := suite.checker.replicasToBalance() + replicasToBalance := suite.checker.replicasToBalance(ctx) suite.Empty(replicasToBalance) - segPlans, _ := suite.checker.balanceReplicas(replicasToBalance) + segPlans, _ := suite.checker.balanceReplicas(ctx, replicasToBalance) suite.Empty(segPlans) // test enable auto balance paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") idsToBalance := []int64{int64(replicaID1)} - replicasToBalance = suite.checker.replicasToBalance() + replicasToBalance = suite.checker.replicasToBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) // next round idsToBalance = []int64{int64(replicaID2)} - replicasToBalance = suite.checker.replicasToBalance() + replicasToBalance = suite.checker.replicasToBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) // final round - replicasToBalance = suite.checker.replicasToBalance() + replicasToBalance = suite.checker.replicasToBalance(ctx) suite.Empty(replicasToBalance) } func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { + ctx := context.Background() // set up nodes info nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -175,8 +177,8 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { Address: "localhost", Hostname: "localhost", })) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) segments := []*datapb.SegmentInfo{ { @@ -199,31 +201,32 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { collection1.Status = querypb.LoadStatus_Loaded replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(collection1, partition1) - suite.checker.meta.ReplicaManager.Put(replica1) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) cid2, replicaID2, partitionID2 := 2, 2, 2 collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) collection2.Status = querypb.LoadStatus_Loaded replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(collection2, partition2) - suite.checker.meta.ReplicaManager.Put(replica2) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) // test scheduler busy paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { return 1 }) - replicasToBalance := suite.checker.replicasToBalance() + replicasToBalance := suite.checker.replicasToBalance(ctx) suite.Len(replicasToBalance, 1) } func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { + ctx := context.Background() // set up nodes info, stopping node1 nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -237,8 +240,8 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { Hostname: "localhost", })) suite.nodeMgr.Stopping(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID1)) - suite.checker.meta.ResourceManager.HandleNodeUp(int64(nodeID2)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID1)) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, int64(nodeID2)) segments := []*datapb.SegmentInfo{ { @@ -261,32 +264,32 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { collection1.Status = querypb.LoadStatus_Loaded replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{int64(nodeID1), int64(nodeID2)}) partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(collection1, partition1) - suite.checker.meta.ReplicaManager.Put(replica1) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) cid2, replicaID2, partitionID2 := 2, 2, 2 collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) collection2.Status = querypb.LoadStatus_Loaded replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{int64(nodeID1), int64(nodeID2)}) partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(collection2, partition2) - suite.checker.meta.ReplicaManager.Put(replica2) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid2)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid2)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid2)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid2)) mr1 := replica1.CopyForWrite() mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica()) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) mr2 := replica2.CopyForWrite() mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica()) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) // test stopping balance idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} - replicasToBalance := suite.checker.replicasToBalance() + replicasToBalance := suite.checker.replicasToBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) // checker check @@ -298,12 +301,13 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { To: 2, } segPlans = append(segPlans, mockPlan) - suite.balancer.EXPECT().BalanceReplica(mock.Anything).Return(segPlans, chanPlans) + suite.balancer.EXPECT().BalanceReplica(mock.Anything, mock.Anything).Return(segPlans, chanPlans) tasks := suite.checker.Check(context.TODO()) suite.Len(tasks, 2) } func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { + ctx := context.Background() // set up nodes info, stopping node1 nodeID1, nodeID2 := int64(1), int64(2) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -317,8 +321,8 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { Hostname: "localhost", })) suite.nodeMgr.Stopping(nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(nodeID1) - suite.checker.meta.ResourceManager.HandleNodeUp(nodeID2) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID1) + suite.checker.meta.ResourceManager.HandleNodeUp(ctx, nodeID2) segments := []*datapb.SegmentInfo{ { @@ -341,30 +345,30 @@ func (suite *BalanceCheckerTestSuite) TestTargetNotReady() { collection1.Status = querypb.LoadStatus_Loaded replica1 := utils.CreateTestReplica(int64(replicaID1), int64(cid1), []int64{nodeID1, nodeID2}) partition1 := utils.CreateTestPartition(int64(cid1), int64(partitionID1)) - suite.checker.meta.CollectionManager.PutCollection(collection1, partition1) - suite.checker.meta.ReplicaManager.Put(replica1) - suite.targetMgr.UpdateCollectionNextTarget(int64(cid1)) - suite.targetMgr.UpdateCollectionCurrentTarget(int64(cid1)) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection1, partition1) + suite.checker.meta.ReplicaManager.Put(ctx, replica1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, int64(cid1)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(cid1)) cid2, replicaID2, partitionID2 := 2, 2, 2 collection2 := utils.CreateTestCollection(int64(cid2), int32(replicaID2)) collection2.Status = querypb.LoadStatus_Loaded replica2 := utils.CreateTestReplica(int64(replicaID2), int64(cid2), []int64{nodeID1, nodeID2}) partition2 := utils.CreateTestPartition(int64(cid2), int64(partitionID2)) - suite.checker.meta.CollectionManager.PutCollection(collection2, partition2) - suite.checker.meta.ReplicaManager.Put(replica2) + suite.checker.meta.CollectionManager.PutCollection(ctx, collection2, partition2) + suite.checker.meta.ReplicaManager.Put(ctx, replica2) mr1 := replica1.CopyForWrite() mr1.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(mr1.IntoReplica()) + suite.checker.meta.ReplicaManager.Put(ctx, mr1.IntoReplica()) mr2 := replica2.CopyForWrite() mr2.AddRONode(1) - suite.checker.meta.ReplicaManager.Put(mr2.IntoReplica()) + suite.checker.meta.ReplicaManager.Put(ctx, mr2.IntoReplica()) // test stopping balance idsToBalance := []int64{int64(replicaID1)} - replicasToBalance := suite.checker.replicasToBalance() + replicasToBalance := suite.checker.replicasToBalance(ctx) suite.ElementsMatch(idsToBalance, replicasToBalance) } diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index 78860f56f3a5f..ce7e36c31e00a 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -70,9 +70,9 @@ func (c *ChannelChecker) Description() string { return "DmChannelChecker checks the lack of DmChannels, or some DmChannels are redundant" } -func (c *ChannelChecker) readyToCheck(collectionID int64) bool { - metaExist := (c.meta.GetCollection(collectionID) != nil) - targetExist := c.targetMgr.IsNextTargetExist(collectionID) || c.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID) +func (c *ChannelChecker) readyToCheck(ctx context.Context, collectionID int64) bool { + metaExist := (c.meta.GetCollection(ctx, collectionID) != nil) + targetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) || c.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) return metaExist && targetExist } @@ -81,11 +81,11 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task { if !c.IsActive() { return nil } - collectionIDs := c.meta.CollectionManager.GetAll() + collectionIDs := c.meta.CollectionManager.GetAll(ctx) tasks := make([]task.Task, 0) for _, cid := range collectionIDs { - if c.readyToCheck(cid) { - replicas := c.meta.ReplicaManager.GetByCollection(cid) + if c.readyToCheck(ctx, cid) { + replicas := c.meta.ReplicaManager.GetByCollection(ctx, cid) for _, r := range replicas { tasks = append(tasks, c.checkReplica(ctx, r)...) } @@ -105,7 +105,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task { channelOnQN := c.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(nodeID)) collectionChannels := lo.GroupBy(channelOnQN, func(ch *meta.DmChannel) int64 { return ch.CollectionID }) for collectionID, channels := range collectionChannels { - replica := c.meta.ReplicaManager.GetByCollectionAndNode(collectionID, nodeID) + replica := c.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, nodeID) if replica == nil { reduceTasks := c.createChannelReduceTasks(ctx, channels, meta.NilReplica) task.SetReason("dirty channel exists", reduceTasks...) @@ -119,7 +119,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task { func (c *ChannelChecker) checkReplica(ctx context.Context, replica *meta.Replica) []task.Task { ret := make([]task.Task, 0) - lacks, redundancies := c.getDmChannelDiff(replica.GetCollectionID(), replica.GetID()) + lacks, redundancies := c.getDmChannelDiff(ctx, replica.GetCollectionID(), replica.GetID()) tasks := c.createChannelLoadTask(c.getTraceCtx(ctx, replica.GetCollectionID()), lacks, replica) task.SetReason("lacks of channel", tasks...) ret = append(ret, tasks...) @@ -139,10 +139,10 @@ func (c *ChannelChecker) checkReplica(ctx context.Context, replica *meta.Replica } // GetDmChannelDiff get channel diff between target and dist -func (c *ChannelChecker) getDmChannelDiff(collectionID int64, +func (c *ChannelChecker) getDmChannelDiff(ctx context.Context, collectionID int64, replicaID int64, ) (toLoad, toRelease []*meta.DmChannel) { - replica := c.meta.Get(replicaID) + replica := c.meta.Get(ctx, replicaID) if replica == nil { log.Info("replica does not exist, skip it") return @@ -154,8 +154,8 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, distMap.Insert(ch.GetChannelName()) } - nextTargetMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.NextTarget) - currentTargetMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget) + nextTargetMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.NextTarget) + currentTargetMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget) // get channels which exists on dist, but not exist on current and next for _, ch := range dist { @@ -179,7 +179,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, func (c *ChannelChecker) findRepeatedChannels(ctx context.Context, replicaID int64) []*meta.DmChannel { log := log.Ctx(ctx).WithRateGroup("ChannelChecker.findRepeatedChannels", 1, 60) - replica := c.meta.Get(replicaID) + replica := c.meta.Get(ctx, replicaID) ret := make([]*meta.DmChannel, 0) if replica == nil { @@ -232,7 +232,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []* if len(rwNodes) == 0 { rwNodes = replica.GetRWNodes() } - plan := c.getBalancerFunc().AssignChannel([]*meta.DmChannel{ch}, rwNodes, false) + plan := c.getBalancerFunc().AssignChannel(ctx, []*meta.DmChannel{ch}, rwNodes, false) plans = append(plans, plan...) } @@ -264,7 +264,7 @@ func (c *ChannelChecker) createChannelReduceTasks(ctx context.Context, channels } func (c *ChannelChecker) getTraceCtx(ctx context.Context, collectionID int64) context.Context { - coll := c.meta.GetCollection(collectionID) + coll := c.meta.GetCollection(ctx, collectionID) if coll == nil || coll.LoadSpan == nil { return ctx } diff --git a/internal/querycoordv2/checkers/channel_checker_test.go b/internal/querycoordv2/checkers/channel_checker_test.go index 3fe1d50e314dd..9ea63f1da1ea1 100644 --- a/internal/querycoordv2/checkers/channel_checker_test.go +++ b/internal/querycoordv2/checkers/channel_checker_test.go @@ -100,7 +100,7 @@ func (suite *ChannelCheckerTestSuite) setNodeAvailable(nodes ...int64) { func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance { balancer := balance.NewMockBalancer(suite.T()) - balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan { + balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(ctx context.Context, channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan { plans := make([]balance.ChannelAssignPlan, 0, len(channels)) for i, c := range channels { plan := balance.ChannelAssignPlan{ @@ -117,16 +117,17 @@ func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance { } func (suite *ChannelCheckerTestSuite) TestLoadChannel() { + ctx := context.Background() checker := suite.checker - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) channels := []*datapb.VchannelInfo{ { @@ -137,7 +138,7 @@ func (suite *ChannelCheckerTestSuite) TestLoadChannel() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, nil, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) tasks := checker.Check(context.TODO()) suite.Len(tasks, 1) @@ -151,10 +152,11 @@ func (suite *ChannelCheckerTestSuite) TestLoadChannel() { } func (suite *ChannelCheckerTestSuite) TestReduceChannel() { + ctx := context.Background() checker := suite.checker - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1})) channels := []*datapb.VchannelInfo{ { @@ -164,8 +166,8 @@ func (suite *ChannelCheckerTestSuite) TestReduceChannel() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, nil, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel1")) checker.dist.LeaderViewManager.Update(1, &meta.LeaderView{ID: 1, Channel: "test-insert-channel1"}) @@ -184,11 +186,12 @@ func (suite *ChannelCheckerTestSuite) TestReduceChannel() { } func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() { + ctx := context.Background() checker := suite.checker - err := checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) + err := checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) suite.NoError(err) - err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + err = checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.NoError(err) segments := []*datapb.SegmentInfo{ @@ -206,7 +209,7 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel")) checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 2, "test-insert-channel")) @@ -228,11 +231,12 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() { } func (suite *ChannelCheckerTestSuite) TestReleaseDirtyChannels() { + ctx := context.Background() checker := suite.checker - err := checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) + err := checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) suite.NoError(err) - err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1})) + err = checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1})) suite.NoError(err) segments := []*datapb.SegmentInfo{ @@ -261,7 +265,7 @@ func (suite *ChannelCheckerTestSuite) TestReleaseDirtyChannels() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 2, "test-insert-channel")) checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 2, "test-insert-channel")) checker.dist.LeaderViewManager.Update(1, &meta.LeaderView{ID: 1, Channel: "test-insert-channel"}) diff --git a/internal/querycoordv2/checkers/controller_test.go b/internal/querycoordv2/checkers/controller_test.go index 95087bf256891..60200c847bb5c 100644 --- a/internal/querycoordv2/checkers/controller_test.go +++ b/internal/querycoordv2/checkers/controller_test.go @@ -17,6 +17,7 @@ package checkers import ( + "context" "testing" "time" @@ -85,10 +86,11 @@ func (suite *CheckerControllerSuite) SetupTest() { } func (suite *CheckerControllerSuite) TestBasic() { + ctx := context.Background() // set meta - suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - suite.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -99,8 +101,8 @@ func (suite *CheckerControllerSuite) TestBasic() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(1) - suite.meta.ResourceManager.HandleNodeUp(2) + suite.meta.ResourceManager.HandleNodeUp(ctx, 1) + suite.meta.ResourceManager.HandleNodeUp(ctx, 2) // set target channels := []*datapb.VchannelInfo{ @@ -119,7 +121,7 @@ func (suite *CheckerControllerSuite) TestBasic() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - suite.targetManager.UpdateCollectionNextTarget(int64(1)) + suite.targetManager.UpdateCollectionNextTarget(ctx, int64(1)) // set dist suite.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -134,11 +136,11 @@ func (suite *CheckerControllerSuite) TestBasic() { assignSegCounter := atomic.NewInt32(0) assingChanCounter := atomic.NewInt32(0) - suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan { + suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan { assignSegCounter.Inc() return nil }) - suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan { + suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan { assingChanCounter.Inc() return nil }) diff --git a/internal/querycoordv2/checkers/index_checker.go b/internal/querycoordv2/checkers/index_checker.go index 2afabc6bf3ee1..970ef6fb4b9a1 100644 --- a/internal/querycoordv2/checkers/index_checker.go +++ b/internal/querycoordv2/checkers/index_checker.go @@ -79,7 +79,7 @@ func (c *IndexChecker) Check(ctx context.Context) []task.Task { if !c.IsActive() { return nil } - collectionIDs := c.meta.CollectionManager.GetAll() + collectionIDs := c.meta.CollectionManager.GetAll(ctx) var tasks []task.Task for _, collectionID := range collectionIDs { @@ -89,12 +89,12 @@ func (c *IndexChecker) Check(ctx context.Context) []task.Task { continue } - collection := c.meta.CollectionManager.GetCollection(collectionID) + collection := c.meta.CollectionManager.GetCollection(ctx, collectionID) if collection == nil { log.Warn("collection released during check index", zap.Int64("collection", collectionID)) continue } - replicas := c.meta.ReplicaManager.GetByCollection(collectionID) + replicas := c.meta.ReplicaManager.GetByCollection(ctx, collectionID) for _, replica := range replicas { tasks = append(tasks, c.checkReplica(ctx, collection, replica, indexInfos)...) } @@ -121,7 +121,7 @@ func (c *IndexChecker) checkReplica(ctx context.Context, collection *meta.Collec } // skip update index for l0 segment - segmentInTarget := c.targetMgr.GetSealedSegment(collection.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) + segmentInTarget := c.targetMgr.GetSealedSegment(ctx, collection.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) if segmentInTarget == nil || segmentInTarget.GetLevel() == datapb.SegmentLevel_L0 { continue } diff --git a/internal/querycoordv2/checkers/index_checker_test.go b/internal/querycoordv2/checkers/index_checker_test.go index aaa8f341a8196..2a6a7b99f232d 100644 --- a/internal/querycoordv2/checkers/index_checker_test.go +++ b/internal/querycoordv2/checkers/index_checker_test.go @@ -78,7 +78,7 @@ func (suite *IndexCheckerSuite) SetupTest() { suite.targetMgr = meta.NewMockTargetManager(suite.T()) suite.checker = NewIndexChecker(suite.meta, distManager, suite.broker, suite.nodeMgr, suite.targetMgr) - suite.targetMgr.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(cid, sid int64, i3 int32) *datapb.SegmentInfo { + suite.targetMgr.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cid, sid int64, i3 int32) *datapb.SegmentInfo { return &datapb.SegmentInfo{ ID: sid, Level: datapb.SegmentLevel_L1, @@ -92,12 +92,13 @@ func (suite *IndexCheckerSuite) TearDownTest() { func (suite *IndexCheckerSuite) TestLoadIndex() { checker := suite.checker + ctx := context.Background() // meta coll := utils.CreateTestCollection(1, 1) coll.FieldIndexID = map[int64]int64{101: 1000} - checker.meta.CollectionManager.PutCollection(coll) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, coll) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -108,8 +109,8 @@ func (suite *IndexCheckerSuite) TestLoadIndex() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // dist checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel")) @@ -147,8 +148,8 @@ func (suite *IndexCheckerSuite) TestLoadIndex() { // test skip load index for read only node suite.nodeMgr.Stopping(1) suite.nodeMgr.Stopping(2) - suite.meta.ResourceManager.HandleNodeStopping(1) - suite.meta.ResourceManager.HandleNodeStopping(2) + suite.meta.ResourceManager.HandleNodeStopping(ctx, 1) + suite.meta.ResourceManager.HandleNodeStopping(ctx, 2) utils.RecoverAllCollection(suite.meta) tasks = checker.Check(context.Background()) suite.Require().Len(tasks, 0) @@ -156,12 +157,13 @@ func (suite *IndexCheckerSuite) TestLoadIndex() { func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() { checker := suite.checker + ctx := context.Background() // meta coll := utils.CreateTestCollection(1, 1) coll.FieldIndexID = map[int64]int64{101: 1000} - checker.meta.CollectionManager.PutCollection(coll) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, coll) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -172,8 +174,8 @@ func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // dist checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel")) @@ -216,12 +218,13 @@ func (suite *IndexCheckerSuite) TestIndexInfoNotMatch() { func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() { checker := suite.checker + ctx := context.Background() // meta coll := utils.CreateTestCollection(1, 1) coll.FieldIndexID = map[int64]int64{101: 1000} - checker.meta.CollectionManager.PutCollection(coll) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, coll) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -232,8 +235,8 @@ func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // dist checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel")) @@ -255,12 +258,13 @@ func (suite *IndexCheckerSuite) TestGetIndexInfoFailed() { func (suite *IndexCheckerSuite) TestCreateNewIndex() { checker := suite.checker + ctx := context.Background() // meta coll := utils.CreateTestCollection(1, 1) coll.FieldIndexID = map[int64]int64{101: 1000} - checker.meta.CollectionManager.PutCollection(coll) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(200, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, coll) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(200, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -271,8 +275,8 @@ func (suite *IndexCheckerSuite) TestCreateNewIndex() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // dist segment := utils.CreateTestSegment(1, 1, 2, 1, 1, "test-insert-channel") diff --git a/internal/querycoordv2/checkers/leader_checker.go b/internal/querycoordv2/checkers/leader_checker.go index 77fdda14dc859..f95c376c3cd68 100644 --- a/internal/querycoordv2/checkers/leader_checker.go +++ b/internal/querycoordv2/checkers/leader_checker.go @@ -65,9 +65,9 @@ func (c *LeaderChecker) Description() string { return "LeaderChecker checks the difference of leader view between dist, and try to correct it" } -func (c *LeaderChecker) readyToCheck(collectionID int64) bool { - metaExist := (c.meta.GetCollection(collectionID) != nil) - targetExist := c.target.IsNextTargetExist(collectionID) || c.target.IsCurrentTargetExist(collectionID, common.AllPartitionsID) +func (c *LeaderChecker) readyToCheck(ctx context.Context, collectionID int64) bool { + metaExist := (c.meta.GetCollection(ctx, collectionID) != nil) + targetExist := c.target.IsNextTargetExist(ctx, collectionID) || c.target.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) return metaExist && targetExist } @@ -77,20 +77,20 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task { return nil } - collectionIDs := c.meta.CollectionManager.GetAll() + collectionIDs := c.meta.CollectionManager.GetAll(ctx) tasks := make([]task.Task, 0) for _, collectionID := range collectionIDs { - if !c.readyToCheck(collectionID) { + if !c.readyToCheck(ctx, collectionID) { continue } - collection := c.meta.CollectionManager.GetCollection(collectionID) + collection := c.meta.CollectionManager.GetCollection(ctx, collectionID) if collection == nil { log.Warn("collection released during check leader", zap.Int64("collection", collectionID)) continue } - replicas := c.meta.ReplicaManager.GetByCollection(collectionID) + replicas := c.meta.ReplicaManager.GetByCollection(ctx, collectionID) for _, replica := range replicas { for _, node := range replica.GetRWNodes() { leaderViews := c.dist.LeaderViewManager.GetByFilter(meta.WithCollectionID2LeaderView(replica.GetCollectionID()), meta.WithNodeID2LeaderView(node)) @@ -109,7 +109,7 @@ func (c *LeaderChecker) Check(ctx context.Context) []task.Task { func (c *LeaderChecker) findNeedSyncPartitionStats(ctx context.Context, replica *meta.Replica, leaderView *meta.LeaderView, nodeID int64) []task.Task { ret := make([]task.Task, 0) - curDmlChannel := c.target.GetDmChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget) + curDmlChannel := c.target.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget) if curDmlChannel == nil { return ret } @@ -163,7 +163,7 @@ func (c *LeaderChecker) findNeedLoadedSegments(ctx context.Context, replica *met latestNodeDist := utils.FindMaxVersionSegments(dist) for _, s := range latestNodeDist { - segment := c.target.GetSealedSegment(leaderView.CollectionID, s.GetID(), meta.CurrentTargetFirst) + segment := c.target.GetSealedSegment(ctx, leaderView.CollectionID, s.GetID(), meta.CurrentTargetFirst) existInTarget := segment != nil isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0 // shouldn't set l0 segment location to delegator. l0 segment should be reload in delegator @@ -213,7 +213,7 @@ func (c *LeaderChecker) findNeedRemovedSegments(ctx context.Context, replica *me for sid, s := range leaderView.Segments { _, ok := distMap[sid] - segment := c.target.GetSealedSegment(leaderView.CollectionID, sid, meta.CurrentTargetFirst) + segment := c.target.GetSealedSegment(ctx, leaderView.CollectionID, sid, meta.CurrentTargetFirst) existInTarget := segment != nil isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0 if ok || existInTarget || isL0Segment { diff --git a/internal/querycoordv2/checkers/leader_checker_test.go b/internal/querycoordv2/checkers/leader_checker_test.go index 0d2249b14ad15..2a4a238cefa2c 100644 --- a/internal/querycoordv2/checkers/leader_checker_test.go +++ b/internal/querycoordv2/checkers/leader_checker_test.go @@ -83,10 +83,11 @@ func (suite *LeaderCheckerTestSuite) TearDownTest() { } func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -119,13 +120,13 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() { })) // test leader view lack of segments - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) loadVersion := time.Now().UnixMilli() observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, loadVersion, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) tasks = suite.checker.Check(context.TODO()) @@ -140,7 +141,7 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() { // test segment's version in leader view doesn't match segment's version in dist observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel")) view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) view.Segments[1] = &querypb.SegmentDist{ NodeID: 0, Version: time.Now().UnixMilli() - 1, @@ -168,23 +169,24 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegments() { suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) // mock l0 segment exist on non delegator node, doesn't set to leader view observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, loadVersion, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) tasks = suite.checker.Check(context.TODO()) suite.Len(tasks, 0) } func (suite *LeaderCheckerTestSuite) TestActivation() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -211,12 +213,12 @@ func (suite *LeaderCheckerTestSuite) TestActivation() { Address: "localhost", Hostname: "localhost", })) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) suite.checker.Deactivate() @@ -234,11 +236,12 @@ func (suite *LeaderCheckerTestSuite) TestActivation() { } func (suite *LeaderCheckerTestSuite) TestStoppingNode() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) replica := utils.CreateTestReplica(1, 1, []int64{1, 2}) - observer.meta.ReplicaManager.Put(replica) + observer.meta.ReplicaManager.Put(ctx, replica) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -254,27 +257,28 @@ func (suite *LeaderCheckerTestSuite) TestStoppingNode() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) mutableReplica := replica.CopyForWrite() mutableReplica.AddRONode(2) - observer.meta.ReplicaManager.Put(mutableReplica.IntoReplica()) + observer.meta.ReplicaManager.Put(ctx, mutableReplica.IntoReplica()) tasks := suite.checker.Check(context.TODO()) suite.Len(tasks, 0) } func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -301,14 +305,14 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() { Address: "localhost", Hostname: "localhost", })) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) - observer.target.UpdateCollectionNextTarget(int64(1)) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 2, 1, "test-insert-channel"), utils.CreateTestSegment(1, 1, 2, 2, 1, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) tasks := suite.checker.Check(context.TODO()) @@ -322,11 +326,12 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncLoadedSegments() { } func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(2, 1, []int64{3, 4})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 2)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(2, 1, []int64{3, 4})) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -354,17 +359,17 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() { Hostname: "localhost", })) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, 0, "test-insert-channel")) observer.dist.SegmentDistManager.Update(4, utils.CreateTestSegment(1, 1, 1, 4, 0, "test-insert-channel")) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) observer.dist.ChannelDistManager.Update(4, utils.CreateTestChannel(1, 4, 2, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) view2 := utils.CreateTestLeaderView(4, 1, "test-insert-channel", map[int64]int64{1: 4}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(4, view2) tasks := suite.checker.Check(context.TODO()) @@ -379,10 +384,11 @@ func (suite *LeaderCheckerTestSuite) TestSyncLoadedSegmentsWithReplicas() { } func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) channels := []*datapb.VchannelInfo{ { @@ -393,12 +399,12 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, nil, nil) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 1}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) tasks := suite.checker.Check(context.TODO()) @@ -425,12 +431,12 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) view = utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 1}, map[int64]*meta.Segment{}) - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, view) tasks = suite.checker.Check(context.TODO()) @@ -438,10 +444,11 @@ func (suite *LeaderCheckerTestSuite) TestSyncRemovedSegments() { } func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() { + ctx := context.Background() observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { @@ -458,7 +465,7 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - observer.target.UpdateCollectionNextTarget(int64(1)) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) observer.dist.LeaderViewManager.Update(2, utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2, 2: 2}, map[int64]*meta.Segment{})) @@ -475,12 +482,13 @@ func (suite *LeaderCheckerTestSuite) TestIgnoreSyncRemovedSegments() { } func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() { + ctx := context.Background() testChannel := "test-insert-channel" leaderID := int64(2) observer := suite.checker - observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + observer.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + observer.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + observer.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { ID: 1, @@ -506,8 +514,8 @@ func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() { suite.Len(tasks, 0) // try to update cur/next target - observer.target.UpdateCollectionNextTarget(int64(1)) - observer.target.UpdateCollectionCurrentTarget(1) + observer.target.UpdateCollectionNextTarget(ctx, int64(1)) + observer.target.UpdateCollectionCurrentTarget(ctx, 1) loadVersion := time.Now().UnixMilli() observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 2, 1, loadVersion, testChannel)) observer.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, testChannel)) @@ -516,7 +524,7 @@ func (suite *LeaderCheckerTestSuite) TestUpdatePartitionStats() { 1: 100, } // current partition stat version in leader view is version100 for partition1 - view.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) + view.TargetVersion = observer.target.GetCollectionTargetVersion(ctx, 1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(leaderID, view) tasks = suite.checker.Check(context.TODO()) diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 38c8aff2e22f3..80b8514f52a9d 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -75,9 +75,9 @@ func (c *SegmentChecker) Description() string { return "SegmentChecker checks the lack of segments, or some segments are redundant" } -func (c *SegmentChecker) readyToCheck(collectionID int64) bool { - metaExist := (c.meta.GetCollection(collectionID) != nil) - targetExist := c.targetMgr.IsNextTargetExist(collectionID) || c.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID) +func (c *SegmentChecker) readyToCheck(ctx context.Context, collectionID int64) bool { + metaExist := (c.meta.GetCollection(ctx, collectionID) != nil) + targetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) || c.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) return metaExist && targetExist } @@ -86,11 +86,11 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task { if !c.IsActive() { return nil } - collectionIDs := c.meta.CollectionManager.GetAll() + collectionIDs := c.meta.CollectionManager.GetAll(ctx) results := make([]task.Task, 0) for _, cid := range collectionIDs { - if c.readyToCheck(cid) { - replicas := c.meta.ReplicaManager.GetByCollection(cid) + if c.readyToCheck(ctx, cid) { + replicas := c.meta.ReplicaManager.GetByCollection(ctx, cid) for _, r := range replicas { results = append(results, c.checkReplica(ctx, r)...) } @@ -111,7 +111,7 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task { segmentsOnQN := c.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(nodeID)) collectionSegments := lo.GroupBy(segmentsOnQN, func(segment *meta.Segment) int64 { return segment.GetCollectionID() }) for collectionID, segments := range collectionSegments { - replica := c.meta.ReplicaManager.GetByCollectionAndNode(collectionID, nodeID) + replica := c.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, nodeID) if replica == nil { reduceTasks := c.createSegmentReduceTasks(ctx, segments, meta.NilReplica, querypb.DataScope_Historical) task.SetReason("dirty segment exists", reduceTasks...) @@ -128,21 +128,21 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica ret := make([]task.Task, 0) // compare with targets to find the lack and redundancy of segments - lacks, redundancies := c.getSealedSegmentDiff(replica.GetCollectionID(), replica.GetID()) + lacks, redundancies := c.getSealedSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID()) // loadCtx := trace.ContextWithSpan(context.Background(), c.meta.GetCollection(replica.CollectionID).LoadSpan) tasks := c.createSegmentLoadTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), lacks, replica) task.SetReason("lacks of segment", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) ret = append(ret, tasks...) - redundancies = c.filterSegmentInUse(replica, redundancies) + redundancies = c.filterSegmentInUse(ctx, replica, redundancies) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical) task.SetReason("segment not exists in target", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) ret = append(ret, tasks...) // compare inner dists to find repeated loaded segments - redundancies = c.findRepeatedSealedSegments(replica.GetID()) + redundancies = c.findRepeatedSealedSegments(ctx, replica.GetID()) redundancies = c.filterExistedOnLeader(replica, redundancies) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Historical) task.SetReason("redundancies of segment", tasks...) @@ -151,7 +151,7 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica ret = append(ret, tasks...) // compare with target to find the lack and redundancy of segments - _, redundancies = c.getGrowingSegmentDiff(replica.GetCollectionID(), replica.GetID()) + _, redundancies = c.getGrowingSegmentDiff(ctx, replica.GetCollectionID(), replica.GetID()) tasks = c.createSegmentReduceTasks(c.getTraceCtx(ctx, replica.GetCollectionID()), redundancies, replica, querypb.DataScope_Streaming) task.SetReason("streaming segment not exists in target", tasks...) task.SetPriority(task.TaskPriorityNormal, tasks...) @@ -161,10 +161,10 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica } // GetGrowingSegmentDiff get streaming segment diff between leader view and target -func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, +func (c *SegmentChecker) getGrowingSegmentDiff(ctx context.Context, collectionID int64, replicaID int64, ) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { - replica := c.meta.Get(replicaID) + replica := c.meta.Get(ctx, replicaID) if replica == nil { log.Info("replica does not exist, skip it") return @@ -181,7 +181,7 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, log.Info("leaderView is not ready, skip", zap.String("channelName", channelName), zap.Int64("node", node)) continue } - targetVersion := c.targetMgr.GetCollectionTargetVersion(collectionID, meta.CurrentTarget) + targetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget) if view.TargetVersion != targetVersion { // before shard delegator update it's readable version, skip release segment log.RatedInfo(20, "before shard delegator update it's readable version, skip release segment", @@ -193,10 +193,10 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, continue } - nextTargetExist := c.targetMgr.IsNextTargetExist(collectionID) - nextTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.NextTarget) - currentTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.CurrentTarget) - currentTargetChannelMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget) + nextTargetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) + nextTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(ctx, collectionID, meta.NextTarget) + currentTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(ctx, collectionID, meta.CurrentTarget) + currentTargetChannelMap := c.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget) // get segment which exist on leader view, but not on current target and next target for _, segment := range view.GrowingSegments { @@ -227,10 +227,11 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, // GetSealedSegmentDiff get historical segment diff between target and dist func (c *SegmentChecker) getSealedSegmentDiff( + ctx context.Context, collectionID int64, replicaID int64, ) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { - replica := c.meta.Get(replicaID) + replica := c.meta.Get(ctx, replicaID) if replica == nil { log.Info("replica does not exist, skip it") return @@ -278,9 +279,9 @@ func (c *SegmentChecker) getSealedSegmentDiff( return !existInDist } - nextTargetExist := c.targetMgr.IsNextTargetExist(collectionID) - nextTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.NextTarget) - currentTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.CurrentTarget) + nextTargetExist := c.targetMgr.IsNextTargetExist(ctx, collectionID) + nextTargetMap := c.targetMgr.GetSealedSegmentsByCollection(ctx, collectionID, meta.NextTarget) + currentTargetMap := c.targetMgr.GetSealedSegmentsByCollection(ctx, collectionID, meta.CurrentTarget) // Segment which exist on next target, but not on dist for _, segment := range nextTargetMap { @@ -325,9 +326,9 @@ func (c *SegmentChecker) getSealedSegmentDiff( return } -func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Segment { +func (c *SegmentChecker) findRepeatedSealedSegments(ctx context.Context, replicaID int64) []*meta.Segment { segments := make([]*meta.Segment, 0) - replica := c.meta.Get(replicaID) + replica := c.meta.Get(ctx, replicaID) if replica == nil { log.Info("replica does not exist, skip it") return segments @@ -336,7 +337,7 @@ func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Seg versions := make(map[int64]*meta.Segment) for _, s := range dist { // l0 segment should be release with channel together - segment := c.targetMgr.GetSealedSegment(s.GetCollectionID(), s.GetID(), meta.CurrentTargetFirst) + segment := c.targetMgr.GetSealedSegment(ctx, s.GetCollectionID(), s.GetID(), meta.CurrentTargetFirst) existInTarget := segment != nil isL0Segment := existInTarget && segment.GetLevel() == datapb.SegmentLevel_L0 if isL0Segment { @@ -378,7 +379,7 @@ func (c *SegmentChecker) filterExistedOnLeader(replica *meta.Replica, segments [ return filtered } -func (c *SegmentChecker) filterSegmentInUse(replica *meta.Replica, segments []*meta.Segment) []*meta.Segment { +func (c *SegmentChecker) filterSegmentInUse(ctx context.Context, replica *meta.Replica, segments []*meta.Segment) []*meta.Segment { filtered := make([]*meta.Segment, 0, len(segments)) for _, s := range segments { leaderID, ok := c.dist.ChannelDistManager.GetShardLeader(replica, s.GetInsertChannel()) @@ -387,8 +388,8 @@ func (c *SegmentChecker) filterSegmentInUse(replica *meta.Replica, segments []*m } view := c.dist.LeaderViewManager.GetLeaderShardView(leaderID, s.GetInsertChannel()) - currentTargetVersion := c.targetMgr.GetCollectionTargetVersion(s.CollectionID, meta.CurrentTarget) - partition := c.meta.CollectionManager.GetPartition(s.PartitionID) + currentTargetVersion := c.targetMgr.GetCollectionTargetVersion(ctx, s.CollectionID, meta.CurrentTarget) + partition := c.meta.CollectionManager.GetPartition(ctx, s.PartitionID) // if delegator has valid target version, and before it update to latest readable version, skip release it's sealed segment // Notice: if syncTargetVersion stuck, segment on delegator won't be released @@ -435,7 +436,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] SegmentInfo: s, } }) - shardPlans := c.getBalancerFunc().AssignSegment(replica.GetCollectionID(), segmentInfos, rwNodes, false) + shardPlans := c.getBalancerFunc().AssignSegment(ctx, replica.GetCollectionID(), segmentInfos, rwNodes, false) for i := range shardPlans { shardPlans[i].Replica = replica } @@ -474,7 +475,7 @@ func (c *SegmentChecker) createSegmentReduceTasks(ctx context.Context, segments } func (c *SegmentChecker) getTraceCtx(ctx context.Context, collectionID int64) context.Context { - coll := c.meta.GetCollection(collectionID) + coll := c.meta.GetCollection(ctx, collectionID) if coll == nil || coll.LoadSpan == nil { return ctx } diff --git a/internal/querycoordv2/checkers/segment_checker_test.go b/internal/querycoordv2/checkers/segment_checker_test.go index 228013c48dcfb..08f586648dfe4 100644 --- a/internal/querycoordv2/checkers/segment_checker_test.go +++ b/internal/querycoordv2/checkers/segment_checker_test.go @@ -88,7 +88,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() { func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance { balancer := balance.NewMockBalancer(suite.T()) - balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan { + balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(ctx context.Context, collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan { plans := make([]balance.SegmentAssignPlan, 0, len(segments)) for i, s := range segments { plan := balance.SegmentAssignPlan{ @@ -105,11 +105,12 @@ func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance { } func (suite *SegmentCheckerTestSuite) TestLoadSegments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -120,8 +121,8 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // set target segments := []*datapb.SegmentInfo{ @@ -141,7 +142,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -170,11 +171,12 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() { } func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -187,8 +189,8 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { Hostname: "localhost", Version: common.Version, })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // set target segments := []*datapb.SegmentInfo{ @@ -209,7 +211,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -227,7 +229,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { suite.EqualValues(2, action.Node()) suite.Equal(tasks[0].Priority(), task.TaskPriorityNormal) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) // test load l0 segments in current target tasks = checker.Check(context.TODO()) suite.Len(tasks, 1) @@ -241,7 +243,7 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { suite.Equal(tasks[0].Priority(), task.TaskPriorityNormal) // seg l0 segment exist on a non delegator node - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) checker.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, 1, "test-insert-channel")) // test load l0 segments to delegator tasks = checker.Check(context.TODO()) @@ -257,11 +259,12 @@ func (suite *SegmentCheckerTestSuite) TestLoadL0Segments() { } func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -272,8 +275,8 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // set target segments := []*datapb.SegmentInfo{ @@ -294,8 +297,8 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -315,9 +318,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() { suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, nil, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) tasks = checker.Check(context.TODO()) suite.Len(tasks, 1) @@ -332,11 +335,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseL0Segments() { } func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -347,8 +351,8 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() { Address: "localhost", Hostname: "localhost", })) - checker.meta.ResourceManager.HandleNodeUp(1) - checker.meta.ResourceManager.HandleNodeUp(2) + checker.meta.ResourceManager.HandleNodeUp(ctx, 1) + checker.meta.ResourceManager.HandleNodeUp(ctx, 2) // set target segments := []*datapb.SegmentInfo{ @@ -368,7 +372,7 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // when channel not subscribed, segment_checker won't generate load segment task tasks := checker.Check(context.TODO()) @@ -376,11 +380,12 @@ func (suite *SegmentCheckerTestSuite) TestSkipLoadSegments() { } func (suite *SegmentCheckerTestSuite) TestReleaseSegments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) // set target channels := []*datapb.VchannelInfo{ @@ -391,7 +396,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, nil, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -410,11 +415,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseSegments() { } func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) // set target segments := []*datapb.SegmentInfo{ @@ -432,7 +438,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -458,11 +464,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() { } func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() { + ctx := context.Background() checker := suite.checker // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1})) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -490,7 +497,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) // set dist checker.dist.ChannelDistManager.Update(2, utils.CreateTestChannel(1, 2, 1, "test-insert-channel")) @@ -510,15 +517,16 @@ func (suite *SegmentCheckerTestSuite) TestReleaseDirtySegments() { } func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() { + ctx := context.Background() checker := suite.checker collectionID := int64(1) partitionID := int64(1) // set meta - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(collectionID, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(collectionID, partitionID)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, collectionID, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(collectionID, partitionID)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, collectionID, []int64{1, 2})) // set target channels := []*datapb.VchannelInfo{ @@ -531,10 +539,10 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() { segments := []*datapb.SegmentInfo{} suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(collectionID) - checker.targetMgr.UpdateCollectionCurrentTarget(collectionID) - checker.targetMgr.UpdateCollectionNextTarget(collectionID) - readableVersion := checker.targetMgr.GetCollectionTargetVersion(collectionID, meta.CurrentTarget) + checker.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) + checker.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) + readableVersion := checker.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.CurrentTarget) // test less target version exist on leader,meet segment doesn't exit in target, segment should be released nodeID := int64(2) @@ -579,12 +587,13 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseSealedSegments() { } func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() { + ctx := context.Background() checker := suite.checker // segment3 is compacted from segment2, and node2 has growing segments 2 and 3. checker should generate // 2 tasks to reduce segment 2 and 3. - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { @@ -602,9 +611,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) growingSegments := make(map[int64]*meta.Segment) growingSegments[2] = utils.CreateTestSegment(1, 1, 2, 2, 0, "test-insert-channel") @@ -618,7 +627,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() { dmChannel.UnflushedSegmentIds = []int64{2, 3} checker.dist.ChannelDistManager.Update(2, dmChannel) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2}, growingSegments) - view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget) + view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget) checker.dist.LeaderViewManager.Update(2, view) checker.dist.SegmentDistManager.Update(2, utils.CreateTestSegment(1, 1, 3, 2, 2, "test-insert-channel")) @@ -647,11 +656,12 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() { } func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() { + ctx := context.Background() checker := suite.checker - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{ { @@ -670,9 +680,9 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) growingSegments := make(map[int64]*meta.Segment) // segment start pos after chekcpoint @@ -683,7 +693,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() { dmChannel.UnflushedSegmentIds = []int64{2, 3} checker.dist.ChannelDistManager.Update(2, dmChannel) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{3: 2}, growingSegments) - view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget) + view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget) checker.dist.LeaderViewManager.Update(2, view) checker.dist.SegmentDistManager.Update(2, utils.CreateTestSegment(1, 1, 3, 2, 2, "test-insert-channel")) @@ -703,10 +713,11 @@ func (suite *SegmentCheckerTestSuite) TestReleaseCompactedGrowingSegments() { } func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() { + ctx := context.Background() checker := suite.checker - checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - checker.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) - checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) + checker.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + checker.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(1, 1)) + checker.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(1, 1, []int64{1, 2})) segments := []*datapb.SegmentInfo{} channels := []*datapb.VchannelInfo{ @@ -718,9 +729,9 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) - checker.targetMgr.UpdateCollectionCurrentTarget(int64(1)) - checker.targetMgr.UpdateCollectionNextTarget(int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionCurrentTarget(ctx, int64(1)) + checker.targetMgr.UpdateCollectionNextTarget(ctx, int64(1)) growingSegments := make(map[int64]*meta.Segment) growingSegments[2] = utils.CreateTestSegment(1, 1, 2, 2, 0, "test-insert-channel") @@ -730,13 +741,13 @@ func (suite *SegmentCheckerTestSuite) TestSkipReleaseGrowingSegments() { dmChannel.UnflushedSegmentIds = []int64{2, 3} checker.dist.ChannelDistManager.Update(2, dmChannel) view := utils.CreateTestLeaderView(2, 1, "test-insert-channel", map[int64]int64{}, growingSegments) - view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget) - 1 + view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget) - 1 checker.dist.LeaderViewManager.Update(2, view) tasks := checker.Check(context.TODO()) suite.Len(tasks, 0) - view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(int64(1), meta.CurrentTarget) + view.TargetVersion = checker.targetMgr.GetCollectionTargetVersion(ctx, int64(1), meta.CurrentTarget) checker.dist.LeaderViewManager.Update(2, view) tasks = checker.Check(context.TODO()) suite.Len(tasks, 1) diff --git a/internal/querycoordv2/dist/dist_controller.go b/internal/querycoordv2/dist/dist_controller.go index 5661eaae33413..5f46f04125028 100644 --- a/internal/querycoordv2/dist/dist_controller.go +++ b/internal/querycoordv2/dist/dist_controller.go @@ -79,7 +79,7 @@ func (dc *ControllerImpl) SyncAll(ctx context.Context) { if err != nil { log.Warn("SyncAll come across err when getting data distribution", zap.Error(err)) } else { - handler.handleDistResp(resp, true) + handler.handleDistResp(ctx, resp, true) } }(h) } diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index f1a5434f8fbfd..828cedc6e5ce3 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -103,11 +103,11 @@ func (dh *distHandler) pullDist(ctx context.Context, failures *int, dispatchTask log.RatedWarn(30.0, "failed to get data distribution", fields...) } else { *failures = 0 - dh.handleDistResp(resp, dispatchTask) + dh.handleDistResp(ctx, resp, dispatchTask) } } -func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, dispatchTask bool) { +func (dh *distHandler) handleDistResp(ctx context.Context, resp *querypb.GetDataDistributionResponse, dispatchTask bool) { node := dh.nodeManager.Get(resp.GetNodeID()) if node == nil { return @@ -130,9 +130,9 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, session.WithChannelCnt(len(resp.GetChannels())), session.WithMemCapacity(resp.GetMemCapacityInMB()), ) - dh.updateSegmentsDistribution(resp) - dh.updateChannelsDistribution(resp) - dh.updateLeaderView(resp) + dh.updateSegmentsDistribution(ctx, resp) + dh.updateChannelsDistribution(ctx, resp) + dh.updateLeaderView(ctx, resp) } if dispatchTask { @@ -140,10 +140,10 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, } } -func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistributionResponse) { +func (dh *distHandler) updateSegmentsDistribution(ctx context.Context, resp *querypb.GetDataDistributionResponse) { updates := make([]*meta.Segment, 0, len(resp.GetSegments())) for _, s := range resp.GetSegments() { - segmentInfo := dh.target.GetSealedSegment(s.GetCollection(), s.GetID(), meta.CurrentTargetFirst) + segmentInfo := dh.target.GetSealedSegment(ctx, s.GetCollection(), s.GetID(), meta.CurrentTargetFirst) if segmentInfo == nil { segmentInfo = &datapb.SegmentInfo{ ID: s.GetID(), @@ -166,10 +166,10 @@ func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistribut dh.dist.SegmentDistManager.Update(resp.GetNodeID(), updates...) } -func (dh *distHandler) updateChannelsDistribution(resp *querypb.GetDataDistributionResponse) { +func (dh *distHandler) updateChannelsDistribution(ctx context.Context, resp *querypb.GetDataDistributionResponse) { updates := make([]*meta.DmChannel, 0, len(resp.GetChannels())) for _, ch := range resp.GetChannels() { - channelInfo := dh.target.GetDmChannel(ch.GetCollection(), ch.GetChannel(), meta.CurrentTarget) + channelInfo := dh.target.GetDmChannel(ctx, ch.GetCollection(), ch.GetChannel(), meta.CurrentTarget) var channel *meta.DmChannel if channelInfo == nil { channel = &meta.DmChannel{ @@ -193,7 +193,7 @@ func (dh *distHandler) updateChannelsDistribution(resp *querypb.GetDataDistribut dh.dist.ChannelDistManager.Update(resp.GetNodeID(), updates...) } -func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionResponse) { +func (dh *distHandler) updateLeaderView(ctx context.Context, resp *querypb.GetDataDistributionResponse) { updates := make([]*meta.LeaderView, 0, len(resp.GetLeaderViews())) channels := lo.SliceToMap(resp.GetChannels(), func(channel *querypb.ChannelVersionInfo) (string, *querypb.ChannelVersionInfo) { @@ -248,7 +248,7 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons // if target version hasn't been synced, delegator will get empty readable segment list // so shard leader should be unserviceable until target version is synced - currentTargetVersion := dh.target.GetCollectionTargetVersion(lview.GetCollection(), meta.CurrentTarget) + currentTargetVersion := dh.target.GetCollectionTargetVersion(ctx, lview.GetCollection(), meta.CurrentTarget) if lview.TargetVersion <= 0 { err := merr.WrapErrServiceInternal(fmt.Sprintf("target version mismatch, collection: %d, channel: %s, current target version: %v, leader version: %v", lview.GetCollection(), lview.GetChannel(), currentTargetVersion, lview.TargetVersion)) diff --git a/internal/querycoordv2/dist/dist_handler_test.go b/internal/querycoordv2/dist/dist_handler_test.go index 17b1fe5ae7937..c66902bc43309 100644 --- a/internal/querycoordv2/dist/dist_handler_test.go +++ b/internal/querycoordv2/dist/dist_handler_test.go @@ -66,9 +66,9 @@ func (suite *DistHandlerSuite) SetupSuite() { suite.executedFlagChan = make(chan struct{}, 1) suite.scheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(suite.executedFlagChan).Maybe() - suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() - suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() - suite.target.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.target.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() } func (suite *DistHandlerSuite) TestBasic() { @@ -77,7 +77,7 @@ func (suite *DistHandlerSuite) TestBasic() { suite.dispatchMockCall = nil } - suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}) + suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}) suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe() suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, @@ -126,7 +126,7 @@ func (suite *DistHandlerSuite) TestGetDistributionFailed() { suite.dispatchMockCall.Unset() suite.dispatchMockCall = nil } - suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe() + suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe() suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe() suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, @@ -148,7 +148,7 @@ func (suite *DistHandlerSuite) TestForcePullDist() { suite.dispatchMockCall = nil } - suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe() + suite.target.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{}).Maybe() suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 7c69574980245..d770edc1e6fef 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -49,7 +49,7 @@ import ( // may come from different replica group. We only need these shards to form a replica that serves query // requests. func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool { - for _, replica := range s.meta.ReplicaManager.GetByCollection(collectionID) { + for _, replica := range s.meta.ReplicaManager.GetByCollection(s.ctx, collectionID) { isAvailable := true for _, node := range replica.GetRONodes() { if s.nodeMgr.Get(node) == nil { @@ -64,9 +64,9 @@ func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool { return false } -func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentInfo { +func (s *Server) getCollectionSegmentInfo(ctx context.Context, collection int64) []*querypb.SegmentInfo { segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(collection)) - currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget) + currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget) infos := make(map[int64]*querypb.SegmentInfo) for _, segment := range segments { if _, existCurrentTarget := currentTargetSegmentsMap[segment.GetID()]; !existCurrentTarget { @@ -104,7 +104,7 @@ func (s *Server) balanceSegments(ctx context.Context, copyMode bool, ) error { log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("srcNode", srcNode)) - plans := s.getBalancerFunc().AssignSegment(collectionID, segments, dstNodes, true) + plans := s.getBalancerFunc().AssignSegment(ctx, collectionID, segments, dstNodes, true) for i := range plans { plans[i].From = srcNode plans[i].Replica = replica @@ -183,7 +183,7 @@ func (s *Server) balanceChannels(ctx context.Context, ) error { log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID)) - plans := s.getBalancerFunc().AssignChannel(channels, dstNodes, true) + plans := s.getBalancerFunc().AssignChannel(ctx, channels, dstNodes, true) for i := range plans { plans[i].From = srcNode plans[i].Replica = replica @@ -458,16 +458,16 @@ func (s *Server) tryGetNodesMetrics(ctx context.Context, req *milvuspb.GetMetric return ret } -func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) *milvuspb.ReplicaInfo { +func (s *Server) fillReplicaInfo(ctx context.Context, replica *meta.Replica, withShardNodes bool) *milvuspb.ReplicaInfo { info := &milvuspb.ReplicaInfo{ ReplicaID: replica.GetID(), CollectionID: replica.GetCollectionID(), NodeIds: replica.GetNodes(), ResourceGroupName: replica.GetResourceGroup(), - NumOutboundNode: s.meta.GetOutgoingNodeNumByReplica(replica), + NumOutboundNode: s.meta.GetOutgoingNodeNumByReplica(ctx, replica), } - channels := s.targetMgr.GetDmChannelsByCollection(replica.GetCollectionID(), meta.CurrentTarget) + channels := s.targetMgr.GetDmChannelsByCollection(ctx, replica.GetCollectionID(), meta.CurrentTarget) if len(channels) == 0 { log.Warn("failed to get channels, collection may be not loaded or in recovering", zap.Int64("collectionID", replica.GetCollectionID())) return info diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index f7c75da2fc33a..9457b1b303ab3 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -98,14 +98,14 @@ func (job *LoadCollectionJob) PreExecute() error { req.ResourceGroups = []string{meta.DefaultResourceGroupName} } - collection := job.meta.GetCollection(req.GetCollectionID()) + collection := job.meta.GetCollection(job.ctx, req.GetCollectionID()) if collection == nil { return nil } if collection.GetReplicaNumber() != req.GetReplicaNumber() { msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number", - job.meta.GetReplicaNumber(req.GetCollectionID()), + job.meta.GetReplicaNumber(job.ctx, req.GetCollectionID()), ) log.Warn(msg) return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded collection") @@ -125,7 +125,7 @@ func (job *LoadCollectionJob) PreExecute() error { ) return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection") } - collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect() + collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect() left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups()) if len(left) > 0 || len(right) > 0 { msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups", @@ -149,7 +149,7 @@ func (job *LoadCollectionJob) Execute() error { log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } - loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()), + loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 { return partition.GetPartitionID() }) @@ -163,10 +163,10 @@ func (job *LoadCollectionJob) Execute() error { job.undo.LackPartitions = lackPartitionIDs log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs)) - colExisted := job.meta.CollectionManager.Exist(req.GetCollectionID()) + colExisted := job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) if !colExisted { // Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444 - err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) + err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { msg := "failed to clear stale replicas" log.Warn(msg, zap.Error(err)) @@ -175,7 +175,7 @@ func (job *LoadCollectionJob) Execute() error { } // 2. create replica if not exist - replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) + replicas := job.meta.ReplicaManager.GetByCollection(job.ctx, req.GetCollectionID()) if len(replicas) == 0 { collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID()) if err != nil { @@ -184,7 +184,7 @@ func (job *LoadCollectionJob) Execute() error { // API of LoadCollection is wired, we should use map[resourceGroupNames]replicaNumber as input, to keep consistency with `TransferReplica` API. // Then we can implement dynamic replica changed in different resource group independently. - _, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) + _, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) @@ -227,7 +227,7 @@ func (job *LoadCollectionJob) Execute() error { LoadSpan: sp, } job.undo.IsNewCollection = true - err = job.meta.CollectionManager.PutCollection(collection, partitions...) + err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...) if err != nil { msg := "failed to store collection and partitions" log.Warn(msg, zap.Error(err)) @@ -312,7 +312,7 @@ func (job *LoadPartitionJob) PreExecute() error { req.ResourceGroups = []string{meta.DefaultResourceGroupName} } - collection := job.meta.GetCollection(req.GetCollectionID()) + collection := job.meta.GetCollection(job.ctx, req.GetCollectionID()) if collection == nil { return nil } @@ -337,7 +337,7 @@ func (job *LoadPartitionJob) PreExecute() error { ) return merr.WrapErrParameterInvalid(collection.GetLoadFields(), req.GetLoadFields(), "can't change the load field list for loaded collection") } - collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect() + collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect() left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups()) if len(left) > 0 || len(right) > 0 { msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups", @@ -358,7 +358,7 @@ func (job *LoadPartitionJob) Execute() error { meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) // 1. Fetch target partitions - loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()), + loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 { return partition.GetPartitionID() }) @@ -373,9 +373,9 @@ func (job *LoadPartitionJob) Execute() error { log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs)) var err error - if !job.meta.CollectionManager.Exist(req.GetCollectionID()) { + if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) { // Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444 - err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) + err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { msg := "failed to clear stale replicas" log.Warn(msg, zap.Error(err)) @@ -384,13 +384,13 @@ func (job *LoadPartitionJob) Execute() error { } // 2. create replica if not exist - replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) + replicas := job.meta.ReplicaManager.GetByCollection(context.TODO(), req.GetCollectionID()) if len(replicas) == 0 { collectionInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionID()) if err != nil { return err } - _, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) + _, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(), collectionInfo.GetVirtualChannelNames()) if err != nil { msg := "failed to spawn replica for collection" log.Warn(msg, zap.Error(err)) @@ -419,7 +419,7 @@ func (job *LoadPartitionJob) Execute() error { } }) ctx, sp := otel.Tracer(typeutil.QueryCoordRole).Start(job.ctx, "LoadPartition", trace.WithNewRoot()) - if !job.meta.CollectionManager.Exist(req.GetCollectionID()) { + if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) { job.undo.IsNewCollection = true collection := &meta.Collection{ @@ -434,14 +434,14 @@ func (job *LoadPartitionJob) Execute() error { CreatedAt: time.Now(), LoadSpan: sp, } - err = job.meta.CollectionManager.PutCollection(collection, partitions...) + err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...) if err != nil { msg := "failed to store collection and partitions" log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } } else { // collection exists, put partitions only - err = job.meta.CollectionManager.PutPartition(partitions...) + err = job.meta.CollectionManager.PutPartition(job.ctx, partitions...) if err != nil { msg := "failed to store partitions" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/job/job_release.go b/internal/querycoordv2/job/job_release.go index ca903159a5698..b5f7de892452a 100644 --- a/internal/querycoordv2/job/job_release.go +++ b/internal/querycoordv2/job/job_release.go @@ -77,25 +77,25 @@ func (job *ReleaseCollectionJob) Execute() error { req := job.req log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID())) - if !job.meta.CollectionManager.Exist(req.GetCollectionID()) { + if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) { log.Info("release collection end, the collection has not been loaded into QueryNode") return nil } - loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()) + loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()) toRelease := lo.Map(loadedPartitions, func(partition *meta.Partition, _ int) int64 { return partition.GetPartitionID() }) releasePartitions(job.ctx, job.meta, job.cluster, req.GetCollectionID(), toRelease...) - err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID()) + err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { msg := "failed to remove collection" log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } - err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) + err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { msg := "failed to remove replicas" log.Warn(msg, zap.Error(err)) @@ -166,12 +166,12 @@ func (job *ReleasePartitionJob) Execute() error { zap.Int64s("partitionIDs", req.GetPartitionIDs()), ) - if !job.meta.CollectionManager.Exist(req.GetCollectionID()) { + if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) { log.Info("release collection end, the collection has not been loaded into QueryNode") return nil } - loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()) + loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()) toRelease := lo.FilterMap(loadedPartitions, func(partition *meta.Partition, _ int) (int64, bool) { return partition.GetPartitionID(), lo.Contains(req.GetPartitionIDs(), partition.GetPartitionID()) }) @@ -185,13 +185,13 @@ func (job *ReleasePartitionJob) Execute() error { // If all partitions are released, clear all if len(toRelease) == len(loadedPartitions) { log.Info("release partitions covers all partitions, will remove the whole collection") - err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID()) + err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { msg := "failed to release partitions from store" log.Warn(msg, zap.Error(err)) return errors.Wrap(err, msg) } - err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID()) + err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID()) if err != nil { log.Warn("failed to remove replicas", zap.Error(err)) } @@ -207,7 +207,7 @@ func (job *ReleasePartitionJob) Execute() error { waitCollectionReleased(job.dist, job.checkerController, req.GetCollectionID()) } else { - err := job.meta.CollectionManager.RemovePartition(req.GetCollectionID(), toRelease...) + err := job.meta.CollectionManager.RemovePartition(job.ctx, req.GetCollectionID(), toRelease...) if err != nil { msg := "failed to release partitions from store" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/job/job_sync.go b/internal/querycoordv2/job/job_sync.go index 72a25b9a67e97..49cb805b7d449 100644 --- a/internal/querycoordv2/job/job_sync.go +++ b/internal/querycoordv2/job/job_sync.go @@ -65,13 +65,13 @@ func (job *SyncNewCreatedPartitionJob) Execute() error { ) // check if collection not load or loadType is loadPartition - collection := job.meta.GetCollection(job.req.GetCollectionID()) + collection := job.meta.GetCollection(job.ctx, job.req.GetCollectionID()) if collection == nil || collection.GetLoadType() == querypb.LoadType_LoadPartition { return nil } // check if partition already existed - if partition := job.meta.GetPartition(job.req.GetPartitionID()); partition != nil { + if partition := job.meta.GetPartition(job.ctx, job.req.GetPartitionID()); partition != nil { return nil } @@ -89,7 +89,7 @@ func (job *SyncNewCreatedPartitionJob) Execute() error { LoadPercentage: 100, CreatedAt: time.Now(), } - err = job.meta.CollectionManager.PutPartition(partition) + err = job.meta.CollectionManager.PutPartition(job.ctx, partition) if err != nil { msg := "failed to store partitions" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index 0fc786dbbc763..02615713a5b00 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -77,6 +77,8 @@ type JobSuite struct { // Test objects scheduler *Scheduler + + ctx context.Context } func (suite *JobSuite) SetupSuite() { @@ -160,6 +162,7 @@ func (suite *JobSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() suite.store = querycoord.NewCatalog(suite.kv) suite.dist = meta.NewDistributionManager() @@ -195,9 +198,9 @@ func (suite *JobSuite) SetupTest() { Hostname: "localhost", })) - suite.meta.HandleNodeUp(1000) - suite.meta.HandleNodeUp(2000) - suite.meta.HandleNodeUp(3000) + suite.meta.HandleNodeUp(suite.ctx, 1000) + suite.meta.HandleNodeUp(suite.ctx, 2000) + suite.meta.HandleNodeUp(suite.ctx, 3000) suite.checkerController = &checkers.CheckerController{} suite.collectionObserver = observers.NewCollectionObserver( @@ -253,8 +256,8 @@ func (suite *JobSuite) TestLoadCollection() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertCollectionLoaded(collection) } @@ -346,9 +349,9 @@ func (suite *JobSuite) TestLoadCollection() { }, } - suite.meta.ResourceManager.AddResourceGroup("rg1", cfg) - suite.meta.ResourceManager.AddResourceGroup("rg2", cfg) - suite.meta.ResourceManager.AddResourceGroup("rg3", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg) // Load with 3 replica on 1 rg req := &querypb.LoadCollectionRequest{ @@ -455,8 +458,8 @@ func (suite *JobSuite) TestLoadCollectionWithLoadFields() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertCollectionLoaded(collection) } }) @@ -580,8 +583,8 @@ func (suite *JobSuite) TestLoadPartition() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertCollectionLoaded(collection) } @@ -704,9 +707,9 @@ func (suite *JobSuite) TestLoadPartition() { NodeNum: 1, }, } - suite.meta.ResourceManager.AddResourceGroup("rg1", cfg) - suite.meta.ResourceManager.AddResourceGroup("rg2", cfg) - suite.meta.ResourceManager.AddResourceGroup("rg3", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg) + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg) // test load 3 replica in 1 rg, should pass rg check req := &querypb.LoadPartitionsRequest{ @@ -786,8 +789,8 @@ func (suite *JobSuite) TestLoadPartitionWithLoadFields() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertCollectionLoaded(collection) } }) @@ -941,7 +944,7 @@ func (suite *JobSuite) TestDynamicLoad() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p0, p1, p2) // loaded: p0, p1, p2 @@ -961,13 +964,13 @@ func (suite *JobSuite) TestDynamicLoad() { suite.scheduler.Add(job) err = job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p0, p1) job = newLoadPartJob(p2) suite.scheduler.Add(job) err = job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p2) // loaded: p0, p1 @@ -978,13 +981,13 @@ func (suite *JobSuite) TestDynamicLoad() { suite.scheduler.Add(job) err = job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p0, p1) job = newLoadPartJob(p1, p2) suite.scheduler.Add(job) err = job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p2) // loaded: p0, p1 @@ -995,13 +998,13 @@ func (suite *JobSuite) TestDynamicLoad() { suite.scheduler.Add(job) err = job.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p0, p1) colJob := newLoadColJob() suite.scheduler.Add(colJob) err = colJob.Wait() suite.NoError(err) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) suite.assertPartitionLoaded(collection, p2) } @@ -1166,8 +1169,8 @@ func (suite *JobSuite) TestReleasePartition() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.True(suite.meta.Exist(collection)) - partitions := suite.meta.GetPartitionsByCollection(collection) + suite.True(suite.meta.Exist(ctx, collection)) + partitions := suite.meta.GetPartitionsByCollection(ctx, collection) suite.Len(partitions, 1) suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID()) suite.assertPartitionReleased(collection, suite.partitions[collection][1:]...) @@ -1247,7 +1250,7 @@ func (suite *JobSuite) TestDynamicRelease() { err = job.Wait() suite.NoError(err) suite.assertPartitionReleased(col0, p0, p1, p2) - suite.False(suite.meta.Exist(col0)) + suite.False(suite.meta.Exist(ctx, col0)) // loaded: p0, p1, p2 // action: release col @@ -1275,14 +1278,15 @@ func (suite *JobSuite) TestDynamicRelease() { } func (suite *JobSuite) TestLoadCollectionStoreFailed() { + ctx := context.Background() // Store collection failed store := mocks.NewQueryCoordCatalog(suite.T()) suite.meta = meta.NewMeta(RandomIncrementIDAllocator(), store, suite.nodeMgr) store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) - suite.meta.HandleNodeUp(1000) - suite.meta.HandleNodeUp(2000) - suite.meta.HandleNodeUp(3000) + suite.meta.HandleNodeUp(ctx, 1000) + suite.meta.HandleNodeUp(ctx, 2000) + suite.meta.HandleNodeUp(ctx, 3000) for _, collection := range suite.collections { if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { @@ -1290,9 +1294,9 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() { } suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) err := errors.New("failed to store collection") - store.EXPECT().SaveReplica(mock.Anything).Return(nil) - store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err) - store.EXPECT().ReleaseReplicas(collection).Return(nil) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err) + store.EXPECT().ReleaseReplicas(mock.Anything, collection).Return(nil) req := &querypb.LoadCollectionRequest{ CollectionID: collection, @@ -1316,14 +1320,15 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() { } func (suite *JobSuite) TestLoadPartitionStoreFailed() { + ctx := context.Background() // Store partition failed store := mocks.NewQueryCoordCatalog(suite.T()) suite.meta = meta.NewMeta(RandomIncrementIDAllocator(), store, suite.nodeMgr) - store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) - suite.meta.HandleNodeUp(1000) - suite.meta.HandleNodeUp(2000) - suite.meta.HandleNodeUp(3000) + store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil) + suite.meta.HandleNodeUp(ctx, 1000) + suite.meta.HandleNodeUp(ctx, 2000) + suite.meta.HandleNodeUp(ctx, 3000) err := errors.New("failed to store collection") for _, collection := range suite.collections { @@ -1331,9 +1336,9 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() { continue } - store.EXPECT().SaveReplica(mock.Anything).Return(nil) - store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err) - store.EXPECT().ReleaseReplicas(collection).Return(nil) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err) + store.EXPECT().ReleaseReplicas(mock.Anything, collection).Return(nil) req := &querypb.LoadPartitionsRequest{ CollectionID: collection, @@ -1548,6 +1553,7 @@ func (suite *JobSuite) TestCallReleasePartitionFailed() { func (suite *JobSuite) TestSyncNewCreatedPartition() { newPartition := int64(999) + ctx := context.Background() // test sync new created partition suite.loadAll() @@ -1565,7 +1571,7 @@ func (suite *JobSuite) TestSyncNewCreatedPartition() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - partition := suite.meta.CollectionManager.GetPartition(newPartition) + partition := suite.meta.CollectionManager.GetPartition(ctx, newPartition) suite.NotNil(partition) suite.Equal(querypb.LoadStatus_Loaded, partition.GetStatus()) @@ -1624,11 +1630,11 @@ func (suite *JobSuite) loadAll() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.True(suite.meta.Exist(collection)) - suite.NotNil(suite.meta.GetCollection(collection)) - suite.NotNil(suite.meta.GetPartitionsByCollection(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotNil(suite.meta.GetCollection(ctx, collection)) + suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) } else { req := &querypb.LoadPartitionsRequest{ CollectionID: collection, @@ -1649,11 +1655,11 @@ func (suite *JobSuite) loadAll() { suite.scheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) - suite.True(suite.meta.Exist(collection)) - suite.NotNil(suite.meta.GetCollection(collection)) - suite.NotNil(suite.meta.GetPartitionsByCollection(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection)) + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotNil(suite.meta.GetCollection(ctx, collection)) + suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) } } } @@ -1684,54 +1690,58 @@ func (suite *JobSuite) releaseAll() { } func (suite *JobSuite) assertCollectionLoaded(collection int64) { - suite.True(suite.meta.Exist(collection)) - suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection))) + ctx := context.Background() + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection))) for _, channel := range suite.channels[collection] { - suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget)) } for _, segments := range suite.segments[collection] { for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) } } } func (suite *JobSuite) assertPartitionLoaded(collection int64, partitionIDs ...int64) { - suite.True(suite.meta.Exist(collection)) - suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection))) + ctx := context.Background() + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection))) for _, channel := range suite.channels[collection] { - suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget)) } for partitionID, segments := range suite.segments[collection] { if !lo.Contains(partitionIDs, partitionID) { continue } - suite.NotNil(suite.meta.GetPartition(partitionID)) + suite.NotNil(suite.meta.GetPartition(ctx, partitionID)) for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) } } } func (suite *JobSuite) assertCollectionReleased(collection int64) { - suite.False(suite.meta.Exist(collection)) - suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(collection))) + ctx := context.Background() + suite.False(suite.meta.Exist(ctx, collection)) + suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection))) for _, channel := range suite.channels[collection] { - suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget)) } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) } } } func (suite *JobSuite) assertPartitionReleased(collection int64, partitionIDs ...int64) { + ctx := context.Background() for _, partition := range partitionIDs { - suite.Nil(suite.meta.GetPartition(partition)) + suite.Nil(suite.meta.GetPartition(ctx, partition)) segments := suite.segments[collection][partition] for _, segment := range segments { - suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) } } } diff --git a/internal/querycoordv2/job/job_update.go b/internal/querycoordv2/job/job_update.go index cb60af36fa51e..7b5259c83d06c 100644 --- a/internal/querycoordv2/job/job_update.go +++ b/internal/querycoordv2/job/job_update.go @@ -62,7 +62,7 @@ func NewUpdateLoadConfigJob(ctx context.Context, } func (job *UpdateLoadConfigJob) Execute() error { - if !job.meta.CollectionManager.Exist(job.collectionID) { + if !job.meta.CollectionManager.Exist(job.ctx, job.collectionID) { msg := "modify replica for unloaded collection is not supported" err := merr.WrapErrCollectionNotLoaded(msg) log.Warn(msg, zap.Error(err)) @@ -83,7 +83,7 @@ func (job *UpdateLoadConfigJob) Execute() error { var err error // 2. reassign - toSpawn, toTransfer, toRelease, err := utils.ReassignReplicaToRG(job.meta, job.collectionID, job.newReplicaNumber, job.newResourceGroups) + toSpawn, toTransfer, toRelease, err := utils.ReassignReplicaToRG(job.ctx, job.meta, job.collectionID, job.newReplicaNumber, job.newResourceGroups) if err != nil { log.Warn("failed to reassign replica", zap.Error(err)) return err @@ -98,8 +98,8 @@ func (job *UpdateLoadConfigJob) Execute() error { zap.Any("toRelease", toRelease)) // 3. try to spawn new replica - channels := job.targetMgr.GetDmChannelsByCollection(job.collectionID, meta.CurrentTargetFirst) - newReplicas, spawnErr := job.meta.ReplicaManager.Spawn(job.collectionID, toSpawn, lo.Keys(channels)) + channels := job.targetMgr.GetDmChannelsByCollection(job.ctx, job.collectionID, meta.CurrentTargetFirst) + newReplicas, spawnErr := job.meta.ReplicaManager.Spawn(job.ctx, job.collectionID, toSpawn, lo.Keys(channels)) if spawnErr != nil { log.Warn("failed to spawn replica", zap.Error(spawnErr)) err := spawnErr @@ -109,7 +109,7 @@ func (job *UpdateLoadConfigJob) Execute() error { if err != nil { // roll back replica from meta replicaIDs := lo.Map(newReplicas, func(r *meta.Replica, _ int) int64 { return r.GetID() }) - err := job.meta.ReplicaManager.RemoveReplicas(job.collectionID, replicaIDs...) + err := job.meta.ReplicaManager.RemoveReplicas(job.ctx, job.collectionID, replicaIDs...) if err != nil { log.Warn("failed to remove replicas", zap.Int64s("replicaIDs", replicaIDs), zap.Error(err)) } @@ -125,7 +125,7 @@ func (job *UpdateLoadConfigJob) Execute() error { replicaOldRG[replica.GetID()] = replica.GetResourceGroup() } - if transferErr := job.meta.ReplicaManager.MoveReplica(rg, replicas); transferErr != nil { + if transferErr := job.meta.ReplicaManager.MoveReplica(job.ctx, rg, replicas); transferErr != nil { log.Warn("failed to transfer replica for collection", zap.Int64("collectionID", collectionID), zap.Error(transferErr)) err = transferErr return err @@ -138,7 +138,7 @@ func (job *UpdateLoadConfigJob) Execute() error { for _, replica := range replicas { oldRG := replicaOldRG[replica.GetID()] if replica.GetResourceGroup() != oldRG { - if err := job.meta.ReplicaManager.TransferReplica(replica.GetID(), replica.GetResourceGroup(), oldRG, 1); err != nil { + if err := job.meta.ReplicaManager.TransferReplica(job.ctx, replica.GetID(), replica.GetResourceGroup(), oldRG, 1); err != nil { log.Warn("failed to roll back replicas", zap.Int64("replica", replica.GetID()), zap.Error(err)) } } @@ -148,17 +148,17 @@ func (job *UpdateLoadConfigJob) Execute() error { }() // 5. remove replica from meta - err = job.meta.ReplicaManager.RemoveReplicas(job.collectionID, toRelease...) + err = job.meta.ReplicaManager.RemoveReplicas(job.ctx, job.collectionID, toRelease...) if err != nil { log.Warn("failed to remove replicas", zap.Int64s("replicaIDs", toRelease), zap.Error(err)) return err } // 6. recover node distribution among replicas - utils.RecoverReplicaOfCollection(job.meta, job.collectionID) + utils.RecoverReplicaOfCollection(job.ctx, job.meta, job.collectionID) // 7. update replica number in meta - err = job.meta.UpdateReplicaNumber(job.collectionID, job.newReplicaNumber) + err = job.meta.UpdateReplicaNumber(job.ctx, job.collectionID, job.newReplicaNumber) if err != nil { msg := "failed to update replica number" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/job/undo.go b/internal/querycoordv2/job/undo.go index e1314f0aec6e0..3fe97e98aff63 100644 --- a/internal/querycoordv2/job/undo.go +++ b/internal/querycoordv2/job/undo.go @@ -68,9 +68,9 @@ func (u *UndoList) RollBack() { var err error if u.IsNewCollection || u.IsReplicaCreated { - err = u.meta.CollectionManager.RemoveCollection(u.CollectionID) + err = u.meta.CollectionManager.RemoveCollection(u.ctx, u.CollectionID) } else { - err = u.meta.CollectionManager.RemovePartition(u.CollectionID, u.LackPartitions...) + err = u.meta.CollectionManager.RemovePartition(u.ctx, u.CollectionID, u.LackPartitions...) } if err != nil { log.Warn("failed to rollback collection from meta", zap.Error(err)) diff --git a/internal/querycoordv2/job/utils.go b/internal/querycoordv2/job/utils.go index 7f56794144480..a202d03662b78 100644 --- a/internal/querycoordv2/job/utils.go +++ b/internal/querycoordv2/job/utils.go @@ -90,7 +90,7 @@ func loadPartitions(ctx context.Context, return err } - replicas := meta.ReplicaManager.GetByCollection(collection) + replicas := meta.ReplicaManager.GetByCollection(ctx, collection) loadReq := &querypb.LoadPartitionsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadPartitions, @@ -124,7 +124,7 @@ func releasePartitions(ctx context.Context, partitions ...int64, ) { log := log.Ctx(ctx).With(zap.Int64("collection", collection), zap.Int64s("partitions", partitions)) - replicas := meta.ReplicaManager.GetByCollection(collection) + replicas := meta.ReplicaManager.GetByCollection(ctx, collection) releaseReq := &querypb.ReleasePartitionsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_ReleasePartitions, diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index 4a6f2ef7fd61c..124feb3e871da 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -122,17 +122,17 @@ func NewCollectionManager(catalog metastore.QueryCoordCatalog) *CollectionManage // Recover recovers collections from kv store, // panics if failed -func (m *CollectionManager) Recover(broker Broker) error { - collections, err := m.catalog.GetCollections() +func (m *CollectionManager) Recover(ctx context.Context, broker Broker) error { + collections, err := m.catalog.GetCollections(ctx) if err != nil { return err } - partitions, err := m.catalog.GetPartitions() + partitions, err := m.catalog.GetPartitions(ctx) if err != nil { return err } - ctx := log.WithTraceID(context.Background(), strconv.FormatInt(time.Now().UnixNano(), 10)) + ctx = log.WithTraceID(ctx, strconv.FormatInt(time.Now().UnixNano(), 10)) ctxLog := log.Ctx(ctx) ctxLog.Info("recover collections and partitions from kv store") @@ -141,13 +141,13 @@ func (m *CollectionManager) Recover(broker Broker) error { ctxLog.Info("skip recovery and release collection due to invalid replica number", zap.Int64("collectionID", collection.GetCollectionID()), zap.Int32("replicaNumber", collection.GetReplicaNumber())) - m.catalog.ReleaseCollection(collection.GetCollectionID()) + m.catalog.ReleaseCollection(ctx, collection.GetCollectionID()) continue } if collection.GetStatus() != querypb.LoadStatus_Loaded { if collection.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() { - m.catalog.ReleaseCollection(collection.CollectionID) + m.catalog.ReleaseCollection(ctx, collection.CollectionID) ctxLog.Info("recover loading collection times reach limit, release collection", zap.Int64("collectionID", collection.CollectionID), zap.Int32("recoverTimes", collection.RecoverTimes)) @@ -155,11 +155,11 @@ func (m *CollectionManager) Recover(broker Broker) error { } // update recoverTimes meta in etcd collection.RecoverTimes += 1 - m.putCollection(true, &Collection{CollectionLoadInfo: collection}) + m.putCollection(ctx, true, &Collection{CollectionLoadInfo: collection}) continue } - err := m.upgradeLoadFields(collection, broker) + err := m.upgradeLoadFields(ctx, collection, broker) if err != nil { if errors.Is(err, merr.ErrCollectionNotFound) { log.Warn("collection not found, skip upgrade logic and wait for release") @@ -170,7 +170,7 @@ func (m *CollectionManager) Recover(broker Broker) error { } // update collection's CreateAt and UpdateAt to now after qc restart - m.putCollection(false, &Collection{ + m.putCollection(ctx, false, &Collection{ CollectionLoadInfo: collection, CreatedAt: time.Now(), }) @@ -181,7 +181,7 @@ func (m *CollectionManager) Recover(broker Broker) error { // Partitions not loaded done should be deprecated if partition.GetStatus() != querypb.LoadStatus_Loaded { if partition.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() { - m.catalog.ReleaseCollection(collection) + m.catalog.ReleaseCollection(ctx, collection) ctxLog.Info("recover loading partition times reach limit, release collection", zap.Int64("collectionID", collection), zap.Int32("recoverTimes", partition.RecoverTimes)) @@ -189,7 +189,7 @@ func (m *CollectionManager) Recover(broker Broker) error { } partition.RecoverTimes += 1 - m.putPartition([]*Partition{ + m.putPartition(ctx, []*Partition{ { PartitionLoadInfo: partition, CreatedAt: time.Now(), @@ -198,7 +198,7 @@ func (m *CollectionManager) Recover(broker Broker) error { continue } - m.putPartition([]*Partition{ + m.putPartition(ctx, []*Partition{ { PartitionLoadInfo: partition, CreatedAt: time.Now(), @@ -207,7 +207,7 @@ func (m *CollectionManager) Recover(broker Broker) error { } } - err = m.upgradeRecover(broker) + err = m.upgradeRecover(ctx, broker) if err != nil { log.Warn("upgrade recover failed", zap.Error(err)) return err @@ -215,7 +215,7 @@ func (m *CollectionManager) Recover(broker Broker) error { return nil } -func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoadInfo, broker Broker) error { +func (m *CollectionManager) upgradeLoadFields(ctx context.Context, collection *querypb.CollectionLoadInfo, broker Broker) error { // only fill load fields when value is nil if collection.LoadFields != nil { return nil @@ -234,7 +234,7 @@ func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoad }) // put updated meta back to store - err = m.putCollection(true, &Collection{ + err = m.putCollection(ctx, true, &Collection{ CollectionLoadInfo: collection, LoadPercentage: 100, }) @@ -246,10 +246,10 @@ func (m *CollectionManager) upgradeLoadFields(collection *querypb.CollectionLoad } // upgradeRecover recovers from old version <= 2.2.x for compatibility. -func (m *CollectionManager) upgradeRecover(broker Broker) error { +func (m *CollectionManager) upgradeRecover(ctx context.Context, broker Broker) error { // for loaded collection from 2.2, it only save a old version CollectionLoadInfo without LoadType. // we should update the CollectionLoadInfo and save all PartitionLoadInfo to meta store - for _, collection := range m.GetAllCollections() { + for _, collection := range m.GetAllCollections(ctx) { if collection.GetLoadType() == querypb.LoadType_UnKnownType { partitionIDs, err := broker.GetPartitions(context.Background(), collection.GetCollectionID()) if err != nil { @@ -267,14 +267,14 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error { LoadPercentage: 100, } }) - err = m.putPartition(partitions, true) + err = m.putPartition(ctx, partitions, true) if err != nil { return err } newInfo := collection.Clone() newInfo.LoadType = querypb.LoadType_LoadCollection - err = m.putCollection(true, newInfo) + err = m.putCollection(ctx, true, newInfo) if err != nil { return err } @@ -283,7 +283,7 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error { // for loaded partition from 2.2, it only save load PartitionLoadInfo. // we should save it's CollectionLoadInfo to meta store - for _, partition := range m.GetAllPartitions() { + for _, partition := range m.GetAllPartitions(ctx) { // In old version, collection would NOT be stored if the partition existed. if _, ok := m.collections[partition.GetCollectionID()]; !ok { col := &Collection{ @@ -296,7 +296,7 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error { }, LoadPercentage: 100, } - err := m.PutCollection(col) + err := m.PutCollection(ctx, col) if err != nil { return err } @@ -305,21 +305,21 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error { return nil } -func (m *CollectionManager) GetCollection(collectionID typeutil.UniqueID) *Collection { +func (m *CollectionManager) GetCollection(ctx context.Context, collectionID typeutil.UniqueID) *Collection { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.collections[collectionID] } -func (m *CollectionManager) GetPartition(partitionID typeutil.UniqueID) *Partition { +func (m *CollectionManager) GetPartition(ctx context.Context, partitionID typeutil.UniqueID) *Partition { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.partitions[partitionID] } -func (m *CollectionManager) GetLoadType(collectionID typeutil.UniqueID) querypb.LoadType { +func (m *CollectionManager) GetLoadType(ctx context.Context, collectionID typeutil.UniqueID) querypb.LoadType { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -330,7 +330,7 @@ func (m *CollectionManager) GetLoadType(collectionID typeutil.UniqueID) querypb. return querypb.LoadType_UnKnownType } -func (m *CollectionManager) GetReplicaNumber(collectionID typeutil.UniqueID) int32 { +func (m *CollectionManager) GetReplicaNumber(ctx context.Context, collectionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -342,7 +342,7 @@ func (m *CollectionManager) GetReplicaNumber(collectionID typeutil.UniqueID) int } // CalculateLoadPercentage checks if collection is currently fully loaded. -func (m *CollectionManager) CalculateLoadPercentage(collectionID typeutil.UniqueID) int32 { +func (m *CollectionManager) CalculateLoadPercentage(ctx context.Context, collectionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -362,7 +362,7 @@ func (m *CollectionManager) calculateLoadPercentage(collectionID typeutil.Unique return -1 } -func (m *CollectionManager) GetPartitionLoadPercentage(partitionID typeutil.UniqueID) int32 { +func (m *CollectionManager) GetPartitionLoadPercentage(ctx context.Context, partitionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -373,7 +373,7 @@ func (m *CollectionManager) GetPartitionLoadPercentage(partitionID typeutil.Uniq return -1 } -func (m *CollectionManager) CalculateLoadStatus(collectionID typeutil.UniqueID) querypb.LoadStatus { +func (m *CollectionManager) CalculateLoadStatus(ctx context.Context, collectionID typeutil.UniqueID) querypb.LoadStatus { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -396,7 +396,7 @@ func (m *CollectionManager) CalculateLoadStatus(collectionID typeutil.UniqueID) return querypb.LoadStatus_Invalid } -func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[int64]int64 { +func (m *CollectionManager) GetFieldIndex(ctx context.Context, collectionID typeutil.UniqueID) map[int64]int64 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -407,7 +407,7 @@ func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[in return nil } -func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int64 { +func (m *CollectionManager) GetLoadFields(ctx context.Context, collectionID typeutil.UniqueID) []int64 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -418,7 +418,7 @@ func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int6 return nil } -func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { +func (m *CollectionManager) Exist(ctx context.Context, collectionID typeutil.UniqueID) bool { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -427,7 +427,7 @@ func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { } // GetAll returns the collection ID of all loaded collections -func (m *CollectionManager) GetAll() []int64 { +func (m *CollectionManager) GetAll(ctx context.Context) []int64 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -438,21 +438,21 @@ func (m *CollectionManager) GetAll() []int64 { return ids.Collect() } -func (m *CollectionManager) GetAllCollections() []*Collection { +func (m *CollectionManager) GetAllCollections(ctx context.Context) []*Collection { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return lo.Values(m.collections) } -func (m *CollectionManager) GetAllPartitions() []*Partition { +func (m *CollectionManager) GetAllPartitions(ctx context.Context) []*Partition { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return lo.Values(m.partitions) } -func (m *CollectionManager) GetPartitionsByCollection(collectionID typeutil.UniqueID) []*Partition { +func (m *CollectionManager) GetPartitionsByCollection(ctx context.Context, collectionID typeutil.UniqueID) []*Partition { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -463,26 +463,26 @@ func (m *CollectionManager) getPartitionsByCollection(collectionID typeutil.Uniq return lo.Map(m.collectionPartitions[collectionID].Collect(), func(partitionID int64, _ int) *Partition { return m.partitions[partitionID] }) } -func (m *CollectionManager) PutCollection(collection *Collection, partitions ...*Partition) error { +func (m *CollectionManager) PutCollection(ctx context.Context, collection *Collection, partitions ...*Partition) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.putCollection(true, collection, partitions...) + return m.putCollection(ctx, true, collection, partitions...) } -func (m *CollectionManager) PutCollectionWithoutSave(collection *Collection) error { +func (m *CollectionManager) PutCollectionWithoutSave(ctx context.Context, collection *Collection) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.putCollection(false, collection) + return m.putCollection(ctx, false, collection) } -func (m *CollectionManager) putCollection(withSave bool, collection *Collection, partitions ...*Partition) error { +func (m *CollectionManager) putCollection(ctx context.Context, withSave bool, collection *Collection, partitions ...*Partition) error { if withSave { partitionInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo { return partition.PartitionLoadInfo }) - err := m.catalog.SaveCollection(collection.CollectionLoadInfo, partitionInfos...) + err := m.catalog.SaveCollection(ctx, collection.CollectionLoadInfo, partitionInfos...) if err != nil { return err } @@ -504,26 +504,26 @@ func (m *CollectionManager) putCollection(withSave bool, collection *Collection, return nil } -func (m *CollectionManager) PutPartition(partitions ...*Partition) error { +func (m *CollectionManager) PutPartition(ctx context.Context, partitions ...*Partition) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.putPartition(partitions, true) + return m.putPartition(ctx, partitions, true) } -func (m *CollectionManager) PutPartitionWithoutSave(partitions ...*Partition) error { +func (m *CollectionManager) PutPartitionWithoutSave(ctx context.Context, partitions ...*Partition) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.putPartition(partitions, false) + return m.putPartition(ctx, partitions, false) } -func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool) error { +func (m *CollectionManager) putPartition(ctx context.Context, partitions []*Partition, withSave bool) error { if withSave { loadInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo { return partition.PartitionLoadInfo }) - err := m.catalog.SavePartition(loadInfos...) + err := m.catalog.SavePartition(ctx, loadInfos...) if err != nil { return err } @@ -543,7 +543,7 @@ func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool) return nil } -func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int32) (int32, error) { +func (m *CollectionManager) UpdateLoadPercent(ctx context.Context, partitionID int64, loadPercent int32) (int32, error) { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -565,7 +565,7 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Partition %d loaded", partitionID))) } - err := m.putPartition([]*Partition{newPartition}, savePartition) + err := m.putPartition(ctx, []*Partition{newPartition}, savePartition) if err != nil { return 0, err } @@ -595,17 +595,17 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Collection %d loaded", newCollection.CollectionID))) } - return collectionPercent, m.putCollection(saveCollection, newCollection) + return collectionPercent, m.putCollection(ctx, saveCollection, newCollection) } // RemoveCollection removes collection and its partitions. -func (m *CollectionManager) RemoveCollection(collectionID typeutil.UniqueID) error { +func (m *CollectionManager) RemoveCollection(ctx context.Context, collectionID typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() _, ok := m.collections[collectionID] if ok { - err := m.catalog.ReleaseCollection(collectionID) + err := m.catalog.ReleaseCollection(ctx, collectionID) if err != nil { return err } @@ -619,7 +619,7 @@ func (m *CollectionManager) RemoveCollection(collectionID typeutil.UniqueID) err return nil } -func (m *CollectionManager) RemovePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { +func (m *CollectionManager) RemovePartition(ctx context.Context, collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { if len(partitionIDs) == 0 { return nil } @@ -627,11 +627,11 @@ func (m *CollectionManager) RemovePartition(collectionID typeutil.UniqueID, part m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.removePartition(collectionID, partitionIDs...) + return m.removePartition(ctx, collectionID, partitionIDs...) } -func (m *CollectionManager) removePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { - err := m.catalog.ReleasePartition(collectionID, partitionIDs...) +func (m *CollectionManager) removePartition(ctx context.Context, collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { + err := m.catalog.ReleasePartition(ctx, collectionID, partitionIDs...) if err != nil { return err } @@ -644,7 +644,7 @@ func (m *CollectionManager) removePartition(collectionID typeutil.UniqueID, part return nil } -func (m *CollectionManager) UpdateReplicaNumber(collectionID typeutil.UniqueID, replicaNumber int32) error { +func (m *CollectionManager) UpdateReplicaNumber(ctx context.Context, collectionID typeutil.UniqueID, replicaNumber int32) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -663,5 +663,5 @@ func (m *CollectionManager) UpdateReplicaNumber(collectionID typeutil.UniqueID, newPartitions = append(newPartitions, newPartition) } - return m.putCollection(true, newCollection, newPartitions...) + return m.putCollection(ctx, true, newCollection, newPartitions...) } diff --git a/internal/querycoordv2/meta/collection_manager_test.go b/internal/querycoordv2/meta/collection_manager_test.go index 30f0f7b958e6b..2ee5e99dcc318 100644 --- a/internal/querycoordv2/meta/collection_manager_test.go +++ b/internal/querycoordv2/meta/collection_manager_test.go @@ -17,6 +17,7 @@ package meta import ( + "context" "sort" "testing" "time" @@ -59,6 +60,8 @@ type CollectionManagerSuite struct { // Test object mgr *CollectionManager + + ctx context.Context } func (suite *CollectionManagerSuite) SetupSuite() { @@ -85,6 +88,7 @@ func (suite *CollectionManagerSuite) SetupSuite() { 102: {100, 100, 100}, 103: {}, } + suite.ctx = context.Background() } func (suite *CollectionManagerSuite) SetupTest() { @@ -113,12 +117,13 @@ func (suite *CollectionManagerSuite) TearDownTest() { func (suite *CollectionManagerSuite) TestGetProperty() { mgr := suite.mgr + ctx := suite.ctx for i, collection := range suite.collections { - loadType := mgr.GetLoadType(collection) - replicaNumber := mgr.GetReplicaNumber(collection) - percentage := mgr.CalculateLoadPercentage(collection) - exist := mgr.Exist(collection) + loadType := mgr.GetLoadType(ctx, collection) + replicaNumber := mgr.GetReplicaNumber(ctx, collection) + percentage := mgr.CalculateLoadPercentage(ctx, collection) + exist := mgr.Exist(ctx, collection) suite.Equal(suite.loadTypes[i], loadType) suite.Equal(suite.replicaNumber[i], replicaNumber) suite.Equal(suite.colLoadPercent[i], percentage) @@ -126,10 +131,10 @@ func (suite *CollectionManagerSuite) TestGetProperty() { } invalidCollection := -1 - loadType := mgr.GetLoadType(int64(invalidCollection)) - replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection)) - percentage := mgr.CalculateLoadPercentage(int64(invalidCollection)) - exist := mgr.Exist(int64(invalidCollection)) + loadType := mgr.GetLoadType(ctx, int64(invalidCollection)) + replicaNumber := mgr.GetReplicaNumber(ctx, int64(invalidCollection)) + percentage := mgr.CalculateLoadPercentage(ctx, int64(invalidCollection)) + exist := mgr.Exist(ctx, int64(invalidCollection)) suite.Equal(querypb.LoadType_UnKnownType, loadType) suite.EqualValues(-1, replicaNumber) suite.EqualValues(-1, percentage) @@ -138,6 +143,7 @@ func (suite *CollectionManagerSuite) TestGetProperty() { func (suite *CollectionManagerSuite) TestPut() { suite.releaseAll() + ctx := suite.ctx // test put collection with partitions for i, collection := range suite.collections { status := querypb.LoadStatus_Loaded @@ -167,7 +173,7 @@ func (suite *CollectionManagerSuite) TestPut() { CreatedAt: time.Now(), } }) - err := suite.mgr.PutCollection(col, partitions...) + err := suite.mgr.PutCollection(ctx, col, partitions...) suite.NoError(err) } suite.checkLoadResult() @@ -179,43 +185,44 @@ func (suite *CollectionManagerSuite) TestGet() { func (suite *CollectionManagerSuite) TestUpdate() { mgr := suite.mgr + ctx := suite.ctx - collections := mgr.GetAllCollections() - partitions := mgr.GetAllPartitions() + collections := mgr.GetAllCollections(ctx) + partitions := mgr.GetAllPartitions(ctx) for _, collection := range collections { collection := collection.Clone() collection.LoadPercentage = 100 - err := mgr.PutCollectionWithoutSave(collection) + err := mgr.PutCollectionWithoutSave(ctx, collection) suite.NoError(err) - modified := mgr.GetCollection(collection.GetCollectionID()) + modified := mgr.GetCollection(ctx, collection.GetCollectionID()) suite.Equal(collection, modified) suite.EqualValues(100, modified.LoadPercentage) collection.Status = querypb.LoadStatus_Loaded - err = mgr.PutCollection(collection) + err = mgr.PutCollection(ctx, collection) suite.NoError(err) } for _, partition := range partitions { partition := partition.Clone() partition.LoadPercentage = 100 - err := mgr.PutPartitionWithoutSave(partition) + err := mgr.PutPartitionWithoutSave(ctx, partition) suite.NoError(err) - modified := mgr.GetPartition(partition.GetPartitionID()) + modified := mgr.GetPartition(ctx, partition.GetPartitionID()) suite.Equal(partition, modified) suite.EqualValues(100, modified.LoadPercentage) partition.Status = querypb.LoadStatus_Loaded - err = mgr.PutPartition(partition) + err = mgr.PutPartition(ctx, partition) suite.NoError(err) } suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) - collections = mgr.GetAllCollections() - partitions = mgr.GetAllPartitions() + collections = mgr.GetAllCollections(ctx) + partitions = mgr.GetAllPartitions(ctx) for _, collection := range collections { suite.Equal(querypb.LoadStatus_Loaded, collection.GetStatus()) } @@ -226,7 +233,8 @@ func (suite *CollectionManagerSuite) TestUpdate() { func (suite *CollectionManagerSuite) TestGetFieldIndex() { mgr := suite.mgr - mgr.PutCollection(&Collection{ + ctx := suite.ctx + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: 1, ReplicaNumber: 1, @@ -237,7 +245,7 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() { LoadPercentage: 0, CreatedAt: time.Now(), }) - indexID := mgr.GetFieldIndex(1) + indexID := mgr.GetFieldIndex(ctx, 1) suite.Len(indexID, 2) suite.Contains(indexID, int64(1)) suite.Contains(indexID, int64(2)) @@ -245,14 +253,15 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() { func (suite *CollectionManagerSuite) TestRemove() { mgr := suite.mgr + ctx := suite.ctx // Remove collections/partitions for i, collectionID := range suite.collections { if suite.loadTypes[i] == querypb.LoadType_LoadCollection { - err := mgr.RemoveCollection(collectionID) + err := mgr.RemoveCollection(ctx, collectionID) suite.NoError(err) } else { - err := mgr.RemovePartition(collectionID, suite.partitions[collectionID]...) + err := mgr.RemovePartition(ctx, collectionID, suite.partitions[collectionID]...) suite.NoError(err) } } @@ -260,23 +269,23 @@ func (suite *CollectionManagerSuite) TestRemove() { // Try to get the removed items for i, collectionID := range suite.collections { if suite.loadTypes[i] == querypb.LoadType_LoadCollection { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.Nil(collection) } else { - partitions := mgr.GetPartitionsByCollection(collectionID) + partitions := mgr.GetPartitionsByCollection(ctx, collectionID) suite.Empty(partitions) } } // Make sure the removes applied to meta store - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) for i, collectionID := range suite.collections { if suite.loadTypes[i] == querypb.LoadType_LoadCollection { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.Nil(collection) } else { - partitions := mgr.GetPartitionsByCollection(collectionID) + partitions := mgr.GetPartitionsByCollection(ctx, collectionID) suite.Empty(partitions) } } @@ -285,9 +294,9 @@ func (suite *CollectionManagerSuite) TestRemove() { suite.loadAll() for i, collectionID := range suite.collections { if suite.loadTypes[i] == querypb.LoadType_LoadPartition { - err := mgr.RemoveCollection(collectionID) + err := mgr.RemoveCollection(ctx, collectionID) suite.NoError(err) - partitions := mgr.GetPartitionsByCollection(collectionID) + partitions := mgr.GetPartitionsByCollection(ctx, collectionID) suite.Empty(partitions) } } @@ -296,27 +305,28 @@ func (suite *CollectionManagerSuite) TestRemove() { suite.releaseAll() suite.loadAll() for _, collectionID := range suite.collections { - err := mgr.RemoveCollection(collectionID) + err := mgr.RemoveCollection(ctx, collectionID) suite.NoError(err) - err = mgr.Recover(suite.broker) + err = mgr.Recover(ctx, suite.broker) suite.NoError(err) - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.Nil(collection) - partitions := mgr.GetPartitionsByCollection(collectionID) + partitions := mgr.GetPartitionsByCollection(ctx, collectionID) suite.Empty(partitions) } } func (suite *CollectionManagerSuite) TestRecover_normal() { mgr := suite.mgr + ctx := suite.ctx suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) for _, collection := range suite.collections { - suite.True(mgr.Exist(collection)) + suite.True(mgr.Exist(ctx, collection)) for _, partitionID := range suite.partitions[collection] { - partition := mgr.GetPartition(partitionID) + partition := mgr.GetPartition(ctx, partitionID) suite.NotNil(partition) } } @@ -325,6 +335,7 @@ func (suite *CollectionManagerSuite) TestRecover_normal() { func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { mgr := suite.mgr suite.releaseAll() + ctx := suite.ctx // test put collection with partitions for i, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() @@ -350,20 +361,20 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { CreatedAt: time.Now(), } }) - err := suite.mgr.PutCollection(col, partitions...) + err := suite.mgr.PutCollection(ctx, col, partitions...) suite.NoError(err) } // recover for first time, expected recover success suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) for _, collectionID := range suite.collections { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.NotNil(collection) suite.Equal(int32(1), collection.GetRecoverTimes()) for _, partitionID := range suite.partitions[collectionID] { - partition := mgr.GetPartition(partitionID) + partition := mgr.GetPartition(ctx, partitionID) suite.NotNil(partition) suite.Equal(int32(1), partition.GetRecoverTimes()) } @@ -372,18 +383,18 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { // update load percent, then recover for second time for _, collectionID := range suite.collections { for _, partitionID := range suite.partitions[collectionID] { - mgr.UpdateLoadPercent(partitionID, 10) + mgr.UpdateLoadPercent(ctx, partitionID, 10) } } suite.clearMemory() - err = mgr.Recover(suite.broker) + err = mgr.Recover(ctx, suite.broker) suite.NoError(err) for _, collectionID := range suite.collections { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.NotNil(collection) suite.Equal(int32(2), collection.GetRecoverTimes()) for _, partitionID := range suite.partitions[collectionID] { - partition := mgr.GetPartition(partitionID) + partition := mgr.GetPartition(ctx, partitionID) suite.NotNil(partition) suite.Equal(int32(2), partition.GetRecoverTimes()) } @@ -393,14 +404,14 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { for i := 0; i < int(paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32()); i++ { log.Info("stupid", zap.Int("count", i)) suite.clearMemory() - err = mgr.Recover(suite.broker) + err = mgr.Recover(ctx, suite.broker) suite.NoError(err) } for _, collectionID := range suite.collections { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.Nil(collection) for _, partitionID := range suite.partitions[collectionID] { - partition := mgr.GetPartition(partitionID) + partition := mgr.GetPartition(ctx, partitionID) suite.Nil(partition) } } @@ -408,7 +419,8 @@ func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() { mgr := suite.mgr - mgr.PutCollection(&Collection{ + ctx := suite.ctx + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: 1, ReplicaNumber: 1, @@ -421,7 +433,7 @@ func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() { partitions := []int64{1, 2} for _, partition := range partitions { - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: 1, PartitionID: partition, @@ -432,42 +444,43 @@ func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() { }) } // test update partition load percentage - mgr.UpdateLoadPercent(1, 30) - partition := mgr.GetPartition(1) + mgr.UpdateLoadPercent(ctx, 1, 30) + partition := mgr.GetPartition(ctx, 1) suite.Equal(int32(30), partition.LoadPercentage) - suite.Equal(int32(30), mgr.GetPartitionLoadPercentage(partition.PartitionID)) + suite.Equal(int32(30), mgr.GetPartitionLoadPercentage(ctx, partition.PartitionID)) suite.Equal(querypb.LoadStatus_Loading, partition.Status) - collection := mgr.GetCollection(1) + collection := mgr.GetCollection(ctx, 1) suite.Equal(int32(15), collection.LoadPercentage) suite.Equal(querypb.LoadStatus_Loading, collection.Status) // test update partition load percentage to 100 - mgr.UpdateLoadPercent(1, 100) - partition = mgr.GetPartition(1) + mgr.UpdateLoadPercent(ctx, 1, 100) + partition = mgr.GetPartition(ctx, 1) suite.Equal(int32(100), partition.LoadPercentage) suite.Equal(querypb.LoadStatus_Loaded, partition.Status) - collection = mgr.GetCollection(1) + collection = mgr.GetCollection(ctx, 1) suite.Equal(int32(50), collection.LoadPercentage) suite.Equal(querypb.LoadStatus_Loading, collection.Status) // test update collection load percentage - mgr.UpdateLoadPercent(2, 100) - partition = mgr.GetPartition(1) + mgr.UpdateLoadPercent(ctx, 2, 100) + partition = mgr.GetPartition(ctx, 1) suite.Equal(int32(100), partition.LoadPercentage) suite.Equal(querypb.LoadStatus_Loaded, partition.Status) - collection = mgr.GetCollection(1) + collection = mgr.GetCollection(ctx, 1) suite.Equal(int32(100), collection.LoadPercentage) suite.Equal(querypb.LoadStatus_Loaded, collection.Status) - suite.Equal(querypb.LoadStatus_Loaded, mgr.CalculateLoadStatus(collection.CollectionID)) + suite.Equal(querypb.LoadStatus_Loaded, mgr.CalculateLoadStatus(ctx, collection.CollectionID)) } func (suite *CollectionManagerSuite) TestUpgradeRecover() { suite.releaseAll() mgr := suite.mgr + ctx := suite.ctx // put old version of collections and partitions for i, collection := range suite.collections { status := querypb.LoadStatus_Loaded if suite.loadTypes[i] == querypb.LoadType_LoadCollection { - mgr.PutCollection(&Collection{ + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, ReplicaNumber: suite.replicaNumber[i], @@ -479,7 +492,7 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() { }) } else { for _, partition := range suite.partitions[collection] { - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collection, PartitionID: partition, @@ -513,12 +526,12 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() { // do recovery suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) suite.checkLoadResult() for i, collection := range suite.collections { - newColl := mgr.GetCollection(collection) + newColl := mgr.GetCollection(ctx, collection) suite.Equal(suite.loadTypes[i], newColl.GetLoadType()) } } @@ -526,10 +539,11 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() { func (suite *CollectionManagerSuite) TestUpgradeLoadFields() { suite.releaseAll() mgr := suite.mgr + ctx := suite.ctx // put old version of collections and partitions for i, collection := range suite.collections { - mgr.PutCollection(&Collection{ + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, ReplicaNumber: suite.replicaNumber[i], @@ -541,7 +555,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFields() { CreatedAt: time.Now(), }) for j, partition := range suite.partitions[collection] { - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collection, PartitionID: partition, @@ -570,12 +584,12 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFields() { // do recovery suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) suite.checkLoadResult() for _, collection := range suite.collections { - newColl := mgr.GetCollection(collection) + newColl := mgr.GetCollection(ctx, collection) suite.ElementsMatch([]int64{100, 101}, newColl.GetLoadFields()) } } @@ -584,8 +598,9 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() { suite.Run("normal_error", func() { suite.releaseAll() mgr := suite.mgr + ctx := suite.ctx - mgr.PutCollection(&Collection{ + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: 100, ReplicaNumber: 1, @@ -596,7 +611,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() { LoadPercentage: 100, CreatedAt: time.Now(), }) - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: 100, PartitionID: 1000, @@ -609,15 +624,16 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() { suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(100)).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() // do recovery suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.Error(err) }) suite.Run("normal_error", func() { suite.releaseAll() mgr := suite.mgr + ctx := suite.ctx - mgr.PutCollection(&Collection{ + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: 100, ReplicaNumber: 1, @@ -628,7 +644,7 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() { LoadPercentage: 100, CreatedAt: time.Now(), }) - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: 100, PartitionID: 1000, @@ -643,13 +659,14 @@ func (suite *CollectionManagerSuite) TestUpgradeLoadFieldsFail() { }, nil).Once() // do recovery suite.clearMemory() - err := mgr.Recover(suite.broker) + err := mgr.Recover(ctx, suite.broker) suite.NoError(err) }) } func (suite *CollectionManagerSuite) loadAll() { mgr := suite.mgr + ctx := suite.ctx for i, collection := range suite.collections { status := querypb.LoadStatus_Loaded @@ -657,7 +674,7 @@ func (suite *CollectionManagerSuite) loadAll() { status = querypb.LoadStatus_Loading } - mgr.PutCollection(&Collection{ + mgr.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, ReplicaNumber: suite.replicaNumber[i], @@ -670,7 +687,7 @@ func (suite *CollectionManagerSuite) loadAll() { }) for j, partition := range suite.partitions[collection] { - mgr.PutPartition(&Partition{ + mgr.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collection, PartitionID: partition, @@ -685,18 +702,19 @@ func (suite *CollectionManagerSuite) loadAll() { func (suite *CollectionManagerSuite) checkLoadResult() { mgr := suite.mgr + ctx := suite.ctx - allCollections := mgr.GetAllCollections() - allPartitions := mgr.GetAllPartitions() + allCollections := mgr.GetAllCollections(ctx) + allPartitions := mgr.GetAllPartitions(ctx) for _, collectionID := range suite.collections { - collection := mgr.GetCollection(collectionID) + collection := mgr.GetCollection(ctx, collectionID) suite.Equal(collectionID, collection.GetCollectionID()) suite.Contains(allCollections, collection) - partitions := mgr.GetPartitionsByCollection(collectionID) + partitions := mgr.GetPartitionsByCollection(ctx, collectionID) suite.Len(partitions, len(suite.partitions[collectionID])) for _, partitionID := range suite.partitions[collectionID] { - partition := mgr.GetPartition(partitionID) + partition := mgr.GetPartition(ctx, partitionID) suite.Equal(collectionID, partition.GetCollectionID()) suite.Equal(partitionID, partition.GetPartitionID()) suite.Contains(partitions, partition) @@ -704,14 +722,14 @@ func (suite *CollectionManagerSuite) checkLoadResult() { } } - all := mgr.GetAll() + all := mgr.GetAll(ctx) sort.Slice(all, func(i, j int) bool { return all[i] < all[j] }) suite.Equal(suite.collections, all) } func (suite *CollectionManagerSuite) releaseAll() { for _, collection := range suite.collections { - err := suite.mgr.RemoveCollection(collection) + err := suite.mgr.RemoveCollection(context.TODO(), collection) suite.NoError(err) } } diff --git a/internal/querycoordv2/meta/mock_target_manager.go b/internal/querycoordv2/meta/mock_target_manager.go index 9968d495fe3ab..d31dece57cbd0 100644 --- a/internal/querycoordv2/meta/mock_target_manager.go +++ b/internal/querycoordv2/meta/mock_target_manager.go @@ -3,6 +3,8 @@ package meta import ( + context "context" + metastore "github.com/milvus-io/milvus/internal/metastore" datapb "github.com/milvus-io/milvus/internal/proto/datapb" @@ -24,17 +26,17 @@ func (_m *MockTargetManager) EXPECT() *MockTargetManager_Expecter { return &MockTargetManager_Expecter{mock: &_m.Mock} } -// CanSegmentBeMoved provides a mock function with given fields: collectionID, segmentID -func (_m *MockTargetManager) CanSegmentBeMoved(collectionID int64, segmentID int64) bool { - ret := _m.Called(collectionID, segmentID) +// CanSegmentBeMoved provides a mock function with given fields: ctx, collectionID, segmentID +func (_m *MockTargetManager) CanSegmentBeMoved(ctx context.Context, collectionID int64, segmentID int64) bool { + ret := _m.Called(ctx, collectionID, segmentID) if len(ret) == 0 { panic("no return value specified for CanSegmentBeMoved") } var r0 bool - if rf, ok := ret.Get(0).(func(int64, int64) bool); ok { - r0 = rf(collectionID, segmentID) + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) bool); ok { + r0 = rf(ctx, collectionID, segmentID) } else { r0 = ret.Get(0).(bool) } @@ -48,15 +50,16 @@ type MockTargetManager_CanSegmentBeMoved_Call struct { } // CanSegmentBeMoved is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - segmentID int64 -func (_e *MockTargetManager_Expecter) CanSegmentBeMoved(collectionID interface{}, segmentID interface{}) *MockTargetManager_CanSegmentBeMoved_Call { - return &MockTargetManager_CanSegmentBeMoved_Call{Call: _e.mock.On("CanSegmentBeMoved", collectionID, segmentID)} +func (_e *MockTargetManager_Expecter) CanSegmentBeMoved(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockTargetManager_CanSegmentBeMoved_Call { + return &MockTargetManager_CanSegmentBeMoved_Call{Call: _e.mock.On("CanSegmentBeMoved", ctx, collectionID, segmentID)} } -func (_c *MockTargetManager_CanSegmentBeMoved_Call) Run(run func(collectionID int64, segmentID int64)) *MockTargetManager_CanSegmentBeMoved_Call { +func (_c *MockTargetManager_CanSegmentBeMoved_Call) Run(run func(ctx context.Context, collectionID int64, segmentID int64)) *MockTargetManager_CanSegmentBeMoved_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int64)) + run(args[0].(context.Context), args[1].(int64), args[2].(int64)) }) return _c } @@ -66,22 +69,22 @@ func (_c *MockTargetManager_CanSegmentBeMoved_Call) Return(_a0 bool) *MockTarget return _c } -func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(int64, int64) bool) *MockTargetManager_CanSegmentBeMoved_Call { +func (_c *MockTargetManager_CanSegmentBeMoved_Call) RunAndReturn(run func(context.Context, int64, int64) bool) *MockTargetManager_CanSegmentBeMoved_Call { _c.Call.Return(run) return _c } -// GetCollectionTargetVersion provides a mock function with given fields: collectionID, scope -func (_m *MockTargetManager) GetCollectionTargetVersion(collectionID int64, scope int32) int64 { - ret := _m.Called(collectionID, scope) +// GetCollectionTargetVersion provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope int32) int64 { + ret := _m.Called(ctx, collectionID, scope) if len(ret) == 0 { panic("no return value specified for GetCollectionTargetVersion") } var r0 int64 - if rf, ok := ret.Get(0).(func(int64, int32) int64); ok { - r0 = rf(collectionID, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) int64); ok { + r0 = rf(ctx, collectionID, scope) } else { r0 = ret.Get(0).(int64) } @@ -95,15 +98,16 @@ type MockTargetManager_GetCollectionTargetVersion_Call struct { } // GetCollectionTargetVersion is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetCollectionTargetVersion(collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionTargetVersion_Call { - return &MockTargetManager_GetCollectionTargetVersion_Call{Call: _e.mock.On("GetCollectionTargetVersion", collectionID, scope)} +func (_e *MockTargetManager_Expecter) GetCollectionTargetVersion(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionTargetVersion_Call { + return &MockTargetManager_GetCollectionTargetVersion_Call{Call: _e.mock.On("GetCollectionTargetVersion", ctx, collectionID, scope)} } -func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetCollectionTargetVersion_Call { +func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetCollectionTargetVersion_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) }) return _c } @@ -113,22 +117,22 @@ func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Return(_a0 int64) * return _c } -func (_c *MockTargetManager_GetCollectionTargetVersion_Call) RunAndReturn(run func(int64, int32) int64) *MockTargetManager_GetCollectionTargetVersion_Call { +func (_c *MockTargetManager_GetCollectionTargetVersion_Call) RunAndReturn(run func(context.Context, int64, int32) int64) *MockTargetManager_GetCollectionTargetVersion_Call { _c.Call.Return(run) return _c } -// GetDmChannel provides a mock function with given fields: collectionID, channel, scope -func (_m *MockTargetManager) GetDmChannel(collectionID int64, channel string, scope int32) *DmChannel { - ret := _m.Called(collectionID, channel, scope) +// GetDmChannel provides a mock function with given fields: ctx, collectionID, channel, scope +func (_m *MockTargetManager) GetDmChannel(ctx context.Context, collectionID int64, channel string, scope int32) *DmChannel { + ret := _m.Called(ctx, collectionID, channel, scope) if len(ret) == 0 { panic("no return value specified for GetDmChannel") } var r0 *DmChannel - if rf, ok := ret.Get(0).(func(int64, string, int32) *DmChannel); ok { - r0 = rf(collectionID, channel, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) *DmChannel); ok { + r0 = rf(ctx, collectionID, channel, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*DmChannel) @@ -144,16 +148,17 @@ type MockTargetManager_GetDmChannel_Call struct { } // GetDmChannel is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - channel string // - scope int32 -func (_e *MockTargetManager_Expecter) GetDmChannel(collectionID interface{}, channel interface{}, scope interface{}) *MockTargetManager_GetDmChannel_Call { - return &MockTargetManager_GetDmChannel_Call{Call: _e.mock.On("GetDmChannel", collectionID, channel, scope)} +func (_e *MockTargetManager_Expecter) GetDmChannel(ctx interface{}, collectionID interface{}, channel interface{}, scope interface{}) *MockTargetManager_GetDmChannel_Call { + return &MockTargetManager_GetDmChannel_Call{Call: _e.mock.On("GetDmChannel", ctx, collectionID, channel, scope)} } -func (_c *MockTargetManager_GetDmChannel_Call) Run(run func(collectionID int64, channel string, scope int32)) *MockTargetManager_GetDmChannel_Call { +func (_c *MockTargetManager_GetDmChannel_Call) Run(run func(ctx context.Context, collectionID int64, channel string, scope int32)) *MockTargetManager_GetDmChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(string), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32)) }) return _c } @@ -163,22 +168,22 @@ func (_c *MockTargetManager_GetDmChannel_Call) Return(_a0 *DmChannel) *MockTarge return _c } -func (_c *MockTargetManager_GetDmChannel_Call) RunAndReturn(run func(int64, string, int32) *DmChannel) *MockTargetManager_GetDmChannel_Call { +func (_c *MockTargetManager_GetDmChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) *DmChannel) *MockTargetManager_GetDmChannel_Call { _c.Call.Return(run) return _c } -// GetDmChannelsByCollection provides a mock function with given fields: collectionID, scope -func (_m *MockTargetManager) GetDmChannelsByCollection(collectionID int64, scope int32) map[string]*DmChannel { - ret := _m.Called(collectionID, scope) +// GetDmChannelsByCollection provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope int32) map[string]*DmChannel { + ret := _m.Called(ctx, collectionID, scope) if len(ret) == 0 { panic("no return value specified for GetDmChannelsByCollection") } var r0 map[string]*DmChannel - if rf, ok := ret.Get(0).(func(int64, int32) map[string]*DmChannel); ok { - r0 = rf(collectionID, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) map[string]*DmChannel); ok { + r0 = rf(ctx, collectionID, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[string]*DmChannel) @@ -194,15 +199,16 @@ type MockTargetManager_GetDmChannelsByCollection_Call struct { } // GetDmChannelsByCollection is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetDmChannelsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetDmChannelsByCollection_Call { - return &MockTargetManager_GetDmChannelsByCollection_Call{Call: _e.mock.On("GetDmChannelsByCollection", collectionID, scope)} +func (_e *MockTargetManager_Expecter) GetDmChannelsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetDmChannelsByCollection_Call { + return &MockTargetManager_GetDmChannelsByCollection_Call{Call: _e.mock.On("GetDmChannelsByCollection", ctx, collectionID, scope)} } -func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetDmChannelsByCollection_Call { +func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetDmChannelsByCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) }) return _c } @@ -212,22 +218,22 @@ func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Return(_a0 map[strin return _c } -func (_c *MockTargetManager_GetDmChannelsByCollection_Call) RunAndReturn(run func(int64, int32) map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call { +func (_c *MockTargetManager_GetDmChannelsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call { _c.Call.Return(run) return _c } -// GetDroppedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope -func (_m *MockTargetManager) GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope int32) []int64 { - ret := _m.Called(collectionID, channelName, scope) +// GetDroppedSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope +func (_m *MockTargetManager) GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) []int64 { + ret := _m.Called(ctx, collectionID, channelName, scope) if len(ret) == 0 { panic("no return value specified for GetDroppedSegmentsByChannel") } var r0 []int64 - if rf, ok := ret.Get(0).(func(int64, string, int32) []int64); ok { - r0 = rf(collectionID, channelName, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) []int64); ok { + r0 = rf(ctx, collectionID, channelName, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int64) @@ -243,16 +249,17 @@ type MockTargetManager_GetDroppedSegmentsByChannel_Call struct { } // GetDroppedSegmentsByChannel is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - channelName string // - scope int32 -func (_e *MockTargetManager_Expecter) GetDroppedSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetDroppedSegmentsByChannel_Call { - return &MockTargetManager_GetDroppedSegmentsByChannel_Call{Call: _e.mock.On("GetDroppedSegmentsByChannel", collectionID, channelName, scope)} +func (_e *MockTargetManager_Expecter) GetDroppedSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetDroppedSegmentsByChannel_Call { + return &MockTargetManager_GetDroppedSegmentsByChannel_Call{Call: _e.mock.On("GetDroppedSegmentsByChannel", ctx, collectionID, channelName, scope)} } -func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetDroppedSegmentsByChannel_Call { +func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetDroppedSegmentsByChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(string), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32)) }) return _c } @@ -262,22 +269,22 @@ func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Return(_a0 []int64 return _c } -func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call { +func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call { _c.Call.Return(run) return _c } -// GetGrowingSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope -func (_m *MockTargetManager) GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope int32) typeutil.Set[int64] { - ret := _m.Called(collectionID, channelName, scope) +// GetGrowingSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope +func (_m *MockTargetManager) GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) typeutil.Set[int64] { + ret := _m.Called(ctx, collectionID, channelName, scope) if len(ret) == 0 { panic("no return value specified for GetGrowingSegmentsByChannel") } var r0 typeutil.Set[int64] - if rf, ok := ret.Get(0).(func(int64, string, int32) typeutil.Set[int64]); ok { - r0 = rf(collectionID, channelName, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) typeutil.Set[int64]); ok { + r0 = rf(ctx, collectionID, channelName, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(typeutil.Set[int64]) @@ -293,16 +300,17 @@ type MockTargetManager_GetGrowingSegmentsByChannel_Call struct { } // GetGrowingSegmentsByChannel is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - channelName string // - scope int32 -func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByChannel_Call { - return &MockTargetManager_GetGrowingSegmentsByChannel_Call{Call: _e.mock.On("GetGrowingSegmentsByChannel", collectionID, channelName, scope)} +func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByChannel_Call { + return &MockTargetManager_GetGrowingSegmentsByChannel_Call{Call: _e.mock.On("GetGrowingSegmentsByChannel", ctx, collectionID, channelName, scope)} } -func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetGrowingSegmentsByChannel_Call { +func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetGrowingSegmentsByChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(string), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32)) }) return _c } @@ -312,22 +320,22 @@ func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Return(_a0 typeuti return _c } -func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call { +func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call { _c.Call.Return(run) return _c } -// GetGrowingSegmentsByCollection provides a mock function with given fields: collectionID, scope -func (_m *MockTargetManager) GetGrowingSegmentsByCollection(collectionID int64, scope int32) typeutil.Set[int64] { - ret := _m.Called(collectionID, scope) +// GetGrowingSegmentsByCollection provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64, scope int32) typeutil.Set[int64] { + ret := _m.Called(ctx, collectionID, scope) if len(ret) == 0 { panic("no return value specified for GetGrowingSegmentsByCollection") } var r0 typeutil.Set[int64] - if rf, ok := ret.Get(0).(func(int64, int32) typeutil.Set[int64]); ok { - r0 = rf(collectionID, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) typeutil.Set[int64]); ok { + r0 = rf(ctx, collectionID, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(typeutil.Set[int64]) @@ -343,15 +351,16 @@ type MockTargetManager_GetGrowingSegmentsByCollection_Call struct { } // GetGrowingSegmentsByCollection is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByCollection_Call { - return &MockTargetManager_GetGrowingSegmentsByCollection_Call{Call: _e.mock.On("GetGrowingSegmentsByCollection", collectionID, scope)} +func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByCollection_Call { + return &MockTargetManager_GetGrowingSegmentsByCollection_Call{Call: _e.mock.On("GetGrowingSegmentsByCollection", ctx, collectionID, scope)} } -func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetGrowingSegmentsByCollection_Call { +func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetGrowingSegmentsByCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) }) return _c } @@ -361,22 +370,22 @@ func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Return(_a0 type return _c } -func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call { +func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call { _c.Call.Return(run) return _c } -// GetSealedSegment provides a mock function with given fields: collectionID, id, scope -func (_m *MockTargetManager) GetSealedSegment(collectionID int64, id int64, scope int32) *datapb.SegmentInfo { - ret := _m.Called(collectionID, id, scope) +// GetSealedSegment provides a mock function with given fields: ctx, collectionID, id, scope +func (_m *MockTargetManager) GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope int32) *datapb.SegmentInfo { + ret := _m.Called(ctx, collectionID, id, scope) if len(ret) == 0 { panic("no return value specified for GetSealedSegment") } var r0 *datapb.SegmentInfo - if rf, ok := ret.Get(0).(func(int64, int64, int32) *datapb.SegmentInfo); ok { - r0 = rf(collectionID, id, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int32) *datapb.SegmentInfo); ok { + r0 = rf(ctx, collectionID, id, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.SegmentInfo) @@ -392,16 +401,17 @@ type MockTargetManager_GetSealedSegment_Call struct { } // GetSealedSegment is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - id int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetSealedSegment(collectionID interface{}, id interface{}, scope interface{}) *MockTargetManager_GetSealedSegment_Call { - return &MockTargetManager_GetSealedSegment_Call{Call: _e.mock.On("GetSealedSegment", collectionID, id, scope)} +func (_e *MockTargetManager_Expecter) GetSealedSegment(ctx interface{}, collectionID interface{}, id interface{}, scope interface{}) *MockTargetManager_GetSealedSegment_Call { + return &MockTargetManager_GetSealedSegment_Call{Call: _e.mock.On("GetSealedSegment", ctx, collectionID, id, scope)} } -func (_c *MockTargetManager_GetSealedSegment_Call) Run(run func(collectionID int64, id int64, scope int32)) *MockTargetManager_GetSealedSegment_Call { +func (_c *MockTargetManager_GetSealedSegment_Call) Run(run func(ctx context.Context, collectionID int64, id int64, scope int32)) *MockTargetManager_GetSealedSegment_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int64), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(int32)) }) return _c } @@ -411,22 +421,22 @@ func (_c *MockTargetManager_GetSealedSegment_Call) Return(_a0 *datapb.SegmentInf return _c } -func (_c *MockTargetManager_GetSealedSegment_Call) RunAndReturn(run func(int64, int64, int32) *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call { +func (_c *MockTargetManager_GetSealedSegment_Call) RunAndReturn(run func(context.Context, int64, int64, int32) *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call { _c.Call.Return(run) return _c } -// GetSealedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope -func (_m *MockTargetManager) GetSealedSegmentsByChannel(collectionID int64, channelName string, scope int32) map[int64]*datapb.SegmentInfo { - ret := _m.Called(collectionID, channelName, scope) +// GetSealedSegmentsByChannel provides a mock function with given fields: ctx, collectionID, channelName, scope +func (_m *MockTargetManager) GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(ctx, collectionID, channelName, scope) if len(ret) == 0 { panic("no return value specified for GetSealedSegmentsByChannel") } var r0 map[int64]*datapb.SegmentInfo - if rf, ok := ret.Get(0).(func(int64, string, int32) map[int64]*datapb.SegmentInfo); ok { - r0 = rf(collectionID, channelName, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, string, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(ctx, collectionID, channelName, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) @@ -442,16 +452,17 @@ type MockTargetManager_GetSealedSegmentsByChannel_Call struct { } // GetSealedSegmentsByChannel is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - channelName string // - scope int32 -func (_e *MockTargetManager_Expecter) GetSealedSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByChannel_Call { - return &MockTargetManager_GetSealedSegmentsByChannel_Call{Call: _e.mock.On("GetSealedSegmentsByChannel", collectionID, channelName, scope)} +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByChannel(ctx interface{}, collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByChannel_Call { + return &MockTargetManager_GetSealedSegmentsByChannel_Call{Call: _e.mock.On("GetSealedSegmentsByChannel", ctx, collectionID, channelName, scope)} } -func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetSealedSegmentsByChannel_Call { +func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Run(run func(ctx context.Context, collectionID int64, channelName string, scope int32)) *MockTargetManager_GetSealedSegmentsByChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(string), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(string), args[3].(int32)) }) return _c } @@ -461,22 +472,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Return(_a0 map[int6 return _c } -func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call { +func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) RunAndReturn(run func(context.Context, int64, string, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call { _c.Call.Return(run) return _c } -// GetSealedSegmentsByCollection provides a mock function with given fields: collectionID, scope -func (_m *MockTargetManager) GetSealedSegmentsByCollection(collectionID int64, scope int32) map[int64]*datapb.SegmentInfo { - ret := _m.Called(collectionID, scope) +// GetSealedSegmentsByCollection provides a mock function with given fields: ctx, collectionID, scope +func (_m *MockTargetManager) GetSealedSegmentsByCollection(ctx context.Context, collectionID int64, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(ctx, collectionID, scope) if len(ret) == 0 { panic("no return value specified for GetSealedSegmentsByCollection") } var r0 map[int64]*datapb.SegmentInfo - if rf, ok := ret.Get(0).(func(int64, int32) map[int64]*datapb.SegmentInfo); ok { - r0 = rf(collectionID, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(ctx, collectionID, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) @@ -492,15 +503,16 @@ type MockTargetManager_GetSealedSegmentsByCollection_Call struct { } // GetSealedSegmentsByCollection is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetSealedSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByCollection_Call { - return &MockTargetManager_GetSealedSegmentsByCollection_Call{Call: _e.mock.On("GetSealedSegmentsByCollection", collectionID, scope)} +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByCollection(ctx interface{}, collectionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByCollection_Call { + return &MockTargetManager_GetSealedSegmentsByCollection_Call{Call: _e.mock.On("GetSealedSegmentsByCollection", ctx, collectionID, scope)} } -func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByCollection_Call { +func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Run(run func(ctx context.Context, collectionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int32)) }) return _c } @@ -510,22 +522,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Return(_a0 map[i return _c } -func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call { +func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) RunAndReturn(run func(context.Context, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call { _c.Call.Return(run) return _c } -// GetSealedSegmentsByPartition provides a mock function with given fields: collectionID, partitionID, scope -func (_m *MockTargetManager) GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope int32) map[int64]*datapb.SegmentInfo { - ret := _m.Called(collectionID, partitionID, scope) +// GetSealedSegmentsByPartition provides a mock function with given fields: ctx, collectionID, partitionID, scope +func (_m *MockTargetManager) GetSealedSegmentsByPartition(ctx context.Context, collectionID int64, partitionID int64, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(ctx, collectionID, partitionID, scope) if len(ret) == 0 { panic("no return value specified for GetSealedSegmentsByPartition") } var r0 map[int64]*datapb.SegmentInfo - if rf, ok := ret.Get(0).(func(int64, int64, int32) map[int64]*datapb.SegmentInfo); ok { - r0 = rf(collectionID, partitionID, scope) + if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(ctx, collectionID, partitionID, scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) @@ -541,16 +553,17 @@ type MockTargetManager_GetSealedSegmentsByPartition_Call struct { } // GetSealedSegmentsByPartition is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - partitionID int64 // - scope int32 -func (_e *MockTargetManager_Expecter) GetSealedSegmentsByPartition(collectionID interface{}, partitionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByPartition_Call { - return &MockTargetManager_GetSealedSegmentsByPartition_Call{Call: _e.mock.On("GetSealedSegmentsByPartition", collectionID, partitionID, scope)} +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByPartition(ctx interface{}, collectionID interface{}, partitionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByPartition_Call { + return &MockTargetManager_GetSealedSegmentsByPartition_Call{Call: _e.mock.On("GetSealedSegmentsByPartition", ctx, collectionID, partitionID, scope)} } -func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Run(run func(collectionID int64, partitionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByPartition_Call { +func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByPartition_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int64), args[2].(int32)) + run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(int32)) }) return _c } @@ -560,22 +573,22 @@ func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Return(_a0 map[in return _c } -func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) RunAndReturn(run func(int64, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call { +func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) RunAndReturn(run func(context.Context, int64, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call { _c.Call.Return(run) return _c } -// GetTargetJSON provides a mock function with given fields: scope -func (_m *MockTargetManager) GetTargetJSON(scope int32) string { - ret := _m.Called(scope) +// GetTargetJSON provides a mock function with given fields: ctx, scope +func (_m *MockTargetManager) GetTargetJSON(ctx context.Context, scope int32) string { + ret := _m.Called(ctx, scope) if len(ret) == 0 { panic("no return value specified for GetTargetJSON") } var r0 string - if rf, ok := ret.Get(0).(func(int32) string); ok { - r0 = rf(scope) + if rf, ok := ret.Get(0).(func(context.Context, int32) string); ok { + r0 = rf(ctx, scope) } else { r0 = ret.Get(0).(string) } @@ -589,14 +602,15 @@ type MockTargetManager_GetTargetJSON_Call struct { } // GetTargetJSON is a helper method to define mock.On call +// - ctx context.Context // - scope int32 -func (_e *MockTargetManager_Expecter) GetTargetJSON(scope interface{}) *MockTargetManager_GetTargetJSON_Call { - return &MockTargetManager_GetTargetJSON_Call{Call: _e.mock.On("GetTargetJSON", scope)} +func (_e *MockTargetManager_Expecter) GetTargetJSON(ctx interface{}, scope interface{}) *MockTargetManager_GetTargetJSON_Call { + return &MockTargetManager_GetTargetJSON_Call{Call: _e.mock.On("GetTargetJSON", ctx, scope)} } -func (_c *MockTargetManager_GetTargetJSON_Call) Run(run func(scope int32)) *MockTargetManager_GetTargetJSON_Call { +func (_c *MockTargetManager_GetTargetJSON_Call) Run(run func(ctx context.Context, scope int32)) *MockTargetManager_GetTargetJSON_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int32)) + run(args[0].(context.Context), args[1].(int32)) }) return _c } @@ -606,22 +620,22 @@ func (_c *MockTargetManager_GetTargetJSON_Call) Return(_a0 string) *MockTargetMa return _c } -func (_c *MockTargetManager_GetTargetJSON_Call) RunAndReturn(run func(int32) string) *MockTargetManager_GetTargetJSON_Call { +func (_c *MockTargetManager_GetTargetJSON_Call) RunAndReturn(run func(context.Context, int32) string) *MockTargetManager_GetTargetJSON_Call { _c.Call.Return(run) return _c } -// IsCurrentTargetExist provides a mock function with given fields: collectionID, partitionID -func (_m *MockTargetManager) IsCurrentTargetExist(collectionID int64, partitionID int64) bool { - ret := _m.Called(collectionID, partitionID) +// IsCurrentTargetExist provides a mock function with given fields: ctx, collectionID, partitionID +func (_m *MockTargetManager) IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool { + ret := _m.Called(ctx, collectionID, partitionID) if len(ret) == 0 { panic("no return value specified for IsCurrentTargetExist") } var r0 bool - if rf, ok := ret.Get(0).(func(int64, int64) bool); ok { - r0 = rf(collectionID, partitionID) + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) bool); ok { + r0 = rf(ctx, collectionID, partitionID) } else { r0 = ret.Get(0).(bool) } @@ -635,15 +649,16 @@ type MockTargetManager_IsCurrentTargetExist_Call struct { } // IsCurrentTargetExist is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - partitionID int64 -func (_e *MockTargetManager_Expecter) IsCurrentTargetExist(collectionID interface{}, partitionID interface{}) *MockTargetManager_IsCurrentTargetExist_Call { - return &MockTargetManager_IsCurrentTargetExist_Call{Call: _e.mock.On("IsCurrentTargetExist", collectionID, partitionID)} +func (_e *MockTargetManager_Expecter) IsCurrentTargetExist(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockTargetManager_IsCurrentTargetExist_Call { + return &MockTargetManager_IsCurrentTargetExist_Call{Call: _e.mock.On("IsCurrentTargetExist", ctx, collectionID, partitionID)} } -func (_c *MockTargetManager_IsCurrentTargetExist_Call) Run(run func(collectionID int64, partitionID int64)) *MockTargetManager_IsCurrentTargetExist_Call { +func (_c *MockTargetManager_IsCurrentTargetExist_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64)) *MockTargetManager_IsCurrentTargetExist_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64), args[1].(int64)) + run(args[0].(context.Context), args[1].(int64), args[2].(int64)) }) return _c } @@ -653,22 +668,22 @@ func (_c *MockTargetManager_IsCurrentTargetExist_Call) Return(_a0 bool) *MockTar return _c } -func (_c *MockTargetManager_IsCurrentTargetExist_Call) RunAndReturn(run func(int64, int64) bool) *MockTargetManager_IsCurrentTargetExist_Call { +func (_c *MockTargetManager_IsCurrentTargetExist_Call) RunAndReturn(run func(context.Context, int64, int64) bool) *MockTargetManager_IsCurrentTargetExist_Call { _c.Call.Return(run) return _c } -// IsNextTargetExist provides a mock function with given fields: collectionID -func (_m *MockTargetManager) IsNextTargetExist(collectionID int64) bool { - ret := _m.Called(collectionID) +// IsNextTargetExist provides a mock function with given fields: ctx, collectionID +func (_m *MockTargetManager) IsNextTargetExist(ctx context.Context, collectionID int64) bool { + ret := _m.Called(ctx, collectionID) if len(ret) == 0 { panic("no return value specified for IsNextTargetExist") } var r0 bool - if rf, ok := ret.Get(0).(func(int64) bool); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok { + r0 = rf(ctx, collectionID) } else { r0 = ret.Get(0).(bool) } @@ -682,14 +697,15 @@ type MockTargetManager_IsNextTargetExist_Call struct { } // IsNextTargetExist is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *MockTargetManager_Expecter) IsNextTargetExist(collectionID interface{}) *MockTargetManager_IsNextTargetExist_Call { - return &MockTargetManager_IsNextTargetExist_Call{Call: _e.mock.On("IsNextTargetExist", collectionID)} +func (_e *MockTargetManager_Expecter) IsNextTargetExist(ctx interface{}, collectionID interface{}) *MockTargetManager_IsNextTargetExist_Call { + return &MockTargetManager_IsNextTargetExist_Call{Call: _e.mock.On("IsNextTargetExist", ctx, collectionID)} } -func (_c *MockTargetManager_IsNextTargetExist_Call) Run(run func(collectionID int64)) *MockTargetManager_IsNextTargetExist_Call { +func (_c *MockTargetManager_IsNextTargetExist_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_IsNextTargetExist_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -699,22 +715,22 @@ func (_c *MockTargetManager_IsNextTargetExist_Call) Return(_a0 bool) *MockTarget return _c } -func (_c *MockTargetManager_IsNextTargetExist_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_IsNextTargetExist_Call { +func (_c *MockTargetManager_IsNextTargetExist_Call) RunAndReturn(run func(context.Context, int64) bool) *MockTargetManager_IsNextTargetExist_Call { _c.Call.Return(run) return _c } -// Recover provides a mock function with given fields: catalog -func (_m *MockTargetManager) Recover(catalog metastore.QueryCoordCatalog) error { - ret := _m.Called(catalog) +// Recover provides a mock function with given fields: ctx, catalog +func (_m *MockTargetManager) Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error { + ret := _m.Called(ctx, catalog) if len(ret) == 0 { panic("no return value specified for Recover") } var r0 error - if rf, ok := ret.Get(0).(func(metastore.QueryCoordCatalog) error); ok { - r0 = rf(catalog) + if rf, ok := ret.Get(0).(func(context.Context, metastore.QueryCoordCatalog) error); ok { + r0 = rf(ctx, catalog) } else { r0 = ret.Error(0) } @@ -728,14 +744,15 @@ type MockTargetManager_Recover_Call struct { } // Recover is a helper method to define mock.On call +// - ctx context.Context // - catalog metastore.QueryCoordCatalog -func (_e *MockTargetManager_Expecter) Recover(catalog interface{}) *MockTargetManager_Recover_Call { - return &MockTargetManager_Recover_Call{Call: _e.mock.On("Recover", catalog)} +func (_e *MockTargetManager_Expecter) Recover(ctx interface{}, catalog interface{}) *MockTargetManager_Recover_Call { + return &MockTargetManager_Recover_Call{Call: _e.mock.On("Recover", ctx, catalog)} } -func (_c *MockTargetManager_Recover_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_Recover_Call { +func (_c *MockTargetManager_Recover_Call) Run(run func(ctx context.Context, catalog metastore.QueryCoordCatalog)) *MockTargetManager_Recover_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(metastore.QueryCoordCatalog)) + run(args[0].(context.Context), args[1].(metastore.QueryCoordCatalog)) }) return _c } @@ -745,14 +762,14 @@ func (_c *MockTargetManager_Recover_Call) Return(_a0 error) *MockTargetManager_R return _c } -func (_c *MockTargetManager_Recover_Call) RunAndReturn(run func(metastore.QueryCoordCatalog) error) *MockTargetManager_Recover_Call { +func (_c *MockTargetManager_Recover_Call) RunAndReturn(run func(context.Context, metastore.QueryCoordCatalog) error) *MockTargetManager_Recover_Call { _c.Call.Return(run) return _c } -// RemoveCollection provides a mock function with given fields: collectionID -func (_m *MockTargetManager) RemoveCollection(collectionID int64) { - _m.Called(collectionID) +// RemoveCollection provides a mock function with given fields: ctx, collectionID +func (_m *MockTargetManager) RemoveCollection(ctx context.Context, collectionID int64) { + _m.Called(ctx, collectionID) } // MockTargetManager_RemoveCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollection' @@ -761,14 +778,15 @@ type MockTargetManager_RemoveCollection_Call struct { } // RemoveCollection is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *MockTargetManager_Expecter) RemoveCollection(collectionID interface{}) *MockTargetManager_RemoveCollection_Call { - return &MockTargetManager_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", collectionID)} +func (_e *MockTargetManager_Expecter) RemoveCollection(ctx interface{}, collectionID interface{}) *MockTargetManager_RemoveCollection_Call { + return &MockTargetManager_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", ctx, collectionID)} } -func (_c *MockTargetManager_RemoveCollection_Call) Run(run func(collectionID int64)) *MockTargetManager_RemoveCollection_Call { +func (_c *MockTargetManager_RemoveCollection_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_RemoveCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -778,19 +796,19 @@ func (_c *MockTargetManager_RemoveCollection_Call) Return() *MockTargetManager_R return _c } -func (_c *MockTargetManager_RemoveCollection_Call) RunAndReturn(run func(int64)) *MockTargetManager_RemoveCollection_Call { +func (_c *MockTargetManager_RemoveCollection_Call) RunAndReturn(run func(context.Context, int64)) *MockTargetManager_RemoveCollection_Call { _c.Call.Return(run) return _c } -// RemovePartition provides a mock function with given fields: collectionID, partitionIDs -func (_m *MockTargetManager) RemovePartition(collectionID int64, partitionIDs ...int64) { +// RemovePartition provides a mock function with given fields: ctx, collectionID, partitionIDs +func (_m *MockTargetManager) RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64) { _va := make([]interface{}, len(partitionIDs)) for _i := range partitionIDs { _va[_i] = partitionIDs[_i] } var _ca []interface{} - _ca = append(_ca, collectionID) + _ca = append(_ca, ctx, collectionID) _ca = append(_ca, _va...) _m.Called(_ca...) } @@ -801,22 +819,23 @@ type MockTargetManager_RemovePartition_Call struct { } // RemovePartition is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 // - partitionIDs ...int64 -func (_e *MockTargetManager_Expecter) RemovePartition(collectionID interface{}, partitionIDs ...interface{}) *MockTargetManager_RemovePartition_Call { +func (_e *MockTargetManager_Expecter) RemovePartition(ctx interface{}, collectionID interface{}, partitionIDs ...interface{}) *MockTargetManager_RemovePartition_Call { return &MockTargetManager_RemovePartition_Call{Call: _e.mock.On("RemovePartition", - append([]interface{}{collectionID}, partitionIDs...)...)} + append([]interface{}{ctx, collectionID}, partitionIDs...)...)} } -func (_c *MockTargetManager_RemovePartition_Call) Run(run func(collectionID int64, partitionIDs ...int64)) *MockTargetManager_RemovePartition_Call { +func (_c *MockTargetManager_RemovePartition_Call) Run(run func(ctx context.Context, collectionID int64, partitionIDs ...int64)) *MockTargetManager_RemovePartition_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]int64, len(args)-1) - for i, a := range args[1:] { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { if a != nil { variadicArgs[i] = a.(int64) } } - run(args[0].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), variadicArgs...) }) return _c } @@ -826,14 +845,14 @@ func (_c *MockTargetManager_RemovePartition_Call) Return() *MockTargetManager_Re return _c } -func (_c *MockTargetManager_RemovePartition_Call) RunAndReturn(run func(int64, ...int64)) *MockTargetManager_RemovePartition_Call { +func (_c *MockTargetManager_RemovePartition_Call) RunAndReturn(run func(context.Context, int64, ...int64)) *MockTargetManager_RemovePartition_Call { _c.Call.Return(run) return _c } -// SaveCurrentTarget provides a mock function with given fields: catalog -func (_m *MockTargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) { - _m.Called(catalog) +// SaveCurrentTarget provides a mock function with given fields: ctx, catalog +func (_m *MockTargetManager) SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog) { + _m.Called(ctx, catalog) } // MockTargetManager_SaveCurrentTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCurrentTarget' @@ -842,14 +861,15 @@ type MockTargetManager_SaveCurrentTarget_Call struct { } // SaveCurrentTarget is a helper method to define mock.On call +// - ctx context.Context // - catalog metastore.QueryCoordCatalog -func (_e *MockTargetManager_Expecter) SaveCurrentTarget(catalog interface{}) *MockTargetManager_SaveCurrentTarget_Call { - return &MockTargetManager_SaveCurrentTarget_Call{Call: _e.mock.On("SaveCurrentTarget", catalog)} +func (_e *MockTargetManager_Expecter) SaveCurrentTarget(ctx interface{}, catalog interface{}) *MockTargetManager_SaveCurrentTarget_Call { + return &MockTargetManager_SaveCurrentTarget_Call{Call: _e.mock.On("SaveCurrentTarget", ctx, catalog)} } -func (_c *MockTargetManager_SaveCurrentTarget_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { +func (_c *MockTargetManager_SaveCurrentTarget_Call) Run(run func(ctx context.Context, catalog metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(metastore.QueryCoordCatalog)) + run(args[0].(context.Context), args[1].(metastore.QueryCoordCatalog)) }) return _c } @@ -859,22 +879,22 @@ func (_c *MockTargetManager_SaveCurrentTarget_Call) Return() *MockTargetManager_ return _c } -func (_c *MockTargetManager_SaveCurrentTarget_Call) RunAndReturn(run func(metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { +func (_c *MockTargetManager_SaveCurrentTarget_Call) RunAndReturn(run func(context.Context, metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { _c.Call.Return(run) return _c } -// UpdateCollectionCurrentTarget provides a mock function with given fields: collectionID -func (_m *MockTargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool { - ret := _m.Called(collectionID) +// UpdateCollectionCurrentTarget provides a mock function with given fields: ctx, collectionID +func (_m *MockTargetManager) UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool { + ret := _m.Called(ctx, collectionID) if len(ret) == 0 { panic("no return value specified for UpdateCollectionCurrentTarget") } var r0 bool - if rf, ok := ret.Get(0).(func(int64) bool); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) bool); ok { + r0 = rf(ctx, collectionID) } else { r0 = ret.Get(0).(bool) } @@ -888,14 +908,15 @@ type MockTargetManager_UpdateCollectionCurrentTarget_Call struct { } // UpdateCollectionCurrentTarget is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *MockTargetManager_Expecter) UpdateCollectionCurrentTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionCurrentTarget_Call { - return &MockTargetManager_UpdateCollectionCurrentTarget_Call{Call: _e.mock.On("UpdateCollectionCurrentTarget", collectionID)} +func (_e *MockTargetManager_Expecter) UpdateCollectionCurrentTarget(ctx interface{}, collectionID interface{}) *MockTargetManager_UpdateCollectionCurrentTarget_Call { + return &MockTargetManager_UpdateCollectionCurrentTarget_Call{Call: _e.mock.On("UpdateCollectionCurrentTarget", ctx, collectionID)} } -func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionCurrentTarget_Call { +func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_UpdateCollectionCurrentTarget_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -905,22 +926,22 @@ func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Return(_a0 bool) return _c } -func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call { +func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) RunAndReturn(run func(context.Context, int64) bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call { _c.Call.Return(run) return _c } -// UpdateCollectionNextTarget provides a mock function with given fields: collectionID -func (_m *MockTargetManager) UpdateCollectionNextTarget(collectionID int64) error { - ret := _m.Called(collectionID) +// UpdateCollectionNextTarget provides a mock function with given fields: ctx, collectionID +func (_m *MockTargetManager) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error { + ret := _m.Called(ctx, collectionID) if len(ret) == 0 { panic("no return value specified for UpdateCollectionNextTarget") } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(collectionID) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, collectionID) } else { r0 = ret.Error(0) } @@ -934,14 +955,15 @@ type MockTargetManager_UpdateCollectionNextTarget_Call struct { } // UpdateCollectionNextTarget is a helper method to define mock.On call +// - ctx context.Context // - collectionID int64 -func (_e *MockTargetManager_Expecter) UpdateCollectionNextTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionNextTarget_Call { - return &MockTargetManager_UpdateCollectionNextTarget_Call{Call: _e.mock.On("UpdateCollectionNextTarget", collectionID)} +func (_e *MockTargetManager_Expecter) UpdateCollectionNextTarget(ctx interface{}, collectionID interface{}) *MockTargetManager_UpdateCollectionNextTarget_Call { + return &MockTargetManager_UpdateCollectionNextTarget_Call{Call: _e.mock.On("UpdateCollectionNextTarget", ctx, collectionID)} } -func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionNextTarget_Call { +func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Run(run func(ctx context.Context, collectionID int64)) *MockTargetManager_UpdateCollectionNextTarget_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) + run(args[0].(context.Context), args[1].(int64)) }) return _c } @@ -951,7 +973,7 @@ func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Return(_a0 error) * return _c } -func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) RunAndReturn(run func(int64) error) *MockTargetManager_UpdateCollectionNextTarget_Call { +func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) RunAndReturn(run func(context.Context, int64) error) *MockTargetManager_UpdateCollectionNextTarget_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index a4263e1a55435..7fcc7ea5df306 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -17,6 +17,7 @@ package meta import ( + "context" "fmt" "sync" @@ -78,8 +79,8 @@ func NewReplicaManager(idAllocator func() (int64, error), catalog metastore.Quer } // Recover recovers the replicas for given collections from meta store -func (m *ReplicaManager) Recover(collections []int64) error { - replicas, err := m.catalog.GetReplicas() +func (m *ReplicaManager) Recover(ctx context.Context, collections []int64) error { + replicas, err := m.catalog.GetReplicas(ctx) if err != nil { return fmt.Errorf("failed to recover replicas, err=%w", err) } @@ -98,7 +99,7 @@ func (m *ReplicaManager) Recover(collections []int64) error { zap.Int64s("nodes", replica.GetNodes()), ) } else { - err := m.catalog.ReleaseReplica(replica.GetCollectionID(), replica.GetID()) + err := m.catalog.ReleaseReplica(ctx, replica.GetCollectionID(), replica.GetID()) if err != nil { return err } @@ -114,7 +115,7 @@ func (m *ReplicaManager) Recover(collections []int64) error { // Get returns the replica by id. // Replica should be read-only, do not modify it. -func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica { +func (m *ReplicaManager) Get(ctx context.Context, id typeutil.UniqueID) *Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -122,7 +123,7 @@ func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica { } // Spawn spawns N replicas at resource group for given collection in ReplicaManager. -func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int, channels []string) ([]*Replica, error) { +func (m *ReplicaManager) Spawn(ctx context.Context, collection int64, replicaNumInRG map[string]int, channels []string) ([]*Replica, error) { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -151,7 +152,7 @@ func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int, })) } } - if err := m.put(replicas...); err != nil { + if err := m.put(ctx, replicas...); err != nil { return nil, err } return replicas, nil @@ -159,14 +160,14 @@ func (m *ReplicaManager) Spawn(collection int64, replicaNumInRG map[string]int, // Deprecated: Warning, break the consistency of ReplicaManager, // never use it in non-test code, use Spawn instead. -func (m *ReplicaManager) Put(replicas ...*Replica) error { +func (m *ReplicaManager) Put(ctx context.Context, replicas ...*Replica) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - return m.put(replicas...) + return m.put(ctx, replicas...) } -func (m *ReplicaManager) put(replicas ...*Replica) error { +func (m *ReplicaManager) put(ctx context.Context, replicas ...*Replica) error { if len(replicas) == 0 { return nil } @@ -175,7 +176,7 @@ func (m *ReplicaManager) put(replicas ...*Replica) error { for _, replica := range replicas { replicaPBs = append(replicaPBs, replica.replicaPB) } - if err := m.catalog.SaveReplica(replicaPBs...); err != nil { + if err := m.catalog.SaveReplica(ctx, replicaPBs...); err != nil { return err } @@ -198,7 +199,7 @@ func (m *ReplicaManager) putReplicaInMemory(replicas ...*Replica) { } // TransferReplica transfers N replicas from srcRGName to dstRGName. -func (m *ReplicaManager) TransferReplica(collectionID typeutil.UniqueID, srcRGName string, dstRGName string, replicaNum int) error { +func (m *ReplicaManager) TransferReplica(ctx context.Context, collectionID typeutil.UniqueID, srcRGName string, dstRGName string, replicaNum int) error { if srcRGName == dstRGName { return merr.WrapErrParameterInvalidMsg("source resource group and target resource group should not be the same, resource group: %s", srcRGName) } @@ -223,10 +224,10 @@ func (m *ReplicaManager) TransferReplica(collectionID typeutil.UniqueID, srcRGNa mutableReplica.SetResourceGroup(dstRGName) replicas = append(replicas, mutableReplica.IntoReplica()) } - return m.put(replicas...) + return m.put(ctx, replicas...) } -func (m *ReplicaManager) MoveReplica(dstRGName string, toMove []*Replica) error { +func (m *ReplicaManager) MoveReplica(ctx context.Context, dstRGName string, toMove []*Replica) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() replicas := make([]*Replica, 0, len(toMove)) @@ -238,7 +239,7 @@ func (m *ReplicaManager) MoveReplica(dstRGName string, toMove []*Replica) error replicaIDs = append(replicaIDs, replica.GetID()) } log.Info("move replicas to resource group", zap.String("dstRGName", dstRGName), zap.Int64s("replicas", replicaIDs)) - return m.put(replicas...) + return m.put(ctx, replicas...) } // getSrcReplicasAndCheckIfTransferable checks if the collection can be transfer from srcRGName to dstRGName. @@ -267,11 +268,11 @@ func (m *ReplicaManager) getSrcReplicasAndCheckIfTransferable(collectionID typeu // RemoveCollection removes replicas of given collection, // returns error if failed to remove replica from KV -func (m *ReplicaManager) RemoveCollection(collectionID typeutil.UniqueID) error { +func (m *ReplicaManager) RemoveCollection(ctx context.Context, collectionID typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() - err := m.catalog.ReleaseReplicas(collectionID) + err := m.catalog.ReleaseReplicas(ctx, collectionID) if err != nil { return err } @@ -286,17 +287,17 @@ func (m *ReplicaManager) RemoveCollection(collectionID typeutil.UniqueID) error return nil } -func (m *ReplicaManager) RemoveReplicas(collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error { +func (m *ReplicaManager) RemoveReplicas(ctx context.Context, collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() log.Info("release replicas", zap.Int64("collectionID", collectionID), zap.Int64s("replicas", replicas)) - return m.removeReplicas(collectionID, replicas...) + return m.removeReplicas(ctx, collectionID, replicas...) } -func (m *ReplicaManager) removeReplicas(collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error { - err := m.catalog.ReleaseReplica(collectionID, replicas...) +func (m *ReplicaManager) removeReplicas(ctx context.Context, collectionID typeutil.UniqueID, replicas ...typeutil.UniqueID) error { + err := m.catalog.ReleaseReplica(ctx, collectionID, replicas...) if err != nil { return err } @@ -312,7 +313,7 @@ func (m *ReplicaManager) removeReplicas(collectionID typeutil.UniqueID, replicas return nil } -func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Replica { +func (m *ReplicaManager) GetByCollection(ctx context.Context, collectionID typeutil.UniqueID) []*Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.getByCollection(collectionID) @@ -327,7 +328,7 @@ func (m *ReplicaManager) getByCollection(collectionID typeutil.UniqueID) []*Repl return collReplicas.replicas } -func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.UniqueID) *Replica { +func (m *ReplicaManager) GetByCollectionAndNode(ctx context.Context, collectionID, nodeID typeutil.UniqueID) *Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -342,7 +343,7 @@ func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.Un return nil } -func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica { +func (m *ReplicaManager) GetByNode(ctx context.Context, nodeID typeutil.UniqueID) []*Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -367,7 +368,7 @@ func (m *ReplicaManager) getByCollectionAndRG(collectionID int64, rgName string) }) } -func (m *ReplicaManager) GetByResourceGroup(rgName string) []*Replica { +func (m *ReplicaManager) GetByResourceGroup(ctx context.Context, rgName string) []*Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -386,7 +387,7 @@ func (m *ReplicaManager) GetByResourceGroup(rgName string) []*Replica { // 1. Move the rw nodes to ro nodes if they are not in related resource group. // 2. Add new incoming nodes into the replica if they are not in-used by other replicas of same collection. // 3. replicas in same resource group will shared the nodes in resource group fairly. -func (m *ReplicaManager) RecoverNodesInCollection(collectionID typeutil.UniqueID, rgs map[string]typeutil.UniqueSet) error { +func (m *ReplicaManager) RecoverNodesInCollection(ctx context.Context, collectionID typeutil.UniqueID, rgs map[string]typeutil.UniqueSet) error { if err := m.validateResourceGroups(rgs); err != nil { return err } @@ -427,7 +428,7 @@ func (m *ReplicaManager) RecoverNodesInCollection(collectionID typeutil.UniqueID modifiedReplicas = append(modifiedReplicas, mutableReplica.IntoReplica()) }) }) - return m.put(modifiedReplicas...) + return m.put(ctx, modifiedReplicas...) } // validateResourceGroups checks if the resource groups are valid. @@ -468,7 +469,7 @@ func (m *ReplicaManager) getCollectionAssignmentHelper(collectionID typeutil.Uni } // RemoveNode removes the node from all replicas of given collection. -func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error { +func (m *ReplicaManager) RemoveNode(ctx context.Context, replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -479,11 +480,11 @@ func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeut mutableReplica := replica.CopyForWrite() mutableReplica.RemoveNode(nodes...) // ro -> unused - return m.put(mutableReplica.IntoReplica()) + return m.put(ctx, mutableReplica.IntoReplica()) } -func (m *ReplicaManager) GetResourceGroupByCollection(collection typeutil.UniqueID) typeutil.Set[string] { - replicas := m.GetByCollection(collection) +func (m *ReplicaManager) GetResourceGroupByCollection(ctx context.Context, collection typeutil.UniqueID) typeutil.Set[string] { + replicas := m.GetByCollection(ctx, collection) ret := typeutil.NewSet(lo.Map(replicas, func(r *Replica, _ int) string { return r.GetResourceGroup() })...) return ret } @@ -492,7 +493,7 @@ func (m *ReplicaManager) GetResourceGroupByCollection(collection typeutil.Unique // It locks the ReplicaManager for reading, converts the replicas to their protobuf representation, // marshals them into a JSON string, and returns the result. // If an error occurs during marshaling, it logs a warning and returns an empty string. -func (m *ReplicaManager) GetReplicasJSON() string { +func (m *ReplicaManager) GetReplicasJSON(ctx context.Context) string { m.rwmutex.RLock() defer m.rwmutex.RUnlock() diff --git a/internal/querycoordv2/meta/replica_manager_test.go b/internal/querycoordv2/meta/replica_manager_test.go index f81cc5dfe0b77..cdea379f9f233 100644 --- a/internal/querycoordv2/meta/replica_manager_test.go +++ b/internal/querycoordv2/meta/replica_manager_test.go @@ -17,6 +17,7 @@ package meta import ( + "context" "testing" "github.com/samber/lo" @@ -62,6 +63,7 @@ type ReplicaManagerSuite struct { kv kv.MetaKv catalog metastore.QueryCoordCatalog mgr *ReplicaManager + ctx context.Context } func (suite *ReplicaManagerSuite) SetupSuite() { @@ -86,6 +88,7 @@ func (suite *ReplicaManagerSuite) SetupSuite() { spawnConfig: map[string]int{"RG1": 1, "RG2": 1, "RG3": 1}, }, } + suite.ctx = context.Background() } func (suite *ReplicaManagerSuite) SetupTest() { @@ -114,16 +117,17 @@ func (suite *ReplicaManagerSuite) TearDownTest() { func (suite *ReplicaManagerSuite) TestSpawn() { mgr := suite.mgr + ctx := suite.ctx mgr.idAllocator = ErrorIDAllocator() - _, err := mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, nil) + _, err := mgr.Spawn(ctx, 1, map[string]int{DefaultResourceGroupName: 1}, nil) suite.Error(err) - replicas := mgr.GetByCollection(1) + replicas := mgr.GetByCollection(ctx, 1) suite.Len(replicas, 0) mgr.idAllocator = suite.idAllocator - replicas, err = mgr.Spawn(1, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + replicas, err = mgr.Spawn(ctx, 1, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) suite.NoError(err) for _, replica := range replicas { suite.Len(replica.replicaPB.GetChannelNodeInfos(), 0) @@ -131,7 +135,7 @@ func (suite *ReplicaManagerSuite) TestSpawn() { paramtable.Get().Save(paramtable.Get().QueryCoordCfg.Balancer.Key, ChannelLevelScoreBalancerName) defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.Balancer.Key) - replicas, err = mgr.Spawn(2, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) + replicas, err = mgr.Spawn(ctx, 2, map[string]int{DefaultResourceGroupName: 1}, []string{"channel1", "channel2"}) suite.NoError(err) for _, replica := range replicas { suite.Len(replica.replicaPB.GetChannelNodeInfos(), 2) @@ -140,14 +144,15 @@ func (suite *ReplicaManagerSuite) TestSpawn() { func (suite *ReplicaManagerSuite) TestGet() { mgr := suite.mgr + ctx := suite.ctx for collectionID, collectionCfg := range suite.collections { - replicas := mgr.GetByCollection(collectionID) + replicas := mgr.GetByCollection(ctx, collectionID) replicaNodes := make(map[int64][]int64) nodes := make([]int64, 0) for _, replica := range replicas { suite.Equal(collectionID, replica.GetCollectionID()) - suite.Equal(replica, mgr.Get(replica.GetID())) + suite.Equal(replica, mgr.Get(ctx, replica.GetID())) suite.Equal(len(replica.replicaPB.GetNodes()), replica.RWNodesCount()) suite.Equal(replica.replicaPB.GetNodes(), replica.GetNodes()) replicaNodes[replica.GetID()] = replica.GetNodes() @@ -162,7 +167,7 @@ func (suite *ReplicaManagerSuite) TestGet() { for replicaID, nodes := range replicaNodes { for _, node := range nodes { - replica := mgr.GetByCollectionAndNode(collectionID, node) + replica := mgr.GetByCollectionAndNode(ctx, collectionID, node) suite.Equal(replicaID, replica.GetID()) } } @@ -171,6 +176,7 @@ func (suite *ReplicaManagerSuite) TestGet() { func (suite *ReplicaManagerSuite) TestGetByNode() { mgr := suite.mgr + ctx := suite.ctx randomNodeID := int64(11111) testReplica1 := newReplica(&querypb.Replica{ @@ -185,18 +191,19 @@ func (suite *ReplicaManagerSuite) TestGetByNode() { Nodes: []int64{randomNodeID}, ResourceGroup: DefaultResourceGroupName, }) - mgr.Put(testReplica1, testReplica2) + mgr.Put(ctx, testReplica1, testReplica2) - replicas := mgr.GetByNode(randomNodeID) + replicas := mgr.GetByNode(ctx, randomNodeID) suite.Len(replicas, 2) } func (suite *ReplicaManagerSuite) TestRecover() { mgr := suite.mgr + ctx := suite.ctx // Clear data in memory, and then recover from meta store suite.clearMemory() - mgr.Recover(lo.Keys(suite.collections)) + mgr.Recover(ctx, lo.Keys(suite.collections)) suite.TestGet() // Test recover from 2.1 meta store @@ -210,8 +217,8 @@ func (suite *ReplicaManagerSuite) TestRecover() { suite.kv.Save(querycoord.ReplicaMetaPrefixV1+"/2100", string(value)) suite.clearMemory() - mgr.Recover(append(lo.Keys(suite.collections), 1000)) - replica := mgr.Get(2100) + mgr.Recover(ctx, append(lo.Keys(suite.collections), 1000)) + replica := mgr.Get(ctx, 2100) suite.NotNil(replica) suite.EqualValues(1000, replica.GetCollectionID()) suite.EqualValues([]int64{1, 2, 3}, replica.GetNodes()) @@ -223,25 +230,27 @@ func (suite *ReplicaManagerSuite) TestRecover() { func (suite *ReplicaManagerSuite) TestRemove() { mgr := suite.mgr + ctx := suite.ctx for collection := range suite.collections { - err := mgr.RemoveCollection(collection) + err := mgr.RemoveCollection(ctx, collection) suite.NoError(err) - replicas := mgr.GetByCollection(collection) + replicas := mgr.GetByCollection(ctx, collection) suite.Empty(replicas) } // Check whether the replicas are also removed from meta store - mgr.Recover(lo.Keys(suite.collections)) + mgr.Recover(ctx, lo.Keys(suite.collections)) for collection := range suite.collections { - replicas := mgr.GetByCollection(collection) + replicas := mgr.GetByCollection(ctx, collection) suite.Empty(replicas) } } func (suite *ReplicaManagerSuite) TestNodeManipulate() { mgr := suite.mgr + ctx := suite.ctx // add node into rg. rgs := map[string]typeutil.UniqueSet{ @@ -256,10 +265,10 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() { for rg := range cfg.spawnConfig { rgsOfCollection[rg] = rgs[rg] } - mgr.RecoverNodesInCollection(collectionID, rgsOfCollection) + mgr.RecoverNodesInCollection(ctx, collectionID, rgsOfCollection) for rg := range cfg.spawnConfig { for _, node := range rgs[rg].Collect() { - replica := mgr.GetByCollectionAndNode(collectionID, node) + replica := mgr.GetByCollectionAndNode(ctx, collectionID, node) suite.Contains(replica.GetNodes(), node) } } @@ -267,11 +276,11 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() { // Check these modifications are applied to meta store suite.clearMemory() - mgr.Recover(lo.Keys(suite.collections)) + mgr.Recover(ctx, lo.Keys(suite.collections)) for collectionID, cfg := range suite.collections { for rg := range cfg.spawnConfig { for _, node := range rgs[rg].Collect() { - replica := mgr.GetByCollectionAndNode(collectionID, node) + replica := mgr.GetByCollectionAndNode(ctx, collectionID, node) suite.Contains(replica.GetNodes(), node) } } @@ -280,9 +289,10 @@ func (suite *ReplicaManagerSuite) TestNodeManipulate() { func (suite *ReplicaManagerSuite) spawnAll() { mgr := suite.mgr + ctx := suite.ctx for id, cfg := range suite.collections { - replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil) + replicas, err := mgr.Spawn(ctx, id, cfg.spawnConfig, nil) suite.NoError(err) totalSpawn := 0 rgsOfCollection := make(map[string]typeutil.UniqueSet) @@ -290,26 +300,27 @@ func (suite *ReplicaManagerSuite) spawnAll() { totalSpawn += spawnNum rgsOfCollection[rg] = suite.rgs[rg] } - mgr.RecoverNodesInCollection(id, rgsOfCollection) + mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection) suite.Len(replicas, totalSpawn) } } func (suite *ReplicaManagerSuite) TestResourceGroup() { mgr := NewReplicaManager(suite.idAllocator, suite.catalog) - replicas1, err := mgr.Spawn(int64(1000), map[string]int{DefaultResourceGroupName: 1}, nil) + ctx := suite.ctx + replicas1, err := mgr.Spawn(ctx, int64(1000), map[string]int{DefaultResourceGroupName: 1}, nil) suite.NoError(err) suite.NotNil(replicas1) suite.Len(replicas1, 1) - replica2, err := mgr.Spawn(int64(2000), map[string]int{DefaultResourceGroupName: 1}, nil) + replica2, err := mgr.Spawn(ctx, int64(2000), map[string]int{DefaultResourceGroupName: 1}, nil) suite.NoError(err) suite.NotNil(replica2) suite.Len(replica2, 1) - replicas := mgr.GetByResourceGroup(DefaultResourceGroupName) + replicas := mgr.GetByResourceGroup(ctx, DefaultResourceGroupName) suite.Len(replicas, 2) - rgNames := mgr.GetResourceGroupByCollection(int64(1000)) + rgNames := mgr.GetResourceGroupByCollection(ctx, int64(1000)) suite.Len(rgNames, 1) suite.True(rgNames.Contain(DefaultResourceGroupName)) } @@ -326,6 +337,7 @@ type ReplicaManagerV2Suite struct { kv kv.MetaKv catalog metastore.QueryCoordCatalog mgr *ReplicaManager + ctx context.Context } func (suite *ReplicaManagerV2Suite) SetupSuite() { @@ -375,6 +387,7 @@ func (suite *ReplicaManagerV2Suite) SetupSuite() { idAllocator := RandomIncrementIDAllocator() suite.mgr = NewReplicaManager(idAllocator, suite.catalog) + suite.ctx = context.Background() } func (suite *ReplicaManagerV2Suite) TearDownSuite() { @@ -383,32 +396,34 @@ func (suite *ReplicaManagerV2Suite) TearDownSuite() { func (suite *ReplicaManagerV2Suite) TestSpawn() { mgr := suite.mgr + ctx := suite.ctx for id, cfg := range suite.collections { - replicas, err := mgr.Spawn(id, cfg.spawnConfig, nil) + replicas, err := mgr.Spawn(ctx, id, cfg.spawnConfig, nil) suite.NoError(err) rgsOfCollection := make(map[string]typeutil.UniqueSet) for rg := range cfg.spawnConfig { rgsOfCollection[rg] = suite.rgs[rg] } - mgr.RecoverNodesInCollection(id, rgsOfCollection) + mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection) for rg := range cfg.spawnConfig { for _, node := range suite.rgs[rg].Collect() { - replica := mgr.GetByCollectionAndNode(id, node) + replica := mgr.GetByCollectionAndNode(ctx, id, node) suite.Contains(replica.GetNodes(), node) } } suite.Len(replicas, cfg.getTotalSpawn()) - replicas = mgr.GetByCollection(id) + replicas = mgr.GetByCollection(ctx, id) suite.Len(replicas, cfg.getTotalSpawn()) } suite.testIfBalanced() } func (suite *ReplicaManagerV2Suite) testIfBalanced() { + ctx := suite.ctx // If balanced for id := range suite.collections { - replicas := suite.mgr.GetByCollection(id) + replicas := suite.mgr.GetByCollection(ctx, id) rgToReplica := make(map[string][]*Replica, 0) for _, r := range replicas { rgToReplica[r.GetResourceGroup()] = append(rgToReplica[r.GetResourceGroup()], r) @@ -440,22 +455,24 @@ func (suite *ReplicaManagerV2Suite) testIfBalanced() { } func (suite *ReplicaManagerV2Suite) TestTransferReplica() { + ctx := suite.ctx // param error - err := suite.mgr.TransferReplica(10086, "RG4", "RG5", 1) + err := suite.mgr.TransferReplica(ctx, 10086, "RG4", "RG5", 1) suite.Error(err) - err = suite.mgr.TransferReplica(1005, "RG4", "RG5", 0) + err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 0) suite.Error(err) - err = suite.mgr.TransferReplica(1005, "RG4", "RG4", 1) + err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG4", 1) suite.Error(err) - err = suite.mgr.TransferReplica(1005, "RG4", "RG5", 1) + err = suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 1) suite.NoError(err) suite.recoverReplica(2, true) suite.testIfBalanced() } func (suite *ReplicaManagerV2Suite) TestTransferReplicaAndAddNode() { - suite.mgr.TransferReplica(1005, "RG4", "RG5", 1) + ctx := suite.ctx + suite.mgr.TransferReplica(ctx, 1005, "RG4", "RG5", 1) suite.recoverReplica(1, false) suite.rgs["RG5"].Insert(16, 17, 18) suite.recoverReplica(2, true) @@ -470,6 +487,7 @@ func (suite *ReplicaManagerV2Suite) TestTransferNode() { } func (suite *ReplicaManagerV2Suite) recoverReplica(k int, clearOutbound bool) { + ctx := suite.ctx // need at least two times to recover the replicas. // transfer node between replicas need set to outbound and then set to incoming. for i := 0; i < k; i++ { @@ -479,16 +497,16 @@ func (suite *ReplicaManagerV2Suite) recoverReplica(k int, clearOutbound bool) { for rg := range cfg.spawnConfig { rgsOfCollection[rg] = suite.rgs[rg] } - suite.mgr.RecoverNodesInCollection(id, rgsOfCollection) + suite.mgr.RecoverNodesInCollection(ctx, id, rgsOfCollection) } // clear all outbound nodes if clearOutbound { for id := range suite.collections { - replicas := suite.mgr.GetByCollection(id) + replicas := suite.mgr.GetByCollection(ctx, id) for _, r := range replicas { outboundNodes := r.GetRONodes() - suite.mgr.RemoveNode(r.GetID(), outboundNodes...) + suite.mgr.RemoveNode(ctx, r.GetID(), outboundNodes...) } } } @@ -502,9 +520,10 @@ func TestReplicaManager(t *testing.T) { func TestGetReplicasJSON(t *testing.T) { catalog := mocks.NewQueryCoordCatalog(t) - catalog.EXPECT().SaveReplica(mock.Anything).Return(nil) + catalog.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil) idAllocator := RandomIncrementIDAllocator() replicaManager := NewReplicaManager(idAllocator, catalog) + ctx := context.Background() // Add some replicas to the ReplicaManager replica1 := newReplica(&querypb.Replica{ @@ -520,13 +539,13 @@ func TestGetReplicasJSON(t *testing.T) { Nodes: []int64{4, 5, 6}, }) - err := replicaManager.put(replica1) + err := replicaManager.put(ctx, replica1) assert.NoError(t, err) - err = replicaManager.put(replica2) + err = replicaManager.put(ctx, replica2) assert.NoError(t, err) - jsonOutput := replicaManager.GetReplicasJSON() + jsonOutput := replicaManager.GetReplicasJSON(ctx) var replicas []*metricsinfo.Replica err = json.Unmarshal([]byte(jsonOutput), &replicas) assert.NoError(t, err) diff --git a/internal/querycoordv2/meta/resource_manager.go b/internal/querycoordv2/meta/resource_manager.go index 408e2da246801..e2c9a9fc44f22 100644 --- a/internal/querycoordv2/meta/resource_manager.go +++ b/internal/querycoordv2/meta/resource_manager.go @@ -17,6 +17,7 @@ package meta import ( + "context" "fmt" "sync" @@ -77,11 +78,11 @@ func NewResourceManager(catalog metastore.QueryCoordCatalog, nodeMgr *session.No } // Recover recover resource group from meta, other interface of ResourceManager can be only called after recover is done. -func (rm *ResourceManager) Recover() error { +func (rm *ResourceManager) Recover(ctx context.Context) error { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() - rgs, err := rm.catalog.GetResourceGroups() + rgs, err := rm.catalog.GetResourceGroups(ctx) if err != nil { return errors.Wrap(err, "failed to recover resource group from store") } @@ -111,14 +112,14 @@ func (rm *ResourceManager) Recover() error { } if len(upgrades) > 0 { log.Info("upgrade resource group meta into latest", zap.Int("num", len(upgrades))) - return rm.catalog.SaveResourceGroup(upgrades...) + return rm.catalog.SaveResourceGroup(ctx, upgrades...) } return nil } // AddResourceGroup create a new ResourceGroup. // Do no changed with node, all node will be reassign to new resource group by auto recover. -func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGroupConfig) error { +func (rm *ResourceManager) AddResourceGroup(ctx context.Context, rgName string, cfg *rgpb.ResourceGroupConfig) error { if len(rgName) == 0 { return merr.WrapErrParameterMissing("resource group name couldn't be empty") } @@ -148,7 +149,7 @@ func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGro } rg := NewResourceGroup(rgName, cfg, rm.nodeMgr) - if err := rm.catalog.SaveResourceGroup(rg.GetMeta()); err != nil { + if err := rm.catalog.SaveResourceGroup(ctx, rg.GetMeta()); err != nil { log.Warn("failed to add resource group", zap.String("rgName", rgName), zap.Any("config", cfg), @@ -170,18 +171,18 @@ func (rm *ResourceManager) AddResourceGroup(rgName string, cfg *rgpb.ResourceGro // UpdateResourceGroups update resource group configuration. // Only change the configuration, no change with node. all node will be reassign by auto recover. -func (rm *ResourceManager) UpdateResourceGroups(rgs map[string]*rgpb.ResourceGroupConfig) error { +func (rm *ResourceManager) UpdateResourceGroups(ctx context.Context, rgs map[string]*rgpb.ResourceGroupConfig) error { if len(rgs) == 0 { return nil } rm.rwmutex.Lock() defer rm.rwmutex.Unlock() - return rm.updateResourceGroups(rgs) + return rm.updateResourceGroups(ctx, rgs) } // updateResourceGroups update resource group configuration. -func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGroupConfig) error { +func (rm *ResourceManager) updateResourceGroups(ctx context.Context, rgs map[string]*rgpb.ResourceGroupConfig) error { modifiedRG := make([]*ResourceGroup, 0, len(rgs)) updates := make([]*querypb.ResourceGroup, 0, len(rgs)) for rgName, cfg := range rgs { @@ -200,7 +201,7 @@ func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGro modifiedRG = append(modifiedRG, rg) } - if err := rm.catalog.SaveResourceGroup(updates...); err != nil { + if err := rm.catalog.SaveResourceGroup(ctx, updates...); err != nil { for rgName, cfg := range rgs { log.Warn("failed to update resource group", zap.String("rgName", rgName), @@ -227,7 +228,7 @@ func (rm *ResourceManager) updateResourceGroups(rgs map[string]*rgpb.ResourceGro // go:deprecated TransferNode transfer node from source resource group to target resource group. // Deprecated, use Declarative API `UpdateResourceGroups` instead. -func (rm *ResourceManager) TransferNode(sourceRGName string, targetRGName string, nodeNum int) error { +func (rm *ResourceManager) TransferNode(ctx context.Context, sourceRGName string, targetRGName string, nodeNum int) error { if sourceRGName == targetRGName { return merr.WrapErrParameterInvalidMsg("source resource group and target resource group should not be the same, resource group: %s", sourceRGName) } @@ -272,14 +273,14 @@ func (rm *ResourceManager) TransferNode(sourceRGName string, targetRGName string if targetCfg.Requests.NodeNum > targetCfg.Limits.NodeNum { targetCfg.Limits.NodeNum = targetCfg.Requests.NodeNum } - return rm.updateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + return rm.updateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ sourceRGName: sourceCfg, targetRGName: targetCfg, }) } // RemoveResourceGroup remove resource group. -func (rm *ResourceManager) RemoveResourceGroup(rgName string) error { +func (rm *ResourceManager) RemoveResourceGroup(ctx context.Context, rgName string) error { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() @@ -296,7 +297,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error { // Nodes may be still assign to these group, // recover the resource group from redundant status before remove it. if rm.groups[rgName].NodeNum() > 0 { - if err := rm.recoverRedundantNodeRG(rgName); err != nil { + if err := rm.recoverRedundantNodeRG(ctx, rgName); err != nil { log.Info("failed to recover redundant node resource group before remove it", zap.String("rgName", rgName), zap.Error(err), @@ -306,7 +307,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error { } // Remove it from meta storage. - if err := rm.catalog.RemoveResourceGroup(rgName); err != nil { + if err := rm.catalog.RemoveResourceGroup(ctx, rgName); err != nil { log.Info("failed to remove resource group", zap.String("rgName", rgName), zap.Error(err), @@ -327,7 +328,7 @@ func (rm *ResourceManager) RemoveResourceGroup(rgName string) error { } // GetNodesOfMultiRG return nodes of multi rg, it can be used to get a consistent view of nodes of multi rg. -func (rm *ResourceManager) GetNodesOfMultiRG(rgName []string) (map[string]typeutil.UniqueSet, error) { +func (rm *ResourceManager) GetNodesOfMultiRG(ctx context.Context, rgName []string) (map[string]typeutil.UniqueSet, error) { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() @@ -342,7 +343,7 @@ func (rm *ResourceManager) GetNodesOfMultiRG(rgName []string) (map[string]typeut } // GetNodes return nodes of given resource group. -func (rm *ResourceManager) GetNodes(rgName string) ([]int64, error) { +func (rm *ResourceManager) GetNodes(ctx context.Context, rgName string) ([]int64, error) { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() if rm.groups[rgName] == nil { @@ -352,7 +353,7 @@ func (rm *ResourceManager) GetNodes(rgName string) ([]int64, error) { } // GetResourceGroupByNodeID return whether resource group's node match required node count -func (rm *ResourceManager) VerifyNodeCount(requiredNodeCount map[string]int) error { +func (rm *ResourceManager) VerifyNodeCount(ctx context.Context, requiredNodeCount map[string]int) error { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() for rgName, nodeCount := range requiredNodeCount { @@ -368,7 +369,7 @@ func (rm *ResourceManager) VerifyNodeCount(requiredNodeCount map[string]int) err } // GetOutgoingNodeNumByReplica return outgoing node num on each rg from this replica. -func (rm *ResourceManager) GetOutgoingNodeNumByReplica(replica *Replica) map[string]int32 { +func (rm *ResourceManager) GetOutgoingNodeNumByReplica(ctx context.Context, replica *Replica) map[string]int32 { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() @@ -397,7 +398,7 @@ func (rm *ResourceManager) getResourceGroupByNodeID(nodeID int64) *ResourceGroup } // ContainsNode return whether given node is in given resource group. -func (rm *ResourceManager) ContainsNode(rgName string, node int64) bool { +func (rm *ResourceManager) ContainsNode(ctx context.Context, rgName string, node int64) bool { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() if rm.groups[rgName] == nil { @@ -407,14 +408,14 @@ func (rm *ResourceManager) ContainsNode(rgName string, node int64) bool { } // ContainResourceGroup return whether given resource group is exist. -func (rm *ResourceManager) ContainResourceGroup(rgName string) bool { +func (rm *ResourceManager) ContainResourceGroup(ctx context.Context, rgName string) bool { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() return rm.groups[rgName] != nil } // GetResourceGroup return resource group snapshot by name. -func (rm *ResourceManager) GetResourceGroup(rgName string) *ResourceGroup { +func (rm *ResourceManager) GetResourceGroup(ctx context.Context, rgName string) *ResourceGroup { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() @@ -425,7 +426,7 @@ func (rm *ResourceManager) GetResourceGroup(rgName string) *ResourceGroup { } // ListResourceGroups return all resource groups names. -func (rm *ResourceManager) ListResourceGroups() []string { +func (rm *ResourceManager) ListResourceGroups(ctx context.Context) []string { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() @@ -434,7 +435,7 @@ func (rm *ResourceManager) ListResourceGroups() []string { // MeetRequirement return whether resource group meet requirement. // Return error with reason if not meet requirement. -func (rm *ResourceManager) MeetRequirement(rgName string) error { +func (rm *ResourceManager) MeetRequirement(ctx context.Context, rgName string) error { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() if rm.groups[rgName] == nil { @@ -444,21 +445,21 @@ func (rm *ResourceManager) MeetRequirement(rgName string) error { } // CheckIncomingNodeNum return incoming node num. -func (rm *ResourceManager) CheckIncomingNodeNum() int { +func (rm *ResourceManager) CheckIncomingNodeNum(ctx context.Context) int { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() return rm.incomingNode.Len() } // HandleNodeUp handle node when new node is incoming. -func (rm *ResourceManager) HandleNodeUp(node int64) { +func (rm *ResourceManager) HandleNodeUp(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() rm.incomingNode.Insert(node) // Trigger assign incoming node right away. // error can be ignored here, because `AssignPendingIncomingNode`` will retry assign node. - rgName, err := rm.assignIncomingNodeWithNodeCheck(node) + rgName, err := rm.assignIncomingNodeWithNodeCheck(ctx, node) log.Info("HandleNodeUp: add node to resource group", zap.String("rgName", rgName), zap.Int64("node", node), @@ -467,7 +468,7 @@ func (rm *ResourceManager) HandleNodeUp(node int64) { } // HandleNodeDown handle the node when node is leave. -func (rm *ResourceManager) HandleNodeDown(node int64) { +func (rm *ResourceManager) HandleNodeDown(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() @@ -476,7 +477,7 @@ func (rm *ResourceManager) HandleNodeDown(node int64) { // for stopping query node becomes offline, node change won't be triggered, // cause when it becomes stopping, it already remove from resource manager // then `unassignNode` will do nothing - rgName, err := rm.unassignNode(node) + rgName, err := rm.unassignNode(ctx, node) // trigger node changes, expected to remove ro node from replica immediately rm.nodeChangedNotifier.NotifyAll() @@ -487,12 +488,12 @@ func (rm *ResourceManager) HandleNodeDown(node int64) { ) } -func (rm *ResourceManager) HandleNodeStopping(node int64) { +func (rm *ResourceManager) HandleNodeStopping(ctx context.Context, node int64) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() rm.incomingNode.Remove(node) - rgName, err := rm.unassignNode(node) + rgName, err := rm.unassignNode(ctx, node) log.Info("HandleNodeStopping: remove node from resource group", zap.String("rgName", rgName), zap.Int64("node", node), @@ -501,22 +502,22 @@ func (rm *ResourceManager) HandleNodeStopping(node int64) { } // ListenResourceGroupChanged return a listener for resource group changed. -func (rm *ResourceManager) ListenResourceGroupChanged() *syncutil.VersionedListener { +func (rm *ResourceManager) ListenResourceGroupChanged(ctx context.Context) *syncutil.VersionedListener { return rm.rgChangedNotifier.Listen(syncutil.VersionedListenAtEarliest) } // ListenNodeChanged return a listener for node changed. -func (rm *ResourceManager) ListenNodeChanged() *syncutil.VersionedListener { +func (rm *ResourceManager) ListenNodeChanged(ctx context.Context) *syncutil.VersionedListener { return rm.nodeChangedNotifier.Listen(syncutil.VersionedListenAtEarliest) } // AssignPendingIncomingNode assign incoming node to resource group. -func (rm *ResourceManager) AssignPendingIncomingNode() { +func (rm *ResourceManager) AssignPendingIncomingNode(ctx context.Context) { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() for node := range rm.incomingNode { - rgName, err := rm.assignIncomingNodeWithNodeCheck(node) + rgName, err := rm.assignIncomingNodeWithNodeCheck(ctx, node) log.Info("Pending HandleNodeUp: add node to resource group", zap.String("rgName", rgName), zap.Int64("node", node), @@ -526,7 +527,7 @@ func (rm *ResourceManager) AssignPendingIncomingNode() { } // AutoRecoverResourceGroup auto recover rg, return recover used node num -func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) error { +func (rm *ResourceManager) AutoRecoverResourceGroup(ctx context.Context, rgName string) error { rm.rwmutex.Lock() defer rm.rwmutex.Unlock() @@ -536,19 +537,19 @@ func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) error { } if rg.MissingNumOfNodes() > 0 { - return rm.recoverMissingNodeRG(rgName) + return rm.recoverMissingNodeRG(ctx, rgName) } // DefaultResourceGroup is the backup resource group of redundant recovery, // So after all other resource group is reach the `limits`, rest redundant node will be transfer to DefaultResourceGroup. if rg.RedundantNumOfNodes() > 0 { - return rm.recoverRedundantNodeRG(rgName) + return rm.recoverRedundantNodeRG(ctx, rgName) } return nil } // recoverMissingNodeRG recover resource group by transfer node from other resource group. -func (rm *ResourceManager) recoverMissingNodeRG(rgName string) error { +func (rm *ResourceManager) recoverMissingNodeRG(ctx context.Context, rgName string) error { for rm.groups[rgName].MissingNumOfNodes() > 0 { targetRG := rm.groups[rgName] node, sourceRG := rm.selectNodeForMissingRecover(targetRG) @@ -557,7 +558,7 @@ func (rm *ResourceManager) recoverMissingNodeRG(rgName string) error { return ErrNodeNotEnough } - err := rm.transferNode(targetRG.GetName(), node) + err := rm.transferNode(ctx, targetRG.GetName(), node) if err != nil { log.Warn("failed to recover missing node by transfer node from other resource group", zap.String("sourceRG", sourceRG.GetName()), @@ -622,7 +623,7 @@ func (rm *ResourceManager) selectNodeForMissingRecover(targetRG *ResourceGroup) } // recoverRedundantNodeRG recover resource group by transfer node to other resource group. -func (rm *ResourceManager) recoverRedundantNodeRG(rgName string) error { +func (rm *ResourceManager) recoverRedundantNodeRG(ctx context.Context, rgName string) error { for rm.groups[rgName].RedundantNumOfNodes() > 0 { sourceRG := rm.groups[rgName] node, targetRG := rm.selectNodeForRedundantRecover(sourceRG) @@ -632,7 +633,7 @@ func (rm *ResourceManager) recoverRedundantNodeRG(rgName string) error { return errors.New("all resource group reach limits") } - if err := rm.transferNode(targetRG.GetName(), node); err != nil { + if err := rm.transferNode(ctx, targetRG.GetName(), node); err != nil { log.Warn("failed to recover redundant node by transfer node to other resource group", zap.String("sourceRG", sourceRG.GetName()), zap.String("targetRG", targetRG.GetName()), @@ -704,7 +705,7 @@ func (rm *ResourceManager) selectNodeForRedundantRecover(sourceRG *ResourceGroup } // assignIncomingNodeWithNodeCheck assign node to resource group with node status check. -func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string, error) { +func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(ctx context.Context, node int64) (string, error) { // node is on stopping or stopped, remove it from incoming node set. if rm.nodeMgr.Get(node) == nil { rm.incomingNode.Remove(node) @@ -715,7 +716,7 @@ func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string, return "", errors.New("node has been stopped") } - rgName, err := rm.assignIncomingNode(node) + rgName, err := rm.assignIncomingNode(ctx, node) if err != nil { return "", err } @@ -725,7 +726,7 @@ func (rm *ResourceManager) assignIncomingNodeWithNodeCheck(node int64) (string, } // assignIncomingNode assign node to resource group. -func (rm *ResourceManager) assignIncomingNode(node int64) (string, error) { +func (rm *ResourceManager) assignIncomingNode(ctx context.Context, node int64) (string, error) { // If node already assign to rg. rg := rm.getResourceGroupByNodeID(node) if rg != nil { @@ -738,7 +739,7 @@ func (rm *ResourceManager) assignIncomingNode(node int64) (string, error) { // select a resource group to assign incoming node. rg = rm.mustSelectAssignIncomingNodeTargetRG(node) - if err := rm.transferNode(rg.GetName(), node); err != nil { + if err := rm.transferNode(ctx, rg.GetName(), node); err != nil { return "", errors.Wrap(err, "at finally assign to default resource group") } return rg.GetName(), nil @@ -791,7 +792,7 @@ func (rm *ResourceManager) findMaxRGWithGivenFilter(filter func(rg *ResourceGrou // transferNode transfer given node to given resource group. // if given node is assigned in given resource group, do nothing. // if given node is assigned to other resource group, it will be unassigned first. -func (rm *ResourceManager) transferNode(rgName string, node int64) error { +func (rm *ResourceManager) transferNode(ctx context.Context, rgName string, node int64) error { if rm.groups[rgName] == nil { return merr.WrapErrResourceGroupNotFound(rgName) } @@ -827,7 +828,7 @@ func (rm *ResourceManager) transferNode(rgName string, node int64) error { modifiedRG = append(modifiedRG, rg) // Commit updates to meta storage. - if err := rm.catalog.SaveResourceGroup(updates...); err != nil { + if err := rm.catalog.SaveResourceGroup(ctx, updates...); err != nil { log.Warn("failed to transfer node to resource group", zap.String("rgName", rgName), zap.String("originalRG", originalRG), @@ -854,12 +855,12 @@ func (rm *ResourceManager) transferNode(rgName string, node int64) error { } // unassignNode remove a node from resource group where it belongs to. -func (rm *ResourceManager) unassignNode(node int64) (string, error) { +func (rm *ResourceManager) unassignNode(ctx context.Context, node int64) (string, error) { if rg := rm.getResourceGroupByNodeID(node); rg != nil { mrg := rg.CopyForWrite() mrg.UnassignNode(node) rg := mrg.ToResourceGroup() - if err := rm.catalog.SaveResourceGroup(rg.GetMeta()); err != nil { + if err := rm.catalog.SaveResourceGroup(ctx, rg.GetMeta()); err != nil { log.Fatal("unassign node from resource group", zap.String("rgName", rg.GetName()), zap.Int64("node", node), @@ -943,7 +944,7 @@ func (rm *ResourceManager) validateResourceGroupIsDeletable(rgName string) error return nil } -func (rm *ResourceManager) GetResourceGroupsJSON() string { +func (rm *ResourceManager) GetResourceGroupsJSON(ctx context.Context) string { rm.rwmutex.RLock() defer rm.rwmutex.RUnlock() diff --git a/internal/querycoordv2/meta/resource_manager_test.go b/internal/querycoordv2/meta/resource_manager_test.go index b7cb16da98620..33899585d2e92 100644 --- a/internal/querycoordv2/meta/resource_manager_test.go +++ b/internal/querycoordv2/meta/resource_manager_test.go @@ -16,6 +16,7 @@ package meta import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -44,6 +45,7 @@ type ResourceManagerSuite struct { kv kv.MetaKv manager *ResourceManager + ctx context.Context } func (suite *ResourceManagerSuite) SetupSuite() { @@ -65,6 +67,7 @@ func (suite *ResourceManagerSuite) SetupTest() { store := querycoord.NewCatalog(suite.kv) suite.manager = NewResourceManager(store, session.NewNodeManager()) + suite.ctx = context.Background() } func (suite *ResourceManagerSuite) TearDownSuite() { @@ -76,6 +79,7 @@ func TestResourceManager(t *testing.T) { } func (suite *ResourceManagerSuite) TestValidateConfiguration() { + ctx := suite.ctx err := suite.manager.validateResourceGroupConfig("rg1", newResourceGroupConfig(0, 0)) suite.NoError(err) @@ -111,16 +115,17 @@ func (suite *ResourceManagerSuite) TestValidateConfiguration() { err = suite.manager.validateResourceGroupConfig("rg1", cfg) suite.ErrorIs(err, merr.ErrResourceGroupIllegalConfig) - err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(0, 0)) + err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(0, 0)) suite.NoError(err) - err = suite.manager.RemoveResourceGroup("rg2") + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.NoError(err) } func (suite *ResourceManagerSuite) TestValidateDelete() { + ctx := suite.ctx // Non empty resource group can not be removed. - err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1)) + err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1)) suite.NoError(err) err = suite.manager.validateResourceGroupIsDeletable(DefaultResourceGroupName) @@ -131,8 +136,8 @@ func (suite *ResourceManagerSuite) TestValidateDelete() { cfg := newResourceGroupConfig(0, 0) cfg.TransferFrom = []*rgpb.ResourceGroupTransfer{{ResourceGroup: "rg1"}} - suite.manager.AddResourceGroup("rg2", cfg) - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg2", cfg) + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(0, 0), }) err = suite.manager.validateResourceGroupIsDeletable("rg1") @@ -140,64 +145,65 @@ func (suite *ResourceManagerSuite) TestValidateDelete() { cfg = newResourceGroupConfig(0, 0) cfg.TransferTo = []*rgpb.ResourceGroupTransfer{{ResourceGroup: "rg1"}} - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg2": cfg, }) err = suite.manager.validateResourceGroupIsDeletable("rg1") suite.ErrorIs(err, merr.ErrParameterInvalid) - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg2": newResourceGroupConfig(0, 0), }) err = suite.manager.validateResourceGroupIsDeletable("rg1") suite.NoError(err) - err = suite.manager.RemoveResourceGroup("rg1") + err = suite.manager.RemoveResourceGroup(ctx, "rg1") suite.NoError(err) - err = suite.manager.RemoveResourceGroup("rg2") + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.NoError(err) } func (suite *ResourceManagerSuite) TestManipulateResourceGroup() { + ctx := suite.ctx // test add rg - err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(0, 0)) + err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(0, 0)) suite.NoError(err) - suite.True(suite.manager.ContainResourceGroup("rg1")) - suite.Len(suite.manager.ListResourceGroups(), 2) + suite.True(suite.manager.ContainResourceGroup(ctx, "rg1")) + suite.Len(suite.manager.ListResourceGroups(ctx), 2) // test add duplicate rg but same configuration is ok - err = suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(0, 0)) + err = suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(0, 0)) suite.NoError(err) - err = suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1)) + err = suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1)) suite.Error(err) // test delete rg - err = suite.manager.RemoveResourceGroup("rg1") + err = suite.manager.RemoveResourceGroup(ctx, "rg1") suite.NoError(err) // test delete rg which doesn't exist - err = suite.manager.RemoveResourceGroup("rg1") + err = suite.manager.RemoveResourceGroup(ctx, "rg1") suite.NoError(err) // test delete default rg - err = suite.manager.RemoveResourceGroup(DefaultResourceGroupName) + err = suite.manager.RemoveResourceGroup(ctx, DefaultResourceGroupName) suite.ErrorIs(err, merr.ErrParameterInvalid) // test delete a rg not empty. - err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1)) + err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1)) suite.NoError(err) - err = suite.manager.RemoveResourceGroup("rg2") + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.ErrorIs(err, merr.ErrParameterInvalid) // test delete a rg after update - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg2": newResourceGroupConfig(0, 0), }) - err = suite.manager.RemoveResourceGroup("rg2") + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.NoError(err) // assign a node to rg. - err = suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1)) + err = suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1)) suite.NoError(err) suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, @@ -205,68 +211,69 @@ func (suite *ResourceManagerSuite) TestManipulateResourceGroup() { Hostname: "localhost", })) defer suite.manager.nodeMgr.Remove(1) - suite.manager.HandleNodeUp(1) - err = suite.manager.RemoveResourceGroup("rg2") + suite.manager.HandleNodeUp(ctx, 1) + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.ErrorIs(err, merr.ErrParameterInvalid) - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg2": newResourceGroupConfig(0, 0), }) log.Info("xxxxx") // RemoveResourceGroup will remove all nodes from the resource group. - err = suite.manager.RemoveResourceGroup("rg2") + err = suite.manager.RemoveResourceGroup(ctx, "rg2") suite.NoError(err) } func (suite *ResourceManagerSuite) TestNodeUpAndDown() { + ctx := suite.ctx suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", Hostname: "localhost", })) - err := suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(1, 1)) + err := suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(1, 1)) suite.NoError(err) // test add node to rg - suite.manager.HandleNodeUp(1) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) + suite.manager.HandleNodeUp(ctx, 1) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) // test add non-exist node to rg - err = suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + err = suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(2, 3), }) suite.NoError(err) - suite.manager.HandleNodeUp(2) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeUp(ctx, 2) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // teardown a non-exist node from rg. - suite.manager.HandleNodeDown(2) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeDown(ctx, 2) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // test add exist node to rg - suite.manager.HandleNodeUp(1) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeUp(ctx, 1) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // teardown a exist node from rg. - suite.manager.HandleNodeDown(1) - suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeDown(ctx, 1) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // teardown a exist node from rg. - suite.manager.HandleNodeDown(1) - suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeDown(ctx, 1) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) - suite.manager.HandleNodeUp(1) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeUp(ctx, 1) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) - err = suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + err = suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(4, 4), }) suite.NoError(err) - suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(1, 1)) + suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(1, 1)) suite.NoError(err) suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -289,29 +296,29 @@ func (suite *ResourceManagerSuite) TestNodeUpAndDown() { Address: "localhost", Hostname: "localhost", })) - suite.manager.HandleNodeUp(11) - suite.manager.HandleNodeUp(12) - suite.manager.HandleNodeUp(13) - suite.manager.HandleNodeUp(14) - - suite.Equal(4, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(1, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - - suite.manager.HandleNodeDown(11) - suite.manager.HandleNodeDown(12) - suite.manager.HandleNodeDown(13) - suite.manager.HandleNodeDown(14) - suite.Equal(1, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - - suite.manager.HandleNodeDown(1) - suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Zero(suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.HandleNodeUp(ctx, 11) + suite.manager.HandleNodeUp(ctx, 12) + suite.manager.HandleNodeUp(ctx, 13) + suite.manager.HandleNodeUp(ctx, 14) + + suite.Equal(4, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + + suite.manager.HandleNodeDown(ctx, 11) + suite.manager.HandleNodeDown(ctx, 12) + suite.manager.HandleNodeDown(ctx, 13) + suite.manager.HandleNodeDown(ctx, 14) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + + suite.manager.HandleNodeDown(ctx, 1) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Zero(suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(20, 30), "rg2": newResourceGroupConfig(30, 40), }) @@ -321,106 +328,107 @@ func (suite *ResourceManagerSuite) TestNodeUpAndDown() { Address: "localhost", Hostname: "localhost", })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(50, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(50, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // down all nodes for i := 1; i <= 100; i++ { - suite.manager.HandleNodeDown(int64(i)) - suite.Equal(100-i, suite.manager.GetResourceGroup("rg1").NodeNum()+ - suite.manager.GetResourceGroup("rg2").NodeNum()+ - suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeDown(ctx, int64(i)) + suite.Equal(100-i, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()+ + suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()+ + suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) } // if there are all rgs reach limit, should be fall back to default rg. - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(0, 0), "rg2": newResourceGroupConfig(0, 0), DefaultResourceGroupName: newResourceGroupConfig(0, 0), }) for i := 1; i <= 100; i++ { - suite.manager.HandleNodeUp(int64(i)) - suite.Equal(i, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum()) + suite.manager.HandleNodeUp(ctx, int64(i)) + suite.Equal(i, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) } } func (suite *ResourceManagerSuite) TestAutoRecover() { + ctx := suite.ctx for i := 1; i <= 100; i++ { suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: int64(i), Address: "localhost", Hostname: "localhost", })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Recover 10 nodes from default resource group - suite.manager.AddResourceGroup("rg1", newResourceGroupConfig(10, 30)) - suite.Zero(suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg1").MissingNumOfNodes()) - suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg1").MissingNumOfNodes()) - suite.Equal(90, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AddResourceGroup(ctx, "rg1", newResourceGroupConfig(10, 30)) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").MissingNumOfNodes()) + suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").MissingNumOfNodes()) + suite.Equal(90, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Recover 20 nodes from default resource group - suite.manager.AddResourceGroup("rg2", newResourceGroupConfig(20, 30)) - suite.Zero(suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(20, suite.manager.GetResourceGroup("rg2").MissingNumOfNodes()) - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(90, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - suite.manager.AutoRecoverResourceGroup("rg2") - suite.Equal(20, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(70, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AddResourceGroup(ctx, "rg2", newResourceGroupConfig(20, 30)) + suite.Zero(suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").MissingNumOfNodes()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(90, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(70, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Recover 5 redundant nodes from resource group - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(5, 5), }) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.Equal(20, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(5, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(75, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(75, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Recover 10 redundant nodes from resource group 2 to resource group 1 and default resource group. - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(10, 20), "rg2": newResourceGroupConfig(5, 10), }) - suite.manager.AutoRecoverResourceGroup("rg2") - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(80, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(80, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // recover redundant nodes from default resource group - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(10, 20), "rg2": newResourceGroupConfig(20, 30), DefaultResourceGroupName: newResourceGroupConfig(10, 20), }) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) // Even though the default resource group has 20 nodes limits, // all redundant nodes will be assign to default resource group. - suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(50, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(50, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Test recover missing from high priority resource group by set `from`. - suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 15, }, @@ -431,23 +439,23 @@ func (suite *ResourceManagerSuite) TestAutoRecover() { ResourceGroup: "rg1", }}, }) - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(30, 40), }) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg3") + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") // Get 10 from default group for redundant nodes, get 5 from rg1 for rg3 at high priority. - suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(30, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(15, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(30, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Test recover redundant to high priority resource group by set `to`. - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg3": { Requests: &rgpb.ResourceGroupLimit{ NodeNum: 0, @@ -463,21 +471,21 @@ func (suite *ResourceManagerSuite) TestAutoRecover() { "rg2": newResourceGroupConfig(15, 40), }) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg3") + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") // Recover rg3 by transfer 10 nodes to rg2 with high priority, 5 to rg1. - suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) suite.testTransferNode() // Test redundant nodes recover to default resource group. - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(1, 1), "rg3": newResourceGroupConfig(0, 0), "rg2": newResourceGroupConfig(0, 0), @@ -485,107 +493,109 @@ func (suite *ResourceManagerSuite) TestAutoRecover() { }) // Even default resource group has 1 node limit, // all redundant nodes will be assign to default resource group if there's no resource group can hold. - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup("rg3") - suite.Equal(0, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(100, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(100, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Test redundant recover to missing nodes and missing nodes from redundant nodes. // Initialize - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(0, 0), "rg3": newResourceGroupConfig(10, 10), "rg2": newResourceGroupConfig(80, 80), "rg1": newResourceGroupConfig(10, 10), }) - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup("rg3") - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(0, 5), "rg3": newResourceGroupConfig(5, 5), "rg2": newResourceGroupConfig(80, 80), "rg1": newResourceGroupConfig(20, 30), }) - suite.manager.AutoRecoverResourceGroup("rg3") // recover redundant to missing rg. - suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - suite.manager.updateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") // recover redundant to missing rg. + suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + suite.manager.updateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(5, 5), "rg3": newResourceGroupConfig(5, 10), "rg2": newResourceGroupConfig(80, 80), "rg1": newResourceGroupConfig(10, 10), }) - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) // recover missing from redundant rg. - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(80, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(5, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) // recover missing from redundant rg. + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(80, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(5, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) } func (suite *ResourceManagerSuite) testTransferNode() { + ctx := suite.ctx // Test redundant nodes recover to default resource group. - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ DefaultResourceGroupName: newResourceGroupConfig(40, 40), "rg3": newResourceGroupConfig(0, 0), "rg2": newResourceGroupConfig(40, 40), "rg1": newResourceGroupConfig(20, 20), }) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg3") + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") - suite.Equal(20, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // Test TransferNode. // param error. - err := suite.manager.TransferNode("rg1", "rg1", 1) + err := suite.manager.TransferNode(ctx, "rg1", "rg1", 1) suite.Error(err) - err = suite.manager.TransferNode("rg1", "rg2", 0) + err = suite.manager.TransferNode(ctx, "rg1", "rg2", 0) suite.Error(err) - err = suite.manager.TransferNode("rg3", "rg2", 1) + err = suite.manager.TransferNode(ctx, "rg3", "rg2", 1) suite.Error(err) - err = suite.manager.TransferNode("rg1", "rg10086", 1) + err = suite.manager.TransferNode(ctx, "rg1", "rg10086", 1) suite.Error(err) - err = suite.manager.TransferNode("rg10086", "rg2", 1) + err = suite.manager.TransferNode(ctx, "rg10086", "rg2", 1) suite.Error(err) // success - err = suite.manager.TransferNode("rg1", "rg3", 5) + err = suite.manager.TransferNode(ctx, "rg1", "rg3", 5) suite.NoError(err) - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - suite.manager.AutoRecoverResourceGroup("rg3") + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") - suite.Equal(15, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(5, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(40, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(15, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(5, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(40, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) } func (suite *ResourceManagerSuite) TestIncomingNode() { + ctx := suite.ctx suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 1, Address: "localhost", @@ -593,15 +603,16 @@ func (suite *ResourceManagerSuite) TestIncomingNode() { })) suite.manager.incomingNode.Insert(1) - suite.Equal(1, suite.manager.CheckIncomingNodeNum()) - suite.manager.AssignPendingIncomingNode() - suite.Equal(0, suite.manager.CheckIncomingNodeNum()) - nodes, err := suite.manager.GetNodes(DefaultResourceGroupName) + suite.Equal(1, suite.manager.CheckIncomingNodeNum(ctx)) + suite.manager.AssignPendingIncomingNode(ctx) + suite.Equal(0, suite.manager.CheckIncomingNodeNum(ctx)) + nodes, err := suite.manager.GetNodes(ctx, DefaultResourceGroupName) suite.NoError(err) suite.Len(nodes, 1) } func (suite *ResourceManagerSuite) TestUnassignFail() { + ctx := suite.ctx // suite.man mockKV := mocks.NewMetaKv(suite.T()) mockKV.EXPECT().MultiSave(mock.Anything).Return(nil).Once() @@ -609,7 +620,7 @@ func (suite *ResourceManagerSuite) TestUnassignFail() { store := querycoord.NewCatalog(mockKV) suite.manager = NewResourceManager(store, session.NewNodeManager()) - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": newResourceGroupConfig(20, 30), }) @@ -618,16 +629,17 @@ func (suite *ResourceManagerSuite) TestUnassignFail() { Address: "localhost", Hostname: "localhost", })) - suite.manager.HandleNodeUp(1) + suite.manager.HandleNodeUp(ctx, 1) mockKV.EXPECT().MultiSave(mock.Anything).Return(merr.WrapErrServiceInternal("mocked")).Once() suite.Panics(func() { - suite.manager.HandleNodeDown(1) + suite.manager.HandleNodeDown(ctx, 1) }) } func TestGetResourceGroupsJSON(t *testing.T) { + ctx := context.Background() nodeManager := session.NewNodeManager() manager := &ResourceManager{groups: make(map[string]*ResourceGroup)} rg1 := NewResourceGroup("rg1", newResourceGroupConfig(0, 10), nodeManager) @@ -637,7 +649,7 @@ func TestGetResourceGroupsJSON(t *testing.T) { manager.groups["rg1"] = rg1 manager.groups["rg2"] = rg2 - jsonOutput := manager.GetResourceGroupsJSON() + jsonOutput := manager.GetResourceGroupsJSON(ctx) var resourceGroups []*metricsinfo.ResourceGroup err := json.Unmarshal([]byte(jsonOutput), &resourceGroups) assert.NoError(t, err) @@ -659,7 +671,8 @@ func TestGetResourceGroupsJSON(t *testing.T) { } func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { - suite.manager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + ctx := suite.ctx + suite.manager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -676,7 +689,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { }, }) - suite.manager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -693,7 +706,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { }, }) - suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -720,12 +733,12 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { "dc_name": "label1", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // test new querynode with label2 for i := 31; i <= 40; i++ { @@ -737,13 +750,13 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { "dc_name": "label2", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(0, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - nodesInRG, _ := suite.manager.GetNodes("rg2") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(0, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + nodesInRG, _ := suite.manager.GetNodes(ctx, "rg2") for _, node := range nodesInRG { suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } @@ -758,19 +771,19 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { "dc_name": "label3", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - nodesInRG, _ = suite.manager.GetNodes("rg3") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + nodesInRG, _ = suite.manager.GetNodes(ctx, "rg3") for _, node := range nodesInRG { suite.Equal("label3", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } // test swap rg's label - suite.manager.UpdateResourceGroups(map[string]*rgpb.ResourceGroupConfig{ + suite.manager.UpdateResourceGroups(ctx, map[string]*rgpb.ResourceGroupConfig{ "rg1": { Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, @@ -823,33 +836,34 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeAssign() { log.Info("test swap rg's label") for i := 0; i < 4; i++ { - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup("rg3") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) } - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(20, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - nodesInRG, _ = suite.manager.GetNodes("rg1") + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(20, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + nodesInRG, _ = suite.manager.GetNodes(ctx, "rg1") for _, node := range nodesInRG { suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } - nodesInRG, _ = suite.manager.GetNodes("rg2") + nodesInRG, _ = suite.manager.GetNodes(ctx, "rg2") for _, node := range nodesInRG { suite.Equal("label3", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } - nodesInRG, _ = suite.manager.GetNodes("rg3") + nodesInRG, _ = suite.manager.GetNodes(ctx, "rg3") for _, node := range nodesInRG { suite.Equal("label1", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } } func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { - suite.manager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + ctx := suite.ctx + suite.manager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -866,7 +880,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { }, }) - suite.manager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -883,7 +897,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { }, }) - suite.manager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + suite.manager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{ NodeNum: 10, }, @@ -910,7 +924,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { "dc_name": "label1", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } // test new querynode with label2 @@ -923,7 +937,7 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { "dc_name": "label2", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } // test new querynode with label3 for i := 41; i <= 50; i++ { @@ -935,18 +949,18 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { "dc_name": "label3", }, })) - suite.manager.HandleNodeUp(int64(i)) + suite.manager.HandleNodeUp(ctx, int64(i)) } - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) // test node down with label1 - suite.manager.HandleNodeDown(int64(1)) + suite.manager.HandleNodeDown(ctx, int64(1)) suite.manager.nodeMgr.Remove(int64(1)) - suite.Equal(9, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) + suite.Equal(9, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) // test node up with label2 suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -957,11 +971,11 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { "dc_name": "label2", }, })) - suite.manager.HandleNodeUp(int64(101)) - suite.Equal(9, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(1, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) + suite.manager.HandleNodeUp(ctx, int64(101)) + suite.Equal(9, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) // test node up with label1 suite.manager.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ @@ -972,21 +986,21 @@ func (suite *ResourceManagerSuite) TestNodeLabels_NodeDown() { "dc_name": "label1", }, })) - suite.manager.HandleNodeUp(int64(102)) - suite.Equal(10, suite.manager.GetResourceGroup("rg1").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg2").NodeNum()) - suite.Equal(10, suite.manager.GetResourceGroup("rg3").NodeNum()) - suite.Equal(1, suite.manager.GetResourceGroup(DefaultResourceGroupName).NodeNum()) - nodesInRG, _ := suite.manager.GetNodes("rg1") + suite.manager.HandleNodeUp(ctx, int64(102)) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg1").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg2").NodeNum()) + suite.Equal(10, suite.manager.GetResourceGroup(ctx, "rg3").NodeNum()) + suite.Equal(1, suite.manager.GetResourceGroup(ctx, DefaultResourceGroupName).NodeNum()) + nodesInRG, _ := suite.manager.GetNodes(ctx, "rg1") for _, node := range nodesInRG { suite.Equal("label1", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } - suite.manager.AutoRecoverResourceGroup("rg1") - suite.manager.AutoRecoverResourceGroup("rg2") - suite.manager.AutoRecoverResourceGroup("rg3") - suite.manager.AutoRecoverResourceGroup(DefaultResourceGroupName) - nodesInRG, _ = suite.manager.GetNodes(DefaultResourceGroupName) + suite.manager.AutoRecoverResourceGroup(ctx, "rg1") + suite.manager.AutoRecoverResourceGroup(ctx, "rg2") + suite.manager.AutoRecoverResourceGroup(ctx, "rg3") + suite.manager.AutoRecoverResourceGroup(ctx, DefaultResourceGroupName) + nodesInRG, _ = suite.manager.GetNodes(ctx, DefaultResourceGroupName) for _, node := range nodesInRG { suite.Equal("label2", suite.manager.nodeMgr.Get(node).Labels()["dc_name"]) } diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index 2115297de469c..68f5b8cf7ba55 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -50,26 +50,26 @@ const ( ) type TargetManagerInterface interface { - UpdateCollectionCurrentTarget(collectionID int64) bool - UpdateCollectionNextTarget(collectionID int64) error - RemoveCollection(collectionID int64) - RemovePartition(collectionID int64, partitionIDs ...int64) - GetGrowingSegmentsByCollection(collectionID int64, scope TargetScope) typeutil.UniqueSet - GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) typeutil.UniqueSet - GetSealedSegmentsByCollection(collectionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo - GetSealedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) map[int64]*datapb.SegmentInfo - GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) []int64 - GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo - GetDmChannelsByCollection(collectionID int64, scope TargetScope) map[string]*DmChannel - GetDmChannel(collectionID int64, channel string, scope TargetScope) *DmChannel - GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo - GetCollectionTargetVersion(collectionID int64, scope TargetScope) int64 - IsCurrentTargetExist(collectionID int64, partitionID int64) bool - IsNextTargetExist(collectionID int64) bool - SaveCurrentTarget(catalog metastore.QueryCoordCatalog) - Recover(catalog metastore.QueryCoordCatalog) error - CanSegmentBeMoved(collectionID, segmentID int64) bool - GetTargetJSON(scope TargetScope) string + UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool + UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error + RemoveCollection(ctx context.Context, collectionID int64) + RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64) + GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope) typeutil.UniqueSet + GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) typeutil.UniqueSet + GetSealedSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo + GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) map[int64]*datapb.SegmentInfo + GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope) []int64 + GetSealedSegmentsByPartition(ctx context.Context, collectionID int64, partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo + GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[string]*DmChannel + GetDmChannel(ctx context.Context, collectionID int64, channel string, scope TargetScope) *DmChannel + GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo + GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope TargetScope) int64 + IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool + IsNextTargetExist(ctx context.Context, collectionID int64) bool + SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog) + Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error + CanSegmentBeMoved(ctx context.Context, collectionID, segmentID int64) bool + GetTargetJSON(ctx context.Context, scope TargetScope) string } type TargetManager struct { @@ -96,7 +96,7 @@ func NewTargetManager(broker Broker, meta *Meta) *TargetManager { // UpdateCollectionCurrentTarget updates the current target to next target, // WARN: DO NOT call this method for an existing collection as target observer running, or it will lead to a double-update, // which may make the current target not available -func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool { +func (mgr *TargetManager) UpdateCollectionCurrentTarget(ctx context.Context, collectionID int64) bool { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() log := log.With(zap.Int64("collectionID", collectionID)) @@ -137,7 +137,7 @@ func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool // UpdateCollectionNextTarget updates the next target with new target pulled from DataCoord, // WARN: DO NOT call this method for an existing collection as target observer running, or it will lead to a double-update, // which may make the current target not available -func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error { +func (mgr *TargetManager) UpdateCollectionNextTarget(ctx context.Context, collectionID int64) error { var vChannelInfos []*datapb.VchannelInfo var segmentInfos []*datapb.SegmentInfo err := retry.Handle(context.TODO(), func() (bool, error) { @@ -155,7 +155,7 @@ func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() - partitions := mgr.meta.GetPartitionsByCollection(collectionID) + partitions := mgr.meta.GetPartitionsByCollection(ctx, collectionID) partitionIDs := lo.Map(partitions, func(partition *Partition, i int) int64 { return partition.PartitionID }) @@ -223,7 +223,7 @@ func (mgr *TargetManager) mergeDmChannelInfo(infos []*datapb.VchannelInfo) *DmCh } // RemoveCollection removes all channels and segments in the given collection -func (mgr *TargetManager) RemoveCollection(collectionID int64) { +func (mgr *TargetManager) RemoveCollection(ctx context.Context, collectionID int64) { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() log.Info("remove collection from targets", @@ -245,7 +245,7 @@ func (mgr *TargetManager) RemoveCollection(collectionID int64) { // RemovePartition removes all segment in the given partition, // NOTE: this doesn't remove any channel even the given one is the only partition -func (mgr *TargetManager) RemovePartition(collectionID int64, partitionIDs ...int64) { +func (mgr *TargetManager) RemovePartition(ctx context.Context, collectionID int64, partitionIDs ...int64) { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() @@ -352,7 +352,7 @@ func (mgr *TargetManager) getCollectionTarget(scope TargetScope, collectionID in return nil } -func (mgr *TargetManager) GetGrowingSegmentsByCollection(collectionID int64, +func (mgr *TargetManager) GetGrowingSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope, ) typeutil.UniqueSet { mgr.rwMutex.RLock() @@ -374,7 +374,7 @@ func (mgr *TargetManager) GetGrowingSegmentsByCollection(collectionID int64, return nil } -func (mgr *TargetManager) GetGrowingSegmentsByChannel(collectionID int64, +func (mgr *TargetManager) GetGrowingSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope, ) typeutil.UniqueSet { @@ -398,7 +398,7 @@ func (mgr *TargetManager) GetGrowingSegmentsByChannel(collectionID int64, return nil } -func (mgr *TargetManager) GetSealedSegmentsByCollection(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByCollection(ctx context.Context, collectionID int64, scope TargetScope, ) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() @@ -413,7 +413,7 @@ func (mgr *TargetManager) GetSealedSegmentsByCollection(collectionID int64, return nil } -func (mgr *TargetManager) GetSealedSegmentsByChannel(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope, ) map[int64]*datapb.SegmentInfo { @@ -437,7 +437,7 @@ func (mgr *TargetManager) GetSealedSegmentsByChannel(collectionID int64, return nil } -func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64, +func (mgr *TargetManager) GetDroppedSegmentsByChannel(ctx context.Context, collectionID int64, channelName string, scope TargetScope, ) []int64 { @@ -454,7 +454,7 @@ func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64, return nil } -func (mgr *TargetManager) GetSealedSegmentsByPartition(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByPartition(ctx context.Context, collectionID int64, partitionID int64, scope TargetScope, ) map[int64]*datapb.SegmentInfo { @@ -478,7 +478,7 @@ func (mgr *TargetManager) GetSealedSegmentsByPartition(collectionID int64, return nil } -func (mgr *TargetManager) GetDmChannelsByCollection(collectionID int64, scope TargetScope) map[string]*DmChannel { +func (mgr *TargetManager) GetDmChannelsByCollection(ctx context.Context, collectionID int64, scope TargetScope) map[string]*DmChannel { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -491,7 +491,7 @@ func (mgr *TargetManager) GetDmChannelsByCollection(collectionID int64, scope Ta return nil } -func (mgr *TargetManager) GetDmChannel(collectionID int64, channel string, scope TargetScope) *DmChannel { +func (mgr *TargetManager) GetDmChannel(ctx context.Context, collectionID int64, channel string, scope TargetScope) *DmChannel { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -504,7 +504,7 @@ func (mgr *TargetManager) GetDmChannel(collectionID int64, channel string, scope return nil } -func (mgr *TargetManager) GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo { +func (mgr *TargetManager) GetSealedSegment(ctx context.Context, collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -518,7 +518,7 @@ func (mgr *TargetManager) GetSealedSegment(collectionID int64, id int64, scope T return nil } -func (mgr *TargetManager) GetCollectionTargetVersion(collectionID int64, scope TargetScope) int64 { +func (mgr *TargetManager) GetCollectionTargetVersion(ctx context.Context, collectionID int64, scope TargetScope) int64 { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -532,7 +532,7 @@ func (mgr *TargetManager) GetCollectionTargetVersion(collectionID int64, scope T return 0 } -func (mgr *TargetManager) IsCurrentTargetExist(collectionID int64, partitionID int64) bool { +func (mgr *TargetManager) IsCurrentTargetExist(ctx context.Context, collectionID int64, partitionID int64) bool { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -541,13 +541,13 @@ func (mgr *TargetManager) IsCurrentTargetExist(collectionID int64, partitionID i return len(targets) > 0 && (targets[0].partitions.Contain(partitionID) || partitionID == common.AllPartitionsID) && len(targets[0].dmChannels) > 0 } -func (mgr *TargetManager) IsNextTargetExist(collectionID int64) bool { - newChannels := mgr.GetDmChannelsByCollection(collectionID, NextTarget) +func (mgr *TargetManager) IsNextTargetExist(ctx context.Context, collectionID int64) bool { + newChannels := mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget) return len(newChannels) > 0 } -func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) { +func (mgr *TargetManager) SaveCurrentTarget(ctx context.Context, catalog metastore.QueryCoordCatalog) { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() if mgr.current != nil { @@ -562,7 +562,7 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) pool.Submit(func() (any, error) { defer wg.Done() ids := lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) int64 { return p.A }) - if err := catalog.SaveCollectionTargets(lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget { + if err := catalog.SaveCollectionTargets(ctx, lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget { return p.B })...); err != nil { log.Warn("failed to save current target for collection", zap.Int64s("collectionIDs", ids), zap.Error(err)) @@ -587,11 +587,11 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) } } -func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error { +func (mgr *TargetManager) Recover(ctx context.Context, catalog metastore.QueryCoordCatalog) error { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() - targets, err := catalog.GetCollectionTargets() + targets, err := catalog.GetCollectionTargets(ctx) if err != nil { log.Warn("failed to recover collection target from etcd", zap.Error(err)) return err @@ -608,7 +608,7 @@ func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error { ) // clear target info in meta store - err := catalog.RemoveCollectionTarget(t.GetCollectionID()) + err := catalog.RemoveCollectionTarget(ctx, t.GetCollectionID()) if err != nil { log.Warn("failed to clear collection target from etcd", zap.Error(err)) } @@ -618,7 +618,7 @@ func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error { } // if segment isn't l0 segment, and exist in current/next target, then it can be moved -func (mgr *TargetManager) CanSegmentBeMoved(collectionID, segmentID int64) bool { +func (mgr *TargetManager) CanSegmentBeMoved(ctx context.Context, collectionID, segmentID int64) bool { mgr.rwMutex.Lock() defer mgr.rwMutex.Unlock() current := mgr.current.getCollectionTarget(collectionID) @@ -634,7 +634,7 @@ func (mgr *TargetManager) CanSegmentBeMoved(collectionID, segmentID int64) bool return false } -func (mgr *TargetManager) GetTargetJSON(scope TargetScope) string { +func (mgr *TargetManager) GetTargetJSON(ctx context.Context, scope TargetScope) string { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index 7eeb99e83e2fe..e7ca040ecbc66 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -17,6 +17,7 @@ package meta import ( + "context" "testing" "time" @@ -60,6 +61,8 @@ type TargetManagerSuite struct { broker *MockBroker // Test object mgr *TargetManager + + ctx context.Context } func (suite *TargetManagerSuite) SetupSuite() { @@ -110,6 +113,7 @@ func (suite *TargetManagerSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() // meta suite.catalog = querycoord.NewCatalog(suite.kv) @@ -141,14 +145,14 @@ func (suite *TargetManagerSuite) SetupTest() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil) - suite.meta.PutCollection(&Collection{ + suite.meta.PutCollection(suite.ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, ReplicaNumber: 1, }, }) for _, partition := range suite.partitions[collection] { - suite.meta.PutPartition(&Partition{ + suite.meta.PutPartition(suite.ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collection, PartitionID: partition, @@ -156,7 +160,7 @@ func (suite *TargetManagerSuite) SetupTest() { }) } - suite.mgr.UpdateCollectionNextTarget(collection) + suite.mgr.UpdateCollectionNextTarget(suite.ctx, collection) } } @@ -165,35 +169,37 @@ func (suite *TargetManagerSuite) TearDownSuite() { } func (suite *TargetManagerSuite) TestUpdateCurrentTarget() { + ctx := suite.ctx collectionID := int64(1000) suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), - suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) - - suite.mgr.UpdateCollectionCurrentTarget(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) + suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) + + suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), - suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) } func (suite *TargetManagerSuite) TestUpdateNextTarget() { + ctx := suite.ctx collectionID := int64(1003) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) - suite.meta.PutCollection(&Collection{ + suite.meta.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, ReplicaNumber: 1, }, }) - suite.meta.PutPartition(&Partition{ + suite.meta.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collectionID, PartitionID: 1, @@ -225,62 +231,64 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) - suite.mgr.UpdateCollectionNextTarget(collectionID) - suite.assertSegments([]int64{11, 12}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) + suite.assertSegments([]int64{11, 12}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) suite.broker.ExpectedCalls = nil // test getRecoveryInfoV2 failed , then retry getRecoveryInfoV2 succeed suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nil, nil, errors.New("fake error")).Times(1) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) - err := suite.mgr.UpdateCollectionNextTarget(collectionID) + err := suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) suite.NoError(err) - err = suite.mgr.UpdateCollectionNextTarget(collectionID) + err = suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) suite.NoError(err) } func (suite *TargetManagerSuite) TestRemovePartition() { + ctx := suite.ctx collectionID := int64(1000) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) - - suite.mgr.RemovePartition(collectionID, 100) - suite.assertSegments(append([]int64{3, 4}, suite.level0Segments...), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) + + suite.mgr.RemovePartition(ctx, collectionID, 100) + suite.assertSegments(append([]int64{3, 4}, suite.level0Segments...), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) } func (suite *TargetManagerSuite) TestRemoveCollection() { + ctx := suite.ctx collectionID := int64(1000) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) - suite.mgr.RemoveCollection(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.mgr.RemoveCollection(ctx, collectionID) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) collectionID = int64(1001) - suite.mgr.UpdateCollectionCurrentTarget(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) - - suite.mgr.RemoveCollection(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) + + suite.mgr.RemoveCollection(ctx, collectionID) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) } func (suite *TargetManagerSuite) getAllSegment(collectionID int64, partitionIDs []int64) []int64 { @@ -325,6 +333,7 @@ func (suite *TargetManagerSuite) assertSegments(expected []int64, actual map[int } func (suite *TargetManagerSuite) TestGetCollectionTargetVersion() { + ctx := suite.ctx t1 := time.Now().UnixNano() target := NewCollectionTarget(nil, nil, nil) t2 := time.Now().UnixNano() @@ -335,28 +344,29 @@ func (suite *TargetManagerSuite) TestGetCollectionTargetVersion() { collectionID := suite.collections[0] t3 := time.Now().UnixNano() - suite.mgr.UpdateCollectionNextTarget(collectionID) + suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) t4 := time.Now().UnixNano() - collectionVersion := suite.mgr.GetCollectionTargetVersion(collectionID, NextTarget) + collectionVersion := suite.mgr.GetCollectionTargetVersion(ctx, collectionID, NextTarget) suite.True(t3 <= collectionVersion) suite.True(t4 >= collectionVersion) } func (suite *TargetManagerSuite) TestGetSegmentByChannel() { + ctx := suite.ctx collectionID := int64(1003) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) - suite.meta.PutCollection(&Collection{ + suite.meta.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, ReplicaNumber: 1, }, }) - suite.meta.PutPartition(&Partition{ + suite.meta.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collectionID, PartitionID: 1, @@ -391,17 +401,17 @@ func (suite *TargetManagerSuite) TestGetSegmentByChannel() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) - suite.mgr.UpdateCollectionNextTarget(collectionID) - suite.Len(suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget), 2) - suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-1", NextTarget), 1) - suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) - suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-1", NextTarget), 4) - suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) - suite.Len(suite.mgr.GetDroppedSegmentsByChannel(collectionID, "channel-1", NextTarget), 3) - suite.Len(suite.mgr.GetGrowingSegmentsByCollection(collectionID, NextTarget), 5) - suite.Len(suite.mgr.GetSealedSegmentsByPartition(collectionID, 1, NextTarget), 2) - suite.NotNil(suite.mgr.GetSealedSegment(collectionID, 11, NextTarget)) - suite.NotNil(suite.mgr.GetDmChannel(collectionID, "channel-1", NextTarget)) + suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) + suite.Len(suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget), 2) + suite.Len(suite.mgr.GetSealedSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 1) + suite.Len(suite.mgr.GetSealedSegmentsByChannel(ctx, collectionID, "channel-2", NextTarget), 1) + suite.Len(suite.mgr.GetGrowingSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 4) + suite.Len(suite.mgr.GetGrowingSegmentsByChannel(ctx, collectionID, "channel-2", NextTarget), 1) + suite.Len(suite.mgr.GetDroppedSegmentsByChannel(ctx, collectionID, "channel-1", NextTarget), 3) + suite.Len(suite.mgr.GetGrowingSegmentsByCollection(ctx, collectionID, NextTarget), 5) + suite.Len(suite.mgr.GetSealedSegmentsByPartition(ctx, collectionID, 1, NextTarget), 2) + suite.NotNil(suite.mgr.GetSealedSegment(ctx, collectionID, 11, NextTarget)) + suite.NotNil(suite.mgr.GetDmChannel(ctx, collectionID, "channel-1", NextTarget)) } func (suite *TargetManagerSuite) TestGetTarget() { @@ -535,19 +545,20 @@ func (suite *TargetManagerSuite) TestGetTarget() { } func (suite *TargetManagerSuite) TestRecover() { + ctx := suite.ctx collectionID := int64(1003) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) - suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(ctx, collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(ctx, collectionID, CurrentTarget)) - suite.meta.PutCollection(&Collection{ + suite.meta.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, ReplicaNumber: 1, }, }) - suite.meta.PutPartition(&Partition{ + suite.meta.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collectionID, PartitionID: 1, @@ -582,16 +593,16 @@ func (suite *TargetManagerSuite) TestRecover() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) - suite.mgr.UpdateCollectionNextTarget(collectionID) - suite.mgr.UpdateCollectionCurrentTarget(collectionID) + suite.mgr.UpdateCollectionNextTarget(ctx, collectionID) + suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID) - suite.mgr.SaveCurrentTarget(suite.catalog) + suite.mgr.SaveCurrentTarget(ctx, suite.catalog) // clear target in memory version := suite.mgr.current.getCollectionTarget(collectionID).GetTargetVersion() suite.mgr.current.removeCollectionTarget(collectionID) // try to recover - suite.mgr.Recover(suite.catalog) + suite.mgr.Recover(ctx, suite.catalog) target := suite.mgr.current.getCollectionTarget(collectionID) suite.NotNil(target) @@ -600,20 +611,21 @@ func (suite *TargetManagerSuite) TestRecover() { suite.Equal(target.GetTargetVersion(), version) // after recover, target info should be cleaned up - targets, err := suite.catalog.GetCollectionTargets() + targets, err := suite.catalog.GetCollectionTargets(ctx) suite.NoError(err) suite.Len(targets, 0) } func (suite *TargetManagerSuite) TestGetTargetJSON() { + ctx := suite.ctx collectionID := int64(1003) - suite.meta.PutCollection(&Collection{ + suite.meta.PutCollection(ctx, &Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, ReplicaNumber: 1, }, }) - suite.meta.PutPartition(&Partition{ + suite.meta.PutPartition(ctx, &Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collectionID, PartitionID: 1, @@ -648,10 +660,10 @@ func (suite *TargetManagerSuite) TestGetTargetJSON() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) - suite.NoError(suite.mgr.UpdateCollectionNextTarget(collectionID)) - suite.True(suite.mgr.UpdateCollectionCurrentTarget(collectionID)) + suite.NoError(suite.mgr.UpdateCollectionNextTarget(ctx, collectionID)) + suite.True(suite.mgr.UpdateCollectionCurrentTarget(ctx, collectionID)) - jsonStr := suite.mgr.GetTargetJSON(CurrentTarget) + jsonStr := suite.mgr.GetTargetJSON(ctx, CurrentTarget) assert.NotEmpty(suite.T(), jsonStr) var currentTarget []*metricsinfo.QueryCoordTarget diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index 54b6f7c52983b..f36cf2da4d262 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -86,7 +86,7 @@ func NewCollectionObserver( } // Add load task for collection recovery - collections := meta.GetAllCollections() + collections := meta.GetAllCollections(context.TODO()) for _, collection := range collections { ob.LoadCollection(context.Background(), collection.GetCollectionID()) } @@ -157,13 +157,13 @@ func (ob *CollectionObserver) LoadPartitions(ctx context.Context, collectionID i } func (ob *CollectionObserver) Observe(ctx context.Context) { - ob.observeTimeout() + ob.observeTimeout(ctx) ob.observeLoadStatus(ctx) } -func (ob *CollectionObserver) observeTimeout() { +func (ob *CollectionObserver) observeTimeout(ctx context.Context) { ob.loadTasks.Range(func(traceID string, task LoadTask) bool { - collection := ob.meta.CollectionManager.GetCollection(task.CollectionID) + collection := ob.meta.CollectionManager.GetCollection(ctx, task.CollectionID) // collection released if collection == nil { log.Info("Load Collection Task canceled, collection removed from meta", zap.Int64("collectionID", task.CollectionID), zap.String("traceID", traceID)) @@ -178,14 +178,14 @@ func (ob *CollectionObserver) observeTimeout() { log.Info("load collection timeout, cancel it", zap.Int64("collectionID", collection.GetCollectionID()), zap.Duration("loadTime", time.Since(collection.CreatedAt))) - ob.meta.CollectionManager.RemoveCollection(collection.GetCollectionID()) - ob.meta.ReplicaManager.RemoveCollection(collection.GetCollectionID()) + ob.meta.CollectionManager.RemoveCollection(ctx, collection.GetCollectionID()) + ob.meta.ReplicaManager.RemoveCollection(ctx, collection.GetCollectionID()) ob.targetObserver.ReleaseCollection(collection.GetCollectionID()) ob.loadTasks.Remove(traceID) } case querypb.LoadType_LoadPartition: partitionIDs := typeutil.NewSet(task.PartitionIDs...) - partitions := ob.meta.GetPartitionsByCollection(task.CollectionID) + partitions := ob.meta.GetPartitionsByCollection(ctx, task.CollectionID) partitions = lo.Filter(partitions, func(partition *meta.Partition, _ int) bool { return partitionIDs.Contain(partition.GetPartitionID()) }) @@ -213,16 +213,16 @@ func (ob *CollectionObserver) observeTimeout() { zap.Int64("collectionID", task.CollectionID), zap.Int64s("partitionIDs", task.PartitionIDs)) for _, partition := range partitions { - ob.meta.CollectionManager.RemovePartition(partition.CollectionID, partition.GetPartitionID()) + ob.meta.CollectionManager.RemovePartition(ctx, partition.CollectionID, partition.GetPartitionID()) ob.targetObserver.ReleasePartition(partition.GetCollectionID(), partition.GetPartitionID()) } // all partition timeout, remove collection - if len(ob.meta.CollectionManager.GetPartitionsByCollection(task.CollectionID)) == 0 { + if len(ob.meta.CollectionManager.GetPartitionsByCollection(ctx, task.CollectionID)) == 0 { log.Info("collection timeout due to all partition removed", zap.Int64("collection", task.CollectionID)) - ob.meta.CollectionManager.RemoveCollection(task.CollectionID) - ob.meta.ReplicaManager.RemoveCollection(task.CollectionID) + ob.meta.CollectionManager.RemoveCollection(ctx, task.CollectionID) + ob.meta.ReplicaManager.RemoveCollection(ctx, task.CollectionID) ob.targetObserver.ReleaseCollection(task.CollectionID) } } @@ -231,9 +231,9 @@ func (ob *CollectionObserver) observeTimeout() { }) } -func (ob *CollectionObserver) readyToObserve(collectionID int64) bool { - metaExist := (ob.meta.GetCollection(collectionID) != nil) - targetExist := ob.targetMgr.IsNextTargetExist(collectionID) || ob.targetMgr.IsCurrentTargetExist(collectionID, common.AllPartitionsID) +func (ob *CollectionObserver) readyToObserve(ctx context.Context, collectionID int64) bool { + metaExist := (ob.meta.GetCollection(ctx, collectionID) != nil) + targetExist := ob.targetMgr.IsNextTargetExist(ctx, collectionID) || ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, common.AllPartitionsID) return metaExist && targetExist } @@ -243,7 +243,7 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) { ob.loadTasks.Range(func(traceID string, task LoadTask) bool { loading = true - collection := ob.meta.CollectionManager.GetCollection(task.CollectionID) + collection := ob.meta.CollectionManager.GetCollection(ctx, task.CollectionID) if collection == nil { return true } @@ -251,10 +251,10 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) { var partitions []*meta.Partition switch task.LoadType { case querypb.LoadType_LoadCollection: - partitions = ob.meta.GetPartitionsByCollection(task.CollectionID) + partitions = ob.meta.GetPartitionsByCollection(ctx, task.CollectionID) case querypb.LoadType_LoadPartition: partitionIDs := typeutil.NewSet[int64](task.PartitionIDs...) - partitions = ob.meta.GetPartitionsByCollection(task.CollectionID) + partitions = ob.meta.GetPartitionsByCollection(ctx, task.CollectionID) partitions = lo.Filter(partitions, func(partition *meta.Partition, _ int) bool { return partitionIDs.Contain(partition.GetPartitionID()) }) @@ -265,11 +265,11 @@ func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) { if partition.LoadPercentage == 100 { continue } - if ob.readyToObserve(partition.CollectionID) { - replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID()) + if ob.readyToObserve(ctx, partition.CollectionID) { + replicaNum := ob.meta.GetReplicaNumber(ctx, partition.GetCollectionID()) ob.observePartitionLoadStatus(ctx, partition, replicaNum) } - partition = ob.meta.GetPartition(partition.PartitionID) + partition = ob.meta.GetPartition(ctx, partition.PartitionID) if partition != nil && partition.LoadPercentage != 100 { loaded = false } @@ -299,8 +299,8 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa zap.Int64("partitionID", partition.GetPartitionID()), ) - segmentTargets := ob.targetMgr.GetSealedSegmentsByPartition(partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget) - channelTargets := ob.targetMgr.GetDmChannelsByCollection(partition.GetCollectionID(), meta.NextTarget) + segmentTargets := ob.targetMgr.GetSealedSegmentsByPartition(ctx, partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget) + channelTargets := ob.targetMgr.GetDmChannelsByCollection(ctx, partition.GetCollectionID(), meta.NextTarget) targetNum := len(segmentTargets) + len(channelTargets) if targetNum == 0 { @@ -320,7 +320,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa for _, channel := range channelTargets { views := ob.dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName())) nodes := lo.Map(views, func(v *meta.LeaderView, _ int) int64 { return v.ID }) - group := utils.GroupNodesByReplica(ob.meta.ReplicaManager, partition.GetCollectionID(), nodes) + group := utils.GroupNodesByReplica(ctx, ob.meta.ReplicaManager, partition.GetCollectionID(), nodes) loadedCount += len(group) } subChannelCount := loadedCount @@ -329,7 +329,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa meta.WithChannelName2LeaderView(segment.GetInsertChannel()), meta.WithSegment2LeaderView(segment.GetID(), false)) nodes := lo.Map(views, func(view *meta.LeaderView, _ int) int64 { return view.ID }) - group := utils.GroupNodesByReplica(ob.meta.ReplicaManager, partition.GetCollectionID(), nodes) + group := utils.GroupNodesByReplica(ctx, ob.meta.ReplicaManager, partition.GetCollectionID(), nodes) loadedCount += len(group) } if loadedCount > 0 { @@ -352,7 +352,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa } delete(ob.partitionLoadedCount, partition.GetPartitionID()) } - collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage) + collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(ctx, partition.PartitionID, loadPercentage) if err != nil { log.Warn("failed to update load percentage") } diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index 3293103c855df..d8f9e62247666 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -75,6 +75,8 @@ type CollectionObserverSuite struct { // Test object ob *CollectionObserver + + ctx context.Context } func (suite *CollectionObserverSuite) SetupSuite() { @@ -236,6 +238,7 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 3, })) + suite.ctx = context.Background() } func (suite *CollectionObserverSuite) TearDownTest() { @@ -249,7 +252,7 @@ func (suite *CollectionObserverSuite) TestObserve() { timeout = 3 * time.Second ) // time before load - time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt + time := suite.meta.GetCollection(suite.ctx, suite.collections[2]).UpdatedAt // Not timeout paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3") @@ -357,12 +360,13 @@ func (suite *CollectionObserverSuite) TestObservePartition() { } func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool { - exist := suite.meta.Exist(collection) - percentage := suite.meta.CalculateLoadPercentage(collection) - status := suite.meta.CalculateLoadStatus(collection) - replicas := suite.meta.ReplicaManager.GetByCollection(collection) - channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget) - segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget) + ctx := suite.ctx + exist := suite.meta.Exist(ctx, collection) + percentage := suite.meta.CalculateLoadPercentage(ctx, collection) + status := suite.meta.CalculateLoadStatus(ctx, collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) + channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget) + segments := suite.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget) return exist && percentage == 100 && @@ -373,15 +377,16 @@ func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool } func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool { - partition := suite.meta.GetPartition(partitionID) + ctx := suite.ctx + partition := suite.meta.GetPartition(ctx, partitionID) if partition == nil { return false } collection := partition.GetCollectionID() - percentage := suite.meta.GetPartitionLoadPercentage(partitionID) + percentage := suite.meta.GetPartitionLoadPercentage(ctx, partitionID) status := partition.GetStatus() - channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget) - segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget) + channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget) + segments := suite.targetMgr.GetSealedSegmentsByPartition(ctx, collection, partitionID, meta.CurrentTarget) expectedSegments := lo.Filter(suite.segments[collection], func(seg *datapb.SegmentInfo, _ int) bool { return seg.PartitionID == partitionID }) @@ -392,10 +397,11 @@ func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool } func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool { - exist := suite.meta.Exist(collection) - replicas := suite.meta.ReplicaManager.GetByCollection(collection) - channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget) - segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget) + ctx := suite.ctx + exist := suite.meta.Exist(ctx, collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) + channels := suite.targetMgr.GetDmChannelsByCollection(ctx, collection, meta.CurrentTarget) + segments := suite.targetMgr.GetSealedSegmentsByCollection(ctx, collection, meta.CurrentTarget) return !(exist || len(replicas) > 0 || len(channels) > 0 || @@ -403,36 +409,39 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool } func (suite *CollectionObserverSuite) isPartitionTimeout(collection int64, partitionID int64) bool { - partition := suite.meta.GetPartition(partitionID) - segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget) + ctx := suite.ctx + partition := suite.meta.GetPartition(ctx, partitionID) + segments := suite.targetMgr.GetSealedSegmentsByPartition(ctx, collection, partitionID, meta.CurrentTarget) return partition == nil && len(segments) == 0 } func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool { - return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime) + return suite.meta.GetCollection(suite.ctx, collection).UpdatedAt.After(beforeTime) } func (suite *CollectionObserverSuite) loadAll() { + ctx := suite.ctx for _, collection := range suite.collections { suite.load(collection) } - suite.targetMgr.UpdateCollectionCurrentTarget(suite.collections[0]) - suite.targetMgr.UpdateCollectionNextTarget(suite.collections[0]) - suite.targetMgr.UpdateCollectionCurrentTarget(suite.collections[2]) - suite.targetMgr.UpdateCollectionNextTarget(suite.collections[2]) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collections[0]) + suite.targetMgr.UpdateCollectionNextTarget(ctx, suite.collections[0]) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collections[2]) + suite.targetMgr.UpdateCollectionNextTarget(ctx, suite.collections[2]) } func (suite *CollectionObserverSuite) load(collection int64) { + ctx := suite.ctx // Mock meta data - replicas, err := suite.meta.ReplicaManager.Spawn(collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}, nil) + replicas, err := suite.meta.ReplicaManager.Spawn(ctx, collection, map[string]int{meta.DefaultResourceGroupName: int(suite.replicaNumber[collection])}, nil) suite.NoError(err) for _, replica := range replicas { replica.AddRWNode(suite.nodes...) } - err = suite.meta.ReplicaManager.Put(replicas...) + err = suite.meta.ReplicaManager.Put(ctx, replicas...) suite.NoError(err) - suite.meta.PutCollection(&meta.Collection{ + suite.meta.PutCollection(ctx, &meta.Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, ReplicaNumber: suite.replicaNumber[collection], @@ -444,7 +453,7 @@ func (suite *CollectionObserverSuite) load(collection int64) { }) for _, partition := range suite.partitions[collection] { - suite.meta.PutPartition(&meta.Partition{ + suite.meta.PutPartition(ctx, &meta.Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: collection, PartitionID: partition, @@ -474,7 +483,7 @@ func (suite *CollectionObserverSuite) load(collection int64) { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil) - suite.targetMgr.UpdateCollectionNextTarget(collection) + suite.targetMgr.UpdateCollectionNextTarget(ctx, collection) suite.ob.LoadCollection(context.Background(), collection) } diff --git a/internal/querycoordv2/observers/replica_observer.go b/internal/querycoordv2/observers/replica_observer.go index 4251fef99fae7..ae0be66efc693 100644 --- a/internal/querycoordv2/observers/replica_observer.go +++ b/internal/querycoordv2/observers/replica_observer.go @@ -71,7 +71,7 @@ func (ob *ReplicaObserver) schedule(ctx context.Context) { defer ob.wg.Done() log.Info("Start check replica loop") - listener := ob.meta.ResourceManager.ListenNodeChanged() + listener := ob.meta.ResourceManager.ListenNodeChanged(ctx) for { ob.waitNodeChangedOrTimeout(ctx, listener) // stop if the context is canceled. @@ -92,15 +92,16 @@ func (ob *ReplicaObserver) waitNodeChangedOrTimeout(ctx context.Context, listene } func (ob *ReplicaObserver) checkNodesInReplica() { - log := log.Ctx(context.Background()).WithRateGroup("qcv2.replicaObserver", 1, 60) - collections := ob.meta.GetAll() + ctx := context.Background() + log := log.Ctx(ctx).WithRateGroup("qcv2.replicaObserver", 1, 60) + collections := ob.meta.GetAll(ctx) for _, collectionID := range collections { - utils.RecoverReplicaOfCollection(ob.meta, collectionID) + utils.RecoverReplicaOfCollection(ctx, ob.meta, collectionID) } // check all ro nodes, remove it from replica if all segment/channel has been moved for _, collectionID := range collections { - replicas := ob.meta.ReplicaManager.GetByCollection(collectionID) + replicas := ob.meta.ReplicaManager.GetByCollection(ctx, collectionID) for _, replica := range replicas { roNodes := replica.GetRONodes() rwNodes := replica.GetRWNodes() @@ -130,7 +131,7 @@ func (ob *ReplicaObserver) checkNodesInReplica() { zap.Int64s("roNodes", roNodes), zap.Int64s("rwNodes", rwNodes), ) - if err := ob.meta.ReplicaManager.RemoveNode(replica.GetID(), removeNodes...); err != nil { + if err := ob.meta.ReplicaManager.RemoveNode(ctx, replica.GetID(), removeNodes...); err != nil { logger.Warn("fail to remove node from replica", zap.Error(err)) continue } diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index 9f9062488cb86..266d731a00d22 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -16,6 +16,7 @@ package observers import ( + "context" "testing" "time" @@ -47,6 +48,7 @@ type ReplicaObserverSuite struct { collectionID int64 partitionID int64 + ctx context.Context } func (suite *ReplicaObserverSuite) SetupSuite() { @@ -67,6 +69,7 @@ func (suite *ReplicaObserverSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() // meta store := querycoord.NewCatalog(suite.kv) @@ -82,11 +85,12 @@ func (suite *ReplicaObserverSuite) SetupTest() { } func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { - suite.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + ctx := suite.ctx + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 2}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 2}, }) - suite.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 2}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 2}, }) @@ -110,14 +114,14 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { Address: "localhost:8080", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(1) - suite.meta.ResourceManager.HandleNodeUp(2) - suite.meta.ResourceManager.HandleNodeUp(3) - suite.meta.ResourceManager.HandleNodeUp(4) + suite.meta.ResourceManager.HandleNodeUp(ctx, 1) + suite.meta.ResourceManager.HandleNodeUp(ctx, 2) + suite.meta.ResourceManager.HandleNodeUp(ctx, 3) + suite.meta.ResourceManager.HandleNodeUp(ctx, 4) - err := suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 2)) + err := suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(suite.collectionID, 2)) suite.NoError(err) - replicas, err := suite.meta.Spawn(suite.collectionID, map[string]int{ + replicas, err := suite.meta.Spawn(ctx, suite.collectionID, map[string]int{ "rg1": 1, "rg2": 1, }, nil) @@ -127,7 +131,7 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { suite.Eventually(func() bool { availableNodes := typeutil.NewUniqueSet() for _, r := range replicas { - replica := suite.meta.ReplicaManager.Get(r.GetID()) + replica := suite.meta.ReplicaManager.Get(ctx, r.GetID()) suite.NotNil(replica) if replica.RWNodesCount() != 2 { return false @@ -151,13 +155,13 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { } // Do a replica transfer. - suite.meta.ReplicaManager.TransferReplica(suite.collectionID, "rg1", "rg2", 1) + suite.meta.ReplicaManager.TransferReplica(ctx, suite.collectionID, "rg1", "rg2", 1) // All replica should in the rg2 but not rg1 // And some nodes will become ro nodes before all segment and channel on it is cleaned. suite.Eventually(func() bool { for _, r := range replicas { - replica := suite.meta.ReplicaManager.Get(r.GetID()) + replica := suite.meta.ReplicaManager.Get(ctx, r.GetID()) suite.NotNil(replica) suite.Equal("rg2", replica.GetResourceGroup()) // all replica should have ro nodes. @@ -178,7 +182,7 @@ func (suite *ReplicaObserverSuite) TestCheckNodesInReplica() { suite.Eventually(func() bool { for _, r := range replicas { - replica := suite.meta.ReplicaManager.Get(r.GetID()) + replica := suite.meta.ReplicaManager.Get(ctx, r.GetID()) suite.NotNil(replica) suite.Equal("rg2", replica.GetResourceGroup()) if replica.RONodesCount() > 0 { diff --git a/internal/querycoordv2/observers/resource_observer.go b/internal/querycoordv2/observers/resource_observer.go index 3e0938b0d5af9..ae701dc52ef00 100644 --- a/internal/querycoordv2/observers/resource_observer.go +++ b/internal/querycoordv2/observers/resource_observer.go @@ -69,7 +69,7 @@ func (ob *ResourceObserver) schedule(ctx context.Context) { defer ob.wg.Done() log.Info("Start check resource group loop") - listener := ob.meta.ResourceManager.ListenResourceGroupChanged() + listener := ob.meta.ResourceManager.ListenResourceGroupChanged(ctx) for { ob.waitRGChangedOrTimeout(ctx, listener) // stop if the context is canceled. @@ -79,7 +79,7 @@ func (ob *ResourceObserver) schedule(ctx context.Context) { } // do check once. - ob.checkAndRecoverResourceGroup() + ob.checkAndRecoverResourceGroup(ctx) } } @@ -89,29 +89,29 @@ func (ob *ResourceObserver) waitRGChangedOrTimeout(ctx context.Context, listener listener.Wait(ctxWithTimeout) } -func (ob *ResourceObserver) checkAndRecoverResourceGroup() { +func (ob *ResourceObserver) checkAndRecoverResourceGroup(ctx context.Context) { manager := ob.meta.ResourceManager - rgNames := manager.ListResourceGroups() + rgNames := manager.ListResourceGroups(ctx) enableRGAutoRecover := params.Params.QueryCoordCfg.EnableRGAutoRecover.GetAsBool() log.Debug("start to check resource group", zap.Bool("enableRGAutoRecover", enableRGAutoRecover), zap.Int("resourceGroupNum", len(rgNames))) // Check if there is any incoming node. - if manager.CheckIncomingNodeNum() > 0 { - log.Info("new incoming node is ready to be assigned...", zap.Int("incomingNodeNum", manager.CheckIncomingNodeNum())) - manager.AssignPendingIncomingNode() + if manager.CheckIncomingNodeNum(ctx) > 0 { + log.Info("new incoming node is ready to be assigned...", zap.Int("incomingNodeNum", manager.CheckIncomingNodeNum(ctx))) + manager.AssignPendingIncomingNode(ctx) } log.Debug("recover resource groups...") // Recover all resource group into expected configuration. for _, rgName := range rgNames { - if err := manager.MeetRequirement(rgName); err != nil { + if err := manager.MeetRequirement(ctx, rgName); err != nil { log.Info("found resource group need to be recovered", zap.String("rgName", rgName), zap.String("reason", err.Error()), ) if enableRGAutoRecover { - err := manager.AutoRecoverResourceGroup(rgName) + err := manager.AutoRecoverResourceGroup(ctx, rgName) if err != nil { log.Warn("failed to recover resource group", zap.String("rgName", rgName), diff --git a/internal/querycoordv2/observers/resource_observer_test.go b/internal/querycoordv2/observers/resource_observer_test.go index 9079d89e846ce..42f7149b8d478 100644 --- a/internal/querycoordv2/observers/resource_observer_test.go +++ b/internal/querycoordv2/observers/resource_observer_test.go @@ -16,6 +16,7 @@ package observers import ( + "context" "fmt" "testing" "time" @@ -47,6 +48,8 @@ type ResourceObserverSuite struct { collectionID int64 partitionID int64 + + ctx context.Context } func (suite *ResourceObserverSuite) SetupSuite() { @@ -67,6 +70,7 @@ func (suite *ResourceObserverSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdKV.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() // meta suite.store = mocks.NewQueryCoordCatalog(suite.T()) @@ -76,15 +80,15 @@ func (suite *ResourceObserverSuite) SetupTest() { suite.observer = NewResourceObserver(suite.meta) - suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil) suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) + suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil) for i := 0; i < 10; i++ { suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: int64(i), Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(int64(i)) + suite.meta.ResourceManager.HandleNodeUp(suite.ctx, int64(i)) } } @@ -93,80 +97,82 @@ func (suite *ResourceObserverSuite) TearDownTest() { } func (suite *ResourceObserverSuite) TestObserverRecoverOperation() { - suite.meta.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{ + ctx := suite.ctx + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 4}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 6}, }) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg")) // There's 10 exists node in cluster, new incoming resource group should get 4 nodes after recover. - suite.observer.checkAndRecoverResourceGroup() - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg")) + suite.observer.checkAndRecoverResourceGroup(ctx) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg")) - suite.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 6}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 10}, }) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg2")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) // There's 10 exists node in cluster, new incoming resource group should get 6 nodes after recover. - suite.observer.checkAndRecoverResourceGroup() - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) + suite.observer.checkAndRecoverResourceGroup(ctx) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) - suite.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 1}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 1}, }) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) // There's 10 exists node in cluster, but has been occupied by rg1 and rg2, new incoming resource group cannot get any node. - suite.observer.checkAndRecoverResourceGroup() - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3")) + suite.observer.checkAndRecoverResourceGroup(ctx) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) // New node up, rg3 should get the node. suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 10, })) - suite.meta.ResourceManager.HandleNodeUp(10) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg3")) + suite.meta.ResourceManager.HandleNodeUp(ctx, 10) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) // new node is down, rg3 cannot use that node anymore. suite.nodeMgr.Remove(10) - suite.meta.ResourceManager.HandleNodeDown(10) - suite.observer.checkAndRecoverResourceGroup() - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3")) + suite.meta.ResourceManager.HandleNodeDown(ctx, 10) + suite.observer.checkAndRecoverResourceGroup(ctx) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) // create a new incoming node failure. - suite.store.EXPECT().SaveResourceGroup(mock.Anything).Unset() - suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(errors.New("failure")) + suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Unset() + suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(errors.New("failure")) suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ NodeID: 11, })) // should be failure, so new node cannot be used by rg3. - suite.meta.ResourceManager.HandleNodeUp(11) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) - suite.Error(suite.meta.ResourceManager.MeetRequirement("rg3")) - suite.store.EXPECT().SaveResourceGroup(mock.Anything).Unset() - suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil) + suite.meta.ResourceManager.HandleNodeUp(ctx, 11) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) + suite.Error(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) + suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Unset() + suite.store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) // storage recovered, so next recover will be success. - suite.observer.checkAndRecoverResourceGroup() - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg1")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg2")) - suite.NoError(suite.meta.ResourceManager.MeetRequirement("rg3")) + suite.observer.checkAndRecoverResourceGroup(ctx) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg1")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg2")) + suite.NoError(suite.meta.ResourceManager.MeetRequirement(ctx, "rg3")) } func (suite *ResourceObserverSuite) TestSchedule() { suite.observer.Start() defer suite.observer.Stop() + ctx := suite.ctx check := func() { suite.Eventually(func() bool { - rgs := suite.meta.ResourceManager.ListResourceGroups() + rgs := suite.meta.ResourceManager.ListResourceGroups(ctx) for _, rg := range rgs { - if err := suite.meta.ResourceManager.GetResourceGroup(rg).MeetRequirement(); err != nil { + if err := suite.meta.ResourceManager.GetResourceGroup(ctx, rg).MeetRequirement(); err != nil { return false } } @@ -175,7 +181,7 @@ func (suite *ResourceObserverSuite) TestSchedule() { } for i := 1; i <= 4; i++ { - suite.meta.ResourceManager.AddResourceGroup(fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{ + suite.meta.ResourceManager.AddResourceGroup(ctx, fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: int32(i)}, Limits: &rgpb.ResourceGroupLimit{NodeNum: int32(i)}, }) @@ -183,7 +189,7 @@ func (suite *ResourceObserverSuite) TestSchedule() { check() for i := 1; i <= 4; i++ { - suite.meta.ResourceManager.AddResourceGroup(fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{ + suite.meta.ResourceManager.AddResourceGroup(ctx, fmt.Sprintf("rg%d", i), &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 0}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 0}, }) diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index 45e6345488b45..eb415ba567620 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -169,14 +169,14 @@ func (ob *TargetObserver) schedule(ctx context.Context) { return case <-ob.initChan: - for _, collectionID := range ob.meta.GetAll() { + for _, collectionID := range ob.meta.GetAll(ctx) { ob.init(ctx, collectionID) } log.Info("target observer init done") case <-ticker.C: ob.clean() - loaded := lo.FilterMap(ob.meta.GetAllCollections(), func(collection *meta.Collection, _ int) (int64, bool) { + loaded := lo.FilterMap(ob.meta.GetAllCollections(ctx), func(collection *meta.Collection, _ int) (int64, bool) { if collection.GetStatus() == querypb.LoadStatus_Loaded { return collection.GetCollectionID(), true } @@ -192,7 +192,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) { switch req.opType { case UpdateCollection: ob.keylocks.Lock(req.CollectionID) - err := ob.updateNextTarget(req.CollectionID) + err := ob.updateNextTarget(ctx, req.CollectionID) ob.keylocks.Unlock(req.CollectionID) if err != nil { log.Warn("failed to manually update next target", @@ -214,10 +214,10 @@ func (ob *TargetObserver) schedule(ctx context.Context) { delete(ob.readyNotifiers, req.CollectionID) ob.mut.Unlock() - ob.targetMgr.RemoveCollection(req.CollectionID) + ob.targetMgr.RemoveCollection(ctx, req.CollectionID) req.Notifier <- nil case ReleasePartition: - ob.targetMgr.RemovePartition(req.CollectionID, req.PartitionIDs...) + ob.targetMgr.RemovePartition(ctx, req.CollectionID, req.PartitionIDs...) req.Notifier <- nil } log.Info("manually trigger update target done", @@ -230,7 +230,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) { // Check whether provided collection is has current target. // If not, submit an async task into dispatcher. func (ob *TargetObserver) Check(ctx context.Context, collectionID int64, partitionID int64) bool { - result := ob.targetMgr.IsCurrentTargetExist(collectionID, partitionID) + result := ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, partitionID) if !result { ob.loadingDispatcher.AddTask(collectionID) } @@ -246,24 +246,24 @@ func (ob *TargetObserver) check(ctx context.Context, collectionID int64) { defer ob.keylocks.Unlock(collectionID) if ob.shouldUpdateCurrentTarget(ctx, collectionID) { - ob.updateCurrentTarget(collectionID) + ob.updateCurrentTarget(ctx, collectionID) } - if ob.shouldUpdateNextTarget(collectionID) { + if ob.shouldUpdateNextTarget(ctx, collectionID) { // update next target in collection level - ob.updateNextTarget(collectionID) + ob.updateNextTarget(ctx, collectionID) } } func (ob *TargetObserver) init(ctx context.Context, collectionID int64) { // pull next target first if not exist - if !ob.targetMgr.IsNextTargetExist(collectionID) { - ob.updateNextTarget(collectionID) + if !ob.targetMgr.IsNextTargetExist(ctx, collectionID) { + ob.updateNextTarget(ctx, collectionID) } // try to update current target if all segment/channel are ready if ob.shouldUpdateCurrentTarget(ctx, collectionID) { - ob.updateCurrentTarget(collectionID) + ob.updateCurrentTarget(ctx, collectionID) } // refresh collection loading status upon restart ob.check(ctx, collectionID) @@ -310,7 +310,7 @@ func (ob *TargetObserver) ReleasePartition(collectionID int64, partitionID ...in } func (ob *TargetObserver) clean() { - collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll()...) + collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll(context.TODO())...) // for collection which has been removed from target, try to clear nextTargetLastUpdate ob.nextTargetLastUpdate.Range(func(collectionID int64, _ time.Time) bool { if !collectionSet.Contain(collectionID) { @@ -331,8 +331,8 @@ func (ob *TargetObserver) clean() { } } -func (ob *TargetObserver) shouldUpdateNextTarget(collectionID int64) bool { - return !ob.targetMgr.IsNextTargetExist(collectionID) || ob.isNextTargetExpired(collectionID) +func (ob *TargetObserver) shouldUpdateNextTarget(ctx context.Context, collectionID int64) bool { + return !ob.targetMgr.IsNextTargetExist(ctx, collectionID) || ob.isNextTargetExpired(collectionID) } func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool { @@ -343,12 +343,12 @@ func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool { return time.Since(lastUpdated) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second) } -func (ob *TargetObserver) updateNextTarget(collectionID int64) error { +func (ob *TargetObserver) updateNextTarget(ctx context.Context, collectionID int64) error { log := log.Ctx(context.TODO()).WithRateGroup("qcv2.TargetObserver", 1, 60). With(zap.Int64("collectionID", collectionID)) log.RatedInfo(10, "observer trigger update next target") - err := ob.targetMgr.UpdateCollectionNextTarget(collectionID) + err := ob.targetMgr.UpdateCollectionNextTarget(ctx, collectionID) if err != nil { log.Warn("failed to update next target for collection", zap.Error(err)) @@ -363,7 +363,7 @@ func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) { } func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collectionID int64) bool { - replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID) + replicaNum := ob.meta.CollectionManager.GetReplicaNumber(ctx, collectionID) log := log.Ctx(ctx).WithRateGroup( fmt.Sprintf("qcv2.TargetObserver-%d", collectionID), 10, @@ -374,7 +374,7 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect ) // check channel first - channelNames := ob.targetMgr.GetDmChannelsByCollection(collectionID, meta.NextTarget) + channelNames := ob.targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.NextTarget) if len(channelNames) == 0 { // next target is empty, no need to update log.RatedInfo(10, "next target is empty, no need to update") @@ -402,13 +402,13 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect var partitions []int64 var indexInfo []*indexpb.IndexInfo var err error - newVersion := ob.targetMgr.GetCollectionTargetVersion(collectionID, meta.NextTarget) + newVersion := ob.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.NextTarget) for _, leader := range collectionReadyLeaders { updateVersionAction := ob.checkNeedUpdateTargetVersion(ctx, leader, newVersion) if updateVersionAction == nil { continue } - replica := ob.meta.ReplicaManager.GetByCollectionAndNode(collectionID, leader.ID) + replica := ob.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, leader.ID) if replica == nil { log.Warn("replica not found", zap.Int64("nodeID", leader.ID), zap.Int64("collectionID", collectionID)) continue @@ -422,7 +422,7 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect return false } - partitions, err = utils.GetPartitions(ob.meta.CollectionManager, collectionID) + partitions, err = utils.GetPartitions(ctx, ob.meta.CollectionManager, collectionID) if err != nil { log.Warn("failed to get partitions", zap.Error(err)) return false @@ -467,7 +467,7 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, leade Actions: diffs, Schema: collectionInfo.GetSchema(), LoadMeta: &querypb.LoadMetaInfo{ - LoadType: ob.meta.GetLoadType(leaderView.CollectionID), + LoadType: ob.meta.GetLoadType(ctx, leaderView.CollectionID), CollectionID: leaderView.CollectionID, PartitionIDs: partitions, DbName: collectionInfo.GetDbName(), @@ -506,10 +506,10 @@ func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, lead zap.Int64("newVersion", targetVersion), ) - sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget) - growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget) - droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTarget) - channel := ob.targetMgr.GetDmChannel(leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst) + sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget) + growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget) + droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget) + channel := ob.targetMgr.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst) action := &querypb.SyncAction{ Type: querypb.SyncType_UpdateVersion, @@ -526,10 +526,10 @@ func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, lead return action } -func (ob *TargetObserver) updateCurrentTarget(collectionID int64) { - log := log.Ctx(context.TODO()).WithRateGroup("qcv2.TargetObserver", 1, 60) +func (ob *TargetObserver) updateCurrentTarget(ctx context.Context, collectionID int64) { + log := log.Ctx(ctx).WithRateGroup("qcv2.TargetObserver", 1, 60) log.RatedInfo(10, "observer trigger update current target", zap.Int64("collectionID", collectionID)) - if ob.targetMgr.UpdateCollectionCurrentTarget(collectionID) { + if ob.targetMgr.UpdateCollectionCurrentTarget(ctx, collectionID) { ob.mut.Lock() defer ob.mut.Unlock() notifiers := ob.readyNotifiers[collectionID] diff --git a/internal/querycoordv2/observers/target_observer_test.go b/internal/querycoordv2/observers/target_observer_test.go index dcbd8ec5f247e..68903bae2c23e 100644 --- a/internal/querycoordv2/observers/target_observer_test.go +++ b/internal/querycoordv2/observers/target_observer_test.go @@ -57,6 +57,7 @@ type TargetObserverSuite struct { partitionID int64 nextTargetSegments []*datapb.SegmentInfo nextTargetChannels []*datapb.VchannelInfo + ctx context.Context } func (suite *TargetObserverSuite) SetupSuite() { @@ -77,6 +78,7 @@ func (suite *TargetObserverSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() // meta nodeMgr := session.NewNodeManager() @@ -100,14 +102,14 @@ func (suite *TargetObserverSuite) SetupTest() { testCollection := utils.CreateTestCollection(suite.collectionID, 1) testCollection.Status = querypb.LoadStatus_Loaded - err = suite.meta.CollectionManager.PutCollection(testCollection) + err = suite.meta.CollectionManager.PutCollection(suite.ctx, testCollection) suite.NoError(err) - err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) + err = suite.meta.CollectionManager.PutPartition(suite.ctx, utils.CreateTestPartition(suite.collectionID, suite.partitionID)) suite.NoError(err) - replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.ctx, suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) suite.NoError(err) replicas[0].AddRWNode(2) - err = suite.meta.ReplicaManager.Put(replicas...) + err = suite.meta.ReplicaManager.Put(suite.ctx, replicas...) suite.NoError(err) suite.nextTargetChannels = []*datapb.VchannelInfo{ @@ -140,9 +142,11 @@ func (suite *TargetObserverSuite) SetupTest() { } func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { + ctx := suite.ctx + suite.Eventually(func() bool { - return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 2 && - len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2 + return len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2 && + len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2 }, 5*time.Second, 1*time.Second) suite.distMgr.LeaderViewManager.Update(2, @@ -166,7 +170,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { // Never update current target if it's empty, even the next target is ready suite.Eventually(func() bool { - return len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.CurrentTarget)) == 0 + return len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 0 }, 3*time.Second, 1*time.Second) suite.broker.AssertExpectations(suite.T()) @@ -176,7 +180,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { PartitionID: suite.partitionID, InsertChannel: "channel-1", }) - suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, suite.collectionID) // Pull next again suite.broker.EXPECT(). @@ -184,8 +188,8 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) suite.Eventually(func() bool { - return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && - len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2 + return len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 3 && + len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.NextTarget)) == 2 }, 7*time.Second, 1*time.Second) suite.broker.AssertExpectations(suite.T()) @@ -226,18 +230,19 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { default: } return isReady && - len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.CurrentTarget)) == 3 && - len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.CurrentTarget)) == 2 + len(suite.targetMgr.GetSealedSegmentsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 3 && + len(suite.targetMgr.GetDmChannelsByCollection(ctx, suite.collectionID, meta.CurrentTarget)) == 2 }, 7*time.Second, 1*time.Second) } func (suite *TargetObserverSuite) TestTriggerRelease() { + ctx := suite.ctx // Manually update next target _, err := suite.observer.UpdateNextTarget(suite.collectionID) suite.NoError(err) // manually release partition - partitions := suite.meta.CollectionManager.GetPartitionsByCollection(suite.collectionID) + partitions := suite.meta.CollectionManager.GetPartitionsByCollection(ctx, suite.collectionID) partitionIDs := lo.Map(partitions, func(partition *meta.Partition, _ int) int64 { return partition.PartitionID }) suite.observer.ReleasePartition(suite.collectionID, partitionIDs[0]) @@ -265,6 +270,7 @@ type TargetObserverCheckSuite struct { collectionID int64 partitionID int64 + ctx context.Context } func (suite *TargetObserverCheckSuite) SetupSuite() { @@ -284,6 +290,7 @@ func (suite *TargetObserverCheckSuite) SetupTest() { config.EtcdTLSMinVersion.GetValue()) suite.Require().NoError(err) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + suite.ctx = context.Background() // meta store := querycoord.NewCatalog(suite.kv) @@ -306,14 +313,14 @@ func (suite *TargetObserverCheckSuite) SetupTest() { suite.collectionID = int64(1000) suite.partitionID = int64(100) - err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1)) + err = suite.meta.CollectionManager.PutCollection(suite.ctx, utils.CreateTestCollection(suite.collectionID, 1)) suite.NoError(err) - err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) + err = suite.meta.CollectionManager.PutPartition(suite.ctx, utils.CreateTestPartition(suite.collectionID, suite.partitionID)) suite.NoError(err) - replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.ctx, suite.collectionID, map[string]int{meta.DefaultResourceGroupName: 1}, nil) suite.NoError(err) replicas[0].AddRWNode(2) - err = suite.meta.ReplicaManager.Put(replicas...) + err = suite.meta.ReplicaManager.Put(suite.ctx, replicas...) suite.NoError(err) } diff --git a/internal/querycoordv2/ops_service_test.go b/internal/querycoordv2/ops_service_test.go index 2eb4a1c34ae1f..db56c8ded85d4 100644 --- a/internal/querycoordv2/ops_service_test.go +++ b/internal/querycoordv2/ops_service_test.go @@ -440,8 +440,8 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(1) - nodes, err := suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName) + suite.meta.ResourceManager.HandleNodeUp(ctx, 1) + nodes, err := suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName) suite.NoError(err) suite.Contains(nodes, int64(1)) // test success @@ -451,7 +451,7 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() { }) suite.NoError(err) suite.True(merr.Ok(resp)) - nodes, err = suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName) + nodes, err = suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName) suite.NoError(err) suite.NotContains(nodes, int64(1)) @@ -460,7 +460,7 @@ func (suite *OpsServiceSuite) TestSuspendAndResumeNode() { }) suite.NoError(err) suite.True(merr.Ok(resp)) - nodes, err = suite.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName) + nodes, err = suite.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName) suite.NoError(err) suite.Contains(nodes, int64(1)) } @@ -492,10 +492,10 @@ func (suite *OpsServiceSuite) TestTransferSegment() { replicaID := int64(1) nodes := []int64{1, 2, 3, 4} replica := utils.CreateTestReplica(replicaID, collectionID, nodes) - suite.meta.ReplicaManager.Put(replica) + suite.meta.ReplicaManager.Put(ctx, replica) collection := utils.CreateTestCollection(collectionID, 1) partition := utils.CreateTestPartition(partitionID, collectionID) - suite.meta.PutCollection(collection, partition) + suite.meta.PutCollection(ctx, collection, partition) segmentIDs := []int64{1, 2, 3, 4} channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"} @@ -594,8 +594,8 @@ func (suite *OpsServiceSuite) TestTransferSegment() { suite.True(merr.Ok(resp)) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil) - suite.targetMgr.UpdateCollectionNextTarget(1) - suite.targetMgr.UpdateCollectionCurrentTarget(1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, 1) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) suite.dist.SegmentDistManager.Update(1, segmentInfos...) suite.dist.ChannelDistManager.Update(1, chanenlInfos...) @@ -605,7 +605,7 @@ func (suite *OpsServiceSuite) TestTransferSegment() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(node) + suite.meta.ResourceManager.HandleNodeUp(ctx, node) } // test transfer segment success, expect generate 1 balance segment task @@ -741,10 +741,10 @@ func (suite *OpsServiceSuite) TestTransferChannel() { replicaID := int64(1) nodes := []int64{1, 2, 3, 4} replica := utils.CreateTestReplica(replicaID, collectionID, nodes) - suite.meta.ReplicaManager.Put(replica) + suite.meta.ReplicaManager.Put(ctx, replica) collection := utils.CreateTestCollection(collectionID, 1) partition := utils.CreateTestPartition(partitionID, collectionID) - suite.meta.PutCollection(collection, partition) + suite.meta.PutCollection(ctx, collection, partition) segmentIDs := []int64{1, 2, 3, 4} channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"} @@ -845,8 +845,8 @@ func (suite *OpsServiceSuite) TestTransferChannel() { suite.True(merr.Ok(resp)) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil) - suite.targetMgr.UpdateCollectionNextTarget(1) - suite.targetMgr.UpdateCollectionCurrentTarget(1) + suite.targetMgr.UpdateCollectionNextTarget(ctx, 1) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, 1) suite.dist.SegmentDistManager.Update(1, segmentInfos...) suite.dist.ChannelDistManager.Update(1, chanenlInfos...) @@ -856,7 +856,7 @@ func (suite *OpsServiceSuite) TestTransferChannel() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(node) + suite.meta.ResourceManager.HandleNodeUp(ctx, node) } // test transfer channel success, expect generate 1 balance channel task diff --git a/internal/querycoordv2/ops_services.go b/internal/querycoordv2/ops_services.go index e9d76feb6635d..9051f50bd93e3 100644 --- a/internal/querycoordv2/ops_services.go +++ b/internal/querycoordv2/ops_services.go @@ -212,7 +212,7 @@ func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeReques return merr.Status(err), nil } - s.meta.ResourceManager.HandleNodeDown(req.GetNodeID()) + s.meta.ResourceManager.HandleNodeDown(ctx, req.GetNodeID()) return merr.Success(), nil } @@ -233,7 +233,7 @@ func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) return merr.Status(err), nil } - s.meta.ResourceManager.HandleNodeUp(req.GetNodeID()) + s.meta.ResourceManager.HandleNodeUp(ctx, req.GetNodeID()) return merr.Success(), nil } @@ -262,7 +262,7 @@ func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegme return merr.Status(err), nil } - replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID()) + replicas := s.meta.ReplicaManager.GetByNode(ctx, req.GetSourceNodeID()) for _, replica := range replicas { // when no dst node specified, default to use all other nodes in same dstNodeSet := typeutil.NewUniqueSet() @@ -292,7 +292,7 @@ func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegme return merr.Status(err), nil } - existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil + existInTarget := s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil if !existInTarget { log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", req.GetSegmentID())) } else { @@ -334,7 +334,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann return merr.Status(err), nil } - replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID()) + replicas := s.meta.ReplicaManager.GetByNode(ctx, req.GetSourceNodeID()) for _, replica := range replicas { // when no dst node specified, default to use all other nodes in same dstNodeSet := typeutil.NewUniqueSet() @@ -362,7 +362,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann err := merr.WrapErrChannelNotFound(req.GetChannelName(), "channel not found in source node") return merr.Status(err), nil } - existInTarget := s.targetMgr.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil + existInTarget := s.targetMgr.GetDmChannel(ctx, channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil if !existInTarget { log.Info("channel doesn't exist in current target, skip it", zap.String("channelName", channel.GetChannelName())) } else { @@ -414,7 +414,7 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch return ch.GetChannelName(), ch }) for _, ch := range channelOnSrc { - if s.targetMgr.GetDmChannel(ch.GetCollectionID(), ch.GetChannelName(), meta.CurrentTargetFirst) == nil { + if s.targetMgr.GetDmChannel(ctx, ch.GetCollectionID(), ch.GetChannelName(), meta.CurrentTargetFirst) == nil { continue } @@ -430,7 +430,7 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch return s.GetID(), s }) for _, segment := range segmentOnSrc { - if s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) == nil { + if s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTargetFirst) == nil { continue } diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 43f0264448ed7..8549f3d1a2852 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -209,15 +209,15 @@ func (s *Server) registerMetricsRequest() { if v.Exists() { scope = meta.TargetScope(v.Int()) } - return s.targetMgr.GetTargetJSON(scope), nil + return s.targetMgr.GetTargetJSON(ctx, scope), nil } QueryReplicasAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) { - return s.meta.GetReplicasJSON(), nil + return s.meta.GetReplicasJSON(ctx), nil } QueryResourceGroupsAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) { - return s.meta.GetResourceGroupsJSON(), nil + return s.meta.GetResourceGroupsJSON(ctx), nil } QuerySegmentsAction := func(ctx context.Context, req *milvuspb.GetMetricsRequest, jsonReq gjson.Result) (string, error) { @@ -421,26 +421,26 @@ func (s *Server) initMeta() error { ) log.Info("recover meta...") - err := s.meta.CollectionManager.Recover(s.broker) + err := s.meta.CollectionManager.Recover(s.ctx, s.broker) if err != nil { log.Warn("failed to recover collections", zap.Error(err)) return err } - collections := s.meta.GetAll() + collections := s.meta.GetAll(s.ctx) log.Info("recovering collections...", zap.Int64s("collections", collections)) // We really update the metric after observers think the collection loaded. metrics.QueryCoordNumCollections.WithLabelValues().Set(0) - metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions()))) + metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions(s.ctx)))) - err = s.meta.ReplicaManager.Recover(collections) + err = s.meta.ReplicaManager.Recover(s.ctx, collections) if err != nil { log.Warn("failed to recover replicas", zap.Error(err)) return err } - err = s.meta.ResourceManager.Recover() + err = s.meta.ResourceManager.Recover(s.ctx) if err != nil { log.Warn("failed to recover resource groups", zap.Error(err)) return err @@ -452,7 +452,7 @@ func (s *Server) initMeta() error { LeaderViewManager: meta.NewLeaderViewManager(), } s.targetMgr = meta.NewTargetManager(s.broker, s.meta) - err = s.targetMgr.Recover(s.store) + err = s.targetMgr.Recover(s.ctx, s.store) if err != nil { log.Warn("failed to recover collection targets", zap.Error(err)) } @@ -609,7 +609,7 @@ func (s *Server) Stop() error { // save target to meta store, after querycoord restart, make it fast to recover current target // should save target after target observer stop, incase of target changed if s.targetMgr != nil { - s.targetMgr.SaveCurrentTarget(s.store) + s.targetMgr.SaveCurrentTarget(s.ctx, s.store) } if s.replicaObserver != nil { @@ -773,7 +773,7 @@ func (s *Server) watchNodes(revision int64) { ) s.nodeMgr.Stopping(nodeID) s.checkerController.Check() - s.meta.ResourceManager.HandleNodeStopping(nodeID) + s.meta.ResourceManager.HandleNodeStopping(s.ctx, nodeID) case sessionutil.SessionDelEvent: nodeID := event.Session.ServerID @@ -833,7 +833,7 @@ func (s *Server) handleNodeUp(node int64) { s.taskScheduler.AddExecutor(node) s.distController.StartDistInstance(s.ctx, node) // need assign to new rg and replica - s.meta.ResourceManager.HandleNodeUp(node) + s.meta.ResourceManager.HandleNodeUp(s.ctx, node) } func (s *Server) handleNodeDown(node int64) { @@ -848,18 +848,18 @@ func (s *Server) handleNodeDown(node int64) { // Clear tasks s.taskScheduler.RemoveByNode(node) - s.meta.ResourceManager.HandleNodeDown(node) + s.meta.ResourceManager.HandleNodeDown(s.ctx, node) } func (s *Server) checkNodeStateInRG() { - for _, rgName := range s.meta.ListResourceGroups() { - rg := s.meta.ResourceManager.GetResourceGroup(rgName) + for _, rgName := range s.meta.ListResourceGroups(s.ctx) { + rg := s.meta.ResourceManager.GetResourceGroup(s.ctx, rgName) for _, node := range rg.GetNodes() { info := s.nodeMgr.Get(node) if info == nil { - s.meta.ResourceManager.HandleNodeDown(node) + s.meta.ResourceManager.HandleNodeDown(s.ctx, node) } else if info.IsStoppingState() { - s.meta.ResourceManager.HandleNodeStopping(node) + s.meta.ResourceManager.HandleNodeStopping(s.ctx, node) } } } @@ -917,7 +917,7 @@ func (s *Server) watchLoadConfigChanges() { replicaNumHandler := config.NewHandler("watchReplicaNumberChanges", func(e *config.Event) { log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType)) - collectionIDs := s.meta.GetAll() + collectionIDs := s.meta.GetAll(s.ctx) if len(collectionIDs) == 0 { log.Warn("no collection loaded, skip to trigger update load config") return @@ -944,7 +944,7 @@ func (s *Server) watchLoadConfigChanges() { rgHandler := config.NewHandler("watchResourceGroupChanges", func(e *config.Event) { log.Info("watch load config changes", zap.String("key", e.Key), zap.String("value", e.Value), zap.String("type", e.EventType)) - collectionIDs := s.meta.GetAll() + collectionIDs := s.meta.GetAll(s.ctx) if len(collectionIDs) == 0 { log.Warn("no collection loaded, skip to trigger update load config") return diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 70f5892e33e36..948d0d7a9277b 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -89,6 +89,7 @@ type ServerSuite struct { tikvCli *txnkv.Client server *Server nodes []*mocks.MockQueryNode + ctx context.Context } var testMeta string @@ -125,6 +126,7 @@ func (suite *ServerSuite) SetupSuite() { 1001: 3, } suite.nodes = make([]*mocks.MockQueryNode, 3) + suite.ctx = context.Background() } func (suite *ServerSuite) SetupTest() { @@ -144,13 +146,13 @@ func (suite *ServerSuite) SetupTest() { suite.Require().NoError(err) ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second) suite.Require().True(ok) - suite.server.meta.ResourceManager.HandleNodeUp(suite.nodes[i].ID) + suite.server.meta.ResourceManager.HandleNodeUp(suite.ctx, suite.nodes[i].ID) suite.expectLoadAndReleasePartitions(suite.nodes[i]) } suite.loadAll() for _, collection := range suite.collections { - suite.True(suite.server.meta.Exist(collection)) + suite.True(suite.server.meta.Exist(suite.ctx, collection)) suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) } } @@ -181,7 +183,7 @@ func (suite *ServerSuite) TestRecover() { suite.NoError(err) for _, collection := range suite.collections { - suite.True(suite.server.meta.Exist(collection)) + suite.True(suite.server.meta.Exist(suite.ctx, collection)) } suite.True(suite.server.nodeMgr.IsStoppingNode(suite.nodes[0].ID)) @@ -201,7 +203,7 @@ func (suite *ServerSuite) TestNodeUp() { return false } for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node1.ID) + replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node1.ID) if replica == nil { return false } @@ -230,7 +232,7 @@ func (suite *ServerSuite) TestNodeUp() { return false } for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node2.ID) + replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID) if replica == nil { return true } @@ -249,7 +251,7 @@ func (suite *ServerSuite) TestNodeUp() { return false } for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, node2.ID) + replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, node2.ID) if replica == nil { return false } @@ -279,7 +281,7 @@ func (suite *ServerSuite) TestNodeDown() { return false } for _, collection := range suite.collections { - replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(collection, downNode.ID) + replica := suite.server.meta.ReplicaManager.GetByCollectionAndNode(suite.ctx, collection, downNode.ID) if replica != nil { return false } @@ -525,7 +527,7 @@ func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, } func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) { - collection := suite.server.meta.GetCollection(collectionID) + collection := suite.server.meta.GetCollection(suite.ctx, collectionID) if collection != nil { collection := collection.Clone() collection.LoadPercentage = 0 @@ -533,9 +535,9 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer collection.LoadPercentage = 100 } collection.CollectionLoadInfo.Status = status - suite.server.meta.PutCollection(collection) + suite.server.meta.PutCollection(suite.ctx, collection) - partitions := suite.server.meta.GetPartitionsByCollection(collectionID) + partitions := suite.server.meta.GetPartitionsByCollection(suite.ctx, collectionID) for _, partition := range partitions { partition := partition.Clone() partition.LoadPercentage = 0 @@ -543,7 +545,7 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer partition.LoadPercentage = 100 } partition.PartitionLoadInfo.Status = status - suite.server.meta.PutPartition(partition) + suite.server.meta.PutPartition(suite.ctx, partition) } } } diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index f67c332a39ceb..d764dcaddba01 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -70,7 +70,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio isGetAll := false collectionSet := typeutil.NewUniqueSet(req.GetCollectionIDs()...) if len(req.GetCollectionIDs()) == 0 { - for _, collection := range s.meta.GetAllCollections() { + for _, collection := range s.meta.GetAllCollections(ctx) { collectionSet.Insert(collection.GetCollectionID()) } isGetAll = true @@ -86,9 +86,9 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio for _, collectionID := range collections { log := log.With(zap.Int64("collectionID", collectionID)) - collection := s.meta.CollectionManager.GetCollection(collectionID) - percentage := s.meta.CollectionManager.CalculateLoadPercentage(collectionID) - loadFields := s.meta.CollectionManager.GetLoadFields(collectionID) + collection := s.meta.CollectionManager.GetCollection(ctx, collectionID) + percentage := s.meta.CollectionManager.CalculateLoadPercentage(ctx, collectionID) + loadFields := s.meta.CollectionManager.GetLoadFields(ctx, collectionID) refreshProgress := int64(0) if percentage < 0 { if isGetAll { @@ -150,13 +150,13 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions refreshProgress := int64(0) if len(partitions) == 0 { - partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 { + partitions = lo.Map(s.meta.GetPartitionsByCollection(ctx, req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 { return partition.GetPartitionID() }) } for _, partitionID := range partitions { - percentage := s.meta.GetPartitionLoadPercentage(partitionID) + percentage := s.meta.GetPartitionLoadPercentage(ctx, partitionID) if percentage < 0 { err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()) if err != nil { @@ -177,7 +177,7 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions percentages = append(percentages, int64(percentage)) } - collection := s.meta.GetCollection(req.GetCollectionID()) + collection := s.meta.GetCollection(ctx, req.GetCollectionID()) if collection != nil && collection.IsRefreshed() { refreshProgress = 100 } @@ -217,7 +217,7 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection // If refresh mode is ON. if req.GetRefresh() { - err := s.refreshCollection(req.GetCollectionID()) + err := s.refreshCollection(ctx, req.GetCollectionID()) if err != nil { log.Warn("failed to refresh collection", zap.Error(err)) } @@ -253,11 +253,11 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection } var loadJob job.Job - collection := s.meta.GetCollection(req.GetCollectionID()) + collection := s.meta.GetCollection(ctx, req.GetCollectionID()) if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded { // if collection is loaded, check if collection is loaded with the same replica number and resource groups // if replica number or resource group changes, switch to update load config - collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect() + collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect() left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups()) rgChanged := len(left) > 0 || len(right) > 0 replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber() @@ -372,7 +372,7 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions // If refresh mode is ON. if req.GetRefresh() { - err := s.refreshCollection(req.GetCollectionID()) + err := s.refreshCollection(ctx, req.GetCollectionID()) if err != nil { log.Warn("failed to refresh partitions", zap.Error(err)) } @@ -494,9 +494,9 @@ func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti } states := make([]*querypb.PartitionStates, 0, len(req.GetPartitionIDs())) - switch s.meta.GetLoadType(req.GetCollectionID()) { + switch s.meta.GetLoadType(ctx, req.GetCollectionID()) { case querypb.LoadType_LoadCollection: - collection := s.meta.GetCollection(req.GetCollectionID()) + collection := s.meta.GetCollection(ctx, req.GetCollectionID()) state := querypb.PartitionState_PartialInMemory if collection.LoadPercentage >= 100 { state = querypb.PartitionState_InMemory @@ -515,7 +515,7 @@ func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti case querypb.LoadType_LoadPartition: for _, partitionID := range req.GetPartitionIDs() { - partition := s.meta.GetPartition(partitionID) + partition := s.meta.GetPartition(ctx, partitionID) if partition == nil { log.Warn(msg, zap.Int64("partition", partitionID)) return notLoadResp, nil @@ -558,7 +558,7 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo infos := make([]*querypb.SegmentInfo, 0, len(req.GetSegmentIDs())) if len(req.GetSegmentIDs()) == 0 { - infos = s.getCollectionSegmentInfo(req.GetCollectionID()) + infos = s.getCollectionSegmentInfo(ctx, req.GetCollectionID()) } else { for _, segmentID := range req.GetSegmentIDs() { segments := s.dist.SegmentDistManager.GetByFilter(meta.WithSegmentID(segmentID)) @@ -611,8 +611,8 @@ func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncN // tries to load them up. It returns when all segments of the given collection are loaded, or when error happens. // Note that a collection's loading progress always stays at 100% after a successful load and will not get updated // during refreshCollection. -func (s *Server) refreshCollection(collectionID int64) error { - collection := s.meta.CollectionManager.GetCollection(collectionID) +func (s *Server) refreshCollection(ctx context.Context, collectionID int64) error { + collection := s.meta.CollectionManager.GetCollection(ctx, collectionID) if collection == nil { return merr.WrapErrCollectionNotLoaded(collectionID) } @@ -724,14 +724,14 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs()))) return merr.Status(err), nil } - if s.meta.CollectionManager.CalculateLoadPercentage(req.GetCollectionID()) < 100 { + if s.meta.CollectionManager.CalculateLoadPercentage(ctx, req.GetCollectionID()) < 100 { err := merr.WrapErrCollectionNotFullyLoaded(req.GetCollectionID()) msg := "can't balance segments of not fully loaded collection" log.Warn(msg) return merr.Status(err), nil } srcNode := req.GetSourceNodeIDs()[0] - replica := s.meta.ReplicaManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode) + replica := s.meta.ReplicaManager.GetByCollectionAndNode(ctx, req.GetCollectionID(), srcNode) if replica == nil { err := merr.WrapErrNodeNotFound(srcNode, fmt.Sprintf("source node not found in any replica of collection %d", req.GetCollectionID())) msg := "source node not found in any replica" @@ -785,7 +785,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques } // Only balance segments in targets - existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil + existInTarget := s.targetMgr.GetSealedSegment(ctx, segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil if !existInTarget { log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", segmentID)) continue @@ -881,13 +881,13 @@ func (s *Server) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque Replicas: make([]*milvuspb.ReplicaInfo, 0), } - replicas := s.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) + replicas := s.meta.ReplicaManager.GetByCollection(ctx, req.GetCollectionID()) if len(replicas) == 0 { return resp, nil } for _, replica := range replicas { - resp.Replicas = append(resp.Replicas, s.fillReplicaInfo(replica, req.GetWithShardNodes())) + resp.Replicas = append(resp.Replicas, s.fillReplicaInfo(ctx, replica, req.GetWithShardNodes())) } return resp, nil } @@ -969,7 +969,7 @@ func (s *Server) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateRe return merr.Status(err), nil } - err := s.meta.ResourceManager.AddResourceGroup(req.GetResourceGroup(), req.GetConfig()) + err := s.meta.ResourceManager.AddResourceGroup(ctx, req.GetResourceGroup(), req.GetConfig()) if err != nil { log.Warn("failed to create resource group", zap.Error(err)) return merr.Status(err), nil @@ -988,7 +988,7 @@ func (s *Server) UpdateResourceGroups(ctx context.Context, req *querypb.UpdateRe return merr.Status(err), nil } - err := s.meta.ResourceManager.UpdateResourceGroups(req.GetResourceGroups()) + err := s.meta.ResourceManager.UpdateResourceGroups(ctx, req.GetResourceGroups()) if err != nil { log.Warn("failed to update resource group", zap.Error(err)) return merr.Status(err), nil @@ -1007,14 +1007,14 @@ func (s *Server) DropResourceGroup(ctx context.Context, req *milvuspb.DropResour return merr.Status(err), nil } - replicas := s.meta.ReplicaManager.GetByResourceGroup(req.GetResourceGroup()) + replicas := s.meta.ReplicaManager.GetByResourceGroup(ctx, req.GetResourceGroup()) if len(replicas) > 0 { err := merr.WrapErrParameterInvalid("empty resource group", fmt.Sprintf("resource group %s has collection %d loaded", req.GetResourceGroup(), replicas[0].GetCollectionID())) return merr.Status(errors.Wrap(err, fmt.Sprintf("some replicas still loaded in resource group[%s], release it first", req.GetResourceGroup()))), nil } - err := s.meta.ResourceManager.RemoveResourceGroup(req.GetResourceGroup()) + err := s.meta.ResourceManager.RemoveResourceGroup(ctx, req.GetResourceGroup()) if err != nil { log.Warn("failed to drop resource group", zap.Error(err)) return merr.Status(err), nil @@ -1037,7 +1037,7 @@ func (s *Server) TransferNode(ctx context.Context, req *milvuspb.TransferNodeReq } // Move node from source resource group to target resource group. - if err := s.meta.ResourceManager.TransferNode(req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumNode())); err != nil { + if err := s.meta.ResourceManager.TransferNode(ctx, req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumNode())); err != nil { log.Warn("failed to transfer node", zap.Error(err)) return merr.Status(err), nil } @@ -1059,20 +1059,20 @@ func (s *Server) TransferReplica(ctx context.Context, req *querypb.TransferRepli } // TODO: !!!WARNING, replica manager and resource manager doesn't protected with each other by lock. - if ok := s.meta.ResourceManager.ContainResourceGroup(req.GetSourceResourceGroup()); !ok { + if ok := s.meta.ResourceManager.ContainResourceGroup(ctx, req.GetSourceResourceGroup()); !ok { err := merr.WrapErrResourceGroupNotFound(req.GetSourceResourceGroup()) return merr.Status(errors.Wrap(err, fmt.Sprintf("the source resource group[%s] doesn't exist", req.GetSourceResourceGroup()))), nil } - if ok := s.meta.ResourceManager.ContainResourceGroup(req.GetTargetResourceGroup()); !ok { + if ok := s.meta.ResourceManager.ContainResourceGroup(ctx, req.GetTargetResourceGroup()); !ok { err := merr.WrapErrResourceGroupNotFound(req.GetTargetResourceGroup()) return merr.Status(errors.Wrap(err, fmt.Sprintf("the target resource group[%s] doesn't exist", req.GetTargetResourceGroup()))), nil } // Apply change into replica manager. - err := s.meta.TransferReplica(req.GetCollectionID(), req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumReplica())) + err := s.meta.TransferReplica(ctx, req.GetCollectionID(), req.GetSourceResourceGroup(), req.GetTargetResourceGroup(), int(req.GetNumReplica())) return merr.Status(err), nil } @@ -1089,7 +1089,7 @@ func (s *Server) ListResourceGroups(ctx context.Context, req *milvuspb.ListResou return resp, nil } - resp.ResourceGroups = s.meta.ResourceManager.ListResourceGroups() + resp.ResourceGroups = s.meta.ResourceManager.ListResourceGroups(ctx) return resp, nil } @@ -1108,7 +1108,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ return resp, nil } - rg := s.meta.ResourceManager.GetResourceGroup(req.GetResourceGroup()) + rg := s.meta.ResourceManager.GetResourceGroup(ctx, req.GetResourceGroup()) if rg == nil { err := merr.WrapErrResourceGroupNotFound(req.GetResourceGroup()) resp.Status = merr.Status(err) @@ -1117,26 +1117,26 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ loadedReplicas := make(map[int64]int32) outgoingNodes := make(map[int64]int32) - replicasInRG := s.meta.GetByResourceGroup(req.GetResourceGroup()) + replicasInRG := s.meta.GetByResourceGroup(ctx, req.GetResourceGroup()) for _, replica := range replicasInRG { loadedReplicas[replica.GetCollectionID()]++ for _, node := range replica.GetRONodes() { - if !s.meta.ContainsNode(replica.GetResourceGroup(), node) { + if !s.meta.ContainsNode(ctx, replica.GetResourceGroup(), node) { outgoingNodes[replica.GetCollectionID()]++ } } } incomingNodes := make(map[int64]int32) - collections := s.meta.GetAll() + collections := s.meta.GetAll(ctx) for _, collection := range collections { - replicas := s.meta.GetByCollection(collection) + replicas := s.meta.GetByCollection(ctx, collection) for _, replica := range replicas { if replica.GetResourceGroup() == req.GetResourceGroup() { continue } for _, node := range replica.GetRONodes() { - if s.meta.ContainsNode(req.GetResourceGroup(), node) { + if s.meta.ContainsNode(ctx, req.GetResourceGroup(), node) { incomingNodes[collection]++ } } @@ -1184,14 +1184,14 @@ func (s *Server) UpdateLoadConfig(ctx context.Context, req *querypb.UpdateLoadCo jobs := make([]job.Job, 0, len(req.GetCollectionIDs())) for _, collectionID := range req.GetCollectionIDs() { - collection := s.meta.GetCollection(collectionID) + collection := s.meta.GetCollection(ctx, collectionID) if collection == nil || collection.GetStatus() != querypb.LoadStatus_Loaded { err := merr.WrapErrCollectionNotLoaded(collectionID) log.Warn("failed to update load config", zap.Error(err)) continue } - collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(collection.GetCollectionID()).Collect() + collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect() left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups()) rgChanged := len(left) > 0 || len(right) > 0 replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber() diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index a573eea5c8de7..451bf85596b73 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -166,7 +166,7 @@ func (suite *ServiceSuite) SetupTest() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(node) + suite.meta.ResourceManager.HandleNodeUp(context.TODO(), node) } suite.cluster = session.NewMockCluster(suite.T()) suite.cluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() @@ -250,15 +250,15 @@ func (suite *ServiceSuite) TestShowCollections() { suite.Equal(collection, resp.CollectionIDs[0]) // Test insufficient memory - colBak := suite.meta.CollectionManager.GetCollection(collection) - err = suite.meta.CollectionManager.RemoveCollection(collection) + colBak := suite.meta.CollectionManager.GetCollection(ctx, collection) + err = suite.meta.CollectionManager.RemoveCollection(ctx, collection) suite.NoError(err) meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowCollections(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) meta.GlobalFailedLoadCache.Remove(collection) - err = suite.meta.CollectionManager.PutCollection(colBak) + err = suite.meta.CollectionManager.PutCollection(ctx, colBak) suite.NoError(err) // Test when server is not healthy @@ -304,27 +304,27 @@ func (suite *ServiceSuite) TestShowPartitions() { // Test insufficient memory if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { - colBak := suite.meta.CollectionManager.GetCollection(collection) - err = suite.meta.CollectionManager.RemoveCollection(collection) + colBak := suite.meta.CollectionManager.GetCollection(ctx, collection) + err = suite.meta.CollectionManager.RemoveCollection(ctx, collection) suite.NoError(err) meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) meta.GlobalFailedLoadCache.Remove(collection) - err = suite.meta.CollectionManager.PutCollection(colBak) + err = suite.meta.CollectionManager.PutCollection(ctx, colBak) suite.NoError(err) } else { partitionID := partitions[0] - parBak := suite.meta.CollectionManager.GetPartition(partitionID) - err = suite.meta.CollectionManager.RemovePartition(collection, partitionID) + parBak := suite.meta.CollectionManager.GetPartition(ctx, partitionID) + err = suite.meta.CollectionManager.RemovePartition(ctx, collection, partitionID) suite.NoError(err) meta.GlobalFailedLoadCache.Put(collection, merr.WrapErrServiceMemoryLimitExceeded(100, 10)) resp, err = server.ShowPartitions(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_InsufficientMemoryToLoad, resp.GetStatus().GetErrorCode()) meta.GlobalFailedLoadCache.Remove(collection) - err = suite.meta.CollectionManager.PutPartition(parBak) + err = suite.meta.CollectionManager.PutPartition(ctx, parBak) suite.NoError(err) } } @@ -354,7 +354,7 @@ func (suite *ServiceSuite) TestLoadCollection() { resp, err := server.LoadCollection(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) - suite.assertLoaded(collection) + suite.assertLoaded(ctx, collection) } // Test load again @@ -430,21 +430,21 @@ func (suite *ServiceSuite) TestResourceGroup() { Address: "localhost", Hostname: "localhost", })) - server.meta.ResourceManager.AddResourceGroup("rg11", &rgpb.ResourceGroupConfig{ + server.meta.ResourceManager.AddResourceGroup(ctx, "rg11", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 2}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 2}, }) - server.meta.ResourceManager.HandleNodeUp(1011) - server.meta.ResourceManager.HandleNodeUp(1012) - server.meta.ResourceManager.AddResourceGroup("rg12", &rgpb.ResourceGroupConfig{ + server.meta.ResourceManager.HandleNodeUp(ctx, 1011) + server.meta.ResourceManager.HandleNodeUp(ctx, 1012) + server.meta.ResourceManager.AddResourceGroup(ctx, "rg12", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 2}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 2}, }) - server.meta.ResourceManager.HandleNodeUp(1013) - server.meta.ResourceManager.HandleNodeUp(1014) - server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(2, 1)) - server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{ + server.meta.ResourceManager.HandleNodeUp(ctx, 1013) + server.meta.ResourceManager.HandleNodeUp(ctx, 1014) + server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(2, 1)) + server.meta.ReplicaManager.Put(ctx, meta.NewReplica(&querypb.Replica{ ID: 1, CollectionID: 1, Nodes: []int64{1011}, @@ -453,7 +453,7 @@ func (suite *ServiceSuite) TestResourceGroup() { }, typeutil.NewUniqueSet(1011, 1013)), ) - server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{ + server.meta.ReplicaManager.Put(ctx, meta.NewReplica(&querypb.Replica{ ID: 2, CollectionID: 2, Nodes: []int64{1014}, @@ -548,18 +548,18 @@ func (suite *ServiceSuite) TestTransferNode() { defer server.resourceObserver.Stop() defer server.replicaObserver.Stop() - err := server.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + err := server.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 0}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 0}, }) suite.NoError(err) - err = server.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 0}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 0}, }) suite.NoError(err) - suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2)) - suite.meta.ReplicaManager.Put(meta.NewReplica( + suite.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 2)) + suite.meta.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 1, CollectionID: 1, @@ -578,15 +578,15 @@ func (suite *ServiceSuite) TestTransferNode() { suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) suite.Eventually(func() bool { - nodes, err := server.meta.ResourceManager.GetNodes("rg1") + nodes, err := server.meta.ResourceManager.GetNodes(ctx, "rg1") if err != nil || len(nodes) != 1 { return false } - nodesInReplica := server.meta.ReplicaManager.Get(1).GetNodes() + nodesInReplica := server.meta.ReplicaManager.Get(ctx, 1).GetNodes() return len(nodesInReplica) == 1 }, 5*time.Second, 100*time.Millisecond) - suite.meta.ReplicaManager.Put(meta.NewReplica( + suite.meta.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 2, CollectionID: 1, @@ -612,12 +612,12 @@ func (suite *ServiceSuite) TestTransferNode() { suite.NoError(err) suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode) - err = server.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 4}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 4}, }) suite.NoError(err) - err = server.meta.ResourceManager.AddResourceGroup("rg4", &rgpb.ResourceGroupConfig{ + err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg4", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 0}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 0}, }) @@ -642,10 +642,10 @@ func (suite *ServiceSuite) TestTransferNode() { Address: "localhost", Hostname: "localhost", })) - suite.meta.ResourceManager.HandleNodeUp(11) - suite.meta.ResourceManager.HandleNodeUp(12) - suite.meta.ResourceManager.HandleNodeUp(13) - suite.meta.ResourceManager.HandleNodeUp(14) + suite.meta.ResourceManager.HandleNodeUp(ctx, 11) + suite.meta.ResourceManager.HandleNodeUp(ctx, 12) + suite.meta.ResourceManager.HandleNodeUp(ctx, 13) + suite.meta.ResourceManager.HandleNodeUp(ctx, 14) resp, err = server.TransferNode(ctx, &milvuspb.TransferNodeRequest{ SourceResourceGroup: "rg3", @@ -656,11 +656,11 @@ func (suite *ServiceSuite) TestTransferNode() { suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) suite.Eventually(func() bool { - nodes, err := server.meta.ResourceManager.GetNodes("rg3") + nodes, err := server.meta.ResourceManager.GetNodes(ctx, "rg3") if err != nil || len(nodes) != 1 { return false } - nodes, err = server.meta.ResourceManager.GetNodes("rg4") + nodes, err = server.meta.ResourceManager.GetNodes(ctx, "rg4") return err == nil && len(nodes) == 3 }, 5*time.Second, 100*time.Millisecond) @@ -695,17 +695,17 @@ func (suite *ServiceSuite) TestTransferReplica() { ctx := context.Background() server := suite.server - err := server.meta.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + err := server.meta.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 1}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 1}, }) suite.NoError(err) - err = server.meta.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 1}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 1}, }) suite.NoError(err) - err = server.meta.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + err = server.meta.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 3}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 3}, }) @@ -747,17 +747,17 @@ func (suite *ServiceSuite) TestTransferReplica() { suite.NoError(err) suite.ErrorIs(merr.Error(resp), merr.ErrParameterInvalid) - suite.server.meta.Put(meta.NewReplica(&querypb.Replica{ + suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{ CollectionID: 1, ID: 111, ResourceGroup: meta.DefaultResourceGroupName, }, typeutil.NewUniqueSet(1))) - suite.server.meta.Put(meta.NewReplica(&querypb.Replica{ + suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{ CollectionID: 1, ID: 222, ResourceGroup: meta.DefaultResourceGroupName, }, typeutil.NewUniqueSet(2))) - suite.server.meta.Put(meta.NewReplica(&querypb.Replica{ + suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{ CollectionID: 1, ID: 333, ResourceGroup: meta.DefaultResourceGroupName, @@ -788,18 +788,18 @@ func (suite *ServiceSuite) TestTransferReplica() { Address: "localhost", Hostname: "localhost", })) - suite.server.meta.HandleNodeUp(1001) - suite.server.meta.HandleNodeUp(1002) - suite.server.meta.HandleNodeUp(1003) - suite.server.meta.HandleNodeUp(1004) - suite.server.meta.HandleNodeUp(1005) + suite.server.meta.HandleNodeUp(ctx, 1001) + suite.server.meta.HandleNodeUp(ctx, 1002) + suite.server.meta.HandleNodeUp(ctx, 1003) + suite.server.meta.HandleNodeUp(ctx, 1004) + suite.server.meta.HandleNodeUp(ctx, 1005) - suite.server.meta.Put(meta.NewReplica(&querypb.Replica{ + suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{ CollectionID: 2, ID: 444, ResourceGroup: meta.DefaultResourceGroupName, }, typeutil.NewUniqueSet(3))) - suite.server.meta.Put(meta.NewReplica(&querypb.Replica{ + suite.server.meta.Put(ctx, meta.NewReplica(&querypb.Replica{ CollectionID: 2, ID: 555, ResourceGroup: "rg2", @@ -824,7 +824,7 @@ func (suite *ServiceSuite) TestTransferReplica() { // we support transfer replica to resource group load same collection. suite.Equal(resp.ErrorCode, commonpb.ErrorCode_Success) - replicaNum := len(suite.server.meta.ReplicaManager.GetByCollection(1)) + replicaNum := len(suite.server.meta.ReplicaManager.GetByCollection(ctx, 1)) suite.Equal(3, replicaNum) resp, err = suite.server.TransferReplica(ctx, &querypb.TransferReplicaRequest{ SourceResourceGroup: meta.DefaultResourceGroupName, @@ -842,7 +842,7 @@ func (suite *ServiceSuite) TestTransferReplica() { }) suite.NoError(err) suite.Equal(resp.ErrorCode, commonpb.ErrorCode_Success) - suite.Len(suite.server.meta.GetByResourceGroup("rg3"), 3) + suite.Len(suite.server.meta.GetByResourceGroup(ctx, "rg3"), 3) // server unhealthy server.UpdateStateCode(commonpb.StateCode_Abnormal) @@ -924,7 +924,7 @@ func (suite *ServiceSuite) TestLoadPartition() { resp, err := server.LoadPartitions(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) - suite.assertLoaded(collection) + suite.assertLoaded(ctx, collection) } // Test load again @@ -1020,7 +1020,7 @@ func (suite *ServiceSuite) TestReleaseCollection() { resp, err := server.ReleaseCollection(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) - suite.assertReleased(collection) + suite.assertReleased(ctx, collection) } // Test release again @@ -1059,7 +1059,7 @@ func (suite *ServiceSuite) TestReleasePartition() { resp, err := server.ReleasePartitions(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) - suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) + suite.assertPartitionLoaded(ctx, collection, suite.partitions[collection][1:]...) } // Test release again @@ -1071,7 +1071,7 @@ func (suite *ServiceSuite) TestReleasePartition() { resp, err := server.ReleasePartitions(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) - suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) + suite.assertPartitionLoaded(ctx, collection, suite.partitions[collection][1:]...) } // Test when server is not healthy @@ -1086,11 +1086,12 @@ func (suite *ServiceSuite) TestReleasePartition() { } func (suite *ServiceSuite) TestRefreshCollection() { + ctx := context.Background() server := suite.server // Test refresh all collections. for _, collection := range suite.collections { - err := server.refreshCollection(collection) + err := server.refreshCollection(ctx, collection) // Collection not loaded error. suite.ErrorIs(err, merr.ErrCollectionNotLoaded) } @@ -1100,19 +1101,19 @@ func (suite *ServiceSuite) TestRefreshCollection() { // Test refresh all collections again when collections are loaded. This time should fail with collection not 100% loaded. for _, collection := range suite.collections { - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loading) - err := server.refreshCollection(collection) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loading) + err := server.refreshCollection(ctx, collection) suite.ErrorIs(err, merr.ErrCollectionNotLoaded) } // Test refresh all collections for _, id := range suite.collections { // Load and explicitly mark load percentage to 100%. - suite.updateChannelDist(id) + suite.updateChannelDist(ctx, id) suite.updateSegmentDist(id, suite.nodes[0]) - suite.updateCollectionStatus(id, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, id, querypb.LoadStatus_Loaded) - err := server.refreshCollection(id) + err := server.refreshCollection(ctx, id) suite.NoError(err) readyCh, err := server.targetObserver.UpdateNextTarget(id) @@ -1120,18 +1121,18 @@ func (suite *ServiceSuite) TestRefreshCollection() { <-readyCh // Now the refresh must be done - collection := server.meta.CollectionManager.GetCollection(id) + collection := server.meta.CollectionManager.GetCollection(ctx, id) suite.True(collection.IsRefreshed()) } // Test refresh not ready for _, id := range suite.collections { - suite.updateChannelDistWithoutSegment(id) - err := server.refreshCollection(id) + suite.updateChannelDistWithoutSegment(ctx, id) + err := server.refreshCollection(ctx, id) suite.NoError(err) // Now the refresh must be not done - collection := server.meta.CollectionManager.GetCollection(id) + collection := server.meta.CollectionManager.GetCollection(ctx, id) suite.False(collection.IsRefreshed()) } } @@ -1209,11 +1210,11 @@ func (suite *ServiceSuite) TestLoadBalance() { // Test get balance first segment for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) nodes := replicas[0].GetNodes() srcNode := nodes[0] dstNode := nodes[1] - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) suite.updateSegmentDist(collection, srcNode) segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ @@ -1258,10 +1259,10 @@ func (suite *ServiceSuite) TestLoadBalanceWithNoDstNode() { // Test get balance first segment for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) nodes := replicas[0].GetNodes() srcNode := nodes[0] - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) suite.updateSegmentDist(collection, srcNode) segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ @@ -1310,10 +1311,10 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { // update two collection's dist for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) replicas[0].AddRWNode(srcNode) replicas[0].AddRWNode(dstNode) - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) for partition, segments := range suite.segments[collection] { for _, segment := range segments { @@ -1336,9 +1337,9 @@ func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { })) defer func() { for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) - suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), srcNode) - suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), dstNode) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) + suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), srcNode) + suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), dstNode) } suite.nodeMgr.Remove(1001) suite.nodeMgr.Remove(1002) @@ -1380,7 +1381,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { // Test load balance without source node for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) dstNode := replicas[0].GetNodes()[1] segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ @@ -1395,11 +1396,11 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { // Test load balance with not fully loaded for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) nodes := replicas[0].GetNodes() srcNode := nodes[0] dstNode := nodes[1] - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loading) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loading) segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ CollectionID: collection, @@ -1418,10 +1419,10 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { continue } - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) srcNode := replicas[0].GetNodes()[0] dstNode := replicas[1].GetNodes()[0] - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) suite.updateSegmentDist(collection, srcNode) segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ @@ -1437,11 +1438,11 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { // Test balance task failed for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) nodes := replicas[0].GetNodes() srcNode := nodes[0] dstNode := nodes[1] - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) suite.updateSegmentDist(collection, srcNode) segments := suite.getAllSegments(collection) req := &querypb.LoadBalanceRequest{ @@ -1458,7 +1459,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode) suite.Contains(resp.Reason, "mock error") - suite.meta.ReplicaManager.RecoverNodesInCollection(collection, map[string]typeutil.UniqueSet{meta.DefaultResourceGroupName: typeutil.NewUniqueSet(10)}) + suite.meta.ReplicaManager.RecoverNodesInCollection(ctx, collection, map[string]typeutil.UniqueSet{meta.DefaultResourceGroupName: typeutil.NewUniqueSet(10)}) req.SourceNodeIDs = []int64{10} resp, err = server.LoadBalance(ctx, req) suite.NoError(err) @@ -1480,7 +1481,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { suite.NoError(err) suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode) suite.nodeMgr.Remove(10) - suite.meta.ReplicaManager.RemoveNode(replicas[0].GetID(), 10) + suite.meta.ReplicaManager.RemoveNode(ctx, replicas[0].GetID(), 10) } } @@ -1545,7 +1546,7 @@ func (suite *ServiceSuite) TestGetReplicas() { server := suite.server for _, collection := range suite.collections { - suite.updateChannelDist(collection) + suite.updateChannelDist(ctx, collection) req := &milvuspb.GetReplicasRequest{ CollectionID: collection, } @@ -1557,11 +1558,11 @@ func (suite *ServiceSuite) TestGetReplicas() { // Test get with shard nodes for _, collection := range suite.collections { - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) for _, replica := range replicas { suite.updateSegmentDist(collection, replica.GetNodes()[0]) } - suite.updateChannelDist(collection) + suite.updateChannelDist(ctx, collection) req := &milvuspb.GetReplicasRequest{ CollectionID: collection, WithShardNodes: true, @@ -1582,7 +1583,7 @@ func (suite *ServiceSuite) TestGetReplicas() { } } - suite.Equal(len(replica.GetNodeIds()), len(suite.meta.ReplicaManager.Get(replica.ReplicaID).GetNodes())) + suite.Equal(len(replica.GetNodeIds()), len(suite.meta.ReplicaManager.Get(ctx, replica.ReplicaID).GetNodes())) } } @@ -1601,13 +1602,13 @@ func (suite *ServiceSuite) TestGetReplicasWhenNoAvailableNodes() { ctx := context.Background() server := suite.server - replicas := suite.meta.ReplicaManager.GetByCollection(suite.collections[0]) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, suite.collections[0]) for _, replica := range replicas { suite.updateSegmentDist(suite.collections[0], replica.GetNodes()[0]) } - suite.updateChannelDist(suite.collections[0]) + suite.updateChannelDist(ctx, suite.collections[0]) - suite.meta.ReplicaManager.Put(utils.CreateTestReplica(100001, suite.collections[0], []int64{})) + suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(100001, suite.collections[0], []int64{})) req := &milvuspb.GetReplicasRequest{ CollectionID: suite.collections[0], @@ -1660,14 +1661,14 @@ func (suite *ServiceSuite) TestCheckHealth() { // Test for check channel ok for _, collection := range suite.collections { - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) - suite.updateChannelDist(collection) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) + suite.updateChannelDist(ctx, collection) } assertCheckHealthResult(true) // Test for check channel fail tm := meta.NewMockTargetManager(suite.T()) - tm.EXPECT().GetDmChannelsByCollection(mock.Anything, mock.Anything).Return(nil).Maybe() + tm.EXPECT().GetDmChannelsByCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() otm := server.targetMgr server.targetMgr = tm assertCheckHealthResult(true) @@ -1686,8 +1687,8 @@ func (suite *ServiceSuite) TestGetShardLeaders() { server := suite.server for _, collection := range suite.collections { - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) - suite.updateChannelDist(collection) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) + suite.updateChannelDist(ctx, collection) req := &querypb.GetShardLeadersRequest{ CollectionID: collection, } @@ -1718,8 +1719,8 @@ func (suite *ServiceSuite) TestGetShardLeadersFailed() { server := suite.server for _, collection := range suite.collections { - suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) - suite.updateChannelDist(collection) + suite.updateCollectionStatus(ctx, collection, querypb.LoadStatus_Loaded) + suite.updateChannelDist(ctx, collection) req := &querypb.GetShardLeadersRequest{ CollectionID: collection, } @@ -1746,7 +1747,7 @@ func (suite *ServiceSuite) TestGetShardLeadersFailed() { suite.dist.ChannelDistManager.Update(node) suite.dist.LeaderViewManager.Update(node) } - suite.updateChannelDistWithoutSegment(collection) + suite.updateChannelDistWithoutSegment(ctx, collection) suite.fetchHeartbeats(time.Now()) resp, err = server.GetShardLeaders(ctx, req) suite.NoError(err) @@ -1789,9 +1790,10 @@ func (suite *ServiceSuite) TestHandleNodeUp() { suite.server.resourceObserver.Start() defer suite.server.resourceObserver.Stop() + ctx := context.Background() server := suite.server - suite.server.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) - suite.server.meta.ReplicaManager.Put(meta.NewReplica( + suite.server.meta.CollectionManager.PutCollection(ctx, utils.CreateTestCollection(1, 1)) + suite.server.meta.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 1, CollectionID: 1, @@ -1812,12 +1814,12 @@ func (suite *ServiceSuite) TestHandleNodeUp() { server.handleNodeUp(111) // wait for async update by observer suite.Eventually(func() bool { - nodes := suite.server.meta.ReplicaManager.Get(1).GetNodes() - nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName) + nodes := suite.server.meta.ReplicaManager.Get(ctx, 1).GetNodes() + nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName) return len(nodes) == len(nodesInRG) }, 5*time.Second, 100*time.Millisecond) - nodes := suite.server.meta.ReplicaManager.Get(1).GetNodes() - nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(meta.DefaultResourceGroupName) + nodes := suite.server.meta.ReplicaManager.Get(ctx, 1).GetNodes() + nodesInRG, _ := suite.server.meta.ResourceManager.GetNodes(ctx, meta.DefaultResourceGroupName) suite.ElementsMatch(nodes, nodesInRG) } @@ -1846,10 +1848,10 @@ func (suite *ServiceSuite) loadAll() { suite.jobScheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(collection)) - suite.True(suite.meta.Exist(collection)) - suite.NotNil(suite.meta.GetCollection(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection)) + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotNil(suite.meta.GetCollection(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) } else { req := &querypb.LoadPartitionsRequest{ CollectionID: collection, @@ -1871,30 +1873,30 @@ func (suite *ServiceSuite) loadAll() { suite.jobScheduler.Add(job) err := job.Wait() suite.NoError(err) - suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(collection)) - suite.True(suite.meta.Exist(collection)) - suite.NotNil(suite.meta.GetPartitionsByCollection(collection)) - suite.targetMgr.UpdateCollectionCurrentTarget(collection) + suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection)) + suite.True(suite.meta.Exist(ctx, collection)) + suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection)) + suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection) } } } -func (suite *ServiceSuite) assertLoaded(collection int64) { - suite.True(suite.meta.Exist(collection)) +func (suite *ServiceSuite) assertLoaded(ctx context.Context, collection int64) { + suite.True(suite.meta.Exist(ctx, collection)) for _, channel := range suite.channels[collection] { - suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.NextTarget)) + suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.NextTarget)) } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.NextTarget)) } } } -func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions ...int64) { - suite.True(suite.meta.Exist(collection)) +func (suite *ServiceSuite) assertPartitionLoaded(ctx context.Context, collection int64, partitions ...int64) { + suite.True(suite.meta.Exist(ctx, collection)) for _, channel := range suite.channels[collection] { - suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget)) } partitionSet := typeutil.NewUniqueSet(partitions...) for partition, segments := range suite.segments[collection] { @@ -1902,20 +1904,20 @@ func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions .. continue } for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) } } } -func (suite *ServiceSuite) assertReleased(collection int64) { - suite.False(suite.meta.Exist(collection)) +func (suite *ServiceSuite) assertReleased(ctx context.Context, collection int64) { + suite.False(suite.meta.Exist(ctx, collection)) for _, channel := range suite.channels[collection] { - suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget)) } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) - suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.NextTarget)) } } } @@ -1989,11 +1991,11 @@ func (suite *ServiceSuite) updateSegmentDist(collection, node int64) { suite.dist.SegmentDistManager.Update(node, metaSegments...) } -func (suite *ServiceSuite) updateChannelDist(collection int64) { +func (suite *ServiceSuite) updateChannelDist(ctx context.Context, collection int64) { channels := suite.channels[collection] segments := lo.Flatten(lo.Values(suite.segments[collection])) - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) for _, replica := range replicas { i := 0 for _, node := range suite.sortInt64(replica.GetNodes()) { @@ -2027,10 +2029,10 @@ func (suite *ServiceSuite) sortInt64(ints []int64) []int64 { return ints } -func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) { +func (suite *ServiceSuite) updateChannelDistWithoutSegment(ctx context.Context, collection int64) { channels := suite.channels[collection] - replicas := suite.meta.ReplicaManager.GetByCollection(collection) + replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection) for _, replica := range replicas { i := 0 for _, node := range suite.sortInt64(replica.GetNodes()) { @@ -2052,8 +2054,8 @@ func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) { } } -func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) { - collection := suite.meta.GetCollection(collectionID) +func (suite *ServiceSuite) updateCollectionStatus(ctx context.Context, collectionID int64, status querypb.LoadStatus) { + collection := suite.meta.GetCollection(ctx, collectionID) if collection != nil { collection := collection.Clone() collection.LoadPercentage = 0 @@ -2061,9 +2063,9 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que collection.LoadPercentage = 100 } collection.CollectionLoadInfo.Status = status - suite.meta.PutCollection(collection) + suite.meta.PutCollection(ctx, collection) - partitions := suite.meta.GetPartitionsByCollection(collectionID) + partitions := suite.meta.GetPartitionsByCollection(ctx, collectionID) for _, partition := range partitions { partition := partition.Clone() partition.LoadPercentage = 0 @@ -2071,7 +2073,7 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que partition.LoadPercentage = 100 } partition.PartitionLoadInfo.Status = status - suite.meta.PutPartition(partition) + suite.meta.PutPartition(ctx, partition) } } } diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 36f04889d23b8..fbdc9ddc14db9 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -208,7 +208,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { ) // get segment's replica first, then get shard leader by replica - replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + replica := ex.meta.ReplicaManager.GetByCollectionAndNode(ctx, task.CollectionID(), action.Node()) if replica == nil { msg := "node doesn't belong to any replica" err := merr.WrapErrNodeNotAvailable(action.Node()) @@ -259,7 +259,7 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) { dstNode := action.Node() req := packReleaseSegmentRequest(task, action) - channel := ex.targetMgr.GetDmChannel(task.CollectionID(), task.Shard(), meta.CurrentTarget) + channel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), task.Shard(), meta.CurrentTarget) if channel != nil { // if channel exists in current target, set cp to ReleaseSegmentRequest, need to use it as growing segment's exclude ts req.Checkpoint = channel.GetSeekPosition() @@ -272,9 +272,9 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) { } else { req.Shard = task.shard - if ex.meta.CollectionManager.Exist(task.CollectionID()) { + if ex.meta.CollectionManager.Exist(ctx, task.CollectionID()) { // get segment's replica first, then get shard leader by replica - replica := ex.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + replica := ex.meta.ReplicaManager.GetByCollectionAndNode(ctx, task.CollectionID(), action.Node()) if replica == nil { msg := "node doesn't belong to any replica, try to send release to worker" err := merr.WrapErrNodeNotAvailable(action.Node()) @@ -344,8 +344,8 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { log.Warn("failed to get collection info") return err } - loadFields := ex.meta.GetLoadFields(task.CollectionID()) - partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) + loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID()) + partitions, err := utils.GetPartitions(ctx, ex.meta.CollectionManager, task.CollectionID()) if err != nil { log.Warn("failed to get partitions of collection") return err @@ -356,7 +356,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { return err } loadMeta := packLoadMeta( - ex.meta.GetLoadType(task.CollectionID()), + ex.meta.GetLoadType(ctx, task.CollectionID()), task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), @@ -364,7 +364,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { partitions..., ) - dmChannel := ex.targetMgr.GetDmChannel(task.CollectionID(), action.ChannelName(), meta.NextTarget) + dmChannel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), action.ChannelName(), meta.NextTarget) if dmChannel == nil { msg := "channel does not exist in next target, skip it" log.Warn(msg, zap.String("channelName", action.ChannelName())) @@ -652,15 +652,15 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr log.Warn("failed to get collection info", zap.Error(err)) return nil, nil, nil, err } - loadFields := ex.meta.GetLoadFields(task.CollectionID()) - partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID) + loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID()) + partitions, err := utils.GetPartitions(ctx, ex.meta.CollectionManager, collectionID) if err != nil { log.Warn("failed to get partitions of collection", zap.Error(err)) return nil, nil, nil, err } loadMeta := packLoadMeta( - ex.meta.GetLoadType(task.CollectionID()), + ex.meta.GetLoadType(ctx, task.CollectionID()), task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), @@ -669,7 +669,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr ) // get channel first, in case of target updated after segment info fetched - channel := ex.targetMgr.GetDmChannel(collectionID, shard, meta.NextTargetFirst) + channel := ex.targetMgr.GetDmChannel(ctx, collectionID, shard, meta.NextTargetFirst) if channel == nil { return nil, nil, nil, merr.WrapErrChannelNotAvailable(shard) } diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 1a0f983b50289..cde0c3d43df34 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -385,7 +385,7 @@ func (scheduler *taskScheduler) preAdd(task Task) error { if taskType == TaskTypeGrow { views := scheduler.distMgr.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(task.Channel())) nodesWithChannel := lo.Map(views, func(v *meta.LeaderView, _ int) UniqueID { return v.ID }) - replicaNodeMap := utils.GroupNodesByReplica(scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel) + replicaNodeMap := utils.GroupNodesByReplica(task.ctx, scheduler.meta.ReplicaManager, task.CollectionID(), nodesWithChannel) if _, ok := replicaNodeMap[task.ReplicaID()]; ok { return merr.WrapErrServiceInternal("channel subscribed, it can be only balanced") } @@ -535,7 +535,7 @@ func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetAct case *SegmentAction: // skip growing segment's count, cause doesn't know realtime row number of growing segment if action.Scope == querypb.DataScope_Historical { - segment := scheduler.targetMgr.GetSealedSegment(collectionID, action.SegmentID, meta.NextTargetFirst) + segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst) if segment != nil { sum += int(segment.GetNumOfRows()) * delta } @@ -708,14 +708,14 @@ func (scheduler *taskScheduler) isRelated(task Task, node int64) bool { taskType := GetTaskType(task) var segment *datapb.SegmentInfo if taskType == TaskTypeMove || taskType == TaskTypeUpdate { - segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTarget) } else { - segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.NextTarget) } if segment == nil { continue } - replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node()) if replica == nil { continue } @@ -851,7 +851,7 @@ func (scheduler *taskScheduler) remove(task Task) { if errors.Is(task.Err(), merr.ErrSegmentNotFound) { log.Info("segment in target has been cleaned, trigger force update next target", zap.Int64("collectionID", task.CollectionID())) - scheduler.targetMgr.UpdateCollectionNextTarget(task.CollectionID()) + scheduler.targetMgr.UpdateCollectionNextTarget(task.Context(), task.CollectionID()) } task.Cancel(nil) @@ -884,7 +884,7 @@ func (scheduler *taskScheduler) remove(task Task) { scheduler.updateTaskMetrics() log.Info("task removed") - if scheduler.meta.Exist(task.CollectionID()) { + if scheduler.meta.Exist(task.Context(), task.CollectionID()) { metrics.QueryCoordTaskLatency.WithLabelValues(fmt.Sprint(task.CollectionID()), scheduler.getTaskMetricsLabel(task), task.Shard()).Observe(float64(task.GetTaskLatency())) } @@ -985,7 +985,7 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error { return merr.WrapErrNodeOffline(action.Node()) } taskType := GetTaskType(task) - segment := scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst) + segment := scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst) if segment == nil { log.Warn("task stale due to the segment to load not exists in targets", zap.Int64("segment", task.segmentID), @@ -994,7 +994,7 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error { return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment") } - replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node()) if replica == nil { log.Warn("task stale due to replica not found") return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID") @@ -1027,7 +1027,7 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error { log.Warn("task stale due to node offline", zap.String("channel", task.Channel())) return merr.WrapErrNodeOffline(action.Node()) } - if scheduler.targetMgr.GetDmChannel(task.collectionID, task.Channel(), meta.NextTargetFirst) == nil { + if scheduler.targetMgr.GetDmChannel(task.ctx, task.collectionID, task.Channel(), meta.NextTargetFirst) == nil { log.Warn("the task is stale, the channel to subscribe not exists in targets", zap.String("channel", task.Channel())) return merr.WrapErrChannelReduplicate(task.Channel(), "target doesn't contain this channel") @@ -1058,7 +1058,7 @@ func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error { } taskType := GetTaskType(task) - segment := scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst) + segment := scheduler.targetMgr.GetSealedSegment(task.ctx, task.CollectionID(), task.SegmentID(), meta.CurrentTargetFirst) if segment == nil { log.Warn("task stale due to the segment to load not exists in targets", zap.Int64("segment", task.segmentID), @@ -1067,7 +1067,7 @@ func (scheduler *taskScheduler) checkLeaderTaskStale(task *LeaderTask) error { return merr.WrapErrSegmentReduplicate(task.SegmentID(), "target doesn't contain this segment") } - replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.CollectionID(), action.Node()) + replica := scheduler.meta.ReplicaManager.GetByCollectionAndNode(task.ctx, task.CollectionID(), action.Node()) if replica == nil { log.Warn("task stale due to replica not found") return merr.WrapErrReplicaNotFound(task.CollectionID(), "by collectionID") diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 46a775e1acfc3..37c824a9fe258 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -86,6 +86,7 @@ type TaskSuite struct { // Test object scheduler *taskScheduler + ctx context.Context } func (suite *TaskSuite) SetupSuite() { @@ -133,6 +134,7 @@ func (suite *TaskSuite) SetupSuite() { segments: typeutil.NewSet[int64](), }, } + suite.ctx = context.Background() } func (suite *TaskSuite) TearDownSuite() { @@ -193,20 +195,20 @@ func (suite *TaskSuite) BeforeTest(suiteName, testName string) { "TestLeaderTaskSet", "TestLeaderTaskRemove", "TestNoExecutor": - suite.meta.PutCollection(&meta.Collection{ + suite.meta.PutCollection(suite.ctx, &meta.Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: suite.collection, ReplicaNumber: 1, Status: querypb.LoadStatus_Loading, }, }) - suite.meta.PutPartition(&meta.Partition{ + suite.meta.PutPartition(suite.ctx, &meta.Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ CollectionID: suite.collection, PartitionID: 1, }, }) - suite.meta.ReplicaManager.Put(utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3})) + suite.meta.ReplicaManager.Put(suite.ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3})) } } @@ -276,7 +278,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0) // Process tasks @@ -371,7 +373,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) // Only first channel exists suite.dist.LeaderViewManager.Update(targetNode, &meta.LeaderView{ @@ -463,7 +465,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -564,7 +566,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskNotIndex() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -658,7 +660,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -874,8 +876,8 @@ func (suite *TaskSuite) TestMoveSegmentTask() { tasks = append(tasks, task) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) - suite.target.UpdateCollectionCurrentTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) + suite.target.UpdateCollectionCurrentTarget(ctx, suite.collection) suite.dist.SegmentDistManager.Update(sourceNode, segments...) suite.dist.LeaderViewManager.Update(leader, view) for _, task := range tasks { @@ -958,8 +960,8 @@ func (suite *TaskSuite) TestMoveSegmentTaskStale() { tasks = append(tasks, task) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) - suite.target.UpdateCollectionCurrentTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) + suite.target.UpdateCollectionCurrentTarget(ctx, suite.collection) suite.dist.LeaderViewManager.Update(leader, view) for _, task := range tasks { err := suite.scheduler.Add(task) @@ -1039,8 +1041,8 @@ func (suite *TaskSuite) TestTaskCanceled() { segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segmentInfos, nil) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, partition)) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, partition)) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) // Process tasks suite.dispatchAndWait(targetNode) @@ -1100,7 +1102,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() { suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task - suite.meta.ReplicaManager.Put(createReplica(suite.collection, targetNode)) + suite.meta.ReplicaManager.Put(ctx, createReplica(suite.collection, targetNode)) suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ CollectionID: suite.collection, ChannelName: channel.ChannelName, @@ -1128,8 +1130,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, partition)) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, partition)) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -1166,8 +1168,8 @@ func (suite *TaskSuite) TestSegmentTaskStale() { suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0] suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collection, 2)) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.meta.CollectionManager.PutPartition(ctx, utils.CreateTestPartition(suite.collection, 2)) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) suite.dispatchAndWait(targetNode) suite.AssertTaskNum(0, 0, 0, 0) @@ -1306,7 +1308,7 @@ func (suite *TaskSuite) TestLeaderTaskSet() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -1452,7 +1454,7 @@ func (suite *TaskSuite) TestNoExecutor() { ChannelName: Params.CommonCfg.RootCoordDml.GetValue() + "-test", } - suite.meta.ReplicaManager.Put(utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3, -1})) + suite.meta.ReplicaManager.Put(ctx, utils.CreateTestReplica(suite.replica.GetID(), suite.collection, []int64{1, 2, 3, -1})) // Test load segment task suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ @@ -1479,7 +1481,7 @@ func (suite *TaskSuite) TestNoExecutor() { suite.NoError(err) } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{channel}, segments, nil) - suite.target.UpdateCollectionNextTarget(suite.collection) + suite.target.UpdateCollectionNextTarget(ctx, suite.collection) segmentsNum := len(suite.loadSegments) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) @@ -1625,6 +1627,7 @@ func createReplica(collection int64, nodes ...int64) *meta.Replica { } func (suite *TaskSuite) TestBalanceChannelTask() { + ctx := context.Background() collectionID := int64(1) partitionID := int64(1) channel := "channel-1" @@ -1653,12 +1656,12 @@ func (suite *TaskSuite) TestBalanceChannelTask() { InsertChannel: channel, }, } - suite.meta.PutCollection(utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1)) + suite.meta.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1)) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return([]*datapb.VchannelInfo{vchannel}, segments, nil) - suite.target.UpdateCollectionNextTarget(collectionID) - suite.target.UpdateCollectionCurrentTarget(collectionID) - suite.target.UpdateCollectionNextTarget(collectionID) + suite.target.UpdateCollectionNextTarget(ctx, collectionID) + suite.target.UpdateCollectionCurrentTarget(ctx, collectionID) + suite.target.UpdateCollectionNextTarget(ctx, collectionID) suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{ ID: 2, @@ -1712,6 +1715,7 @@ func (suite *TaskSuite) TestBalanceChannelTask() { } func (suite *TaskSuite) TestBalanceChannelWithL0SegmentTask() { + ctx := context.Background() collectionID := int64(1) partitionID := int64(1) channel := "channel-1" @@ -1743,12 +1747,12 @@ func (suite *TaskSuite) TestBalanceChannelWithL0SegmentTask() { Level: datapb.SegmentLevel_L0, }, } - suite.meta.PutCollection(utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1)) + suite.meta.PutCollection(ctx, utils.CreateTestCollection(collectionID, 1), utils.CreateTestPartition(collectionID, 1)) suite.broker.ExpectedCalls = nil suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return([]*datapb.VchannelInfo{vchannel}, segments, nil) - suite.target.UpdateCollectionNextTarget(collectionID) - suite.target.UpdateCollectionCurrentTarget(collectionID) - suite.target.UpdateCollectionNextTarget(collectionID) + suite.target.UpdateCollectionNextTarget(ctx, collectionID) + suite.target.UpdateCollectionCurrentTarget(ctx, collectionID) + suite.target.UpdateCollectionNextTarget(ctx, collectionID) suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{ ID: 2, diff --git a/internal/querycoordv2/utils/meta.go b/internal/querycoordv2/utils/meta.go index 8a20eb3bc7c63..9139379a5b784 100644 --- a/internal/querycoordv2/utils/meta.go +++ b/internal/querycoordv2/utils/meta.go @@ -17,6 +17,7 @@ package utils import ( + "context" "strings" "github.com/cockroachdb/errors" @@ -29,10 +30,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) { - collection := collectionMgr.GetCollection(collectionID) +func GetPartitions(ctx context.Context, collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) { + collection := collectionMgr.GetCollection(ctx, collectionID) if collection != nil { - partitions := collectionMgr.GetPartitionsByCollection(collectionID) + partitions := collectionMgr.GetPartitionsByCollection(ctx, collectionID) if partitions != nil { return lo.Map(partitions, func(partition *meta.Partition, i int) int64 { return partition.PartitionID @@ -45,9 +46,9 @@ func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([ // GroupNodesByReplica groups nodes by replica, // returns ReplicaID -> NodeIDs -func GroupNodesByReplica(replicaMgr *meta.ReplicaManager, collectionID int64, nodes []int64) map[int64][]int64 { +func GroupNodesByReplica(ctx context.Context, replicaMgr *meta.ReplicaManager, collectionID int64, nodes []int64) map[int64][]int64 { ret := make(map[int64][]int64) - replicas := replicaMgr.GetByCollection(collectionID) + replicas := replicaMgr.GetByCollection(ctx, collectionID) for _, replica := range replicas { for _, node := range nodes { if replica.Contains(node) { @@ -71,9 +72,9 @@ func GroupPartitionsByCollection(partitions []*meta.Partition) map[int64][]*meta // GroupSegmentsByReplica groups segments by replica, // returns ReplicaID -> Segments -func GroupSegmentsByReplica(replicaMgr *meta.ReplicaManager, collectionID int64, segments []*meta.Segment) map[int64][]*meta.Segment { +func GroupSegmentsByReplica(ctx context.Context, replicaMgr *meta.ReplicaManager, collectionID int64, segments []*meta.Segment) map[int64][]*meta.Segment { ret := make(map[int64][]*meta.Segment) - replicas := replicaMgr.GetByCollection(collectionID) + replicas := replicaMgr.GetByCollection(ctx, collectionID) for _, replica := range replicas { for _, segment := range segments { if replica.Contains(segment.Node) { @@ -85,32 +86,32 @@ func GroupSegmentsByReplica(replicaMgr *meta.ReplicaManager, collectionID int64, } // RecoverReplicaOfCollection recovers all replica of collection with latest resource group. -func RecoverReplicaOfCollection(m *meta.Meta, collectionID typeutil.UniqueID) { +func RecoverReplicaOfCollection(ctx context.Context, m *meta.Meta, collectionID typeutil.UniqueID) { logger := log.With(zap.Int64("collectionID", collectionID)) - rgNames := m.ReplicaManager.GetResourceGroupByCollection(collectionID) + rgNames := m.ReplicaManager.GetResourceGroupByCollection(ctx, collectionID) if rgNames.Len() == 0 { logger.Error("no resource group found for collection", zap.Int64("collectionID", collectionID)) return } - rgs, err := m.ResourceManager.GetNodesOfMultiRG(rgNames.Collect()) + rgs, err := m.ResourceManager.GetNodesOfMultiRG(ctx, rgNames.Collect()) if err != nil { logger.Error("unreachable code as expected, fail to get resource group for replica", zap.Error(err)) return } - if err := m.ReplicaManager.RecoverNodesInCollection(collectionID, rgs); err != nil { + if err := m.ReplicaManager.RecoverNodesInCollection(ctx, collectionID, rgs); err != nil { logger.Warn("fail to set available nodes in replica", zap.Error(err)) } } // RecoverAllCollectionrecovers all replica of all collection in resource group. func RecoverAllCollection(m *meta.Meta) { - for _, collection := range m.CollectionManager.GetAll() { - RecoverReplicaOfCollection(m, collection) + for _, collection := range m.CollectionManager.GetAll(context.TODO()) { + RecoverReplicaOfCollection(context.TODO(), m, collection) } } -func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, checkNodeNum bool) (map[string]int, error) { +func AssignReplica(ctx context.Context, m *meta.Meta, resourceGroups []string, replicaNumber int32, checkNodeNum bool) (map[string]int, error) { if len(resourceGroups) != 0 && len(resourceGroups) != 1 && len(resourceGroups) != int(replicaNumber) { return nil, errors.Errorf( "replica=[%d] resource group=[%s], resource group num can only be 0, 1 or same as replica number", replicaNumber, strings.Join(resourceGroups, ",")) @@ -135,10 +136,10 @@ func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, c // 2. rg1 is removed. // 3. replica1 spawn finished, but cannot find related resource group. for rgName, num := range replicaNumInRG { - if !m.ContainResourceGroup(rgName) { + if !m.ContainResourceGroup(ctx, rgName) { return nil, merr.WrapErrResourceGroupNotFound(rgName) } - nodes, err := m.ResourceManager.GetNodes(rgName) + nodes, err := m.ResourceManager.GetNodes(ctx, rgName) if err != nil { return nil, err } @@ -155,35 +156,36 @@ func AssignReplica(m *meta.Meta, resourceGroups []string, replicaNumber int32, c } // SpawnReplicasWithRG spawns replicas in rgs one by one for given collection. -func SpawnReplicasWithRG(m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32, channels []string) ([]*meta.Replica, error) { - replicaNumInRG, err := AssignReplica(m, resourceGroups, replicaNumber, true) +func SpawnReplicasWithRG(ctx context.Context, m *meta.Meta, collection int64, resourceGroups []string, replicaNumber int32, channels []string) ([]*meta.Replica, error) { + replicaNumInRG, err := AssignReplica(ctx, m, resourceGroups, replicaNumber, true) if err != nil { return nil, err } // Spawn it in replica manager. - replicas, err := m.ReplicaManager.Spawn(collection, replicaNumInRG, channels) + replicas, err := m.ReplicaManager.Spawn(ctx, collection, replicaNumInRG, channels) if err != nil { return nil, err } // Active recover it. - RecoverReplicaOfCollection(m, collection) + RecoverReplicaOfCollection(ctx, m, collection) return replicas, nil } func ReassignReplicaToRG( + ctx context.Context, m *meta.Meta, collectionID int64, newReplicaNumber int32, newResourceGroups []string, ) (map[string]int, map[string][]*meta.Replica, []int64, error) { // assign all replicas to newResourceGroups, got each rg's replica number - newAssignment, err := AssignReplica(m, newResourceGroups, newReplicaNumber, false) + newAssignment, err := AssignReplica(ctx, m, newResourceGroups, newReplicaNumber, false) if err != nil { return nil, nil, nil, err } - replicas := m.ReplicaManager.GetByCollection(collectionID) + replicas := m.ReplicaManager.GetByCollection(context.TODO(), collectionID) replicasInRG := lo.GroupBy(replicas, func(replica *meta.Replica) string { return replica.GetResourceGroup() }) diff --git a/internal/querycoordv2/utils/meta_test.go b/internal/querycoordv2/utils/meta_test.go index d5412b0e69e52..9b2bbbbdd8051 100644 --- a/internal/querycoordv2/utils/meta_test.go +++ b/internal/querycoordv2/utils/meta_test.go @@ -17,6 +17,7 @@ package utils import ( + "context" "testing" "github.com/cockroachdb/errors" @@ -51,18 +52,19 @@ func TestSpawnReplicasWithRG(t *testing.T) { require.NoError(t, err) kv := etcdKV.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + ctx := context.Background() store := querycoord.NewCatalog(kv) nodeMgr := session.NewNodeManager() m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr) - m.ResourceManager.AddResourceGroup("rg1", &rgpb.ResourceGroupConfig{ + m.ResourceManager.AddResourceGroup(ctx, "rg1", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 3}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 3}, }) - m.ResourceManager.AddResourceGroup("rg2", &rgpb.ResourceGroupConfig{ + m.ResourceManager.AddResourceGroup(ctx, "rg2", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 3}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 3}, }) - m.ResourceManager.AddResourceGroup("rg3", &rgpb.ResourceGroupConfig{ + m.ResourceManager.AddResourceGroup(ctx, "rg3", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 3}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 3}, }) @@ -74,13 +76,13 @@ func TestSpawnReplicasWithRG(t *testing.T) { Hostname: "localhost", })) if i%3 == 0 { - m.ResourceManager.HandleNodeUp(int64(i)) + m.ResourceManager.HandleNodeUp(ctx, int64(i)) } if i%3 == 1 { - m.ResourceManager.HandleNodeUp(int64(i)) + m.ResourceManager.HandleNodeUp(ctx, int64(i)) } if i%3 == 2 { - m.ResourceManager.HandleNodeUp(int64(i)) + m.ResourceManager.HandleNodeUp(ctx, int64(i)) } } @@ -120,7 +122,7 @@ func TestSpawnReplicasWithRG(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := SpawnReplicasWithRG(tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber, nil) + got, err := SpawnReplicasWithRG(ctx, tt.args.m, tt.args.collection, tt.args.resourceGroups, tt.args.replicaNumber, nil) if (err != nil) != tt.wantErr { t.Errorf("SpawnReplicasWithRG() error = %v, wantErr %v", err, tt.wantErr) return @@ -135,21 +137,22 @@ func TestSpawnReplicasWithRG(t *testing.T) { func TestAddNodesToCollectionsInRGFailed(t *testing.T) { paramtable.Init() + ctx := context.Background() store := mocks.NewQueryCoordCatalog(t) - store.EXPECT().SaveCollection(mock.Anything).Return(nil) - store.EXPECT().SaveReplica(mock.Anything).Return(nil).Times(4) - store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil) + store.EXPECT().SaveCollection(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil).Times(4) store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil) nodeMgr := session.NewNodeManager() m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr) - m.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{ + m.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 0}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 0}, }) - m.CollectionManager.PutCollection(CreateTestCollection(1, 2)) - m.CollectionManager.PutCollection(CreateTestCollection(2, 2)) - m.ReplicaManager.Put(meta.NewReplica( + m.CollectionManager.PutCollection(ctx, CreateTestCollection(1, 2)) + m.CollectionManager.PutCollection(ctx, CreateTestCollection(2, 2)) + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 1, CollectionID: 1, @@ -159,7 +162,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 2, CollectionID: 1, @@ -169,7 +172,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 3, CollectionID: 2, @@ -179,7 +182,7 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 4, CollectionID: 2, @@ -190,33 +193,34 @@ func TestAddNodesToCollectionsInRGFailed(t *testing.T) { )) storeErr := errors.New("store error") - store.EXPECT().SaveReplica(mock.Anything).Return(storeErr) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(storeErr) RecoverAllCollection(m) - assert.Len(t, m.ReplicaManager.Get(1).GetNodes(), 0) - assert.Len(t, m.ReplicaManager.Get(2).GetNodes(), 0) - assert.Len(t, m.ReplicaManager.Get(3).GetNodes(), 0) - assert.Len(t, m.ReplicaManager.Get(4).GetNodes(), 0) + assert.Len(t, m.ReplicaManager.Get(ctx, 1).GetNodes(), 0) + assert.Len(t, m.ReplicaManager.Get(ctx, 2).GetNodes(), 0) + assert.Len(t, m.ReplicaManager.Get(ctx, 3).GetNodes(), 0) + assert.Len(t, m.ReplicaManager.Get(ctx, 4).GetNodes(), 0) } func TestAddNodesToCollectionsInRG(t *testing.T) { paramtable.Init() + ctx := context.Background() store := mocks.NewQueryCoordCatalog(t) - store.EXPECT().SaveCollection(mock.Anything).Return(nil) - store.EXPECT().SaveReplica(mock.Anything).Return(nil) - store.EXPECT().SaveReplica(mock.Anything, mock.Anything).Return(nil) - store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil) + store.EXPECT().SaveCollection(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveReplica(mock.Anything, mock.Anything, mock.Anything).Return(nil) store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything).Return(nil) + store.EXPECT().SaveResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(nil) nodeMgr := session.NewNodeManager() m := meta.NewMeta(RandomIncrementIDAllocator(), store, nodeMgr) - m.ResourceManager.AddResourceGroup("rg", &rgpb.ResourceGroupConfig{ + m.ResourceManager.AddResourceGroup(ctx, "rg", &rgpb.ResourceGroupConfig{ Requests: &rgpb.ResourceGroupLimit{NodeNum: 4}, Limits: &rgpb.ResourceGroupLimit{NodeNum: 4}, }) - m.CollectionManager.PutCollection(CreateTestCollection(1, 2)) - m.CollectionManager.PutCollection(CreateTestCollection(2, 2)) - m.ReplicaManager.Put(meta.NewReplica( + m.CollectionManager.PutCollection(ctx, CreateTestCollection(1, 2)) + m.CollectionManager.PutCollection(ctx, CreateTestCollection(2, 2)) + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 1, CollectionID: 1, @@ -226,7 +230,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 2, CollectionID: 1, @@ -236,7 +240,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 3, CollectionID: 2, @@ -246,7 +250,7 @@ func TestAddNodesToCollectionsInRG(t *testing.T) { typeutil.NewUniqueSet(), )) - m.ReplicaManager.Put(meta.NewReplica( + m.ReplicaManager.Put(ctx, meta.NewReplica( &querypb.Replica{ ID: 4, CollectionID: 2, @@ -262,12 +266,12 @@ func TestAddNodesToCollectionsInRG(t *testing.T) { Address: "127.0.0.1", Hostname: "localhost", })) - m.ResourceManager.HandleNodeUp(nodeID) + m.ResourceManager.HandleNodeUp(ctx, nodeID) } RecoverAllCollection(m) - assert.Len(t, m.ReplicaManager.Get(1).GetNodes(), 2) - assert.Len(t, m.ReplicaManager.Get(2).GetNodes(), 2) - assert.Len(t, m.ReplicaManager.Get(3).GetNodes(), 2) - assert.Len(t, m.ReplicaManager.Get(4).GetNodes(), 2) + assert.Len(t, m.ReplicaManager.Get(ctx, 1).GetNodes(), 2) + assert.Len(t, m.ReplicaManager.Get(ctx, 2).GetNodes(), 2) + assert.Len(t, m.ReplicaManager.Get(ctx, 3).GetNodes(), 2) + assert.Len(t, m.ReplicaManager.Get(ctx, 4).GetNodes(), 2) } diff --git a/internal/querycoordv2/utils/util.go b/internal/querycoordv2/utils/util.go index 47e54278ba840..5e283b926ee60 100644 --- a/internal/querycoordv2/utils/util.go +++ b/internal/querycoordv2/utils/util.go @@ -68,7 +68,7 @@ func CheckDelegatorDataReady(nodeMgr *session.NodeManager, targetMgr meta.Target return err } } - segmentDist := targetMgr.GetSealedSegmentsByChannel(leader.CollectionID, leader.Channel, scope) + segmentDist := targetMgr.GetSealedSegmentsByChannel(context.TODO(), leader.CollectionID, leader.Channel, scope) // Check whether segments are fully loaded for segmentID, info := range segmentDist { _, exist := leader.Segments[segmentID] @@ -87,13 +87,13 @@ func CheckDelegatorDataReady(nodeMgr *session.NodeManager, targetMgr meta.Target } func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) error { - percentage := m.CollectionManager.CalculateLoadPercentage(collectionID) + percentage := m.CollectionManager.CalculateLoadPercentage(ctx, collectionID) if percentage < 0 { err := merr.WrapErrCollectionNotLoaded(collectionID) log.Ctx(ctx).Warn("failed to GetShardLeaders", zap.Error(err)) return err } - collection := m.CollectionManager.GetCollection(collectionID) + collection := m.CollectionManager.GetCollection(ctx, collectionID) if collection != nil && collection.GetStatus() == querypb.LoadStatus_Loaded { // when collection is loaded, regard collection as readable, set percentage == 100 percentage = 100 @@ -108,7 +108,7 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro return nil } -func GetShardLeadersWithChannels(m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, +func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel, ) ([]*querypb.ShardLeadersList, error) { ret := make([]*querypb.ShardLeadersList, 0) @@ -137,7 +137,7 @@ func GetShardLeadersWithChannels(m *meta.Meta, targetMgr meta.TargetManagerInter return nil, err } - readableLeaders = filterDupLeaders(m.ReplicaManager, readableLeaders) + readableLeaders = filterDupLeaders(ctx, m.ReplicaManager, readableLeaders) ids := make([]int64, 0, len(leaders)) addrs := make([]string, 0, len(leaders)) for _, leader := range readableLeaders { @@ -174,26 +174,26 @@ func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetMan return nil, err } - channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget) + channels := targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget) if len(channels) == 0 { msg := "loaded collection do not found any channel in target, may be in recovery" err := merr.WrapErrCollectionOnRecovering(collectionID, msg) log.Ctx(ctx).Warn("failed to get channels", zap.Error(err)) return nil, err } - return GetShardLeadersWithChannels(m, targetMgr, dist, nodeMgr, collectionID, channels) + return GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels) } // CheckCollectionsQueryable check all channels are watched and all segments are loaded for this collection func CheckCollectionsQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager) error { maxInterval := paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Minute) - for _, coll := range m.GetAllCollections() { + for _, coll := range m.GetAllCollections(ctx) { err := checkCollectionQueryable(ctx, m, targetMgr, dist, nodeMgr, coll) // the collection is not queryable, if meet following conditions: // 1. Some segments are not loaded // 2. Collection is not starting to release // 3. The load percentage has not been updated in the last 5 minutes. - if err != nil && m.Exist(coll.CollectionID) && time.Since(coll.UpdatedAt) >= maxInterval { + if err != nil && m.Exist(ctx, coll.CollectionID) && time.Since(coll.UpdatedAt) >= maxInterval { log.Ctx(ctx).Warn("collection not querable", zap.Int64("collectionID", coll.CollectionID), zap.Time("lastUpdated", coll.UpdatedAt), @@ -212,7 +212,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta. return err } - channels := targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget) + channels := targetMgr.GetDmChannelsByCollection(ctx, collectionID, meta.CurrentTarget) if len(channels) == 0 { msg := "loaded collection do not found any channel in target, may be in recovery" err := merr.WrapErrCollectionOnRecovering(collectionID, msg) @@ -220,7 +220,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta. return err } - shardList, err := GetShardLeadersWithChannels(m, targetMgr, dist, nodeMgr, collectionID, channels) + shardList, err := GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels) if err != nil { return err } @@ -232,7 +232,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta. return nil } -func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView { +func filterDupLeaders(ctx context.Context, replicaManager *meta.ReplicaManager, leaders map[int64]*meta.LeaderView) map[int64]*meta.LeaderView { type leaderID struct { ReplicaID int64 Shard string @@ -240,7 +240,7 @@ func filterDupLeaders(replicaManager *meta.ReplicaManager, leaders map[int64]*me newLeaders := make(map[leaderID]*meta.LeaderView) for _, view := range leaders { - replica := replicaManager.GetByCollectionAndNode(view.CollectionID, view.ID) + replica := replicaManager.GetByCollectionAndNode(ctx, view.CollectionID, view.ID) if replica == nil { continue } diff --git a/internal/querycoordv2/utils/util_test.go b/internal/querycoordv2/utils/util_test.go index f6f3d7fe285ab..a111600e193bf 100644 --- a/internal/querycoordv2/utils/util_test.go +++ b/internal/querycoordv2/utils/util_test.go @@ -59,13 +59,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliable() { } mockTargetManager := meta.NewMockTargetManager(suite.T()) - mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ + mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ 2: { ID: 2, InsertChannel: "test", }, }).Maybe() - mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() suite.setNodeAvailable(1, 2) err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget) @@ -81,13 +81,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() { TargetVersion: 1011, } mockTargetManager := meta.NewMockTargetManager(suite.T()) - mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ + mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ 2: { ID: 2, InsertChannel: "test", }, }).Maybe() - mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() // leader nodeID=1 not available suite.setNodeAvailable(2) err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget) @@ -103,13 +103,13 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() { } mockTargetManager := meta.NewMockTargetManager(suite.T()) - mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ + mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ 2: { ID: 2, InsertChannel: "test", }, }).Maybe() - mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() // leader nodeID=2 not available suite.setNodeAvailable(1) err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget) @@ -124,14 +124,14 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() { TargetVersion: 1011, } mockTargetManager := meta.NewMockTargetManager(suite.T()) - mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ + mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ // target segmentID=1 not in leadView 1: { ID: 1, InsertChannel: "test", }, }).Maybe() - mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() suite.setNodeAvailable(1, 2) err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget) suite.Error(err) @@ -144,14 +144,14 @@ func (suite *UtilTestSuite) TestCheckLeaderAvaliableFailed() { Segments: map[int64]*querypb.SegmentDist{2: {NodeID: 2}}, } mockTargetManager := meta.NewMockTargetManager(suite.T()) - mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ + mockTargetManager.EXPECT().GetSealedSegmentsByChannel(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[int64]*datapb.SegmentInfo{ // target segmentID=1 not in leadView 1: { ID: 1, InsertChannel: "test", }, }).Maybe() - mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything).Return(1011).Maybe() + mockTargetManager.EXPECT().GetCollectionTargetVersion(mock.Anything, mock.Anything, mock.Anything).Return(1011).Maybe() suite.setNodeAvailable(1, 2) err := CheckDelegatorDataReady(suite.nodeMgr, mockTargetManager, leadview, meta.CurrentTarget) suite.Error(err)