Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine querynode scheduler lifetime #26915

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions internal/querynodev2/delegator/delegator.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,6 @@ type ShardDelegator interface {

var _ ShardDelegator = (*shardDelegator)(nil)

const (
initializing int32 = iota
working
stopped
)

func notStopped(state int32) bool {
return state != stopped
}

func isWorking(state int32) bool {
return state == working
}

// shardDelegator maintains the shard distribution and streaming part of the data.
type shardDelegator struct {
// shard information attributes
Expand All @@ -104,7 +90,7 @@ type shardDelegator struct {

workerManager cluster.Manager

lifetime lifetime.Lifetime[int32]
lifetime lifetime.Lifetime[lifetime.State]

distribution *distribution
segmentManager segments.SegmentManager
Expand Down Expand Up @@ -133,16 +119,16 @@ func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger {

// Serviceable returns whether delegator is serviceable now.
func (sd *shardDelegator) Serviceable() bool {
return sd.lifetime.GetState() == working
return lifetime.IsWorking(sd.lifetime.GetState())
}

func (sd *shardDelegator) Stopped() bool {
return sd.lifetime.GetState() == stopped
return !lifetime.NotStopped(sd.lifetime.GetState())
}

// Start sets delegator to working state.
func (sd *shardDelegator) Start() {
sd.lifetime.SetState(working)
sd.lifetime.SetState(lifetime.Working)
}

// Collection returns delegator collection id.
Expand Down Expand Up @@ -192,7 +178,7 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
// Search preforms search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -320,7 +306,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
// Query performs query operation on shard.
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -385,7 +371,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
// GetStatistics returns statistics aggregated by delegator.
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
log := sd.getLogger(ctx)
if !sd.lifetime.Add(isWorking) {
if !sd.lifetime.Add(lifetime.IsWorking) {
return nil, errors.New("delegator is not serviceable")
}
defer sd.lifetime.Done()
Expand Down Expand Up @@ -624,7 +610,7 @@ func (sd *shardDelegator) updateTSafe() {

// Close closes the delegator.
func (sd *shardDelegator) Close() {
sd.lifetime.SetState(stopped)
sd.lifetime.SetState(lifetime.Stopped)
sd.lifetime.Close()
// broadcast to all waitTsafe goroutine to quit
sd.tsCond.Broadcast()
Expand Down Expand Up @@ -659,7 +645,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
collection: collection,
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(),
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
pkOracle: pkoracle.NewPkOracle(),
Expand All @@ -670,7 +656,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}
log.Info("finish build new shardDelegator")
Expand Down
8 changes: 4 additions & 4 deletions internal/querynodev2/delegator/delegator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1050,14 +1050,14 @@ func TestDelegatorWatchTsafe(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()

m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go sd.watchTSafe()
}

Expand All @@ -1077,15 +1077,15 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) {
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: lifetime.NewLifetime(initializing),
lifetime: lifetime.NewLifetime(lifetime.Initializing),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()

m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
signal := make(chan struct{})
if sd.lifetime.Add(notStopped) {
if sd.lifetime.Add(lifetime.NotStopped) {
go func() {
sd.watchTSafe()
close(signal)
Expand Down
5 changes: 4 additions & 1 deletion internal/querynodev2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (node *QueryNode) Init() error {
// Start mainly start QueryNode's query service.
func (node *QueryNode) Start() error {
node.startOnce.Do(func() {
node.scheduler.Start(node.ctx)
node.scheduler.Start()

paramtable.SetCreateTime(time.Now())
paramtable.SetUpdateTime(time.Now())
Expand Down Expand Up @@ -452,6 +452,9 @@ func (node *QueryNode) Stop() error {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
node.lifetime.Wait()
node.cancel()
if node.scheduler != nil {
node.scheduler.Stop()
}
if node.pipelineManager != nil {
node.pipelineManager.Close()
}
Expand Down
114 changes: 76 additions & 38 deletions internal/querynodev2/tasks/concurrent_safe_scheduler.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package tasks

import (
"context"
"fmt"
"sync"

"go.uber.org/atomic"
"go.uber.org/zap"
Expand All @@ -11,6 +11,8 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
Expand All @@ -30,6 +32,7 @@ func newScheduler(policy schedulePolicy) Scheduler {
execChan: make(chan Task),
pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)),
schedulerCounter: schedulerCounter{},
lifetime: lifetime.NewLifetime(lifetime.Initializing),
}
}

Expand All @@ -44,12 +47,23 @@ type scheduler struct {
receiveChan chan addTaskReq
execChan chan Task
pool *conc.Pool[any]

// wg is the waitgroup for internal worker goroutine
wg sync.WaitGroup
// lifetime controls scheduler State & make sure all requests accepted will be processed
lifetime lifetime.Lifetime[lifetime.State]

schedulerCounter
}

// Add a new task into scheduler,
// error will be returned if scheduler reaches some limit.
func (s *scheduler) Add(task Task) (err error) {
if !s.lifetime.Add(lifetime.IsWorking) {
return merr.WrapErrServiceUnavailable("scheduler closed")
}
defer s.lifetime.Done()

errCh := make(chan error, 1)

// TODO: add operation should be fast, is UnsolveLen metric unnesscery?
Expand All @@ -68,16 +82,31 @@ func (s *scheduler) Add(task Task) (err error) {

// Start schedule the owned task asynchronously and continuously.
// Start should be only call once.
func (s *scheduler) Start(ctx context.Context) {
func (s *scheduler) Start() {
s.wg.Add(2)

// Start a background task executing loop.
go s.exec(ctx)
go s.exec()

// Begin to schedule tasks.
go s.schedule(ctx)
go s.schedule()

s.lifetime.SetState(lifetime.Working)
}

func (s *scheduler) Stop() {
s.lifetime.SetState(lifetime.Stopped)
// wait all accepted Add done
s.lifetime.Wait()
// close receiveChan start stopping process for `schedule`
close(s.receiveChan)
// wait workers quit
s.wg.Wait()
}

// schedule the owned task asynchronously and continuously.
func (s *scheduler) schedule(ctx context.Context) {
func (s *scheduler) schedule() {
defer s.wg.Done()
var task Task
for {
s.setupReadyLenMetric()
Expand All @@ -87,10 +116,19 @@ func (s *scheduler) schedule(ctx context.Context) {
task, nq, execChan = s.setupExecListener(task)

select {
case <-ctx.Done():
log.Warn("unexpected quit of schedule loop")
return
case req := <-s.receiveChan:
case req, ok := <-s.receiveChan:
if !ok {
log.Info("receiveChan closed, processing remaining request")
// drain policy maintained task
for task != nil {
execChan <- task
s.updateWaitingTaskCounter(-1, -nq)
task = s.produceExecChan()
}
log.Info("all task put into exeChan, schedule worker exit")
close(s.execChan)
return
}
// Receive add operation request and return the process result.
// And consume recv chan as much as possible.
s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum)
Expand Down Expand Up @@ -166,42 +204,42 @@ func (s *scheduler) produceExecChan() Task {
}

// exec exec the ready task in background continuously.
func (s *scheduler) exec(ctx context.Context) {
func (s *scheduler) exec() {
defer s.wg.Done()
log.Info("start execute loop")
for {
select {
case <-ctx.Done():
log.Warn("unexpected quit of exec loop")
t, ok := <-s.execChan
if !ok {
log.Info("scheduler execChan closed, worker exit")
return
case t := <-s.execChan:
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}
}
// Skip this task if task is canceled.
if err := t.Canceled(); err != nil {
log.Warn("task canceled before executing", zap.Error(err))
t.Done(err)
continue
}
if err := t.PreExecute(); err != nil {
log.Warn("failed to pre-execute task", zap.Error(err))
t.Done(err)
continue
}

s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)
s.pool.Submit(func() (any, error) {
// Update concurrency metric and notify task done.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1)

err := t.Execute()
err := t.Execute()

// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)
// Update all metric after task finished.
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1)

// Notify task done.
t.Done(err)
return nil, err
})
}
// Notify task done.
t.Done(err)
return nil, err
})
}
}

Expand Down
21 changes: 20 additions & 1 deletion internal/querynodev2/tasks/concurrent_safe_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,31 @@ func TestScheduler(t *testing.T) {
t.Run("fifo", func(t *testing.T) {
testScheduler(t, newFIFOPolicy())
})
t.Run("scheduler_not_working", func(t *testing.T) {
scheduler := newScheduler(newFIFOPolicy())

task := newMockTask(mockTaskConfig{
nq: 1,
executeCost: 10 * time.Millisecond,
execution: func(ctx context.Context) error {
return nil
},
})

err := scheduler.Add(task)
assert.Error(t, err)

scheduler.Stop()

err = scheduler.Add(task)
assert.Error(t, err)
})
}

func testScheduler(t *testing.T, policy schedulePolicy) {
// start a new scheduler
scheduler := newScheduler(policy)
go scheduler.Start(context.Background())
scheduler.Start()

var cnt atomic.Int32
n := 100
Expand Down
13 changes: 6 additions & 7 deletions internal/querynodev2/tasks/tasks.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package tasks

import (
"context"
)

const (
schedulePolicyNameFIFO = "fifo"
schedulePolicyNameUserTaskPolling = "user-task-polling"
Expand Down Expand Up @@ -44,9 +40,12 @@ type Scheduler interface {
Add(task Task) error

// Start schedule the owned task asynchronously and continuously.
// 1. Stop processing until ctx.Cancel() is called.
// 2. Only call once.
Start(ctx context.Context)
// Shall be called only once
Start()

// Stop make scheduler deny all incoming tasks
// and cleans up all related resources
Stop()

// GetWaitingTaskTotalNQ
GetWaitingTaskTotalNQ() int64
Expand Down
Loading