diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index d34a4eeaeb848..316f1a552be71 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -487,61 +487,74 @@ func (scheduler *taskScheduler) GetSegmentTaskDelta(nodeID, collectionID int64) scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) - for _, t := range scheduler.segmentTasks { - if collectionID != -1 && collectionID != t.CollectionID() { + targetActions := make(map[int64][]Action) + for _, task := range scheduler.segmentTasks { // Map key: replicaSegmentIndex + taskCollID := task.CollectionID() + if collectionID != -1 && collectionID != taskCollID { continue } - for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) - } + actions := filterActions(task.Actions(), nodeID) + if len(actions) > 0 { + targetActions[taskCollID] = append(targetActions[taskCollID], actions...) } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } func (scheduler *taskScheduler) GetChannelTaskDelta(nodeID, collectionID int64) int { scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) - for _, t := range scheduler.channelTasks { - if collectionID != -1 && collectionID != t.CollectionID() { + targetActions := make(map[int64][]Action) + for _, task := range scheduler.channelTasks { // Map key: replicaChannelIndex + taskCollID := task.CollectionID() + if collectionID != -1 && collectionID != taskCollID { continue } - for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) - } + actions := filterActions(task.Actions(), nodeID) + if len(actions) > 0 { + targetActions[taskCollID] = append(targetActions[taskCollID], actions...) } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } -func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetActions []Action) int { - sum := 0 - for _, action := range targetActions { - delta := 0 - if action.Type() == ActionTypeGrow { - delta = 1 - } else if action.Type() == ActionTypeReduce { - delta = -1 +// filter actions by nodeID +func filterActions(actions []Action, nodeID int64) []Action { + filtered := make([]Action, 0, len(actions)) + for _, action := range actions { + if nodeID == -1 || action.Node() == nodeID { + filtered = append(filtered, action) } + } + return filtered +} - switch action := action.(type) { - case *SegmentAction: - // skip growing segment's count, cause doesn't know realtime row number of growing segment - if action.Scope == querypb.DataScope_Historical { - segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst) - if segment != nil { - sum += int(segment.GetNumOfRows()) * delta +func (scheduler *taskScheduler) calculateTaskDelta(targetActions map[int64][]Action) int { + sum := 0 + for collectionID, actions := range targetActions { + for _, action := range actions { + delta := 0 + if action.Type() == ActionTypeGrow { + delta = 1 + } else if action.Type() == ActionTypeReduce { + delta = -1 + } + + switch action := action.(type) { + case *SegmentAction: + // skip growing segment's count, cause doesn't know realtime row number of growing segment + if action.Scope == querypb.DataScope_Historical { + segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst) + if segment != nil { + sum += int(segment.GetNumOfRows()) * delta + } } + case *ChannelAction: + sum += delta } - case *ChannelAction: - sum += delta } } return sum diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 6de6a70d49175..482282f54f319 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -1855,6 +1855,96 @@ func (suite *TaskSuite) TestGetTasksJSON() { suite.Equal(2, len(tasks)) } +func (suite *TaskSuite) TestCalculateTaskDelta() { + ctx := context.Background() + scheduler := suite.newScheduler() + + mockTarget := meta.NewMockTargetManager(suite.T()) + mockTarget.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datapb.SegmentInfo{ + NumOfRows: 100, + }) + scheduler.targetMgr = mockTarget + + coll := int64(1001) + nodeID := int64(1) + channelName := "channel-1" + segmentID := int64(1) + // add segment task for collection + task1, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewSegmentActionWithScope(nodeID, ActionTypeGrow, "", segmentID, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task1) + suite.NoError(err) + task2, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewChannelAction(nodeID, ActionTypeGrow, channelName), + ) + suite.NoError(err) + err = scheduler.Add(task2) + suite.NoError(err) + + coll2 := int64(1005) + nodeID2 := int64(2) + channelName2 := "channel-2" + segmentID2 := int64(2) + task3, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewSegmentActionWithScope(nodeID2, ActionTypeGrow, "", segmentID2, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task3) + suite.NoError(err) + task4, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewChannelAction(nodeID2, ActionTypeGrow, channelName2), + ) + suite.NoError(err) + err = scheduler.Add(task4) + suite.NoError(err) + + // check task delta with collectionID and nodeID + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, coll2)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, coll2)) + + // check task delta with collectionID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, -1)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, -1)) + + // check task delta with nodeID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + + // check task delta with nodeID=-1 and collectionID=-1 + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) +} + func TestTask(t *testing.T) { suite.Run(t, new(TaskSuite)) }