diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 4d9c5622553d9..bf2749cf8556a 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -78,20 +78,6 @@ type ShardDelegator interface { var _ ShardDelegator = (*shardDelegator)(nil) -const ( - initializing int32 = iota - working - stopped -) - -func notStopped(state int32) bool { - return state != stopped -} - -func isWorking(state int32) bool { - return state == working -} - // shardDelegator maintains the shard distribution and streaming part of the data. type shardDelegator struct { // shard information attributes @@ -104,7 +90,7 @@ type shardDelegator struct { workerManager cluster.Manager - lifetime lifetime.Lifetime[int32] + lifetime lifetime.Lifetime[lifetime.State] distribution *distribution segmentManager segments.SegmentManager @@ -133,16 +119,16 @@ func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger { // Serviceable returns whether delegator is serviceable now. func (sd *shardDelegator) Serviceable() bool { - return sd.lifetime.GetState() == working + return lifetime.IsWorking(sd.lifetime.GetState()) } func (sd *shardDelegator) Stopped() bool { - return sd.lifetime.GetState() == stopped + return !lifetime.NotStopped(sd.lifetime.GetState()) } // Start sets delegator to working state. func (sd *shardDelegator) Start() { - sd.lifetime.SetState(working) + sd.lifetime.SetState(lifetime.Working) } // Collection returns delegator collection id. @@ -192,7 +178,7 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu // Search preforms search operation on shard. func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) { log := sd.getLogger(ctx) - if !sd.lifetime.Add(isWorking) { + if !sd.lifetime.Add(lifetime.IsWorking) { return nil, errors.New("delegator is not serviceable") } defer sd.lifetime.Done() @@ -320,7 +306,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq // Query performs query operation on shard. func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) { log := sd.getLogger(ctx) - if !sd.lifetime.Add(isWorking) { + if !sd.lifetime.Add(lifetime.IsWorking) { return nil, errors.New("delegator is not serviceable") } defer sd.lifetime.Done() @@ -385,7 +371,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) // GetStatistics returns statistics aggregated by delegator. func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) { log := sd.getLogger(ctx) - if !sd.lifetime.Add(isWorking) { + if !sd.lifetime.Add(lifetime.IsWorking) { return nil, errors.New("delegator is not serviceable") } defer sd.lifetime.Done() @@ -624,7 +610,7 @@ func (sd *shardDelegator) updateTSafe() { // Close closes the delegator. func (sd *shardDelegator) Close() { - sd.lifetime.SetState(stopped) + sd.lifetime.SetState(lifetime.Stopped) sd.lifetime.Close() // broadcast to all waitTsafe goroutine to quit sd.tsCond.Broadcast() @@ -659,7 +645,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string collection: collection, segmentManager: manager.Segment, workerManager: workerManager, - lifetime: lifetime.NewLifetime(initializing), + lifetime: lifetime.NewLifetime(lifetime.Initializing), distribution: NewDistribution(), deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer), pkOracle: pkoracle.NewPkOracle(), @@ -670,7 +656,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string } m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - if sd.lifetime.Add(notStopped) { + if sd.lifetime.Add(lifetime.NotStopped) { go sd.watchTSafe() } log.Info("finish build new shardDelegator") diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index 0ef707b44a0a5..4e264e7b03e17 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -1050,14 +1050,14 @@ func TestDelegatorWatchTsafe(t *testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: lifetime.NewLifetime(initializing), + lifetime: lifetime.NewLifetime(lifetime.Initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - if sd.lifetime.Add(notStopped) { + if sd.lifetime.Add(lifetime.NotStopped) { go sd.watchTSafe() } @@ -1077,7 +1077,7 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: lifetime.NewLifetime(initializing), + lifetime: lifetime.NewLifetime(lifetime.Initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() @@ -1085,7 +1085,7 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) signal := make(chan struct{}) - if sd.lifetime.Add(notStopped) { + if sd.lifetime.Add(lifetime.NotStopped) { go func() { sd.watchTSafe() close(signal) diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 7779e00093522..af5e762d03efd 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -389,7 +389,7 @@ func (node *QueryNode) Init() error { // Start mainly start QueryNode's query service. func (node *QueryNode) Start() error { node.startOnce.Do(func() { - node.scheduler.Start(node.ctx) + node.scheduler.Start() paramtable.SetCreateTime(time.Now()) paramtable.SetUpdateTime(time.Now()) @@ -453,6 +453,9 @@ func (node *QueryNode) Stop() error { node.UpdateStateCode(commonpb.StateCode_Abnormal) node.lifetime.Wait() node.cancel() + if node.scheduler != nil { + node.scheduler.Stop() + } if node.pipelineManager != nil { node.pipelineManager.Close() } diff --git a/internal/querynodev2/tasks/concurrent_safe_scheduler.go b/internal/querynodev2/tasks/concurrent_safe_scheduler.go index ab954115db559..c83f94d383d02 100644 --- a/internal/querynodev2/tasks/concurrent_safe_scheduler.go +++ b/internal/querynodev2/tasks/concurrent_safe_scheduler.go @@ -1,8 +1,8 @@ package tasks import ( - "context" "fmt" + "sync" "go.uber.org/atomic" "go.uber.org/zap" @@ -11,6 +11,8 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/lifetime" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -30,6 +32,7 @@ func newScheduler(policy schedulePolicy) Scheduler { execChan: make(chan Task), pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)), schedulerCounter: schedulerCounter{}, + lifetime: lifetime.NewLifetime(lifetime.Initializing), } } @@ -44,12 +47,23 @@ type scheduler struct { receiveChan chan addTaskReq execChan chan Task pool *conc.Pool[any] + + // wg is the waitgroup for internal worker goroutine + wg sync.WaitGroup + // lifetime controls scheduler State & make sure all requests accepted will be processed + lifetime lifetime.Lifetime[lifetime.State] + schedulerCounter } // Add a new task into scheduler, // error will be returned if scheduler reaches some limit. func (s *scheduler) Add(task Task) (err error) { + if !s.lifetime.Add(lifetime.IsWorking) { + return merr.WrapErrServiceUnavailable("scheduler closed") + } + defer s.lifetime.Done() + errCh := make(chan error, 1) // TODO: add operation should be fast, is UnsolveLen metric unnesscery? @@ -68,16 +82,31 @@ func (s *scheduler) Add(task Task) (err error) { // Start schedule the owned task asynchronously and continuously. // Start should be only call once. -func (s *scheduler) Start(ctx context.Context) { +func (s *scheduler) Start() { + s.wg.Add(2) + // Start a background task executing loop. - go s.exec(ctx) + go s.exec() // Begin to schedule tasks. - go s.schedule(ctx) + go s.schedule() + + s.lifetime.SetState(lifetime.Working) +} + +func (s *scheduler) Stop() { + s.lifetime.SetState(lifetime.Stopped) + // wait all accepted Add done + s.lifetime.Wait() + // close receiveChan start stopping process for `schedule` + close(s.receiveChan) + // wait workers quit + s.wg.Wait() } // schedule the owned task asynchronously and continuously. -func (s *scheduler) schedule(ctx context.Context) { +func (s *scheduler) schedule() { + defer s.wg.Done() var task Task for { s.setupReadyLenMetric() @@ -87,10 +116,19 @@ func (s *scheduler) schedule(ctx context.Context) { task, nq, execChan = s.setupExecListener(task) select { - case <-ctx.Done(): - log.Warn("unexpected quit of schedule loop") - return - case req := <-s.receiveChan: + case req, ok := <-s.receiveChan: + if !ok { + log.Info("receiveChan closed, processing remaining request") + // drain policy maintained task + for task != nil { + execChan <- task + s.updateWaitingTaskCounter(-1, -nq) + task = s.produceExecChan() + } + log.Info("all task put into exeChan, schedule worker exit") + close(s.execChan) + return + } // Receive add operation request and return the process result. // And consume recv chan as much as possible. s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum) @@ -166,42 +204,42 @@ func (s *scheduler) produceExecChan() Task { } // exec exec the ready task in background continuously. -func (s *scheduler) exec(ctx context.Context) { +func (s *scheduler) exec() { + defer s.wg.Done() log.Info("start execute loop") for { - select { - case <-ctx.Done(): - log.Warn("unexpected quit of exec loop") + t, ok := <-s.execChan + if !ok { + log.Info("scheduler execChan closed, worker exit") return - case t := <-s.execChan: - // Skip this task if task is canceled. - if err := t.Canceled(); err != nil { - log.Warn("task canceled before executing", zap.Error(err)) - t.Done(err) - continue - } - if err := t.PreExecute(); err != nil { - log.Warn("failed to pre-execute task", zap.Error(err)) - t.Done(err) - continue - } + } + // Skip this task if task is canceled. + if err := t.Canceled(); err != nil { + log.Warn("task canceled before executing", zap.Error(err)) + t.Done(err) + continue + } + if err := t.PreExecute(); err != nil { + log.Warn("failed to pre-execute task", zap.Error(err)) + t.Done(err) + continue + } - s.pool.Submit(func() (any, error) { - // Update concurrency metric and notify task done. - metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() - collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1) + s.pool.Submit(func() (any, error) { + // Update concurrency metric and notify task done. + metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() + collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1) - err := t.Execute() + err := t.Execute() - // Update all metric after task finished. - metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() - collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1) + // Update all metric after task finished. + metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() + collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1) - // Notify task done. - t.Done(err) - return nil, err - }) - } + // Notify task done. + t.Done(err) + return nil, err + }) } } diff --git a/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go b/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go index 167eb3af26321..4af3f05677c97 100644 --- a/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go +++ b/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go @@ -21,12 +21,31 @@ func TestScheduler(t *testing.T) { t.Run("fifo", func(t *testing.T) { testScheduler(t, newFIFOPolicy()) }) + t.Run("scheduler_not_working", func(t *testing.T) { + scheduler := newScheduler(newFIFOPolicy()) + + task := newMockTask(mockTaskConfig{ + nq: 1, + executeCost: 10 * time.Millisecond, + execution: func(ctx context.Context) error { + return nil + }, + }) + + err := scheduler.Add(task) + assert.Error(t, err) + + scheduler.Stop() + + err = scheduler.Add(task) + assert.Error(t, err) + }) } func testScheduler(t *testing.T, policy schedulePolicy) { // start a new scheduler scheduler := newScheduler(policy) - go scheduler.Start(context.Background()) + scheduler.Start() var cnt atomic.Int32 n := 100 diff --git a/internal/querynodev2/tasks/tasks.go b/internal/querynodev2/tasks/tasks.go index 59328be571b51..6a0d55b2edeec 100644 --- a/internal/querynodev2/tasks/tasks.go +++ b/internal/querynodev2/tasks/tasks.go @@ -1,9 +1,5 @@ package tasks -import ( - "context" -) - const ( schedulePolicyNameFIFO = "fifo" schedulePolicyNameUserTaskPolling = "user-task-polling" @@ -44,9 +40,12 @@ type Scheduler interface { Add(task Task) error // Start schedule the owned task asynchronously and continuously. - // 1. Stop processing until ctx.Cancel() is called. - // 2. Only call once. - Start(ctx context.Context) + // Shall be called only once + Start() + + // Stop make scheduler deny all incoming tasks + // and cleans up all related resources + Stop() // GetWaitingTaskTotalNQ GetWaitingTaskTotalNQ() int64 diff --git a/pkg/util/lifetime/state.go b/pkg/util/lifetime/state.go new file mode 100644 index 0000000000000..2adebc5fc072a --- /dev/null +++ b/pkg/util/lifetime/state.go @@ -0,0 +1,45 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lifetime + +// Singal alias for chan struct{}. +type Signal chan struct{} + +// BiState provides pre-defined simple binary state - normal or closed. +type BiState int32 + +const ( + Normal BiState = 0 + Closed BiState = 1 +) + +// State provides pre-defined three stage state. +type State int32 + +const ( + Initializing State = iota + Working + Stopped +) + +func NotStopped(state State) bool { + return state != Stopped +} + +func IsWorking(state State) bool { + return state == Working +}