diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 416d3ea1d4675..69b16c9102a0f 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -481,61 +481,69 @@ func (scheduler *taskScheduler) GetSegmentTaskDelta(nodeID, collectionID int64) scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) + targetActions := make(map[int64][]Action, 0) for _, t := range scheduler.segmentTasks { if collectionID != -1 && collectionID != t.CollectionID() { continue } for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) + if action.Node() == nodeID || nodeID == -1 { + if _, ok := targetActions[t.CollectionID()]; !ok { + targetActions[t.CollectionID()] = make([]Action, 0) + } + targetActions[t.CollectionID()] = append(targetActions[t.CollectionID()], action) } } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } func (scheduler *taskScheduler) GetChannelTaskDelta(nodeID, collectionID int64) int { scheduler.rwmutex.RLock() defer scheduler.rwmutex.RUnlock() - targetActions := make([]Action, 0) + targetActions := make(map[int64][]Action, 0) for _, t := range scheduler.channelTasks { if collectionID != -1 && collectionID != t.CollectionID() { continue } for _, action := range t.Actions() { - if action.Node() == nodeID { - targetActions = append(targetActions, action) + if action.Node() == nodeID || nodeID == -1 { + if _, ok := targetActions[t.CollectionID()]; !ok { + targetActions[t.CollectionID()] = make([]Action, 0) + } + targetActions[t.CollectionID()] = append(targetActions[t.CollectionID()], action) } } } - return scheduler.calculateTaskDelta(collectionID, targetActions) + return scheduler.calculateTaskDelta(targetActions) } -func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetActions []Action) int { +func (scheduler *taskScheduler) calculateTaskDelta(targetActions map[int64][]Action) int { sum := 0 - for _, action := range targetActions { - delta := 0 - if action.Type() == ActionTypeGrow { - delta = 1 - } else if action.Type() == ActionTypeReduce { - delta = -1 - } + for collectionID, actions := range targetActions { + for _, action := range actions { + delta := 0 + if action.Type() == ActionTypeGrow { + delta = 1 + } else if action.Type() == ActionTypeReduce { + delta = -1 + } - switch action := action.(type) { - case *SegmentAction: - // skip growing segment's count, cause doesn't know realtime row number of growing segment - if action.Scope() == querypb.DataScope_Historical { - segment := scheduler.targetMgr.GetSealedSegment(collectionID, action.segmentID, meta.NextTargetFirst) - if segment != nil { - sum += int(segment.GetNumOfRows()) * delta + switch action := action.(type) { + case *SegmentAction: + // skip growing segment's count, cause doesn't know realtime row number of growing segment + if action.Scope() == querypb.DataScope_Historical { + segment := scheduler.targetMgr.GetSealedSegment(collectionID, action.SegmentID(), meta.NextTargetFirst) + if segment != nil { + sum += int(segment.GetNumOfRows()) * delta + } } + case *ChannelAction: + sum += delta } - case *ChannelAction: - sum += delta } } return sum diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 9e82566bfee6b..6d6fcfadbfb24 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -1810,6 +1810,96 @@ func (suite *TaskSuite) TestBalanceChannelWithL0SegmentTask() { suite.Equal(2, task.step) } +func (suite *TaskSuite) TestCalculateTaskDelta() { + ctx := context.Background() + scheduler := suite.newScheduler() + + mockTarget := meta.NewMockTargetManager(suite.T()) + mockTarget.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datapb.SegmentInfo{ + NumOfRows: 100, + }) + scheduler.targetMgr = mockTarget + + coll := int64(1001) + nodeID := int64(1) + channelName := "channel-1" + segmentID := int64(1) + // add segment task for collection + task1, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewSegmentActionWithScope(nodeID, ActionTypeGrow, "", segmentID, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task1) + suite.NoError(err) + task2, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll, + suite.replica, + NewChannelAction(nodeID, ActionTypeGrow, channelName), + ) + suite.NoError(err) + err = scheduler.Add(task2) + suite.NoError(err) + + coll2 := int64(1005) + nodeID2 := int64(2) + channelName2 := "channel-2" + segmentID2 := int64(2) + task3, err := NewSegmentTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewSegmentActionWithScope(nodeID2, ActionTypeGrow, "", segmentID2, querypb.DataScope_Historical), + ) + suite.NoError(err) + err = scheduler.Add(task3) + suite.NoError(err) + task4, err := NewChannelTask( + ctx, + 10*time.Second, + WrapIDSource(0), + coll2, + suite.replica, + NewChannelAction(nodeID2, ActionTypeGrow, channelName2), + ) + suite.NoError(err) + err = scheduler.Add(task4) + suite.NoError(err) + + // check task delta with collectionID and nodeID + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, coll2)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, coll2)) + + // check task delta with collectionID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, -1)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, -1)) + suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, -1)) + + // check task delta with nodeID=-1 + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll)) + suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll)) + + // check task delta with nodeID=-1 and collectionID=-1 + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) + suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1)) + suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1)) +} + func TestTask(t *testing.T) { suite.Run(t, new(TaskSuite)) }