fix: Prevent balancer from overloading the same QueryNode (#38719)
issue: #38718
The balancer tracks the workload of executing tasks as an ongoing score for each target node. However, when GetSegmentTaskDelta or GetChannelTaskDelta is called with collectionID=-1 (meaning all collections), a logic issue makes both functions incorrectly return zero.

Because this global score is wrong, the workload of executing tasks is not reflected for each collection. Each collection therefore submits its own balance task, and the balancer ends up assigning excessive tasks to the same QueryNode.
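
To make the failure mode concrete: the old calculateTaskDelta reused the caller's collectionID when resolving sealed-segment row counts, so a global query (collectionID=-1) could never match a real collection. Below is a minimal, self-contained sketch of that behavior; segKey, sealedRows, and oldDelta are illustrative names, not the actual Milvus code.

package main

import "fmt"

// segKey and sealedRows are illustrative stand-ins for the target manager's
// sealed-segment lookup, which is keyed by collection and segment.
type segKey struct{ collectionID, segmentID int64 }

var sealedRows = map[segKey]int{
	{collectionID: 1001, segmentID: 1}: 100,
	{collectionID: 1005, segmentID: 2}: 100,
}

// oldDelta mirrors the buggy path: the sealed-segment lookup reuses the
// collectionID the caller passed in, so collectionID=-1 matches nothing
// and every segment's row count is silently dropped.
func oldDelta(collectionID int64, segmentIDs []int64) int {
	sum := 0
	for _, segID := range segmentIDs {
		if rows, ok := sealedRows[segKey{collectionID, segID}]; ok {
			sum += rows // an ActionTypeGrow action contributes +rows
		}
	}
	return sum
}

func main() {
	fmt.Println(oldDelta(1001, []int64{1}))  // 100: the per-collection score is correct
	fmt.Println(oldDelta(-1, []int64{1, 2})) // 0: the global score collapses to zero
}

Running this prints 100 for the per-collection query but 0 for the global one, which is exactly the mismatch the balancer saw.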

---------

Signed-off-by: Wei Liu <[email protected]>
weiliu1031 authored Dec 25, 2024 · 1 parent acc8fb7 · commit 9c3f59d
Showing 2 changed files with 136 additions and 33 deletions.
79 changes: 46 additions & 33 deletions internal/querycoordv2/task/scheduler.go
@@ -487,61 +487,74 @@ func (scheduler *taskScheduler) GetSegmentTaskDelta(nodeID, collectionID int64)
 	scheduler.rwmutex.RLock()
 	defer scheduler.rwmutex.RUnlock()
 
-	targetActions := make([]Action, 0)
-	for _, t := range scheduler.segmentTasks {
-		if collectionID != -1 && collectionID != t.CollectionID() {
+	targetActions := make(map[int64][]Action)
+	for _, task := range scheduler.segmentTasks { // Map key: replicaSegmentIndex
+		taskCollID := task.CollectionID()
+		if collectionID != -1 && collectionID != taskCollID {
 			continue
 		}
-		for _, action := range t.Actions() {
-			if action.Node() == nodeID {
-				targetActions = append(targetActions, action)
-			}
+		actions := filterActions(task.Actions(), nodeID)
+		if len(actions) > 0 {
+			targetActions[taskCollID] = append(targetActions[taskCollID], actions...)
 		}
 	}
 
-	return scheduler.calculateTaskDelta(collectionID, targetActions)
+	return scheduler.calculateTaskDelta(targetActions)
 }
 
 func (scheduler *taskScheduler) GetChannelTaskDelta(nodeID, collectionID int64) int {
 	scheduler.rwmutex.RLock()
 	defer scheduler.rwmutex.RUnlock()
 
-	targetActions := make([]Action, 0)
-	for _, t := range scheduler.channelTasks {
-		if collectionID != -1 && collectionID != t.CollectionID() {
+	targetActions := make(map[int64][]Action)
+	for _, task := range scheduler.channelTasks { // Map key: replicaChannelIndex
+		taskCollID := task.CollectionID()
+		if collectionID != -1 && collectionID != taskCollID {
 			continue
 		}
-		for _, action := range t.Actions() {
-			if action.Node() == nodeID {
-				targetActions = append(targetActions, action)
-			}
+		actions := filterActions(task.Actions(), nodeID)
+		if len(actions) > 0 {
+			targetActions[taskCollID] = append(targetActions[taskCollID], actions...)
 		}
 	}
 
-	return scheduler.calculateTaskDelta(collectionID, targetActions)
+	return scheduler.calculateTaskDelta(targetActions)
 }
 
-func (scheduler *taskScheduler) calculateTaskDelta(collectionID int64, targetActions []Action) int {
-	sum := 0
-	for _, action := range targetActions {
-		delta := 0
-		if action.Type() == ActionTypeGrow {
-			delta = 1
-		} else if action.Type() == ActionTypeReduce {
-			delta = -1
-		}
+// filterActions keeps only the actions bound to nodeID; nodeID == -1 matches all nodes
+func filterActions(actions []Action, nodeID int64) []Action {
+	filtered := make([]Action, 0, len(actions))
+	for _, action := range actions {
+		if nodeID == -1 || action.Node() == nodeID {
+			filtered = append(filtered, action)
+		}
+	}
+	return filtered
+}
 
-		switch action := action.(type) {
-		case *SegmentAction:
-			// skip growing segment's count, cause doesn't know realtime row number of growing segment
-			if action.Scope == querypb.DataScope_Historical {
-				segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst)
-				if segment != nil {
-					sum += int(segment.GetNumOfRows()) * delta
+func (scheduler *taskScheduler) calculateTaskDelta(targetActions map[int64][]Action) int {
+	sum := 0
+	for collectionID, actions := range targetActions {
+		for _, action := range actions {
+			delta := 0
+			if action.Type() == ActionTypeGrow {
+				delta = 1
+			} else if action.Type() == ActionTypeReduce {
+				delta = -1
+			}
+
+			switch action := action.(type) {
+			case *SegmentAction:
+				// skip growing segments: their realtime row count is unknown
+				if action.Scope == querypb.DataScope_Historical {
+					segment := scheduler.targetMgr.GetSealedSegment(scheduler.ctx, collectionID, action.SegmentID, meta.NextTargetFirst)
+					if segment != nil {
+						sum += int(segment.GetNumOfRows()) * delta
+					}
 				}
+			case *ChannelAction:
+				sum += delta
 			}
-		case *ChannelAction:
-			sum += delta
 		}
 	}
 	return sum
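
For contrast, here is a sketch of the fixed shape, using the same illustrative names as the sketch above (again, not the real implementation): grouping work under each action's own collectionID means every sealed-segment lookup uses the correct collection, so the global (-1) score no longer collapses.

package main

import "fmt"

// Illustrative stand-ins; the real code resolves row counts through
// targetMgr.GetSealedSegment.
type segKey struct{ collectionID, segmentID int64 }

var sealedRows = map[segKey]int{
	{collectionID: 1001, segmentID: 1}: 100,
	{collectionID: 1005, segmentID: 2}: 100,
}

// newDelta mirrors the fixed path: segments are grouped under their own
// collectionID, so each lookup uses the right collection even when the
// caller asked for the global (-1) score.
func newDelta(segmentsByCollection map[int64][]int64) int {
	sum := 0
	for collID, segmentIDs := range segmentsByCollection {
		for _, segID := range segmentIDs {
			if rows, ok := sealedRows[segKey{collID, segID}]; ok {
				sum += rows
			}
		}
	}
	return sum
}

func main() {
	grouped := map[int64][]int64{1001: {1}, 1005: {2}}
	fmt.Println(newDelta(grouped)) // 200: the global score now counts both collections
}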
90 changes: 90 additions & 0 deletions internal/querycoordv2/task/task_test.go
@@ -1855,6 +1855,96 @@ func (suite *TaskSuite) TestGetTasksJSON() {
 	suite.Equal(2, len(tasks))
 }
 
+func (suite *TaskSuite) TestCalculateTaskDelta() {
+	ctx := context.Background()
+	scheduler := suite.newScheduler()
+
+	mockTarget := meta.NewMockTargetManager(suite.T())
+	mockTarget.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datapb.SegmentInfo{
+		NumOfRows: 100,
+	})
+	scheduler.targetMgr = mockTarget
+
+	coll := int64(1001)
+	nodeID := int64(1)
+	channelName := "channel-1"
+	segmentID := int64(1)
+	// add a segment task and a channel task for the first collection
+	task1, err := NewSegmentTask(
+		ctx,
+		10*time.Second,
+		WrapIDSource(0),
+		coll,
+		suite.replica,
+		NewSegmentActionWithScope(nodeID, ActionTypeGrow, "", segmentID, querypb.DataScope_Historical),
+	)
+	suite.NoError(err)
+	err = scheduler.Add(task1)
+	suite.NoError(err)
+	task2, err := NewChannelTask(
+		ctx,
+		10*time.Second,
+		WrapIDSource(0),
+		coll,
+		suite.replica,
+		NewChannelAction(nodeID, ActionTypeGrow, channelName),
+	)
+	suite.NoError(err)
+	err = scheduler.Add(task2)
+	suite.NoError(err)
+
+	coll2 := int64(1005) // second collection, assigned to a different node
+	nodeID2 := int64(2)
+	channelName2 := "channel-2"
+	segmentID2 := int64(2)
+	task3, err := NewSegmentTask(
+		ctx,
+		10*time.Second,
+		WrapIDSource(0),
+		coll2,
+		suite.replica,
+		NewSegmentActionWithScope(nodeID2, ActionTypeGrow, "", segmentID2, querypb.DataScope_Historical),
+	)
+	suite.NoError(err)
+	err = scheduler.Add(task3)
+	suite.NoError(err)
+	task4, err := NewChannelTask(
+		ctx,
+		10*time.Second,
+		WrapIDSource(0),
+		coll2,
+		suite.replica,
+		NewChannelAction(nodeID2, ActionTypeGrow, channelName2),
+	)
+	suite.NoError(err)
+	err = scheduler.Add(task4)
+	suite.NoError(err)
+
+	// check task delta with explicit nodeID and collectionID
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, coll))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, coll))
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, coll2))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, coll2))
+
+	// check task delta with collectionID=-1
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID, -1))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID, -1))
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(nodeID2, -1))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(nodeID2, -1))
+
+	// check task delta with nodeID=-1
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll))
+	suite.Equal(100, scheduler.GetSegmentTaskDelta(-1, coll2))
+	suite.Equal(1, scheduler.GetChannelTaskDelta(-1, coll2))
+
+	// check task delta with nodeID=-1 and collectionID=-1 (called twice: results must be stable)
+	suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1))
+	suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1))
+	suite.Equal(200, scheduler.GetSegmentTaskDelta(-1, -1))
+	suite.Equal(2, scheduler.GetChannelTaskDelta(-1, -1))
+}
 
 func TestTask(t *testing.T) {
 	suite.Run(t, new(TaskSuite))
 }
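To exercise the new assertions locally, something like go test ./internal/querycoordv2/task -run TestTask from the repo root should run the suite (testify's -testify.m flag can narrow it to TestCalculateTaskDelta); the exact invocation depends on the local build setup.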
