diff --git a/internal/proxy/lb_balancer.go b/internal/proxy/lb_balancer.go index f918ab20a7581..3d46cb63c0f19 100644 --- a/internal/proxy/lb_balancer.go +++ b/internal/proxy/lb_balancer.go @@ -23,6 +23,7 @@ import ( ) type LBBalancer interface { + RegisterNodeInfo(nodeInfos []nodeInfo) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) CancelWorkload(node int64, nq int64) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index b7e55ad171603..0270a47475449 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -39,7 +39,7 @@ type ChannelWorkload struct { collectionName string collectionID int64 channel string - shardLeaders []int64 + shardLeaders []nodeInfo nq int64 exec executeFunc retryTimes uint @@ -115,9 +115,20 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll } // try to select the best node from the available nodes -func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) { - availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) }) - targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq) +func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) { + filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo { + ret := make(map[int64]nodeInfo) + for _, node := range nodes { + if !excludeNodes.Contain(node.nodeID) { + ret[node.nodeID] = node + } + } + return ret + } + + availableNodes := filterDelegator(workload.shardLeaders) + balancer.RegisterNodeInfo(lo.Values(availableNodes)) + targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq) if err != nil { log := log.Ctx(ctx) globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName) @@ -127,32 +138,33 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), zap.Error(err)) - return -1, err + return nodeInfo{}, err } - availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) }) + availableNodes = filterDelegator(shardLeaders[workload.channel]) if len(availableNodes) == 0 { - nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }) log.Warn("no available shard delegator found", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64s("nodes", nodes), + zap.Int64s("availableNodes", lo.Keys(availableNodes)), zap.Int64s("excluded", excludeNodes.Collect())) - return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found") + return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found") } - targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq) + balancer.RegisterNodeInfo(lo.Values(availableNodes)) + targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq) if err != nil { log.Warn("failed to select shard", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64s("availableNodes", availableNodes), + 
zap.Int64s("availableNodes", lo.Keys(availableNodes)), + zap.Int64s("excluded", excludeNodes.Collect()), zap.Error(err)) - return -1, err + return nodeInfo{}, err } } - return targetNode, nil + return availableNodes[targetNode], nil } // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes. @@ -167,7 +179,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo log.Warn("failed to select node for shard", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err), ) if lastErr != nil { @@ -176,30 +188,30 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo return err } // cancel work load which assign to the target node - defer balancer.CancelWorkload(targetNode, workload.nq) + defer balancer.CancelWorkload(targetNode.nodeID, workload.nq) client, err := lb.clientMgr.GetClient(ctx, targetNode) if err != nil { log.Warn("search/query channel failed, node not available", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) - excludeNodes.Insert(targetNode) + excludeNodes.Insert(targetNode.nodeID) - lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel) + lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel) return lastErr } - err = workload.exec(ctx, targetNode, client, workload.channel) + err = workload.exec(ctx, targetNode.nodeID, client, workload.channel) if err != nil { log.Warn("search/query channel failed", zap.Int64("collectionID", workload.collectionID), zap.String("channelName", workload.channel), - zap.Int64("nodeID", targetNode), + zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) - excludeNodes.Insert(targetNode) - lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel) + excludeNodes.Insert(targetNode.nodeID) + lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel) return lastErr } @@ -220,9 +232,9 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad // let every request could retry at least twice, which could retry after update shard leader cache retryTimes := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt() wg, ctx := errgroup.WithContext(ctx) - for channel, nodes := range dml2leaders { - channel := channel - nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID }) + for k, v := range dml2leaders { + channel := k + nodes := v channelRetryTimes := retryTimes if len(nodes) > 0 { channelRetryTimes *= len(nodes) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index f212c83467b23..c17388d40212b 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -52,7 +52,8 @@ type LBPolicySuite struct { lbBalancer *MockLBBalancer lbPolicy *LBPolicyImpl - nodes []int64 + nodeIDs []int64 + nodes []nodeInfo channels []string qnList []*mocks.MockQueryNode @@ -65,7 +66,14 @@ func (s *LBPolicySuite) SetupSuite() { } func (s *LBPolicySuite) SetupTest() { - s.nodes = []int64{1, 2, 3, 4, 5} + s.nodeIDs = make([]int64, 0) + for i := 1; i <= 5; i++ { + s.nodeIDs = append(s.nodeIDs, int64(i)) + 
s.nodes = append(s.nodes, nodeInfo{ + nodeID: int64(i), + address: "localhost", + }) + } s.channels = []string{"channel1", "channel2"} successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc := mocks.NewMockQueryCoordClient(s.T()) @@ -77,12 +85,12 @@ func (s *LBPolicySuite) SetupTest() { Shards: []*querypb.ShardLeadersList{ { ChannelName: s.channels[0], - NodeIds: s.nodes, + NodeIds: s.nodeIDs, NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"}, }, { ChannelName: s.channels[1], - NodeIds: s.nodes, + NodeIds: s.nodeIDs, NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"}, }, }, @@ -99,7 +107,6 @@ func (s *LBPolicySuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.mgr = NewMockShardClientManager(s.T()) - s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() s.lbBalancer = NewMockLBBalancer(s.T()) s.lbBalancer.EXPECT().Start(context.Background()).Maybe() s.lbPolicy = NewLBPolicyImpl(s.mgr) @@ -167,6 +174,7 @@ func (s *LBPolicySuite) loadCollection() { func (s *LBPolicySuite) TestSelectNode() { ctx := context.Background() + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil) targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, @@ -177,10 +185,11 @@ func (s *LBPolicySuite) TestSelectNode() { nq: 1, }, typeutil.NewUniqueSet()) s.NoError(err) - s.Equal(int64(5), targetNode) + s.Equal(int64(5), targetNode.nodeID) // test select node failed, then update shard leader cache and retry, expect success s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ @@ -188,28 +197,29 @@ func (s *LBPolicySuite) TestSelectNode() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: []int64{}, + shardLeaders: []nodeInfo{}, nq: 1, }, typeutil.NewUniqueSet()) s.NoError(err) - s.Equal(int64(3), targetNode) + s.Equal(int64(3), targetNode.nodeID) // test select node always fails, expected failure s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: []int64{}, + shardLeaders: []nodeInfo{}, nq: 1, }, typeutil.NewUniqueSet()) s.ErrorIs(err, merr.ErrNodeNotAvailable) - s.Equal(int64(-1), targetNode) // test all nodes has been excluded, expected failure s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, @@ -218,12 +228,12 @@ func (s *LBPolicySuite) TestSelectNode() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - }, 
typeutil.NewUniqueSet(s.nodes...)) + }, typeutil.NewUniqueSet(s.nodeIDs...)) s.ErrorIs(err, merr.ErrChannelNotAvailable) - s.Equal(int64(-1), targetNode) // test get shard leaders failed, retry to select node failed s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) s.qc.ExpectedCalls = nil s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable) @@ -236,7 +246,6 @@ func (s *LBPolicySuite) TestSelectNode() { nq: 1, }, typeutil.NewUniqueSet()) s.ErrorIs(err, merr.ErrServiceUnavailable) - s.Equal(int64(-1), targetNode) } func (s *LBPolicySuite) TestExecuteWithRetry() { @@ -245,6 +254,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test execute success s.lbBalancer.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ @@ -263,6 +273,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { // test select node failed, expected error s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ db: dbName, @@ -282,6 +293,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ @@ -301,6 +313,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ db: dbName, @@ -320,6 +333,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) counter := 0 @@ -369,6 +383,7 @@ func (s *LBPolicySuite) TestExecute() { mockErr := errors.New("mock error") // test all channel success s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{ diff --git a/internal/proxy/look_aside_balancer.go 
b/internal/proxy/look_aside_balancer.go index 3bd7c3d63e7fb..0aa26a78a7599 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -44,7 +44,8 @@ type CostMetrics struct { type LookAsideBalancer struct { clientMgr shardClientMgr - metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics] + knownNodeInfos *typeutil.ConcurrentMap[int64, nodeInfo] + metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics] // query node id -> number of consecutive heartbeat failures failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64] @@ -64,6 +65,7 @@ type LookAsideBalancer struct { func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer { balancer := &LookAsideBalancer{ clientMgr: clientMgr, + knownNodeInfos: typeutil.NewConcurrentMap[int64, nodeInfo](), metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](), failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](), closeCh: make(chan struct{}), @@ -88,6 +90,12 @@ func (b *LookAsideBalancer) Close() { }) } +func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) { + for _, node := range nodeInfos { + b.knownNodeInfos.Insert(node.nodeID, node) + } +} + func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) { targetNode := int64(-1) defer func() { @@ -233,9 +241,10 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { case <-ticker.C: var futures []*conc.Future[any] now := time.Now() - b.metricsMap.Range(func(node int64, metrics *CostMetrics) bool { + b.knownNodeInfos.Range(func(node int64, info nodeInfo) bool { futures = append(futures, pool.Submit(func() (any, error) { - if now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() { + metrics, ok := b.metricsMap.Get(node) + if !ok || now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() { checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), checkTimeout) defer cancel() @@ -244,7 +253,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { panic("let it panic") } - qn, err := b.clientMgr.GetClient(ctx, node) + qn, err := b.clientMgr.GetClient(ctx, info) if err != nil { // get client from clientMgr failed, which means this qn isn't a shard leader anymore, skip it's health check b.trySetQueryNodeUnReachable(node, err) @@ -304,6 +313,7 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) { zap.Int64("nodeID", node)) // stop the heartbeat b.metricsMap.Remove(node) + b.knownNodeInfos.Remove(node) return } diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go index 69d57c5f550ff..20cc7c51a6892 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -46,7 +47,7 @@ func (suite *LookAsideBalancerSuite) SetupTest() { suite.balancer.Start(context.Background()) qn := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe() + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() 
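
The lb_policy.go and look_aside_balancer.go changes above replace bare node IDs with full delegator descriptors: selectNode filters out excluded delegators, registers the surviving nodeInfo entries with the balancer (so the health-check loop later knows their addresses), selects by ID, and returns the whole nodeInfo to the caller. A minimal, self-contained sketch of that flow, using stand-in types and a trivial selection policy rather than the real LBBalancer/typeutil code:

package main

import (
	"errors"
	"fmt"
)

// nodeInfo and the trivial cost policy below are illustrative stand-ins for the
// proxy package's types; only the shape of the flow matches the real code.
type nodeInfo struct {
	nodeID  int64
	address string
}

type balancer struct {
	known map[int64]nodeInfo // filled by RegisterNodeInfo, consulted by health checks
}

func (b *balancer) RegisterNodeInfo(nodes []nodeInfo) {
	for _, n := range nodes {
		b.known[n.nodeID] = n
	}
}

func (b *balancer) SelectNode(candidates []int64) (int64, error) {
	if len(candidates) == 0 {
		return -1, errors.New("no node available")
	}
	return candidates[0], nil // the real balancer picks by workload cost
}

// selectDelegator filters excluded delegators, registers the survivors with the
// balancer, selects by ID, and returns the full descriptor to the caller.
func selectDelegator(b *balancer, leaders []nodeInfo, excluded map[int64]struct{}) (nodeInfo, error) {
	available := make(map[int64]nodeInfo)
	ids := make([]int64, 0, len(leaders))
	infos := make([]nodeInfo, 0, len(leaders))
	for _, n := range leaders {
		if _, skip := excluded[n.nodeID]; !skip {
			available[n.nodeID] = n
			ids = append(ids, n.nodeID)
			infos = append(infos, n)
		}
	}
	b.RegisterNodeInfo(infos)
	id, err := b.SelectNode(ids)
	if err != nil {
		return nodeInfo{}, err
	}
	return available[id], nil // hand back the full nodeInfo, not just the ID
}

func main() {
	b := &balancer{known: make(map[int64]nodeInfo)}
	leaders := []nodeInfo{{nodeID: 1, address: "qn-1:21123"}, {nodeID: 2, address: "qn-2:21123"}}
	target, err := selectDelegator(b, leaders, map[int64]struct{}{1: {}})
	fmt.Println(target.nodeID, target.address, err)
}
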
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe() } @@ -298,22 +299,45 @@ func (suite *LookAsideBalancerSuite) TestCancelWorkload() { } func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { + qn := mocks.NewMockQueryNodeClient(suite.T()) + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe() qn2 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil).Maybe() qn2.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, }, nil).Maybe() + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) { + if ni.nodeID == 1 { + return qn, nil + } + + if ni.nodeID == 2 { + return qn2, nil + } + return nil, errors.New("unexpected node") + }).Maybe() metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) metrics1.unavailable.Store(true) suite.balancer.metricsMap.Insert(1, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 1, + }, + }) metrics2 := &CostMetrics{} metrics2.ts.Store(time.Now().UnixMilli()) metrics2.unavailable.Store(true) suite.balancer.metricsMap.Insert(2, metrics2) + suite.balancer.knownNodeInfos.Insert(2, nodeInfo{}) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 2, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(1) return ok && metrics.unavailable.Load() @@ -339,10 +363,15 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { metrics1.ts.Store(time.Now().UnixMilli()) metrics1.unavailable.Store(true) suite.balancer.metricsMap.Insert(2, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 2, + }, + }) // test get shard client from client mgr return nil suite.clientMgr.ExpectedCalls = nil - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(nil, errors.New("shard client not found")) + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("shard client not found")) // expected stopping the health check after failure times reaching the limit suite.Eventually(func() bool { return !suite.balancer.metricsMap.Contain(2) @@ -352,7 +381,8 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() { func (suite *LookAsideBalancerSuite) TestNodeRecover() { // mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, @@ -368,6 +398,11 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() { metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) suite.balancer.metricsMap.Insert(3, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 3, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(3) return ok && metrics.unavailable.Load() @@ -384,7 +419,8 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000") // 
mock qn down for a while and then recover qn3 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) + suite.clientMgr.ExpectedCalls = nil + suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn3, nil) qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, @@ -394,6 +430,11 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { metrics1 := &CostMetrics{} metrics1.ts.Store(time.Now().UnixMilli()) suite.balancer.metricsMap.Insert(3, metrics1) + suite.balancer.RegisterNodeInfo([]nodeInfo{ + { + nodeID: 3, + }, + }) suite.Eventually(func() bool { metrics, ok := suite.balancer.metricsMap.Get(3) return ok && metrics.unavailable.Load() diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 16b16c3ce10c0..bd8b20e0f895f 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -72,6 +72,7 @@ type Cache interface { GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) DeprecateShardCache(database, collectionName string) InvalidateShardLeaderCache(collections []int64) + ListShardLocation() map[int64]nodeInfo RemoveCollection(ctx context.Context, database, collectionName string) RemoveCollectionsByID(ctx context.Context, collectionID UniqueID) []string RemovePartition(ctx context.Context, database, collectionName string, partitionName string) @@ -288,9 +289,7 @@ func (info *collectionInfo) isCollectionCached() bool { // shardLeaders wraps shard leader mapping for iteration. type shardLeaders struct { - idx *atomic.Int64 - deprecated *atomic.Bool - + idx *atomic.Int64 collectionID int64 shardLeaders map[string][]nodeInfo } @@ -419,19 +418,19 @@ func (m *MetaCache) getCollection(database, collectionName string, collectionID return nil, false } -func (m *MetaCache) getCollectionShardLeader(database, collectionName string) (*shardLeaders, bool) { +func (m *MetaCache) getCollectionShardLeader(database, collectionName string) *shardLeaders { m.leaderMut.RLock() defer m.leaderMut.RUnlock() db, ok := m.collLeader[database] if !ok { - return nil, false + return nil } if leaders, ok := db[collectionName]; ok { - return leaders, !leaders.deprecated.Load() + return leaders } - return nil, false + return nil } func (m *MetaCache) update(ctx context.Context, database, collectionName string, collectionID UniqueID) (*collectionInfo, error) { @@ -954,9 +953,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col zap.String("collectionName", collectionName), zap.Int64("collectionID", collectionID)) - cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName) + cacheShardLeaders := m.getCollectionShardLeader(database, collectionName) if withCache { - if ok { + if cacheShardLeaders != nil { metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() iterator := cacheShardLeaders.GetReader() return iterator.Shuffle(), nil @@ -992,11 +991,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col newShardLeaders := &shardLeaders{ collectionID: info.collID, shardLeaders: shards, - deprecated: atomic.NewBool(false), idx: atomic.NewInt64(0), } - // lock leader m.leaderMut.Lock() if _, ok := m.collLeader[database]; !ok { m.collLeader[database] = 
make(map[string]*shardLeaders) @@ -1005,15 +1002,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col iterator := newShardLeaders.GetReader() ret := iterator.Shuffle() - - oldLeaders := make(map[string][]nodeInfo) - if cacheShardLeaders != nil { - oldLeaders = cacheShardLeaders.shardLeaders - } - // update refcnt in shardClientMgr - // update shard leader's just create a empty client pool - // and init new client will be execute in getClient - _ = m.shardMgr.UpdateShardLeaders(oldLeaders, ret) m.leaderMut.Unlock() metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) @@ -1039,23 +1027,50 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m // DeprecateShardCache clear the shard leader cache of a collection func (m *MetaCache) DeprecateShardCache(database, collectionName string) { log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName)) - if shards, ok := m.getCollectionShardLeader(database, collectionName); ok { - shards.deprecated.Store(true) + m.leaderMut.Lock() + defer m.leaderMut.Unlock() + dbInfo, ok := m.collLeader[database] + if ok { + delete(dbInfo, collectionName) + if len(dbInfo) == 0 { + delete(m.collLeader, database) + } } } +// used for Garbage collection shard client +func (m *MetaCache) ListShardLocation() map[int64]nodeInfo { + m.leaderMut.RLock() + defer m.leaderMut.RUnlock() + shardLeaderInfo := make(map[int64]nodeInfo) + + for _, dbInfo := range m.collLeader { + for _, shardLeaders := range dbInfo { + for _, nodeInfos := range shardLeaders.shardLeaders { + for _, node := range nodeInfos { + shardLeaderInfo[node.nodeID] = node + } + } + } + } + return shardLeaderInfo +} + func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) { log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections)) m.leaderMut.Lock() defer m.leaderMut.Unlock() collectionSet := typeutil.NewUniqueSet(collections...) 
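
The meta_cache.go hunks drop the deprecated flag: DeprecateShardCache now deletes stale entries outright, and the new ListShardLocation flattens whatever is still cached into a nodeID-to-nodeInfo map that drives shard-client garbage collection. A simplified, self-contained sketch of those two operations, with plain RWMutex-guarded maps standing in for MetaCache's internals:

package main

import (
	"fmt"
	"sync"
)

type nodeInfo struct {
	nodeID  int64
	address string
}

type shardLeaders struct {
	collectionID int64
	shardLeaders map[string][]nodeInfo // channel -> leaders
}

type metaCache struct {
	mu         sync.RWMutex
	collLeader map[string]map[string]*shardLeaders // db -> collection -> leaders
}

// DeprecateShardCache drops the entry outright instead of flipping a
// "deprecated" flag, so stale leaders can never be served again.
func (m *metaCache) DeprecateShardCache(db, collection string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if dbInfo, ok := m.collLeader[db]; ok {
		delete(dbInfo, collection)
		if len(dbInfo) == 0 {
			delete(m.collLeader, db)
		}
	}
}

// ListShardLocation flattens every cached leader into nodeID -> nodeInfo,
// which the shard client manager uses to decide which clients are still needed.
func (m *metaCache) ListShardLocation() map[int64]nodeInfo {
	m.mu.RLock()
	defer m.mu.RUnlock()
	out := make(map[int64]nodeInfo)
	for _, dbInfo := range m.collLeader {
		for _, leaders := range dbInfo {
			for _, nodes := range leaders.shardLeaders {
				for _, n := range nodes {
					out[n.nodeID] = n
				}
			}
		}
	}
	return out
}

func main() {
	m := &metaCache{collLeader: map[string]map[string]*shardLeaders{
		"default": {
			"coll": {
				collectionID: 100,
				shardLeaders: map[string][]nodeInfo{
					"ch1": {{nodeID: 1, address: "qn-1:21123"}},
				},
			},
		},
	}}
	fmt.Println(m.ListShardLocation())
	m.DeprecateShardCache("default", "coll")
	fmt.Println(m.ListShardLocation())
}
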
- for _, db := range m.collLeader { - for _, shardLeaders := range db { + for dbName, dbInfo := range m.collLeader { + for collectionName, shardLeaders := range dbInfo { if collectionSet.Contain(shardLeaders.collectionID) { - shardLeaders.deprecated.Store(true) + delete(dbInfo, collectionName) } } + if len(dbInfo) == 0 { + delete(m.collLeader, dbName) + } } } diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index a1611aacc12e2..475ed2240f69f 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -796,7 +796,6 @@ func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) { }, } sl := &shardLeaders{ - deprecated: uatomic.NewBool(false), idx: uatomic.NewInt64(5), shardLeaders: shards, } diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index f1ef527f8eee2..2601a2a66203e 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -921,6 +921,49 @@ func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int return _c } +// ListShardLocation provides a mock function with given fields: +func (_m *MockCache) ListShardLocation() map[int64]nodeInfo { + ret := _m.Called() + + var r0 map[int64]nodeInfo + if rf, ok := ret.Get(0).(func() map[int64]nodeInfo); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]nodeInfo) + } + } + + return r0 +} + +// MockCache_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation' +type MockCache_ListShardLocation_Call struct { + *mock.Call +} + +// ListShardLocation is a helper method to define mock.On call +func (_e *MockCache_Expecter) ListShardLocation() *MockCache_ListShardLocation_Call { + return &MockCache_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")} +} + +func (_c *MockCache_ListShardLocation_Call) Run(run func()) *MockCache_ListShardLocation_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCache_ListShardLocation_Call) Return(_a0 map[int64]nodeInfo) *MockCache_ListShardLocation_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCache_ListShardLocation_Call) RunAndReturn(run func() map[int64]nodeInfo) *MockCache_ListShardLocation_Call { + _c.Call.Return(run) + return _c +} + // RefreshPolicyInfo provides a mock function with given fields: op func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error { ret := _m.Called(op) diff --git a/internal/proxy/mock_lb_balancer.go b/internal/proxy/mock_lb_balancer.go index 368d6bbbd2cc6..95578d91b4c85 100644 --- a/internal/proxy/mock_lb_balancer.go +++ b/internal/proxy/mock_lb_balancer.go @@ -88,6 +88,39 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Cl return _c } +// RegisterNodeInfo provides a mock function with given fields: nodeInfos +func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) { + _m.Called(nodeInfos) +} + +// MockLBBalancer_RegisterNodeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterNodeInfo' +type MockLBBalancer_RegisterNodeInfo_Call struct { + *mock.Call +} + +// RegisterNodeInfo is a helper method to define mock.On call +// - nodeInfos []nodeInfo +func (_e *MockLBBalancer_Expecter) RegisterNodeInfo(nodeInfos interface{}) *MockLBBalancer_RegisterNodeInfo_Call { + return &MockLBBalancer_RegisterNodeInfo_Call{Call: _e.mock.On("RegisterNodeInfo", nodeInfos)} +} + +func (_c 
*MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]nodeInfo)) + }) + return _c +} + +func (_c *MockLBBalancer_RegisterNodeInfo_Call) Return() *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Return() + return _c +} + +func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call { + _c.Call.Return(run) + return _c +} + // SelectNode provides a mock function with given fields: ctx, availableNodes, nq func (_m *MockLBBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) { ret := _m.Called(ctx, availableNodes, nq) diff --git a/internal/proxy/mock_shardclient_manager.go b/internal/proxy/mock_shardclient_manager.go index 33d886a18dc02..107bc597534ff 100644 --- a/internal/proxy/mock_shardclient_manager.go +++ b/internal/proxy/mock_shardclient_manager.go @@ -54,25 +54,25 @@ func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShard return _c } -// GetClient provides a mock function with given fields: ctx, nodeID -func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (types.QueryNodeClient, error) { - ret := _m.Called(ctx, nodeID) +// GetClient provides a mock function with given fields: ctx, nodeInfo1 +func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo1 nodeInfo) (types.QueryNodeClient, error) { + ret := _m.Called(ctx, nodeInfo1) var r0 types.QueryNodeClient var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (types.QueryNodeClient, error)); ok { - return rf(ctx, nodeID) + if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) (types.QueryNodeClient, error)); ok { + return rf(ctx, nodeInfo1) } - if rf, ok := ret.Get(0).(func(context.Context, int64) types.QueryNodeClient); ok { - r0 = rf(ctx, nodeID) + if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) types.QueryNodeClient); ok { + r0 = rf(ctx, nodeInfo1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(types.QueryNodeClient) } } - if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { - r1 = rf(ctx, nodeID) + if rf, ok := ret.Get(1).(func(context.Context, nodeInfo) error); ok { + r1 = rf(ctx, nodeInfo1) } else { r1 = ret.Error(1) } @@ -87,14 +87,14 @@ type MockShardClientManager_GetClient_Call struct { // GetClient is a helper method to define mock.On call // - ctx context.Context -// - nodeID int64 -func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeID interface{}) *MockShardClientManager_GetClient_Call { - return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeID)} +// - nodeInfo1 nodeInfo +func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo1 interface{}) *MockShardClientManager_GetClient_Call { + return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo1)} } -func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeID int64)) *MockShardClientManager_GetClient_Call { +func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo1 nodeInfo)) *MockShardClientManager_GetClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) + run(args[0].(context.Context), args[1].(nodeInfo)) }) return _c } @@ -104,7 +104,7 @@ func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClien return _c } -func 
(_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, int64) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call { +func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, nodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call { _c.Call.Return(run) return _c } @@ -142,49 +142,6 @@ func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run fun return _c } -// UpdateShardLeaders provides a mock function with given fields: oldLeaders, newLeaders -func (_m *MockShardClientManager) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error { - ret := _m.Called(oldLeaders, newLeaders) - - var r0 error - if rf, ok := ret.Get(0).(func(map[string][]nodeInfo, map[string][]nodeInfo) error); ok { - r0 = rf(oldLeaders, newLeaders) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockShardClientManager_UpdateShardLeaders_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateShardLeaders' -type MockShardClientManager_UpdateShardLeaders_Call struct { - *mock.Call -} - -// UpdateShardLeaders is a helper method to define mock.On call -// - oldLeaders map[string][]nodeInfo -// - newLeaders map[string][]nodeInfo -func (_e *MockShardClientManager_Expecter) UpdateShardLeaders(oldLeaders interface{}, newLeaders interface{}) *MockShardClientManager_UpdateShardLeaders_Call { - return &MockShardClientManager_UpdateShardLeaders_Call{Call: _e.mock.On("UpdateShardLeaders", oldLeaders, newLeaders)} -} - -func (_c *MockShardClientManager_UpdateShardLeaders_Call) Run(run func(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo)) *MockShardClientManager_UpdateShardLeaders_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string][]nodeInfo), args[1].(map[string][]nodeInfo)) - }) - return _c -} - -func (_c *MockShardClientManager_UpdateShardLeaders_Call) Return(_a0 error) *MockShardClientManager_UpdateShardLeaders_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockShardClientManager_UpdateShardLeaders_Call) RunAndReturn(run func(map[string][]nodeInfo, map[string][]nodeInfo) error) *MockShardClientManager_UpdateShardLeaders_Call { - _c.Call.Return(run) - return _c -} - // NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
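
With UpdateShardLeaders removed, GetClient now takes the full nodeInfo and the shard client manager dials the embedded address on demand instead of relying on a prior leader registration. A hedged, self-contained sketch of that lazy per-node creation (one client per node here; the real manager keeps a pool behind typeutil.ConcurrentMap and these names are illustrative):

package main

import (
	"context"
	"fmt"
	"sync"
)

type nodeInfo struct {
	nodeID  int64
	address string
}

// queryNodeClient stands in for types.QueryNodeClient.
type queryNodeClient struct{ addr string }

type clientMgr struct {
	mu      sync.Mutex
	clients map[int64]*queryNodeClient
	creator func(ctx context.Context, addr string, nodeID int64) (*queryNodeClient, error)
}

// GetClient receives the full nodeInfo, so the manager can create the client
// from the address on first use rather than requiring up-front registration.
func (c *clientMgr) GetClient(ctx context.Context, node nodeInfo) (*queryNodeClient, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if cli, ok := c.clients[node.nodeID]; ok {
		return cli, nil
	}
	cli, err := c.creator(ctx, node.address, node.nodeID)
	if err != nil {
		return nil, err
	}
	c.clients[node.nodeID] = cli
	return cli, nil
}

func main() {
	mgr := &clientMgr{
		clients: map[int64]*queryNodeClient{},
		creator: func(_ context.Context, addr string, _ int64) (*queryNodeClient, error) {
			return &queryNodeClient{addr: addr}, nil
		},
	}
	cli, err := mgr.GetClient(context.Background(), nodeInfo{nodeID: 1, address: "qn-1:21123"})
	fmt.Println(cli.addr, err)
}
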
func NewMockShardClientManager(t interface { diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go index 983514d6ac0a7..e8fe18cf38061 100644 --- a/internal/proxy/roundrobin_balancer.go +++ b/internal/proxy/roundrobin_balancer.go @@ -32,6 +32,8 @@ func NewRoundRobinBalancer() *RoundRobinBalancer { return &RoundRobinBalancer{} } +func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {} + func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) { if len(availableNodes) == 0 { return -1, merr.ErrNodeNotAvailable diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index 8475494e6659c..83a94cf38b51c 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/cockroachdb/errors" "go.uber.org/atomic" @@ -13,6 +14,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) @@ -31,132 +33,144 @@ var errClosed = errors.New("client is closed") type shardClient struct { sync.RWMutex info nodeInfo - isClosed bool - refCnt int - clients []types.QueryNodeClient - idx atomic.Int64 poolSize int - pooling bool + clients []types.QueryNodeClient + creator queryNodeCreatorFunc initialized atomic.Bool - creator queryNodeCreatorFunc + isClosed bool + + idx atomic.Int64 + lastActiveTs *atomic.Int64 + expiredDuration time.Duration +} + +func newShardClient(info nodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient { + return &shardClient{ + info: info, + creator: creator, + lastActiveTs: atomic.NewInt64(time.Now().UnixNano()), + expiredDuration: expiredDuration, + } } func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) { + n.lastActiveTs.Store(time.Now().UnixNano()) if !n.initialized.Load() { n.Lock() if !n.initialized.Load() { - if err := n.initClients(); err != nil { + if err := n.initClients(ctx); err != nil { n.Unlock() return nil, err } - n.initialized.Store(true) } n.Unlock() } - n.RLock() - defer n.RUnlock() - if n.isClosed { - return nil, errClosed + // Attempt to get a connection from the idle connection pool, supporting context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + client, err := n.roundRobinSelectClient() + if err != nil { + return nil, err + } + return client, nil } - - idx := n.idx.Inc() - return n.clients[int(idx)%n.poolSize], nil } -func (n *shardClient) inc() { - n.Lock() - defer n.Unlock() - if n.isClosed { - return +func (n *shardClient) initClients(ctx context.Context) error { + poolSize := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() + if poolSize <= 0 { + poolSize = 1 } - n.refCnt++ -} -func (n *shardClient) close() { - n.isClosed = true - n.refCnt = 0 - - for _, client := range n.clients { - if err := client.Close(); err != nil { - log.Warn("close grpc client failed", zap.Error(err)) + clients := make([]types.QueryNodeClient, 0, poolSize) + for i := 0; i < poolSize; i++ { + client, err := n.creator(ctx, n.info.address, n.info.nodeID) + if err != nil { + // Roll back already created clients + for _, c := range clients { + c.Close() + } + log.Info("failed to create client for node", zap.Int64("nodeID", n.info.nodeID), 
zap.Error(err)) + return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID)) } + clients = append(clients, client) } - n.clients = nil + + n.initialized.Store(true) + n.poolSize = poolSize + n.clients = clients + return nil } -func (n *shardClient) dec() bool { - n.Lock() - defer n.Unlock() +func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) { + n.RLock() + defer n.RUnlock() if n.isClosed { - return true - } - if n.refCnt > 0 { - n.refCnt-- + return nil, errClosed } - if n.refCnt == 0 { - n.close() + + if len(n.clients) == 0 { + return nil, errors.New("no available clients") } - return n.refCnt == 0 + + nextClientIndex := n.idx.Inc() % int64(len(n.clients)) + nextClient := n.clients[nextClientIndex] + return nextClient, nil } -func (n *shardClient) Close() { +// Notice: close client should only be called by shard client manager. and after close, the client must be removed from the manager. +// 1. the client hasn't been used for a long time +// 2. shard client manager has been closed. +func (n *shardClient) Close(force bool) bool { n.Lock() defer n.Unlock() - n.close() + if force || n.isExpired() { + n.close() + } + + return n.isClosed } -func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) { - return &shardClient{ - info: nodeInfo{ - nodeID: info.nodeID, - address: info.address, - }, - refCnt: 1, - pooling: true, - creator: creator, - }, nil +func (n *shardClient) isExpired() bool { + return time.Now().UnixNano()-n.lastActiveTs.Load() > n.expiredDuration.Nanoseconds() } -func (n *shardClient) initClients() error { - num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() - if num <= 0 { - num = 1 - } - clients := make([]types.QueryNodeClient, 0, num) - for i := 0; i < num; i++ { - client, err := n.creator(context.Background(), n.info.address, n.info.nodeID) - if err != nil { - // roll back already created clients - for _, c := range clients[:i] { - c.Close() - } - return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID)) +func (n *shardClient) close() { + n.isClosed = true + + for _, client := range n.clients { + if err := client.Close(); err != nil { + log.Warn("close grpc client failed", zap.Error(err)) } - clients = append(clients, client) } - - n.clients = clients - n.poolSize = num - return nil + n.clients = nil } +// roundRobinSelectClient selects a client in a round-robin manner type shardClientMgr interface { - GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) - UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error + GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error) Close() SetClientCreatorFunc(creator queryNodeCreatorFunc) } type shardClientMgrImpl struct { - clients struct { - sync.RWMutex - data map[UniqueID]*shardClient - } + clients *typeutil.ConcurrentMap[UniqueID, *shardClient] clientCreator queryNodeCreatorFunc + closeCh chan struct{} + + purgeInterval time.Duration + expiredDuration time.Duration } +const ( + defaultPurgeInterval = 600 * time.Second + defaultExpiredDuration = 60 * time.Minute +) + // SessionOpt provides a way to set params in SessionManager type shardClientMgrOpt func(s shardClientMgr) @@ -171,15 +185,17 @@ func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int6 // NewShardClientMgr creates a new shardClientMgr func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl { s := 
&shardClientMgrImpl{ - clients: struct { - sync.RWMutex - data map[UniqueID]*shardClient - }{data: make(map[UniqueID]*shardClient)}, - clientCreator: defaultQueryNodeClientCreator, + clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](), + clientCreator: defaultQueryNodeClientCreator, + closeCh: make(chan struct{}), + purgeInterval: defaultPurgeInterval, + expiredDuration: defaultExpiredDuration, } for _, opt := range options { opt(s) } + + go s.PurgeClient() return s } @@ -187,81 +203,45 @@ func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) c.clientCreator = creator } -// Warning this method may modify parameter `oldLeaders` -func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error { - oldLocalMap := make(map[UniqueID]*nodeInfo) - for _, nodes := range oldLeaders { - for i := range nodes { - n := &nodes[i] - _, ok := oldLocalMap[n.nodeID] - if !ok { - oldLocalMap[n.nodeID] = n - } - } - } - newLocalMap := make(map[UniqueID]*nodeInfo) - - for _, nodes := range newLeaders { - for i := range nodes { - n := &nodes[i] - _, ok := oldLocalMap[n.nodeID] - if !ok { - _, ok2 := newLocalMap[n.nodeID] - if !ok2 { - newLocalMap[n.nodeID] = n - } - } - delete(oldLocalMap, n.nodeID) - } - } - c.clients.Lock() - defer c.clients.Unlock() - - for _, node := range newLocalMap { - client, ok := c.clients.data[node.nodeID] - if ok { - client.inc() - } else { - // context.Background() is useless - // TODO QueryNode NewClient remove ctx parameter - // TODO Remove Init && Start interface in QueryNode client - if c.clientCreator == nil { - return fmt.Errorf("clientCreator function is nil") - } - client, err := newPoolingShardClient(node, c.clientCreator) - if err != nil { - return err - } - c.clients.data[node.nodeID] = client - } - } - for _, node := range oldLocalMap { - client, ok := c.clients.data[node.nodeID] - if ok && client.dec() { - delete(c.clients.data, node.nodeID) - } - } - return nil +func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) { + client, _ := c.clients.GetOrInsert(info.nodeID, newShardClient(info, c.clientCreator, c.expiredDuration)) + return client.getClient(ctx) } -func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) { - c.clients.RLock() - client, ok := c.clients.data[nodeID] - c.clients.RUnlock() - - if !ok { - return nil, fmt.Errorf("can not find client of node %d", nodeID) +// PurgeClient purges client if it is not used for a long time +func (c *shardClientMgrImpl) PurgeClient() { + ticker := time.NewTicker(c.purgeInterval) + defer ticker.Stop() + + for { + select { + case <-c.closeCh: + return + case <-ticker.C: + shardLocations := globalMetaCache.ListShardLocation() + c.clients.Range(func(key UniqueID, value *shardClient) bool { + if _, ok := shardLocations[key]; !ok { + // if the client is not used for more than 1 hour, and it's not a delegator anymore, should remove it + if value.isExpired() { + closed := value.Close(false) + if closed { + c.clients.Remove(key) + log.Info("remove idle node client", zap.Int64("nodeID", key)) + } + } + } + return true + }) + } } - return client.getClient(ctx) } // Close release clients func (c *shardClientMgrImpl) Close() { - c.clients.Lock() - defer c.clients.Unlock() - - for _, s := range c.clients.data { - s.Close() - } - c.clients.data = make(map[UniqueID]*shardClient) + close(c.closeCh) + c.clients.Range(func(key UniqueID, value 
*shardClient) bool { + value.Close(true) + c.clients.Remove(key) + return true + }) } diff --git a/internal/proxy/shard_client_test.go b/internal/proxy/shard_client_test.go index 0ef6f516caf2d..ddc8308954716 100644 --- a/internal/proxy/shard_client_test.go +++ b/internal/proxy/shard_client_test.go @@ -3,99 +3,158 @@ package proxy import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo { - leaders := make(map[string][]nodeInfo) - nodeInfos := make([]nodeInfo, len(leaderIDs)) - for i, id := range leaderIDs { - nodeInfos[i] = nodeInfo{ - nodeID: id, - address: "fake", - } +func TestShardClientMgr(t *testing.T) { + ctx := context.Background() + nodeInfo := nodeInfo{ + nodeID: 1, } - leaders[channel] = nodeInfos - return leaders -} - -func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) { - mgr := newShardClientMgr(withShardClientCreator(nil)) - mgr.clientCreator = nil - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err := mgr.UpdateShardLeaders(nil, leaders) - assert.Error(t, err) -} -func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) { - mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { - return &mocks.MockQueryNodeClient{}, nil + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil) + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil } - mgr := newShardClientMgr(withShardClientCreator(mockCreator)) - - _, err := mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) - - err = mgr.UpdateShardLeaders(nil, nil) - assert.NoError(t, err) - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err = mgr.UpdateShardLeaders(leaders, nil) - assert.NoError(t, err) -} - -func TestShardClientMgr_UpdateShardLeaders_NonEmpty(t *testing.T) { mgr := newShardClientMgr() - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - err := mgr.UpdateShardLeaders(nil, leaders) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) + mgr.SetClientCreatorFunc(creator) + _, err := mgr.GetClient(ctx, nodeInfo) + assert.Nil(t, err) - newLeaders := genShardLeaderInfo("c1", []UniqueID{2, 3}) - err = mgr.UpdateShardLeaders(leaders, newLeaders) - assert.NoError(t, err) - - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) + mgr.Close() + assert.Equal(t, mgr.clients.Len(), 0) } -func TestShardClientMgr_UpdateShardLeaders_Ref(t *testing.T) { - mgr := newShardClientMgr() - leaders := genShardLeaderInfo("c1", []UniqueID{1, 2, 3}) - - for i := 0; i < 2; i++ { - err := mgr.UpdateShardLeaders(nil, leaders) - assert.NoError(t, err) +func TestShardClient(t *testing.T) { + nodeInfo := nodeInfo{ + nodeID: 1, } - partLeaders := genShardLeaderInfo("c1", []UniqueID{1}) + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil).Maybe() + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } + shardClient := newShardClient(nodeInfo, creator, 3*time.Second) + assert.Equal(t, 
len(shardClient.clients), 0) + assert.Equal(t, false, shardClient.initialized.Load()) + assert.Equal(t, false, shardClient.isClosed) + + ctx := context.Background() + _, err := shardClient.getClient(ctx) + assert.Nil(t, err) + assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()) + + // test close + closed := shardClient.Close(false) + assert.False(t, closed) + closed = shardClient.Close(true) + assert.True(t, closed) +} - _, err := mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) +func TestPurgeClient(t *testing.T) { + node := nodeInfo{ + nodeID: 1, + } - err = mgr.UpdateShardLeaders(partLeaders, nil) - assert.NoError(t, err) + returnEmptyResult := atomic.NewBool(false) - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.NoError(t, err) + cache := NewMockCache(t) + cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo { + if returnEmptyResult.Load() { + return map[int64]nodeInfo{} + } + return map[int64]nodeInfo{ + 1: node, + } + }) + globalMetaCache = cache - err = mgr.UpdateShardLeaders(partLeaders, nil) - assert.NoError(t, err) + qn := mocks.NewMockQueryNodeClient(t) + qn.EXPECT().Close().Return(nil).Maybe() + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } - _, err = mgr.GetClient(context.Background(), UniqueID(1)) - assert.Error(t, err) + s := &shardClientMgrImpl{ + clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](), + clientCreator: creator, + closeCh: make(chan struct{}), + purgeInterval: 1 * time.Second, + expiredDuration: 3 * time.Second, + } - _, err = mgr.GetClient(context.Background(), UniqueID(2)) - assert.NoError(t, err) + go s.PurgeClient() + defer s.Close() + _, err := s.GetClient(context.Background(), node) + assert.Nil(t, err) + qnClient, ok := s.clients.Get(1) + assert.True(t, ok) + assert.True(t, qnClient.lastActiveTs.Load() > 0) + + time.Sleep(2 * time.Second) + // expected client should not been purged before expiredDuration + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds()) + + _, err = s.GetClient(context.Background(), node) + assert.Nil(t, err) + time.Sleep(2 * time.Second) + // GetClient should refresh lastActiveTs, expected client should not be purged + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds()) + + time.Sleep(2 * time.Second) + // client reach the expiredDuration, expected client should not be purged + assert.Equal(t, s.clients.Len(), 1) + assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds()) + + returnEmptyResult.Store(true) + time.Sleep(2 * time.Second) + // remove client from shard location, expected client should be purged + assert.Equal(t, s.clients.Len(), 0) +} - _, err = mgr.GetClient(context.Background(), UniqueID(3)) - assert.NoError(t, err) +func BenchmarkShardClientMgr(b *testing.B) { + node := nodeInfo{ + nodeID: 1, + } + cache := NewMockCache(b) + cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{ + 1: node, + }).Maybe() + globalMetaCache = cache + qn := mocks.NewMockQueryNodeClient(b) + qn.EXPECT().Close().Return(nil).Maybe() + + creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return qn, nil + } + s := &shardClientMgrImpl{ + clients: typeutil.NewConcurrentMap[UniqueID, 
*shardClient](), + clientCreator: creator, + closeCh: make(chan struct{}), + purgeInterval: 1 * time.Second, + expiredDuration: 10 * time.Second, + } + go s.PurgeClient() + defer s.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := s.GetClient(context.Background(), node) + assert.Nil(b, err) + } + }) } diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index 56c662e34b1b7..f3a047bd38385 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -33,7 +33,7 @@ func RoundRobinPolicy( leaders := dml2leaders[channel] for _, target := range leaders { - qn, err := mgr.GetClient(ctx, target.nodeID) + qn, err := mgr.GetClient(ctx, target) if err != nil { log.Warn("query channel failed, node not available", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) combineErr = merr.Combine(combineErr, err) diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go index 5c4b1732824ac..98bc97096cd70 100644 --- a/internal/proxy/task_policies_test.go +++ b/internal/proxy/task_policies_test.go @@ -26,7 +26,6 @@ func TestRoundRobinPolicy(t *testing.T) { "c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}}, "c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}}, } - mgr.UpdateShardLeaders(nil, shard2leaders) querier := &mockQuery{} querier.init() diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 32512706ff20d..34c7e7bd4e39c 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -79,7 +79,6 @@ func TestQueryTask_all(t *testing.T) { mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) defer rc.Close() diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index c3abf37ad4fef..6fe5a3b1b334b 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1795,7 +1795,6 @@ func TestSearchTask_ErrExecute(t *testing.T) { mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) defer qc.Close() diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 42f0d63b4480d..e3d3786cfe97c 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -81,7 +81,6 @@ func (s *StatisticTaskSuite) SetupTest() { s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(s.T()) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() - mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() s.lb = NewLBPolicyImpl(mgr) err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
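
The shard_client.go rewrite exercised by these tests replaces reference counting with time-based garbage collection: every GetClient refreshes lastActiveTs, and a background loop drops clients that are both expired and no longer listed by ListShardLocation. A self-contained sketch of that purge pattern, using a plain mutex-guarded map instead of typeutil.ConcurrentMap (names here are illustrative):

package main

import (
	"fmt"
	"sync"
	"time"
)

type client struct {
	lastActive time.Time
}

type purger struct {
	mu       sync.Mutex
	clients  map[int64]*client
	expired  time.Duration
	interval time.Duration
	closeCh  chan struct{}
	// listLocations reports which node IDs are still shard delegators.
	listLocations func() map[int64]struct{}
}

func (p *purger) run() {
	ticker := time.NewTicker(p.interval)
	defer ticker.Stop()
	for {
		select {
		case <-p.closeCh:
			return
		case <-ticker.C:
			still := p.listLocations()
			p.mu.Lock()
			for id, c := range p.clients {
				_, active := still[id]
				if !active && time.Since(c.lastActive) > p.expired {
					delete(p.clients, id) // idle and no longer a delegator: drop it
				}
			}
			p.mu.Unlock()
		}
	}
}

func main() {
	p := &purger{
		clients:       map[int64]*client{1: {lastActive: time.Now().Add(-time.Hour)}},
		expired:       time.Minute,
		interval:      100 * time.Millisecond,
		closeCh:       make(chan struct{}),
		listLocations: func() map[int64]struct{} { return map[int64]struct{}{} },
	}
	go p.run()
	time.Sleep(300 * time.Millisecond)
	close(p.closeCh)
	p.mu.Lock()
	fmt.Println("remaining clients:", len(p.clients))
	p.mu.Unlock()
}

In the actual manager the defaults are a 10-minute purge interval and a 1-hour expiry (defaultPurgeInterval, defaultExpiredDuration), and Close(true) force-closes every pooled client when the manager shuts down.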