From 08ee8b1951f44640392368604ac6956664747394 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 1 Nov 2024 17:51:48 +0800 Subject: [PATCH 1/3] enhance: Decouple shard client manager from shard cache The old implementation updates the shard cache and the shard client manager at the same time, which causes lots of corner cases due to concurrency issues without a lock. This PR decouples the shard client manager from the shard cache, so only the shard cache will be updated if the delegator changes, and makes sure the shard client manager will always return the right client, creating a new client if one does not exist. To guard against client leaks, the shard client manager will purge unused clients asynchronously every 10 minutes. Signed-off-by: Wei Liu --- internal/proxy/lb_balancer.go | 1 + internal/proxy/lb_policy.go | 59 +++--- internal/proxy/lb_policy_test.go | 47 +++-- internal/proxy/look_aside_balancer.go | 19 +- internal/proxy/look_aside_balancer_test.go | 56 +++++- internal/proxy/meta_cache.go | 65 ++++--- internal/proxy/meta_cache_test.go | 1 - internal/proxy/mock_cache.go | 47 +++++ internal/proxy/mock_lb_balancer.go | 33 ++++ internal/proxy/mock_shardclient_manager.go | 98 +++++----- internal/proxy/roundrobin_balancer.go | 2 + internal/proxy/shard_client.go | 204 ++++++++++----------- internal/proxy/shard_client_test.go | 123 +++++-------- internal/proxy/task_policies.go | 3 +- internal/proxy/task_policies_test.go | 1 - internal/proxy/task_query_test.go | 2 +- internal/proxy/task_search_test.go | 2 +- internal/proxy/task_statistic_test.go | 2 +- 18 files changed, 449 insertions(+), 316 deletions(-) diff --git a/internal/proxy/lb_balancer.go b/internal/proxy/lb_balancer.go index f918ab20a7581..3d46cb63c0f19 100644 --- a/internal/proxy/lb_balancer.go +++ b/internal/proxy/lb_balancer.go @@ -23,6 +23,7 @@ import ( ) type LBBalancer interface { + RegisterNodeInfo(nodeInfos []nodeInfo) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) CancelWorkload(node int64, nq int64) UpdateCostMetrics(node int64, 
cost *internalpb.CostAggregation) diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 0201bfec2b480..79ef086da0d9f 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -39,7 +39,7 @@ type ChannelWorkload struct { collectionName string collectionID int64 channel string - shardLeaders []int64 + shardLeaders []nodeInfo nq int64 exec executeFunc retryTimes uint @@ -116,9 +116,20 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll } // try to select the best node from the available nodes -func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) { - availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) }) - targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq) +func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) { + filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo { + ret := make(map[int64]nodeInfo) + for _, node := range nodes { + if !excludeNodes.Contain(node.nodeID) { + ret[node.nodeID] = node + } + } + return ret + } + + availableNodes := filterDelegator(workload.shardLeaders) + balancer.RegisterNodeInfo(lo.Values(availableNodes)) + targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq) if err != nil { log := log.Ctx(ctx) globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName) @@ -128,32 +139,33 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), zap.Error(err)) - return -1, err + return nodeInfo{}, err } - availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { 
return node.nodeID, !excludeNodes.Contain(node.nodeID) }) + availableNodes = filterDelegator(shardLeaders[workload.channel]) if len(availableNodes) == 0 { - nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }) log.Warn("no available shard delegator found", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64s("nodes", nodes), + zap.Int64s("availableNodes", lo.Keys(availableNodes)), zap.Int64s("excluded", excludeNodes.Collect())) - return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found") + return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found") } - targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq) + balancer.RegisterNodeInfo(lo.Values(availableNodes)) + targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq) if err != nil { log.Warn("failed to select shard", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64s("availableNodes", availableNodes), + zap.Int64s("availableNodes", lo.Keys(availableNodes)), + zap.Int64s("excluded", excludeNodes.Collect()), zap.Error(err)) - return -1, err + return nodeInfo{}, err } } - return targetNode, nil + return availableNodes[targetNode], nil } // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes. 
@@ -168,7 +180,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo log.Warn("failed to select node for shard", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err), ) if lastErr != nil { @@ -177,29 +189,30 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo return err } // cancel work load which assign to the target node - defer balancer.CancelWorkload(targetNode, workload.nq) + defer balancer.CancelWorkload(targetNode.nodeID, workload.nq) client, err := lb.clientMgr.GetClient(ctx, targetNode) + defer lb.clientMgr.ReleaseClient(targetNode.nodeID) if err != nil { log.Warn("search/query channel failed, node not available", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) - excludeNodes.Insert(targetNode) + excludeNodes.Insert(targetNode.nodeID) lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel) return lastErr } - err = workload.exec(ctx, targetNode, client, workload.channel) + err = workload.exec(ctx, targetNode.nodeID, client, workload.channel) if err != nil { log.Warn("search/query channel failed", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) - excludeNodes.Insert(targetNode) + excludeNodes.Insert(targetNode.nodeID) lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel) return lastErr } @@ -221,9 +234,9 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad // let every request could retry at least twice, which could retry after update shard leader cache 
retryTimes := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt() wg, ctx := errgroup.WithContext(ctx) - for channel, nodes := range dml2leaders { - channel := channel - nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID }) + for k, v := range dml2leaders { + channel := k + nodes := v channelRetryTimes := retryTimes if len(nodes) > 0 { channelRetryTimes *= len(nodes) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 0f0f8e46885bb..4a96ef9c824bb 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -49,7 +49,8 @@ type LBPolicySuite struct { lbBalancer *MockLBBalancer lbPolicy *LBPolicyImpl - nodes []int64 + nodeIDs []int64 + nodes []nodeInfo channels []string qnList []*mocks.MockQueryNode @@ -62,7 +63,14 @@ func (s *LBPolicySuite) SetupSuite() { } func (s *LBPolicySuite) SetupTest() { - s.nodes = []int64{1, 2, 3, 4, 5} + s.nodeIDs = make([]int64, 0) + for i := 1; i <= 5; i++ { + s.nodeIDs = append(s.nodeIDs, int64(i)) + s.nodes = append(s.nodes, nodeInfo{ + nodeID: int64(i), + address: "localhost", + }) + } s.channels = []string{"channel1", "channel2"} successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc := mocks.NewMockQueryCoordClient(s.T()) @@ -74,12 +82,12 @@ func (s *LBPolicySuite) SetupTest() { Shards: []*querypb.ShardLeadersList{ { ChannelName: s.channels[0], - NodeIds: s.nodes, + NodeIds: s.nodeIDs, NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"}, }, { ChannelName: s.channels[1], - NodeIds: s.nodes, + NodeIds: s.nodeIDs, NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"}, }, }, @@ -96,7 +104,6 @@ func (s *LBPolicySuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.mgr = NewMockShardClientManager(s.T()) - s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, 
mock.Anything).Return(nil).Maybe() s.lbBalancer = NewMockLBBalancer(s.T()) s.lbBalancer.EXPECT().Start(context.Background()).Maybe() s.lbPolicy = NewLBPolicyImpl(s.mgr) @@ -164,6 +171,7 @@ func (s *LBPolicySuite) loadCollection() { func (s *LBPolicySuite) TestSelectNode() { ctx := context.Background() + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil) targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, @@ -174,10 +182,11 @@ func (s *LBPolicySuite) TestSelectNode() { nq: 1, }, typeutil.NewUniqueSet()) s.NoError(err) - s.Equal(int64(5), targetNode) + s.Equal(int64(5), targetNode.nodeID) // test select node failed, then update shard leader cache and retry, expect success s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ @@ -185,28 +194,29 @@ func (s *LBPolicySuite) TestSelectNode() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: []int64{}, + shardLeaders: []nodeInfo{}, nq: 1, }, typeutil.NewUniqueSet()) s.NoError(err) - s.Equal(int64(3), targetNode) + s.Equal(int64(3), targetNode.nodeID) // test select node always fails, expected failure s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: []int64{}, + shardLeaders: 
[]nodeInfo{}, nq: 1, }, typeutil.NewUniqueSet()) s.ErrorIs(err, merr.ErrNodeNotAvailable) - s.Equal(int64(-1), targetNode) // test all nodes has been excluded, expected failure s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, @@ -215,12 +225,12 @@ func (s *LBPolicySuite) TestSelectNode() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - }, typeutil.NewUniqueSet(s.nodes...)) + }, typeutil.NewUniqueSet(s.nodeIDs...)) s.ErrorIs(err, merr.ErrChannelNotAvailable) - s.Equal(int64(-1), targetNode) // test get shard leaders failed, retry to select node failed s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) s.qc.ExpectedCalls = nil s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable) @@ -233,7 +243,6 @@ func (s *LBPolicySuite) TestSelectNode() { nq: 1, }, typeutil.NewUniqueSet()) s.ErrorIs(err, merr.ErrServiceUnavailable) - s.Equal(int64(-1), targetNode) } func (s *LBPolicySuite) TestExecuteWithRetry() { @@ -241,7 +250,9 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test execute success s.lbBalancer.ExpectedCalls = nil + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ @@ -260,6 +271,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test select node failed, expected error 
s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ db: dbName, @@ -277,8 +289,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test get client failed, and retry failed, expected success s.mgr.ExpectedCalls = nil + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ @@ -296,8 +310,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.Error(err) s.mgr.ExpectedCalls = nil + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ db: dbName, @@ -315,8 +331,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test exec failed, then retry success s.mgr.ExpectedCalls = nil + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) counter := 0 @@ -341,6 +359,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test 
exec timeout s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1) @@ -365,7 +384,9 @@ func (s *LBPolicySuite) TestExecute() { ctx := context.Background() mockErr := errors.New("mock error") // test all channel success + s.mgr.EXPECT().ReleaseClient(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{ diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go index 3bd7c3d63e7fb..aa7c964dda469 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -44,7 +44,8 @@ type CostMetrics struct { type LookAsideBalancer struct { clientMgr shardClientMgr - metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics] + knownNodeInfos *typeutil.ConcurrentMap[int64, nodeInfo] + metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics] // query node id -> number of consecutive heartbeat failures failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64] @@ -64,6 +65,7 @@ type LookAsideBalancer struct { func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer { balancer := &LookAsideBalancer{ clientMgr: clientMgr, + knownNodeInfos: typeutil.NewConcurrentMap[int64, nodeInfo](), metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](), failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](), closeCh: make(chan struct{}), @@ -88,6 +90,12 @@ func 
(b *LookAsideBalancer) Close() { }) } +func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) { + for _, node := range nodeInfos { + b.knownNodeInfos.Insert(node.nodeID, node) + } +} + func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) { targetNode := int64(-1) defer func() { @@ -233,9 +241,10 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { case <-ticker.C: var futures []*conc.Future[any] now := time.Now() - b.metricsMap.Range(func(node int64, metrics *CostMetrics) bool { + b.knownNodeInfos.Range(func(node int64, info nodeInfo) bool { futures = append(futures, pool.Submit(func() (any, error) { - if now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() { + metrics, ok := b.metricsMap.Get(node) + if !ok || now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() { checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), checkTimeout) defer cancel() @@ -244,13 +253,14 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { panic("let it panic") } - qn, err := b.clientMgr.GetClient(ctx, node) + qn, err := b.clientMgr.GetClient(ctx, info) if err != nil { // get client from clientMgr failed, which means this qn isn't a shard leader anymore, skip it's health check b.trySetQueryNodeUnReachable(node, err) log.RatedInfo(10, "get client failed", zap.Int64("node", node), zap.Error(err)) return struct{}{}, nil } + defer b.clientMgr.ReleaseClient(node) resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { @@ -304,6 +314,7 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) { zap.Int64("nodeID", node)) // stop the heartbeat b.metricsMap.Remove(node) + b.knownNodeInfos.Remove(node) return } diff --git a/internal/proxy/look_aside_balancer_test.go 
b/internal/proxy/look_aside_balancer_test.go index 69d57c5f550ff..abd55070562fa 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -42,11 +43,12 @@ type LookAsideBalancerSuite struct { func (suite *LookAsideBalancerSuite) SetupTest() { suite.clientMgr = NewMockShardClientManager(suite.T()) + suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() suite.balancer = NewLookAsideBalancer(suite.clientMgr) suite.balancer.Start(context.Background()) qn := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe() + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe() } @@ -298,22 +300,46 @@ func (suite *LookAsideBalancerSuite) TestCancelWorkload() { } func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { + qn := mocks.NewMockQueryNodeClient(suite.T()) + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe() qn2 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil).Maybe() qn2.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, }, nil).Maybe() + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) { + if 
ni.nodeID == 1 { + return qn, nil + } + + if ni.nodeID == 2 { + return qn2, nil + } + return nil, errors.New("unexpected node") + }).Maybe() metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) metrics1.unavailable.Store(true) suite.balancer.metricsMap.Insert(1, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 1, + }, + }) metrics2 := &CostMetrics{} metrics2.ts.Store(time.Now().UnixMilli()) metrics2.unavailable.Store(true) suite.balancer.metricsMap.Insert(2, metrics2) + suite.balancer.knownNodeInfos.Insert(2, nodeInfo{}) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 2, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(1) return ok && metrics.unavailable.Load() @@ -339,10 +365,16 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { metrics1.ts.Store(time.Now().UnixMilli()) metrics1.unavailable.Store(true) suite.balancer.metricsMap.Insert(2, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 2, + }, + }) // test get shard client from client mgr return nil suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(nil, errors.New("shard client not found")) + suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found")) // expected stopping the health check after failure times reaching the limit suite.Eventually(func() bool { return !suite.balancer.metricsMap.Contain(2) @@ -352,7 +384,9 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { func (suite *LookAsideBalancerSuite) TestNodeRecover() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().ReleaseClient(mock.Anything) + 
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, @@ -368,6 +402,11 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() { metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) suite.balancer.metricsMap.Insert(3, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 3, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(3) return ok && metrics.unavailable.Load() @@ -384,7 +423,9 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000") // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().ReleaseClient(mock.Anything) + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, @@ -394,6 +435,11 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) suite.balancer.metricsMap.Insert(3, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 3, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(3) return ok && metrics.unavailable.Load() diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 16679ebc3e9d4..134d799d8bc7c 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -72,6 +72,7 @@ type Cache interface { GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) 
(map[string][]nodeInfo, error) DeprecateShardCache(database, collectionName string) InvalidateShardLeaderCache(collections []int64) + ListShardLocation() map[int64]nodeInfo RemoveCollection(ctx context.Context, database, collectionName string) RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string RemovePartition(ctx context.Context, database, collectionName string, partitionName string) @@ -288,9 +289,7 @@ func (info *collectionInfo) isCollectionCached() bool { // shardLeaders wraps shard leader mapping for iteration. type shardLeaders struct { - idx *atomic.Int64 - deprecated *atomic.Bool - + idx *atomic.Int64 collectionID int64 shardLeaders map[string][]nodeInfo } @@ -419,19 +418,19 @@ func (m *MetaCache) getCollection(database, collectionName string, collectionID return nil, false } -func (m *MetaCache) getCollectionShardLeader(database, collectionName string) (*shardLeaders, bool) { +func (m *MetaCache) getCollectionShardLeader(database, collectionName string) *shardLeaders { m.leaderMut.RLock() defer m.leaderMut.RUnlock() db, ok := m.collLeader[database] if !ok { - return nil, false + return nil } if leaders, ok := db[collectionName]; ok { - return leaders, !leaders.deprecated.Load() + return leaders } - return nil, false + return nil } func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) { @@ -957,9 +956,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col zap.String("collectionName", collectionName), zap.Int64("collectionID", collectionID)) - cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName) + cacheShardLeaders := m.getCollectionShardLeader(database, collectionName) if withCache { - if ok { + if cacheShardLeaders != nil { metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() iterator := cacheShardLeaders.GetReader() return 
iterator.Shuffle(), nil @@ -995,11 +994,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col newShardLeaders := &shardLeaders{ collectionID: info.collID, shardLeaders: shards, - deprecated: atomic.NewBool(false), idx: atomic.NewInt64(0), } - // lock leader m.leaderMut.Lock() if _, ok := m.collLeader[database]; !ok { m.collLeader[database] = make(map[string]*shardLeaders) @@ -1008,15 +1005,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col iterator := newShardLeaders.GetReader() ret := iterator.Shuffle() - - oldLeaders := make(map[string][]nodeInfo) - if cacheShardLeaders != nil { - oldLeaders = cacheShardLeaders.shardLeaders - } - // update refcnt in shardClientMgr - // update shard leader's just create a empty client pool - // and init new client will be execute in getClient - _ = m.shardMgr.UpdateShardLeaders(oldLeaders, ret) m.leaderMut.Unlock() metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) @@ -1042,23 +1030,50 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m // DeprecateShardCache clear the shard leader cache of a collection func (m *MetaCache) DeprecateShardCache(database, collectionName string) { log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName)) - if shards, ok := m.getCollectionShardLeader(database, collectionName); ok { - shards.deprecated.Store(true) + m.leaderMut.Lock() + defer m.leaderMut.Unlock() + dbInfo, ok := m.collLeader[database] + if ok { + delete(dbInfo, collectionName) + if len(dbInfo) == 0 { + delete(m.collLeader, database) + } } } +// ListShardLocation returns every cached shard leader keyed by node ID; used for garbage collection of shard clients. +func (m *MetaCache) ListShardLocation() map[int64]nodeInfo { + m.leaderMut.RLock() + defer m.leaderMut.RUnlock() + shardLeaderInfo := make(map[int64]nodeInfo) + + for _, dbInfo := range m.collLeader { + for _, shardLeaders := range dbInfo { + 
for _, nodeInfos := range shardLeaders.shardLeaders { + for _, node := range nodeInfos { + shardLeaderInfo[node.nodeID] = node + } + } + } + } + return shardLeaderInfo +} + func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) { log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections)) m.leaderMut.Lock() defer m.leaderMut.Unlock() collectionSet := typeutil.NewUniqueSet(collections...) - for _, db := range m.collLeader { - for _, shardLeaders := range db { + for dbName, dbInfo := range m.collLeader { + for collectionName, shardLeaders := range dbInfo { if collectionSet.Contain(shardLeaders.collectionID) { - shardLeaders.deprecated.Store(true) + delete(dbInfo, collectionName) } } + if len(dbInfo) == 0 { + delete(m.collLeader, dbName) + } } } diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 3f97a1dca6bbc..7ec54463bd5fc 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -805,7 +805,6 @@ func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) { }, } sl := &shardLeaders{ - deprecated: uatomic.NewBool(false), idx: uatomic.NewInt64(5), shardLeaders: shards, } diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index cb1697b7f2ace..6c961e2a02cf0 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -981,6 +981,53 @@ func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int return _c } +// ListShardLocation provides a mock function with given fields: +func (_m *MockCache) ListShardLocation() map[int64]nodeInfo { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ListShardLocation") + } + + var r0 map[int64]nodeInfo + if rf, ok := ret.Get(0).(func() map[int64]nodeInfo); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]nodeInfo) + } + } + + return r0 +} + +// MockCache_ListShardLocation_Call is a *mock.Call that shadows 
Run/Return methods with type explicit version for method 'ListShardLocation' +type MockCache_ListShardLocation_Call struct { + *mock.Call +} + +// ListShardLocation is a helper method to define mock.On call +func (_e *MockCache_Expecter) ListShardLocation() *MockCache_ListShardLocation_Call { + return &MockCache_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")} +} + +func (_c *MockCache_ListShardLocation_Call) Run(run func()) *MockCache_ListShardLocation_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCache_ListShardLocation_Call) Return(_a0 map[int64]nodeInfo) *MockCache_ListShardLocation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCache_ListShardLocation_Call) RunAndReturn(run func() map[int64]nodeInfo) *MockCache_ListShardLocation_Call { + _c.Call.Return(run) + return _c +} + // RefreshPolicyInfo provides a mock function with given fields: op func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error { ret := _m.Called(op) diff --git a/internal/proxy/mock_lb_balancer.go b/internal/proxy/mock_lb_balancer.go index 8fdbf9e5bbc7a..4d99213078f23 100644 --- a/internal/proxy/mock_lb_balancer.go +++ b/internal/proxy/mock_lb_balancer.go @@ -88,6 +88,39 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Cl return _c } +// RegisterNodeInfo provides a mock function with given fields: nodeInfos +func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) { + _m.Called(nodeInfos) +} + +// MockLBBalancer_RegisterNodeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterNodeInfo' +type MockLBBalancer_RegisterNodeInfo_Call struct { + *mock.Call +} + +// RegisterNodeInfo is a helper method to define mock.On call +// - nodeInfos []nodeInfo +func (_e *MockLBBalancer_Expecter) RegisterNodeInfo(nodeInfos interface{}) *MockLBBalancer_RegisterNodeInfo_Call { + return &MockLBBalancer_RegisterNodeInfo_Call{Call: 
_e.mock.On("RegisterNodeInfo", nodeInfos)} +} + +func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]nodeInfo)) + }) + return _c +} + +func (_c *MockLBBalancer_RegisterNodeInfo_Call) Return() *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Return(run) + return _c +} + // SelectNode provides a mock function with given fields: ctx, availableNodes, nq func (_m *MockLBBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) { ret := _m.Called(ctx, availableNodes, nq) diff --git a/internal/proxy/mock_shardclient_manager.go b/internal/proxy/mock_shardclient_manager.go index edb6205f4660c..3bc433f138f32 100644 --- a/internal/proxy/mock_shardclient_manager.go +++ b/internal/proxy/mock_shardclient_manager.go @@ -54,9 +54,9 @@ func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShard return _c } -// GetClient provides a mock function with given fields: ctx, nodeID -func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (types.QueryNodeClient, error) { - ret := _m.Called(ctx, nodeID) +// GetClient provides a mock function with given fields: ctx, nodeInfo1 +func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo1 nodeInfo) (types.QueryNodeClient, error) { + ret := _m.Called(ctx, nodeInfo1) if len(ret) == 0 { panic("no return value specified for GetClient") @@ -64,19 +64,19 @@ func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) ( var r0 types.QueryNodeClient var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (types.QueryNodeClient, error)); ok { - return rf(ctx, nodeID) + if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) 
(types.QueryNodeClient, error)); ok { + return rf(ctx, nodeInfo1) } - if rf, ok := ret.Get(0).(func(context.Context, int64) types.QueryNodeClient); ok { - r0 = rf(ctx, nodeID) + if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) types.QueryNodeClient); ok { + r0 = rf(ctx, nodeInfo1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(types.QueryNodeClient) } } - if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { - r1 = rf(ctx, nodeID) + if rf, ok := ret.Get(1).(func(context.Context, nodeInfo) error); ok { + r1 = rf(ctx, nodeInfo1) } else { r1 = ret.Error(1) } @@ -91,14 +91,14 @@ type MockShardClientManager_GetClient_Call struct { // GetClient is a helper method to define mock.On call // - ctx context.Context -// - nodeID int64 -func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeID interface{}) *MockShardClientManager_GetClient_Call { - return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeID)} +// - nodeInfo1 nodeInfo +func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo1 interface{}) *MockShardClientManager_GetClient_Call { + return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo1)} } -func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeID int64)) *MockShardClientManager_GetClient_Call { +func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo1 nodeInfo)) *MockShardClientManager_GetClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) + run(args[0].(context.Context), args[1].(nodeInfo)) }) return _c } @@ -108,87 +108,73 @@ func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClien return _c } -func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, int64) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call { +func (_c 
*MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, nodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call { _c.Call.Return(run) return _c } -// SetClientCreatorFunc provides a mock function with given fields: creator -func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) { - _m.Called(creator) +// ReleaseClient provides a mock function with given fields: nodeID +func (_m *MockShardClientManager) ReleaseClient(nodeID int64) { + _m.Called(nodeID) } -// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc' -type MockShardClientManager_SetClientCreatorFunc_Call struct { +// MockShardClientManager_ReleaseClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClient' +type MockShardClientManager_ReleaseClient_Call struct { *mock.Call } -// SetClientCreatorFunc is a helper method to define mock.On call -// - creator queryNodeCreatorFunc -func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call { - return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)} +// ReleaseClient is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockShardClientManager_Expecter) ReleaseClient(nodeID interface{}) *MockShardClientManager_ReleaseClient_Call { + return &MockShardClientManager_ReleaseClient_Call{Call: _e.mock.On("ReleaseClient", nodeID)} } -func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call { +func (_c *MockShardClientManager_ReleaseClient_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(queryNodeCreatorFunc)) + 
run(args[0].(int64)) }) return _c } -func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call { +func (_c *MockShardClientManager_ReleaseClient_Call) Return() *MockShardClientManager_ReleaseClient_Call { _c.Call.Return() return _c } -func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call { +func (_c *MockShardClientManager_ReleaseClient_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClient_Call { _c.Call.Return(run) return _c } -// UpdateShardLeaders provides a mock function with given fields: oldLeaders, newLeaders -func (_m *MockShardClientManager) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error { - ret := _m.Called(oldLeaders, newLeaders) - - if len(ret) == 0 { - panic("no return value specified for UpdateShardLeaders") - } - - var r0 error - if rf, ok := ret.Get(0).(func(map[string][]nodeInfo, map[string][]nodeInfo) error); ok { - r0 = rf(oldLeaders, newLeaders) - } else { - r0 = ret.Error(0) - } - - return r0 +// SetClientCreatorFunc provides a mock function with given fields: creator +func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) { + _m.Called(creator) } -// MockShardClientManager_UpdateShardLeaders_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateShardLeaders' -type MockShardClientManager_UpdateShardLeaders_Call struct { +// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc' +type MockShardClientManager_SetClientCreatorFunc_Call struct { *mock.Call } -// UpdateShardLeaders is a helper method to define mock.On call -// - oldLeaders map[string][]nodeInfo -// - newLeaders map[string][]nodeInfo -func (_e *MockShardClientManager_Expecter) 
UpdateShardLeaders(oldLeaders interface{}, newLeaders interface{}) *MockShardClientManager_UpdateShardLeaders_Call { - return &MockShardClientManager_UpdateShardLeaders_Call{Call: _e.mock.On("UpdateShardLeaders", oldLeaders, newLeaders)} +// SetClientCreatorFunc is a helper method to define mock.On call +// - creator queryNodeCreatorFunc +func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call { + return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)} } -func (_c *MockShardClientManager_UpdateShardLeaders_Call) Run(run func(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo)) *MockShardClientManager_UpdateShardLeaders_Call { +func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string][]nodeInfo), args[1].(map[string][]nodeInfo)) + run(args[0].(queryNodeCreatorFunc)) }) return _c } -func (_c *MockShardClientManager_UpdateShardLeaders_Call) Return(_a0 error) *MockShardClientManager_UpdateShardLeaders_Call { - _c.Call.Return(_a0) +func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call { + _c.Call.Return() return _c } -func (_c *MockShardClientManager_UpdateShardLeaders_Call) RunAndReturn(run func(map[string][]nodeInfo, map[string][]nodeInfo) error) *MockShardClientManager_UpdateShardLeaders_Call { +func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go index 983514d6ac0a7..e8fe18cf38061 100644 --- a/internal/proxy/roundrobin_balancer.go +++ b/internal/proxy/roundrobin_balancer.go 
@@ -32,6 +32,8 @@ func NewRoundRobinBalancer() *RoundRobinBalancer { return &RoundRobinBalancer{} } +func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {} + func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) { if len(availableNodes) == 0 { return -1, merr.ErrNodeNotAvailable diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index 8475494e6659c..d72e001b2d724 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/cockroachdb/errors" "go.uber.org/atomic" @@ -32,7 +33,6 @@ type shardClient struct { sync.RWMutex info nodeInfo isClosed bool - refCnt int clients []types.QueryNodeClient idx atomic.Int64 poolSize int @@ -40,13 +40,15 @@ type shardClient struct { initialized atomic.Bool creator queryNodeCreatorFunc + + refCnt *atomic.Int64 } func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) { if !n.initialized.Load() { n.Lock() if !n.initialized.Load() { - if err := n.initClients(); err != nil { + if err := n.initClients(ctx); err != nil { n.Unlock() return nil, err } @@ -55,28 +57,28 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err n.Unlock() } - n.RLock() - defer n.RUnlock() - if n.isClosed { - return nil, errClosed + // Attempt to get a connection from the idle connection pool, supporting context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + client, err := n.roundRobinSelectClient() + if err != nil { + return nil, err + } + n.refCnt.Inc() + return client, nil } - - idx := n.idx.Inc() - return n.clients[int(idx)%n.poolSize], nil } -func (n *shardClient) inc() { - n.Lock() - defer n.Unlock() - if n.isClosed { - return +func (n *shardClient) Release() { + if n.refCnt.Dec() == 0 { + n.Close() } - n.refCnt++ } func (n *shardClient) close() { n.isClosed = true - n.refCnt = 0 
for _, client := range n.clients { if err := client.Close(); err != nil { @@ -86,50 +88,36 @@ func (n *shardClient) close() { n.clients = nil } -func (n *shardClient) dec() bool { - n.Lock() - defer n.Unlock() - if n.isClosed { - return true - } - if n.refCnt > 0 { - n.refCnt-- - } - if n.refCnt == 0 { - n.close() - } - return n.refCnt == 0 -} - func (n *shardClient) Close() { n.Lock() defer n.Unlock() n.close() } -func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) { +func newShardClient(info nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) { + num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() + if num <= 0 { + num = 1 + } + return &shardClient{ info: nodeInfo{ nodeID: info.nodeID, address: info.address, }, - refCnt: 1, - pooling: true, - creator: creator, + poolSize: num, + creator: creator, + refCnt: atomic.NewInt64(1), }, nil } -func (n *shardClient) initClients() error { - num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() - if num <= 0 { - num = 1 - } - clients := make([]types.QueryNodeClient, 0, num) - for i := 0; i < num; i++ { - client, err := n.creator(context.Background(), n.info.address, n.info.nodeID) +func (n *shardClient) initClients(ctx context.Context) error { + clients := make([]types.QueryNodeClient, 0, n.poolSize) + for i := 0; i < n.poolSize; i++ { + client, err := n.creator(ctx, n.info.address, n.info.nodeID) if err != nil { - // roll back already created clients - for _, c := range clients[:i] { + // Roll back already created clients + for _, c := range clients { c.Close() } return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID)) @@ -138,13 +126,29 @@ func (n *shardClient) initClients() error { } n.clients = clients - n.poolSize = num return nil } +// roundRobinSelectClient selects a client in a round-robin manner +func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) { + n.Lock() + defer n.Unlock() + if 
n.isClosed { + return nil, errClosed + } + + if len(n.clients) == 0 { + return nil, errors.New("no available clients") + } + + nextClientIndex := n.idx.Inc() % int64(len(n.clients)) + nextClient := n.clients[nextClientIndex] + return nextClient, nil +} + type shardClientMgr interface { - GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) - UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error + GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error) + ReleaseClient(nodeID int64) Close() SetClientCreatorFunc(creator queryNodeCreatorFunc) } @@ -155,6 +159,8 @@ type shardClientMgrImpl struct { data map[UniqueID]*shardClient } clientCreator queryNodeCreatorFunc + + closeCh chan struct{} } // SessionOpt provides a way to set params in SessionManager @@ -176,10 +182,13 @@ func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl { data map[UniqueID]*shardClient }{data: make(map[UniqueID]*shardClient)}, clientCreator: defaultQueryNodeClientCreator, + closeCh: make(chan struct{}), } for _, opt := range options { opt(s) } + + go s.PurgeClient() return s } @@ -187,79 +196,64 @@ func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) c.clientCreator = creator } -// Warning this method may modify parameter `oldLeaders` -func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error { - oldLocalMap := make(map[UniqueID]*nodeInfo) - for _, nodes := range oldLeaders { - for i := range nodes { - n := &nodes[i] - _, ok := oldLocalMap[n.nodeID] - if !ok { - oldLocalMap[n.nodeID] = n - } - } - } - newLocalMap := make(map[UniqueID]*nodeInfo) - - for _, nodes := range newLeaders { - for i := range nodes { - n := &nodes[i] - _, ok := oldLocalMap[n.nodeID] - if !ok { - _, ok2 := newLocalMap[n.nodeID] - if !ok2 { - newLocalMap[n.nodeID] = n - } - } - delete(oldLocalMap, n.nodeID) - } - } - 
c.clients.Lock() - defer c.clients.Unlock() +func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) { + c.clients.RLock() + client, ok := c.clients.data[info.nodeID] + c.clients.RUnlock() - for _, node := range newLocalMap { - client, ok := c.clients.data[node.nodeID] - if ok { - client.inc() - } else { - // context.Background() is useless - // TODO QueryNode NewClient remove ctx parameter - // TODO Remove Init && Start interface in QueryNode client - if c.clientCreator == nil { - return fmt.Errorf("clientCreator function is nil") - } - client, err := newPoolingShardClient(node, c.clientCreator) + if !ok { + c.clients.Lock() + // Check again after acquiring the lock + client, ok = c.clients.data[info.nodeID] + if !ok { + // Create a new client if it doesn't exist + newClient, err := newShardClient(info, c.clientCreator) if err != nil { - return err + return nil, err } - c.clients.data[node.nodeID] = client + c.clients.data[info.nodeID] = newClient + client = newClient } + c.clients.Unlock() } - for _, node := range oldLocalMap { - client, ok := c.clients.data[node.nodeID] - if ok && client.dec() { - delete(c.clients.data, node.nodeID) + + return client.getClient(ctx) +} + +func (c *shardClientMgrImpl) PurgeClient() { + ticker := time.NewTicker(600 * time.Second) + defer ticker.Stop() + for { + select { + case <-c.closeCh: + return + case <-ticker.C: + shardLocations := globalMetaCache.ListShardLocation() + for nodeID := range c.clients.data { + if _, ok := shardLocations[nodeID]; !ok { + c.clients.Lock() + delete(c.clients.data, nodeID) + c.clients.Unlock() + c.ReleaseClient(nodeID) + } + } } } - return nil } -func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) { - c.clients.RLock() - client, ok := c.clients.data[nodeID] - c.clients.RUnlock() - - if !ok { - return nil, fmt.Errorf("can not find client of node %d", nodeID) +func (c *shardClientMgrImpl) 
ReleaseClient(nodeID int64) { + c.clients.Lock() + defer c.clients.Unlock() + if client, ok := c.clients.data[nodeID]; ok { + client.Release() } - return client.getClient(ctx) } // Close release clients func (c *shardClientMgrImpl) Close() { c.clients.Lock() defer c.clients.Unlock() - + close(c.closeCh) for _, s := range c.clients.data { s.Close() } diff --git a/internal/proxy/shard_client_test.go b/internal/proxy/shard_client_test.go index 0ef6f516caf2d..809ab0404b3ca 100644 --- a/internal/proxy/shard_client_test.go +++ b/internal/proxy/shard_client_test.go @@ -8,94 +8,59 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) -func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo { - leaders := make(map[string][]nodeInfo) - nodeInfos := make([]nodeInfo, len(leaderIDs)) - for i, id := range leaderIDs { - nodeInfos[i] = nodeInfo{ - nodeID: id, - address: "fake", - } +func TestShardClientMgr(t *testing.T) { + ctx := context.Background() + nodeInfo := nodeInfo{ + nodeID: 1, } - leaders[channel] = nodeInfos - return leaders -} - -func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) { - mgr := newShardClientMgr(withShardClientCreator(nil)) - mgr.clientCreator = nil - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err := mgr.UpdateShardLeaders(nil, leaders) - assert.Error(t, err) -} -func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) { - mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { - return &mocks.MockQueryNodeClient{}, nil + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil) + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil } - mgr := newShardClientMgr(withShardClientCreator(mockCreator)) - - _, err := mgr.GetClient(context.Background(), UniqueID(1)) - 
assert.Error(t, err) - - err = mgr.UpdateShardLeaders(nil, nil) - assert.NoError(t, err) - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err = mgr.UpdateShardLeaders(leaders, nil) - assert.NoError(t, err) -} - -func TestShardClientMgr_UpdateShardLeaders_NonEmpty(t *testing.T) { mgr := newShardClientMgr() - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err := mgr.UpdateShardLeaders(nil, leaders) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) - - newLeaders := genShardLeaderInfo("c1", []UniqueID{2, 3}) - err = mgr.UpdateShardLeaders(leaders, newLeaders) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) + mgr.SetClientCreatorFunc(creator) + _, err := mgr.GetClient(ctx, nodeInfo) + assert.Nil(t, err) + + mgr.ReleaseClient(1) + assert.Equal(t, len(mgr.clients.data), 1) + mgr.Close() + assert.Equal(t, len(mgr.clients.data), 0) } -func TestShardClientMgr_UpdateShardLeaders_Ref(t *testing.T) { - mgr := newShardClientMgr() - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - - for i := 0; i < 2; i++ { - err := mgr.UpdateShardLeaders(nil, leaders) - assert.NoError(t, err) +func TestShardClient(t *testing.T) { + nodeInfo := nodeInfo{ + nodeID: 1, } - partLeaders := genShardLeaderInfo("c1", []UniqueID{1}) - - _, err := mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) - - err = mgr.UpdateShardLeaders(partLeaders, nil) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) - - err = mgr.UpdateShardLeaders(partLeaders, nil) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(2)) - assert.NoError(t, err) - - _, err = 
mgr.GetClient(context.Background(), UniqueID(3)) - assert.NoError(t, err) + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil) + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } + shardClient, err := newShardClient(nodeInfo, creator) + assert.Nil(t, err) + assert.Equal(t, len(shardClient.clients), 0) + assert.Equal(t, int64(1), shardClient.refCnt.Load()) + assert.Equal(t, false, shardClient.initialized.Load()) + + ctx := context.Background() + _, err = shardClient.getClient(ctx) + assert.Nil(t, err) + assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()) + assert.Equal(t, int64(2), shardClient.refCnt.Load()) + assert.Equal(t, true, shardClient.initialized.Load()) + + shardClient.Release() + assert.Equal(t, int64(1), shardClient.refCnt.Load()) + + shardClient.Release() + assert.Equal(t, int64(0), shardClient.refCnt.Load()) + assert.Equal(t, true, shardClient.isClosed) } diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index 56c662e34b1b7..bb64f4b73c05c 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -33,12 +33,13 @@ func RoundRobinPolicy( leaders := dml2leaders[channel] for _, target := range leaders { - qn, err := mgr.GetClient(ctx, target.nodeID) + qn, err := mgr.GetClient(ctx, target) if err != nil { log.Warn("query channel failed, node not available", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) combineErr = merr.Combine(combineErr, err) continue } + defer mgr.ReleaseClient(target.nodeID) err = query(ctx, target.nodeID, qn, channel) if err != nil { log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go index 5c4b1732824ac..98bc97096cd70 100644 --- 
a/internal/proxy/task_policies_test.go +++ b/internal/proxy/task_policies_test.go @@ -26,7 +26,6 @@ func TestRoundRobinPolicy(t *testing.T) { "c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}}, "c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}}, } - mgr.UpdateShardLeaders(nil, shard2leaders) querier := &mockQuery{} querier.init() diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 104ddc79b9f81..7e07f4d898254 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -79,8 +79,8 @@ func TestQueryTask_all(t *testing.T) { }, nil).Maybe() mgr := NewMockShardClientManager(t) + mgr.EXPECT().ReleaseClient(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) defer rc.Close() diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 5b5e618a2f6be..6be4eb6753de5 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2111,8 +2111,8 @@ func TestSearchTask_ErrExecute(t *testing.T) { qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := NewMockShardClientManager(t) + mgr.EXPECT().ReleaseClient(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) defer qc.Close() diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 42f0d63b4480d..f29f1002f749f 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -80,8 +80,8 @@ func (s *StatisticTaskSuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, 
mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(s.T()) + mgr.EXPECT().ReleaseClient(mock.Anything).Maybe() mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() s.lb = NewLBPolicyImpl(mgr) err := InitMetaCache(context.Background(), s.rc, s.qc, mgr) From ead8b31551dde1cd83f7c6b2edb938b7e34bfd45 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 8 Nov 2024 19:01:40 +0800 Subject: [PATCH 2/3] fix review comment Signed-off-by: Wei Liu --- internal/proxy/lb_policy.go | 2 +- internal/proxy/lb_policy_test.go | 12 ++++----- internal/proxy/look_aside_balancer.go | 2 +- internal/proxy/look_aside_balancer_test.go | 10 ++++---- internal/proxy/meta_cache.go | 4 +-- internal/proxy/mock_shardclient_manager.go | 20 +++++++-------- internal/proxy/shard_client.go | 29 ++++++++++++++-------- internal/proxy/shard_client_test.go | 6 ++--- internal/proxy/task_policies.go | 2 +- internal/proxy/task_query_test.go | 2 +- internal/proxy/task_search_test.go | 2 +- internal/proxy/task_statistic_test.go | 2 +- 12 files changed, 50 insertions(+), 43 deletions(-) diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 79ef086da0d9f..b2dd8e2ab52c6 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -192,7 +192,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo defer balancer.CancelWorkload(targetNode.nodeID, workload.nq) client, err := lb.clientMgr.GetClient(ctx, targetNode) - defer lb.clientMgr.ReleaseClient(targetNode.nodeID) if err != nil { log.Warn("search/query channel failed, node not available", zap.Int64("collectionID", workload.collectionID), @@ -204,6 +203,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel) return lastErr } + defer 
lb.clientMgr.ReleaseClientRef(targetNode.nodeID) err = workload.exec(ctx, targetNode.nodeID, client, workload.channel) if err != nil { diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 4a96ef9c824bb..abaf3c215e4c6 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -250,7 +250,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test execute success s.lbBalancer.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) @@ -289,7 +289,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test get client failed, and retry failed, expected success s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -310,7 +310,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.Error(err) s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -331,7 +331,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test exec failed, then retry success s.mgr.ExpectedCalls = nil - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.ExpectedCalls = nil 
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) @@ -359,7 +359,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test exec timeout s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1) @@ -384,7 +384,7 @@ func (s *LBPolicySuite) TestExecute() { ctx := context.Background() mockErr := errors.New("mock error") // test all channel success - s.mgr.EXPECT().ReleaseClient(mock.Anything) + s.mgr.EXPECT().ReleaseClientRef(mock.Anything) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go index aa7c964dda469..71178ffd465ae 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -260,7 +260,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { log.RatedInfo(10, "get client failed", zap.Int64("node", node), zap.Error(err)) return struct{}{}, nil } - defer b.clientMgr.ReleaseClient(node) + defer b.clientMgr.ReleaseClientRef(node) resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go index abd55070562fa..d91e24f37d6b3 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -43,7 +43,7 @@ type LookAsideBalancerSuite struct { func (suite *LookAsideBalancerSuite) SetupTest() { 
suite.clientMgr = NewMockShardClientManager(suite.T()) - suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.balancer = NewLookAsideBalancer(suite.clientMgr) suite.balancer.Start(context.Background()) @@ -309,7 +309,7 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { }, }, nil).Maybe() suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) { if ni.nodeID == 1 { return qn, nil @@ -373,7 +373,7 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { // test get shard client from client mgr return nil suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found")) // expected stopping the health check after failure times reaching the limit suite.Eventually(func() bool { @@ -385,7 +385,7 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().ReleaseClient(mock.Anything) + suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything) suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ @@ -424,7 +424,7 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.ExpectedCalls = nil - 
suite.clientMgr.EXPECT().ReleaseClient(mock.Anything) + suite.clientMgr.EXPECT().ReleaseClientRef(mock.Anything) suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 134d799d8bc7c..2b9bba4ae4027 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -1036,7 +1036,7 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) { if ok { delete(dbInfo, collectionName) if len(dbInfo) == 0 { - delete(dbInfo, database) + delete(m.collLeader, database) } } } @@ -1072,7 +1072,7 @@ func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) { } } if len(dbInfo) == 0 { - delete(dbInfo, dbName) + delete(m.collLeader, dbName) } } } diff --git a/internal/proxy/mock_shardclient_manager.go b/internal/proxy/mock_shardclient_manager.go index 3bc433f138f32..878168e0e7aa6 100644 --- a/internal/proxy/mock_shardclient_manager.go +++ b/internal/proxy/mock_shardclient_manager.go @@ -113,35 +113,35 @@ func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.C return _c } -// ReleaseClient provides a mock function with given fields: nodeID -func (_m *MockShardClientManager) ReleaseClient(nodeID int64) { +// ReleaseClientRef provides a mock function with given fields: nodeID +func (_m *MockShardClientManager) ReleaseClientRef(nodeID int64) { _m.Called(nodeID) } -// MockShardClientManager_ReleaseClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClient' -type MockShardClientManager_ReleaseClient_Call struct { +// MockShardClientManager_ReleaseClientRef_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseClientRef' +type MockShardClientManager_ReleaseClientRef_Call struct { *mock.Call } -// 
ReleaseClient is a helper method to define mock.On call +// ReleaseClientRef is a helper method to define mock.On call // - nodeID int64 -func (_e *MockShardClientManager_Expecter) ReleaseClient(nodeID interface{}) *MockShardClientManager_ReleaseClient_Call { - return &MockShardClientManager_ReleaseClient_Call{Call: _e.mock.On("ReleaseClient", nodeID)} +func (_e *MockShardClientManager_Expecter) ReleaseClientRef(nodeID interface{}) *MockShardClientManager_ReleaseClientRef_Call { + return &MockShardClientManager_ReleaseClientRef_Call{Call: _e.mock.On("ReleaseClientRef", nodeID)} } -func (_c *MockShardClientManager_ReleaseClient_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClient_Call { +func (_c *MockShardClientManager_ReleaseClientRef_Call) Run(run func(nodeID int64)) *MockShardClientManager_ReleaseClientRef_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(int64)) }) return _c } -func (_c *MockShardClientManager_ReleaseClient_Call) Return() *MockShardClientManager_ReleaseClient_Call { +func (_c *MockShardClientManager_ReleaseClientRef_Call) Return() *MockShardClientManager_ReleaseClientRef_Call { _c.Call.Return() return _c } -func (_c *MockShardClientManager_ReleaseClient_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClient_Call { +func (_c *MockShardClientManager_ReleaseClientRef_Call) RunAndReturn(run func(int64)) *MockShardClientManager_ReleaseClientRef_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index d72e001b2d724..301fbc9475c8f 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -66,15 +66,21 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err if err != nil { return nil, err } - n.refCnt.Inc() + n.IncRef() return client, nil } } -func (n *shardClient) Release() { +func (n *shardClient) DecRef() bool { if n.refCnt.Dec() == 0 { n.Close() + return true } + return false +} + 
+func (n *shardClient) IncRef() { + n.refCnt.Inc() } func (n *shardClient) close() { @@ -148,7 +154,7 @@ func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) { type shardClientMgr interface { GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error) - ReleaseClient(nodeID int64) + ReleaseClientRef(nodeID int64) Close() SetClientCreatorFunc(creator queryNodeCreatorFunc) } @@ -209,6 +215,7 @@ func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (type // Create a new client if it doesn't exist newClient, err := newShardClient(info, c.clientCreator) if err != nil { + c.clients.Unlock() return nil, err } c.clients.data[info.nodeID] = newClient @@ -229,23 +236,23 @@ func (c *shardClientMgrImpl) PurgeClient() { return case <-ticker.C: shardLocations := globalMetaCache.ListShardLocation() - for nodeID := range c.clients.data { + c.clients.Lock() + for nodeID, client := range c.clients.data { if _, ok := shardLocations[nodeID]; !ok { - c.clients.Lock() + client.DecRef() delete(c.clients.data, nodeID) - c.clients.Unlock() - c.ReleaseClient(nodeID) } } + c.clients.Unlock() } } } -func (c *shardClientMgrImpl) ReleaseClient(nodeID int64) { - c.clients.Lock() - defer c.clients.Unlock() +func (c *shardClientMgrImpl) ReleaseClientRef(nodeID int64) { + c.clients.RLock() + defer c.clients.RUnlock() if client, ok := c.clients.data[nodeID]; ok { - client.Release() + client.DecRef() } } diff --git a/internal/proxy/shard_client_test.go b/internal/proxy/shard_client_test.go index 809ab0404b3ca..272b10e06e5cd 100644 --- a/internal/proxy/shard_client_test.go +++ b/internal/proxy/shard_client_test.go @@ -28,7 +28,7 @@ func TestShardClientMgr(t *testing.T) { _, err := mgr.GetClient(ctx, nodeInfo) assert.Nil(t, err) - mgr.ReleaseClient(1) + mgr.ReleaseClientRef(1) assert.Equal(t, len(mgr.clients.data), 1) mgr.Close() assert.Equal(t, len(mgr.clients.data), 0) @@ -57,10 +57,10 @@ func TestShardClient(t *testing.T) { 
assert.Equal(t, int64(2), shardClient.refCnt.Load()) assert.Equal(t, true, shardClient.initialized.Load()) - shardClient.Release() + shardClient.DecRef() assert.Equal(t, int64(1), shardClient.refCnt.Load()) - shardClient.Release() + shardClient.DecRef() assert.Equal(t, int64(0), shardClient.refCnt.Load()) assert.Equal(t, true, shardClient.isClosed) } diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index bb64f4b73c05c..28937fe8e5751 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -39,7 +39,7 @@ func RoundRobinPolicy( combineErr = merr.Combine(combineErr, err) continue } - defer mgr.ReleaseClient(target.nodeID) + defer mgr.ReleaseClientRef(target.nodeID) err = query(ctx, target.nodeID, qn, channel) if err != nil { log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 7e07f4d898254..36b26866b46a5 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -79,7 +79,7 @@ func TestQueryTask_all(t *testing.T) { }, nil).Maybe() mgr := NewMockShardClientManager(t) - mgr.EXPECT().ReleaseClient(mock.Anything) + mgr.EXPECT().ReleaseClientRef(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() lb := NewLBPolicyImpl(mgr) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 6be4eb6753de5..aabfed073b0d5 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2111,7 +2111,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := NewMockShardClientManager(t) - mgr.EXPECT().ReleaseClient(mock.Anything) + mgr.EXPECT().ReleaseClientRef(mock.Anything) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, 
nil).Maybe() lb := NewLBPolicyImpl(mgr) diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index f29f1002f749f..54a084b71e782 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -80,7 +80,7 @@ func (s *StatisticTaskSuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(s.T()) - mgr.EXPECT().ReleaseClient(mock.Anything).Maybe() + mgr.EXPECT().ReleaseClientRef(mock.Anything).Maybe() mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() s.lb = NewLBPolicyImpl(mgr) From f3b448150b41ec1f45a678c5afad995e910052a0 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 11 Nov 2024 20:12:56 +0800 Subject: [PATCH 3/3] fix review comment Signed-off-by: Wei Liu --- internal/proxy/meta_cache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 2b9bba4ae4027..47d282fae56b6 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -1043,8 +1043,8 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) { // used for Garbage collection shard client func (m *MetaCache) ListShardLocation() map[int64]nodeInfo { - m.leaderMut.Lock() - defer m.leaderMut.Unlock() + m.leaderMut.RLock() + defer m.leaderMut.RUnlock() shardLeaderInfo := make(map[int64]nodeInfo) for _, dbInfo := range m.collLeader {