Skip to content

Commit

Permalink
Add ctx control for observer manual check methods (#27531)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Oct 9, 2023
1 parent 3759857 commit eca79d1
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 15 deletions.
10 changes: 5 additions & 5 deletions internal/querycoordv2/observers/collection_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (ob *CollectionObserver) Stop() {

func (ob *CollectionObserver) Observe(ctx context.Context) {
ob.observeTimeout()
ob.observeLoadStatus()
ob.observeLoadStatus(ctx)
}

func (ob *CollectionObserver) observeTimeout() {
Expand Down Expand Up @@ -158,7 +158,7 @@ func (ob *CollectionObserver) readyToObserve(collectionID int64) bool {
return metaExist && targetExist
}

func (ob *CollectionObserver) observeLoadStatus() {
func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) {
partitions := ob.meta.CollectionManager.GetAllPartitions()
if len(partitions) > 0 {
log.Info("observe partitions status", zap.Int("partitionNum", len(partitions)))
Expand All @@ -170,7 +170,7 @@ func (ob *CollectionObserver) observeLoadStatus() {
}
if ob.readyToObserve(partition.CollectionID) {
replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID())
ob.observePartitionLoadStatus(partition, replicaNum)
ob.observePartitionLoadStatus(ctx, partition, replicaNum)
loading = true
}
}
Expand All @@ -180,7 +180,7 @@ func (ob *CollectionObserver) observeLoadStatus() {
}
}

func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition, replicaNum int32) {
func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, partition *meta.Partition, replicaNum int32) {
log := log.With(
zap.Int64("collectionID", partition.GetCollectionID()),
zap.Int64("partitionID", partition.GetPartitionID()),
Expand Down Expand Up @@ -230,7 +230,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
}

ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount
if loadPercentage == 100 && ob.targetObserver.Check(partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(partition.GetCollectionID()) {
if loadPercentage == 100 && ob.targetObserver.Check(ctx, partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(ctx, partition.GetCollectionID()) {
delete(ob.partitionLoadedCount, partition.GetPartitionID())
}
collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage)
Expand Down
17 changes: 12 additions & 5 deletions internal/querycoordv2/observers/leader_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,20 @@ func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64
return result
}

func (ob *LeaderObserver) CheckTargetVersion(collectionID int64) bool {
func (ob *LeaderObserver) CheckTargetVersion(ctx context.Context, collectionID int64) bool {
notifier := make(chan bool)
ob.manualCheck <- checkRequest{
CollectionID: collectionID,
Notifier: notifier,
select {
case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}:
case <-ctx.Done():
return false
}

select {
case result := <-notifier:
return result
case <-ctx.Done():
return false
}
return <-notifier
}

func (o *LeaderObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView) *querypb.SyncAction {
Expand Down
38 changes: 38 additions & 0 deletions internal/querycoordv2/observers/leader_observer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,44 @@ func (suite *LeaderObserverTestSuite) TestSyncTargetVersion() {
suite.Len(action.SealedInTarget, 1)
}

func (suite *LeaderObserverTestSuite) TestCheckTargetVersion() {
collectionID := int64(1001)
observer := suite.observer

suite.Run("check_channel_blocked", func() {
oldCh := observer.manualCheck
defer func() {
observer.manualCheck = oldCh
}()

// zero-length channel
observer.manualCheck = make(chan checkRequest)

ctx, cancel := context.WithCancel(context.Background())
// cancel context, make test return fast
cancel()

result := observer.CheckTargetVersion(ctx, collectionID)
suite.False(result)
})

suite.Run("check_return_ctx_timeout", func() {
oldCh := observer.manualCheck
defer func() {
observer.manualCheck = oldCh
}()

// make channel length = 1, task received
observer.manualCheck = make(chan checkRequest, 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()

result := observer.CheckTargetVersion(ctx, collectionID)
suite.False(result)
})
}

func TestLeaderObserverSuite(t *testing.T) {
suite.Run(t, new(LeaderObserverTestSuite))
}
17 changes: 12 additions & 5 deletions internal/querycoordv2/observers/target_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,20 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
// Check checks whether the next target is ready,
// and updates the current target if it is,
// returns true if current target is not nil
func (ob *TargetObserver) Check(collectionID int64) bool {
func (ob *TargetObserver) Check(ctx context.Context, collectionID int64) bool {
notifier := make(chan bool)
ob.manualCheck <- checkRequest{
CollectionID: collectionID,
Notifier: notifier,
select {
case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}:
case <-ctx.Done():
return false
}

select {
case result := <-notifier:
return result
case <-ctx.Done():
return false
}
return <-notifier
}

func (ob *TargetObserver) check(collectionID int64) {
Expand Down
96 changes: 96 additions & 0 deletions internal/querycoordv2/observers/target_observer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package observers

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -215,6 +216,101 @@ func (suite *TargetObserverSuite) TearDownSuite() {
suite.observer.Stop()
}

type TargetObserverCheckSuite struct {
suite.Suite

kv kv.MetaKv
// dependency
meta *meta.Meta
targetMgr *meta.TargetManager
distMgr *meta.DistributionManager
broker *meta.MockBroker

observer *TargetObserver

collectionID int64
partitionID int64
}

func (suite *TargetObserverCheckSuite) SetupSuite() {
paramtable.Init()
}

func (suite *TargetObserverCheckSuite) SetupTest() {
var err error
config := GenerateEtcdConfig()
cli, err := etcd.GetEtcdClient(
config.UseEmbedEtcd.GetAsBool(),
config.EtcdUseSSL.GetAsBool(),
config.Endpoints.GetAsStrings(),
config.EtcdTLSCert.GetValue(),
config.EtcdTLSKey.GetValue(),
config.EtcdTLSCACert.GetValue(),
config.EtcdTLSMinVersion.GetValue())
suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())

// meta
store := querycoord.NewCatalog(suite.kv)
idAllocator := RandomIncrementIDAllocator()
suite.meta = meta.NewMeta(idAllocator, store, session.NewNodeManager())

suite.broker = meta.NewMockBroker(suite.T())
suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
suite.distMgr = meta.NewDistributionManager()
suite.observer = NewTargetObserver(suite.meta, suite.targetMgr, suite.distMgr, suite.broker)
suite.collectionID = int64(1000)
suite.partitionID = int64(100)

err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1))
suite.NoError(err)
err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID))
suite.NoError(err)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName)
suite.NoError(err)
replicas[0].AddNode(2)
err = suite.meta.ReplicaManager.Put(replicas...)
suite.NoError(err)
}

func (suite *TargetObserverCheckSuite) TestCheckCtxDone() {
observer := suite.observer

suite.Run("check_channel_blocked", func() {
oldCh := observer.manualCheck
defer func() {
observer.manualCheck = oldCh
}()

// zero-length channel
observer.manualCheck = make(chan checkRequest)

ctx, cancel := context.WithCancel(context.Background())
// cancel context, make test return fast
cancel()

result := observer.Check(ctx, suite.collectionID)
suite.False(result)
})

suite.Run("check_return_ctx_timeout", func() {
oldCh := observer.manualCheck
defer func() {
observer.manualCheck = oldCh
}()

// make channel length = 1, task received
observer.manualCheck = make(chan checkRequest, 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()

result := observer.Check(ctx, suite.collectionID)
suite.False(result)
})
}

func TestTargetObserver(t *testing.T) {
suite.Run(t, new(TargetObserverSuite))
suite.Run(t, new(TargetObserverCheckSuite))
}

0 comments on commit eca79d1

Please sign in to comment.