
Commit ecee7d9

enhance: try to speed up the loading of small collections (milvus-io#33570)

- issue: milvus-io#33569

Signed-off-by: SimFG <[email protected]>
SimFG authored Jun 7, 2024
1 parent 9c2e325 commit ecee7d9
Showing 12 changed files with 216 additions and 25 deletions.
2 changes: 2 additions & 0 deletions configs/milvus.yaml
@@ -281,6 +281,8 @@ queryCoord:
channelTaskTimeout: 60000 # 1 minute
segmentTaskTimeout: 120000 # 2 minute
distPullInterval: 500
collectionObserverInterval: 200
checkExecutedFlagInterval: 100
heartbeatAvailableInterval: 10000 # 10s, Only QueryNodes which fetched heartbeats within the duration are available
loadTimeoutSeconds: 600
distRequestTimeout: 5000 # the request timeout for querycoord fetching data distribution from querynodes, in milliseconds
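
The two new settings above drive how aggressively QueryCoord re-checks load progress: collectionObserverInterval replaces the collection observer's previously hard-coded one-second period, and checkExecutedFlagInterval controls how often the dist handler looks for a just-finished task so it can pull distribution early. As a rough standalone illustration (plain Go, not Milvus code; only the default values are taken from this diff), millisecond-scale settings like these end up as ticker periods:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Defaults taken from the new milvus.yaml entries above (milliseconds).
	collectionObserverInterval := 200 * time.Millisecond // queryCoord.collectionObserverInterval
	checkExecutedFlagInterval := 100 * time.Millisecond  // queryCoord.checkExecutedFlagInterval

	// Both values are consumed as ticker periods by QueryCoord loops.
	observeTicker := time.NewTicker(collectionObserverInterval)
	defer observeTicker.Stop()
	checkTicker := time.NewTicker(checkExecutedFlagInterval)
	defer checkTicker.Stop()

	// Over roughly one second we expect about 5 observe ticks and 10 check ticks.
	deadline := time.After(time.Second)
	observes, checks := 0, 0
	for done := false; !done; {
		select {
		case <-observeTicker.C:
			observes++
		case <-checkTicker.C:
			checks++
		case <-deadline:
			done = true
		}
	}
	fmt.Printf("observe ticks: %d, check ticks: %d\n", observes, checks)
}
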
2 changes: 1 addition & 1 deletion internal/querycoordv2/dist/dist_controller.go
@@ -78,7 +78,7 @@ func (dc *ControllerImpl) SyncAll(ctx context.Context) {
if err != nil {
log.Warn("SyncAll come across err when getting data distribution", zap.Error(err))
} else {
handler.handleDistResp(resp)
handler.handleDistResp(resp, true)
}
}(h)
}
1 change: 1 addition & 0 deletions internal/querycoordv2/dist/dist_controller_test.go
@@ -80,6 +80,7 @@ func (suite *DistControllerTestSuite) SetupTest() {
suite.broker = meta.NewMockBroker(suite.T())
targetManager := meta.NewTargetManager(suite.broker, suite.meta)
suite.mockScheduler = task.NewMockScheduler(suite.T())
suite.mockScheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(nil).Maybe()
suite.controller = NewDistController(suite.mockCluster, suite.nodeMgr, distManager, targetManager, suite.mockScheduler)
}

46 changes: 32 additions & 14 deletions internal/querycoordv2/dist/dist_handler.go
@@ -58,6 +58,8 @@ func (dh *distHandler) start(ctx context.Context) {
log.Info("start dist handler")
ticker := time.NewTicker(Params.QueryCoordCfg.DistPullInterval.GetAsDuration(time.Millisecond))
defer ticker.Stop()
checkExecutedFlagTicker := time.NewTicker(Params.QueryCoordCfg.CheckExecutedFlagInterval.GetAsDuration(time.Millisecond))
defer checkExecutedFlagTicker.Stop()
failures := 0
for {
select {
@@ -67,25 +69,39 @@ func (dh *distHandler) start(ctx context.Context) {
case <-dh.c:
log.Info("close dist handler")
return
case <-ticker.C:
resp, err := dh.getDistribution(ctx)
if err != nil {
node := dh.nodeManager.Get(dh.nodeID)
fields := []zap.Field{zap.Int("times", failures)}
if node != nil {
fields = append(fields, zap.Time("lastHeartbeat", node.LastHeartbeat()))
case <-checkExecutedFlagTicker.C:
executedFlagChan := dh.scheduler.GetExecutedFlag(dh.nodeID)
if executedFlagChan != nil {
select {
case <-executedFlagChan:
dh.pullDist(ctx, &failures, false)
default:
}
fields = append(fields, zap.Error(err))
log.RatedWarn(30.0, "failed to get data distribution", fields...)
} else {
failures = 0
dh.handleDistResp(resp)
}
case <-ticker.C:
dh.pullDist(ctx, &failures, true)
}
}
}

func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse) {
func (dh *distHandler) pullDist(ctx context.Context, failures *int, dispatchTask bool) {
resp, err := dh.getDistribution(ctx)
if err != nil {
node := dh.nodeManager.Get(dh.nodeID)
*failures = *failures + 1
fields := []zap.Field{zap.Int("times", *failures)}
if node != nil {
fields = append(fields, zap.Time("lastHeartbeat", node.LastHeartbeat()))
}
fields = append(fields, zap.Error(err))
log.RatedWarn(30.0, "failed to get data distribution", fields...)
} else {
*failures = 0
dh.handleDistResp(resp, dispatchTask)
}
}

func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, dispatchTask bool) {
node := dh.nodeManager.Get(resp.GetNodeID())
if node == nil {
return
@@ -113,7 +129,9 @@ func (dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse)
dh.updateLeaderView(resp)
}

dh.scheduler.Dispatch(dh.nodeID)
if dispatchTask {
dh.scheduler.Dispatch(dh.nodeID)
}
}

func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistributionResponse) {
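
The reworked start loop above now runs two tickers. A simplified, self-contained sketch of that control flow (illustrative names only, not the actual distHandler): the slow ticker does the regular pull and lets the scheduler dispatch, while the fast ticker pulls only when the executor has signalled a finished task, and it deliberately skips dispatch.

package main

import (
	"context"
	"fmt"
	"time"
)

// pullDist stands in for distHandler.pullDist: fetch the data distribution
// from the node and optionally let the scheduler dispatch follow-up tasks.
func pullDist(ctx context.Context, dispatchTask bool) {
	fmt.Println("pull distribution, dispatch:", dispatchTask)
}

// run mirrors the shape of the reworked distHandler.start loop.
func run(ctx context.Context, executedFlag <-chan struct{}) {
	pullTicker := time.NewTicker(500 * time.Millisecond) // distPullInterval
	defer pullTicker.Stop()
	checkTicker := time.NewTicker(100 * time.Millisecond) // checkExecutedFlagInterval
	defer checkTicker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-checkTicker.C:
			if executedFlag == nil {
				break // no executor registered for this node
			}
			select {
			case <-executedFlag:
				// A task just finished: refresh the distribution right away,
				// but do not dispatch new tasks from this fast path.
				pullDist(ctx, false)
			default:
				// Nothing executed since the last check; wait for the next tick.
			}
		case <-pullTicker.C:
			// Regular pull: refresh the distribution and let the scheduler dispatch.
			pullDist(ctx, true)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	executedFlag := make(chan struct{}, 1)
	executedFlag <- struct{}{} // pretend a task already finished on the node
	run(ctx, executedFlag)
}

The non-blocking inner select is what keeps the fast path cheap: if nothing has executed since the last check, the case falls through immediately.
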
77 changes: 68 additions & 9 deletions internal/querycoordv2/dist/dist_handler_test.go
@@ -40,12 +40,14 @@ type DistHandlerSuite struct {
meta *meta.Meta
broker *meta.MockBroker

nodeID int64
client *session.MockCluster
nodeManager *session.NodeManager
scheduler *task.MockScheduler
dist *meta.DistributionManager
target *meta.MockTargetManager
nodeID int64
client *session.MockCluster
nodeManager *session.NodeManager
scheduler *task.MockScheduler
dispatchMockCall *mock.Call
executedFlagChan chan struct{}
dist *meta.DistributionManager
target *meta.MockTargetManager

handler *distHandler
}
@@ -61,12 +63,18 @@ func (suite *DistHandlerSuite) SetupSuite() {
suite.target = meta.NewMockTargetManager(suite.T())
suite.ctx = context.Background()

suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe()
suite.executedFlagChan = make(chan struct{}, 1)
suite.scheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(suite.executedFlagChan).Maybe()
suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
}

func (suite *DistHandlerSuite) TestBasic() {
if suite.dispatchMockCall != nil {
suite.dispatchMockCall.Unset()
suite.dispatchMockCall = nil
}
suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe()
suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@@ -104,10 +112,15 @@ func (suite *DistHandlerSuite) TestBasic() {
suite.handler = newDistHandler(suite.ctx, suite.nodeID, suite.client, suite.nodeManager, suite.scheduler, suite.dist, suite.target)
defer suite.handler.stop()

time.Sleep(10 * time.Second)
time.Sleep(3 * time.Second)
}

func (suite *DistHandlerSuite) TestGetDistributionFailed() {
if suite.dispatchMockCall != nil {
suite.dispatchMockCall.Unset()
suite.dispatchMockCall = nil
}
suite.dispatchMockCall = suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe()
suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
@@ -118,7 +131,53 @@ func (suite *DistHandlerSuite) TestGetDistributionFailed() {
suite.handler = newDistHandler(suite.ctx, suite.nodeID, suite.client, suite.nodeManager, suite.scheduler, suite.dist, suite.target)
defer suite.handler.stop()

time.Sleep(10 * time.Second)
time.Sleep(3 * time.Second)
}

func (suite *DistHandlerSuite) TestForcePullDist() {
if suite.dispatchMockCall != nil {
suite.dispatchMockCall.Unset()
suite.dispatchMockCall = nil
}

suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
NodeID: 1,
Address: "localhost",
Hostname: "localhost",
}))
suite.client.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{
Status: merr.Success(),
NodeID: 1,
Channels: []*querypb.ChannelVersionInfo{
{
Channel: "test-channel-1",
Collection: 1,
Version: 1,
},
},
Segments: []*querypb.SegmentVersionInfo{
{
ID: 1,
Collection: 1,
Partition: 1,
Channel: "test-channel-1",
Version: 1,
},
},

LeaderViews: []*querypb.LeaderView{
{
Collection: 1,
Channel: "test-channel-1",
},
},
LastModifyTs: 1,
}, nil)
suite.executedFlagChan <- struct{}{}
suite.handler = newDistHandler(suite.ctx, suite.nodeID, suite.client, suite.nodeManager, suite.scheduler, suite.dist, suite.target)
defer suite.handler.stop()

time.Sleep(300 * time.Millisecond)
}

func TestDistHandlerSuite(t *testing.T) {
2 changes: 1 addition & 1 deletion internal/querycoordv2/observers/collection_observer.go
@@ -88,7 +88,7 @@ func (ob *CollectionObserver) Start() {
ctx, cancel := context.WithCancel(context.Background())
ob.cancel = cancel

const observePeriod = time.Second
observePeriod := Params.QueryCoordCfg.CollectionObserverInterval.GetAsDuration(time.Millisecond)
ob.wg.Add(1)
go func() {
defer ob.wg.Done()
11 changes: 11 additions & 0 deletions internal/querycoordv2/task/executor.go
@@ -63,6 +63,7 @@ type Executor struct {

executingTasks *typeutil.ConcurrentSet[string] // task index
executingTaskNum atomic.Int32
executedFlag chan struct{}
}

func NewExecutor(meta *meta.Meta,
@@ -82,6 +83,7 @@
nodeMgr: nodeMgr,

executingTasks: typeutil.NewConcurrentSet[string](),
executedFlag: make(chan struct{}, 1),
}
}

@@ -131,12 +133,21 @@ func (ex *Executor) Execute(task Task, step int) bool {
return true
}

func (ex *Executor) GetExecutedFlag() <-chan struct{} {
return ex.executedFlag
}

func (ex *Executor) removeTask(task Task, step int) {
if task.Err() != nil {
log.Info("execute action done, remove it",
zap.Int64("taskID", task.ID()),
zap.Int("step", step),
zap.Error(task.Err()))
} else {
select {
case ex.executedFlag <- struct{}{}:
default:
}
}

ex.executingTasks.Remove(task.Index())
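
executedFlag is created with capacity 1 and written with a non-blocking send, so any burst of successful task completions collapses into at most one pending signal and removeTask never blocks. A minimal standalone sketch of this coalescing-signal pattern (illustrative names, not the actual Executor):

package main

import "fmt"

// notify performs a non-blocking send: if a signal is already pending, the
// new one is dropped, so bursts of completions coalesce into a single signal.
func notify(flag chan<- struct{}) {
	select {
	case flag <- struct{}{}:
	default:
	}
}

func main() {
	flag := make(chan struct{}, 1) // capacity 1, like Executor.executedFlag

	// Three tasks finish back to back; only one signal gets buffered.
	for i := 0; i < 3; i++ {
		notify(flag)
	}

	// The consumer drains at most one signal per check, also without blocking.
	for i := 0; i < 2; i++ {
		select {
		case <-flag:
			fmt.Println("saw executed flag; pull distribution early")
		default:
			fmt.Println("no pending signal")
		}
	}
}
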
44 changes: 44 additions & 0 deletions internal/querycoordv2/task/mock_scheduler.go

Some generated files are not rendered by default.

13 changes: 13 additions & 0 deletions internal/querycoordv2/task/scheduler.go
@@ -143,6 +143,7 @@ type Scheduler interface {
RemoveByNode(node int64)
GetNodeSegmentDelta(nodeID int64) int
GetNodeChannelDelta(nodeID int64) int
GetExecutedFlag(nodeID int64) <-chan struct{}
GetChannelTaskNum() int
GetSegmentTaskNum() int
}
@@ -485,6 +486,18 @@ func (scheduler *taskScheduler) GetNodeChannelDelta(nodeID int64) int {
return calculateNodeDelta(nodeID, scheduler.channelTasks)
}

func (scheduler *taskScheduler) GetExecutedFlag(nodeID int64) <-chan struct{} {
scheduler.rwmutex.RLock()
defer scheduler.rwmutex.RUnlock()

executor, ok := scheduler.executors[nodeID]
if !ok {
return nil
}

return executor.GetExecutedFlag()
}

func (scheduler *taskScheduler) GetChannelTaskNum() int {
scheduler.rwmutex.RLock()
defer scheduler.rwmutex.RUnlock()
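
GetExecutedFlag returns nil when no executor is registered for the node, and a receive from a nil channel blocks forever, so callers (the dist handler and the test helper added in this change) check for nil before selecting on the channel. A small standalone sketch of that contract (illustrative map instead of the real taskScheduler):

package main

import "fmt"

// getExecutedFlag mirrors the behaviour above: return nil when no executor is
// registered for the node (illustrative map instead of the real taskScheduler).
func getExecutedFlag(executors map[int64]chan struct{}, nodeID int64) <-chan struct{} {
	ch, ok := executors[nodeID]
	if !ok {
		return nil
	}
	return ch
}

func main() {
	executors := map[int64]chan struct{}{
		1: make(chan struct{}, 1),
	}

	for _, node := range []int64{1, 2} {
		flag := getExecutedFlag(executors, node)
		if flag == nil {
			// Receiving from a nil channel blocks forever, so callers must
			// check for nil (as the dist handler does) before selecting on it.
			fmt.Println("node", node, "has no executor yet; skip the fast pull")
			continue
		}
		fmt.Println("node", node, "exposes an executed-flag channel")
	}
}
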
12 changes: 12 additions & 0 deletions internal/querycoordv2/task/task_test.go
@@ -471,6 +471,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() {

// Process tasks
suite.dispatchAndWait(targetNode)
suite.assertExecutedFlagChan(targetNode)
suite.AssertTaskNum(segmentsNum, 0, 0, segmentsNum)

// Process tasks done
@@ -1536,6 +1537,17 @@ func (suite *TaskSuite) dispatchAndWait(node int64) {
suite.FailNow("executor hangs in executing tasks", "count=%d keys=%+v", count, keys)
}

func (suite *TaskSuite) assertExecutedFlagChan(targetNode int64) {
flagChan := suite.scheduler.GetExecutedFlag(targetNode)
if flagChan != nil {
select {
case <-flagChan:
default:
suite.FailNow("task not executed")
}
}
}

func (suite *TaskSuite) TestLeaderTaskRemove() {
ctx := context.Background()
timeout := 10 * time.Second
21 changes: 21 additions & 0 deletions pkg/util/paramtable/component_param.go
@@ -1569,6 +1569,9 @@ type queryCoordConfig struct {
GracefulStopTimeout ParamItem `refreshable:"true"`
EnableStoppingBalance ParamItem `refreshable:"true"`
ChannelExclusiveNodeFactor ParamItem `refreshable:"true"`

CollectionObserverInterval ParamItem `refreshable:"false"`
CheckExecutedFlagInterval ParamItem `refreshable:"false"`
}

func (p *queryCoordConfig) init(base *BaseTable) {
@@ -2054,6 +2057,24 @@
Export: true,
}
p.ChannelExclusiveNodeFactor.Init(base.mgr)

p.CollectionObserverInterval = ParamItem{
Key: "queryCoord.collectionObserverInterval",
Version: "2.4.4",
DefaultValue: "200",
Doc: "the interval of collection observer",
Export: false,
}
p.CollectionObserverInterval.Init(base.mgr)

p.CheckExecutedFlagInterval = ParamItem{
Key: "queryCoord.checkExecutedFlagInterval",
Version: "2.4.4",
DefaultValue: "100",
Doc: "the interval of check executed flag to force to pull dist",
Export: false,
}
p.CheckExecutedFlagInterval.Init(base.mgr)
}

// /////////////////////////////////////////////////////////////////////////////