From 08afdfe0053a28db12d1583310a36e66fd6d0187 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 24 Dec 2024 20:04:01 +0800 Subject: [PATCH] fix: Prevent balancer from overloading the same QueryNode The balancer calculates the workload of executing tasks as an ongoing score for target nodes. However, a logic issue arises when GetSegmentTaskDelta or GetChannelTaskDelta is called with collectionID=-1, which incorrectly returns zero. Due to the incorrect global score, the executing task's workload is not properly reflected for each collection. Consequently, each collection submits its own balance task, leading to the balancer assigning excessive tasks to the same QueryNode. Signed-off-by: Wei Liu --- internal/querycoordv2/task/scheduler.go | 58 +++++++++------- internal/querycoordv2/task/task_test.go | 90 +++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 25 deletions(-) diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index d34a4eeaeb848..f26c5b561a6a4 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -487,61 +487,69 @@ func (scheduler *taskScheduler) GetSegmentTaskDelta(nodeID, collectionID int64) scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) + targetActions := make(map[int64][]Action, 0) for _, t := range scheduler.segmentTasks { if collectionID != -1 && collectionID != t.CollectionID() { continue } for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) + if action.Node() == nodeID || nodeID == -1 { + if _, ok := targetActions[t.CollectionID()]; !ok { + targetActions[t.CollectionID()] = make([]Action, 0) + } + targetActions[t.CollectionID()] = append(targetActions[t.CollectionID()], action) } } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } func (scheduler *taskScheduler) GetChannelTaskDelta(nodeID, collectionID int64) int { scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) + targetActions := make(map[int64][]Action, 0) for _, t := range scheduler.channelTasks { if collectionID != -1 && collectionID != t.CollectionID() { continue } for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) + if action.Node() == nodeID || nodeID == -1 { + if _, ok := targetActions[t.CollectionID()]; !ok { + targetActions[t.CollectionID()] = make([]Action, 0) + } + targetActions[t.CollectionID()] = append(targetActions[t.CollectionID()], action) } } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } -func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetActions []Action) int { +func (scheduler *taskScheduler) calculateTaskDelta(targetActions map[int64][]Action) int { sum := 0 - for _, action := range targetActions { - delta := 0 - if action.Type() == ActionTypeGrow { - delta = 1 - } else if action.Type() == ActionTypeReduce { - delta = -1 - } + for collectionID, actions := range targetActions { + for _, action := range actions { + delta := 0 + if action.Type() == ActionTypeGrow { + delta = 1 + } else if action.Type() == ActionTypeReduce { + delta = -1 + } - switch action := action.(type) { - case *SegmentAction: - // skip growing segment's count, cause doesn't know realtime row number of growing segment - if action.Scope == querypb.DataScope_Historical { - segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst) - if segment != nil { - sum += int(segment.GetNumOfRows()) * delta + switch action := action.(type) { + case *SegmentAction: + // skip growing segment's count, cause doesn't know realtime row number of growing segment + if action.Scope == querypb.DataScope_Historical { + segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst) + if segment != nil { + sum += int(segment.GetNumOfRows()) * delta + } } + case *ChannelAction: + sum += delta } - case *ChannelAction: - sum += delta } } return sum diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 6de6a70d49175..482282f54f319 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -1855,6 +1855,96 @@ func (suite *TaskSuite) TestGetTasksJSON() { suite.Equal(2, len(tasks)) } +func (suite *TaskSuite) TestCalculateTaskDelta() { + ctx := context.Background() + scheduler := suite.newScheduler() + + mockTarget := meta.NewMockTargetManager(suite.T()) + mockTarget.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datapb.SegmentInfo{ + NumOfRows: 100, + }) + scheduler.targetMgr = mockTarget + + coll := int64(1001) + nodeID := int64(1) + channelName := "channel-1" + segmentID := int64(1) + // add segment task for collection + task1, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewSegmentActionWithScope(nodeID, ActionTypeGrow, "", segmentID, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task1) + suite.NoError(err) + task2, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewChannelAction(nodeID, ActionTypeGrow, channelName), + ) + suite.NoError(err) + err = scheduler.Add(task2) + suite.NoError(err) + + coll2 := int64(1005) + nodeID2 := int64(2) + channelName2 := "channel-2" + segmentID2 := int64(2) + task3, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewSegmentActionWithScope(nodeID2, ActionTypeGrow, "", segmentID2, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task3) + suite.NoError(err) + task4, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewChannelAction(nodeID2, ActionTypeGrow, channelName2), + ) + suite.NoError(err) + err = scheduler.Add(task4) + suite.NoError(err) + + // check task delta with collectionID and nodeID + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, coll2)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, coll2)) + + // check task delta with collectionID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, -1)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, -1)) + + // check task delta with nodeID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + + // check task delta with nodeID=-1 and collectionID=-1 + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) +} + func TestTask(t *testing.T) { suite.Run(t, new(TaskSuite)) }