Skip to content

Commit

Permalink
Refine querynode scheduler lifetime (#26915)
Browse files Browse the repository at this point in the history
This PR refines scheduler lifetime control:
- Move the private tri-state into the lifetime package
- Make the scheduler reject incoming "Add" tasks when it is not in the working state
- Make scheduler Stop wait until all previously accepted tasks are done

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Sep 28, 2023
1 parent 8c59dba commit 258e1cc
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 75 deletions.
34 changes: 10 additions & 24 deletions internal/querynodev2/delegator/delegator.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,6 @@ type ShardDelegator interface {

var _ ShardDelegator = (*shardDelegator)(nil)

const (
initializing int32 = iota
working
stopped
)

func notStopped(state int32) bool {
return state != stopped
}

func isWorking(state int32) bool {
return state == working
}

// shardDelegator maintains the shard distribution and streaming part of the data.
type shardDelegator struct {
// shard information attributes
Expand All @@ -104,7 +90,7 @@ type shardDelegator struct {

workerManager cluster.Manager

lifetime lifetime.Lifetime[int32]
lifetime lifetime.Lifetime[lifetime.State]

distribution *distribution
segmentManager segments.SegmentManager
Expand Down Expand Up @@ -133,16 +119,16 @@ func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger {

// Serviceable returns whether delegator is serviceable now.
func (sd *shardDelegator) Serviceable() bool {
return sd.lifetime.GetState() == working
return lifetime.IsWorking(sd.lifetime.GetState())
}

// Stopped returns whether the delegator lifetime has reached the stopped state.
func (sd *shardDelegator) Stopped() bool {
return sd.lifetime.GetState() == stopped
return !lifetime.NotStopped(sd.lifetime.GetState())
}

// Start sets delegator to working state.
func (sd *shardDelegator) Start() {
sd.lifetime.SetState(working)
sd.lifetime.SetState(lifetime.Working)
}

// Collection returns delegator collection id.
Expand Down Expand Up @@ -192,7 +178,7 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
// Search performs search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -320,7 +306,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
// Query performs query operation on shard.
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -385,7 +371,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
// GetStatistics returns statistics aggregated by delegator.
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -624,7 +610,7 @@ func (sd *shardDelegator) updateTSafe() {

// Close closes the delegator.
func (sd *shardDelegator) Close() {
sd.lifetime.SetState(stopped)
sd.lifetime.SetState(lifetime.Stopped)
sd.lifetime.Close()
// broadcast to all waitTsafe goroutine to quit
sd.tsCond.Broadcast()
Expand Down Expand Up @@ -659,7 +645,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
collection: collection,
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(),
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
pkOracle: pkoracle.NewPkOracle(),
Expand All @@ -670,7 +656,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}
log.Info("finish build new shardDelegator")
Expand Down
8 changes: 4 additions & 4 deletions internal/querynodev2/delegator/delegator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1050,14 +1050,14 @@ func TestDelegatorWatchTsafe(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()

m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}

Expand All @@ -1077,15 +1077,15 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()

m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
signal := make(chan struct{})
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go func() {
sd.watchTSafe()
close(signal)
Expand Down
5 changes: 4 additions & 1 deletion internal/querynodev2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (node *QueryNode) Init() error {
// Start mainly start QueryNode's query service.
func (node *QueryNode) Start() error {
node.startOnce.Do(func() {
node.scheduler.Start(node.ctx)
node.scheduler.Start()

paramtable.SetCreateTime(time.Now())
paramtable.SetUpdateTime(time.Now())
Expand Down Expand Up @@ -453,6 +453,9 @@ func (node *QueryNode) Stop() error {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.lifetime.Wait()
node.cancel()
if node.scheduler != nil {
node.scheduler.Stop()
}
if node.pipelineManager != nil {
node.pipelineManager.Close()
}
Expand Down
114 changes: 76 additions & 38 deletions internal/querynodev2/tasks/concurrent_safe_scheduler.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package tasks

import (
"context"
"fmt"
"sync"

"go.uber.org/atomic"
"go.uber.org/zap"
Expand All @@ -11,6 +11,8 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
Expand All @@ -30,6 +32,7 @@ func newScheduler(policy schedulePolicy) Scheduler {
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
}

Expand All @@ -44,12 +47,23 @@ type scheduler struct {
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]

// wg is the waitgroup for internal worker goroutine
wg sync.WaitGroup
// lifetime controls scheduler State & make sure all requests accepted will be processed
lifetime lifetime.Lifetime[lifetime.State]

schedulerCounter
}

// Add a new task into scheduler,
// error will be returned if scheduler reaches some limit.
func (s *scheduler) Add(task Task) (err error) {
if !s.lifetime.Add(lifetime.IsWorking) {
return merr.WrapErrServiceUnavailable("scheduler closed")
}
defer s.lifetime.Done()

errCh := make(chan error, 1)

// TODO: add operation should be fast, is UnsolveLen metric unnecessary?
Expand All @@ -68,16 +82,31 @@ func (s *scheduler) Add(task Task) (err error) {

// Start launches the background exec and schedule goroutines and then
// switches the scheduler lifetime to the Working state.
// Both goroutines are registered on s.wg (hence the Add(2)) before they start.
// Start should be called only once.
func (s *scheduler) Start(ctx context.Context) {
func (s *scheduler) Start() {
s.wg.Add(2)

// Start a background task executing loop.
go s.exec(ctx)
go s.exec()

// Begin to schedule tasks.
go s.schedule(ctx)
go s.schedule()

s.lifetime.SetState(lifetime.Working)
}

// Stop shuts the scheduler down. The steps run in a fixed order: mark the
// lifetime as Stopped, wait on the lifetime, close receiveChan, then wait
// on s.wg.
// NOTE(review): calling Stop a second time would close receiveChan again,
// which panics in Go — presumably callers invoke Stop at most once; confirm.
func (s *scheduler) Stop() {
s.lifetime.SetState(lifetime.Stopped)
// wait until every Add call that already passed the lifetime check has called Done
s.lifetime.Wait()
// closing receiveChan starts the stopping process of the `schedule` loop
close(s.receiveChan)
// wait for the worker goroutines registered on wg to quit
s.wg.Wait()
}

// schedule the owned task asynchronously and continuously.
func (s *scheduler) schedule(ctx context.Context) {
func (s *scheduler) schedule() {
defer s.wg.Done()
var task Task
for {
s.setupReadyLenMetric()
Expand All @@ -87,10 +116,19 @@ func (s *scheduler) schedule(ctx context.Context) {
task, nq, execChan = s.setupExecListener(task)

select {
case <-ctx.Done():
log.Warn("unexpected quit of schedule loop")
return
case req := <-s.receiveChan:
case req, ok := <-s.receiveChan:
if !ok {
log.Info("receiveChan closed, processing remaining request")
// drain policy maintained task
for task != nil {
execChan <- task
s.updateWaitingTaskCounter(-1, -nq)
task = s.produceExecChan()
}
log.Info("all task put into exeChan, schedule worker exit")
close(s.execChan)
return
}
// Receive add operation request and return the process result.
// And consume recv chan as much as possible.
s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum)
Expand Down Expand Up @@ -166,42 +204,42 @@ func (s *scheduler) produceExecChan() Task {
}

// exec executes the ready tasks in background continuously.
func (s *scheduler) exec(ctx context.Context) {
func (s *scheduler) exec() {
defer s.wg.Done()
log.Info("start execute loop")
for {
select {
case <-ctx.Done():
log.Warn("unexpected quit of exec loop")
t, ok := <-s.execChan
if !ok {
log.Info("scheduler execChan closed, worker exit")
return
case t := <-s.execChan:
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}
}
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}

s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)

err := t.Execute()
err := t.Execute()

// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)
// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)

// Notify task done.
t.Done(err)
return nil, err
})
}
// Notify task done.
t.Done(err)
return nil, err
})
}
}

Expand Down
21 changes: 20 additions & 1 deletion internal/querynodev2/tasks/concurrent_safe_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,31 @@ func TestScheduler(t *testing.T) {
t.Run("fifo", func(t *testing.T) {
testScheduler(t, newFIFOPolicy())
})
t.Run("scheduler_not_working", func(t *testing.T) {
scheduler := newScheduler(newFIFOPolicy())

task := newMockTask(mockTaskConfig{
nq: 1,
executeCost: 10 * time.Millisecond,
execution: func(ctx context.Context) error {
return nil
},
})

err := scheduler.Add(task)
assert.Error(t, err)

scheduler.Stop()

err = scheduler.Add(task)
assert.Error(t, err)
})
}

func testScheduler(t *testing.T, policy schedulePolicy) {
// start a new scheduler
scheduler := newScheduler(policy)
go scheduler.Start(context.Background())
scheduler.Start()

var cnt atomic.Int32
n := 100
Expand Down
13 changes: 6 additions & 7 deletions internal/querynodev2/tasks/tasks.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package tasks

import (
"context"
)

const (
schedulePolicyNameFIFO = "fifo"
schedulePolicyNameUserTaskPolling = "user-task-polling"
Expand Down Expand Up @@ -44,9 +40,12 @@ type Scheduler interface {
Add(task Task) error

// Start schedule the owned task asynchronously and continuously.
// 1. Stop processing until ctx.Cancel() is called.
// 2. Only call once.
Start(ctx context.Context)
// Shall be called only once
Start()

// Stop makes the scheduler deny all incoming tasks
// and cleans up all related resources
Stop()

// GetWaitingTaskTotalNQ
GetWaitingTaskTotalNQ() int64
Expand Down
Loading

0 comments on commit 258e1cc

Please sign in to comment.