diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index e0c98a592be8e..d115e1a2708e1 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -185,7 +185,8 @@ func (suite *CollectionObserverSuite) SetupTest() { // Dependencies suite.dist = meta.NewDistributionManager() - suite.meta = meta.NewMeta(suite.idAllocator, suite.store, session.NewNodeManager()) + nodeMgr := session.NewNodeManager() + suite.meta = meta.NewMeta(suite.idAllocator, suite.store, nodeMgr) suite.broker = meta.NewMockBroker(suite.T()) suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) suite.targetObserver = NewTargetObserver(suite.meta, @@ -196,7 +197,7 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.checkerController = &checkers.CheckerController{} mockCluster := session.NewMockCluster(suite.T()) - suite.leaderObserver = NewLeaderObserver(suite.dist, suite.meta, suite.targetMgr, suite.broker, mockCluster) + suite.leaderObserver = NewLeaderObserver(suite.dist, suite.meta, suite.targetMgr, suite.broker, mockCluster, nodeMgr) mockCluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() // Test object diff --git a/internal/querycoordv2/observers/leader_observer.go b/internal/querycoordv2/observers/leader_observer.go index 5778a29766f6f..ea4bdab79633e 100644 --- a/internal/querycoordv2/observers/leader_observer.go +++ b/internal/querycoordv2/observers/leader_observer.go @@ -48,6 +48,7 @@ type LeaderObserver struct { target *meta.TargetManager broker meta.Broker cluster session.Cluster + nodeMgr *session.NodeManager dispatcher *taskDispatcher[int64] @@ -118,6 +119,11 @@ func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64 for _, replica := range replicas { leaders := 
o.dist.ChannelDistManager.GetShardLeadersByReplica(replica) for ch, leaderID := range leaders { + if ok, _ := o.nodeMgr.IsStoppingNode(leaderID); ok { + // no need to correct leader's view which is loaded on stopping node + continue + } + leaderView := o.dist.LeaderViewManager.GetLeaderShardView(leaderID, ch) if leaderView == nil { continue @@ -326,6 +332,7 @@ func NewLeaderObserver( targetMgr *meta.TargetManager, broker meta.Broker, cluster session.Cluster, + nodeMgr *session.NodeManager, ) *LeaderObserver { ob := &LeaderObserver{ dist: dist, @@ -333,6 +340,7 @@ func NewLeaderObserver( target: targetMgr, broker: broker, cluster: cluster, + nodeMgr: nodeMgr, } dispatcher := newTaskDispatcher[int64](ob.observeCollection) diff --git a/internal/querycoordv2/observers/leader_observer_test.go b/internal/querycoordv2/observers/leader_observer_test.go index 3a2738f4ff1aa..a471457b5a6e3 100644 --- a/internal/querycoordv2/observers/leader_observer_test.go +++ b/internal/querycoordv2/observers/leader_observer_test.go @@ -71,7 +71,8 @@ func (suite *LeaderObserverTestSuite) SetupTest() { // meta store := querycoord.NewCatalog(suite.kv) idAllocator := RandomIncrementIDAllocator() - suite.meta = meta.NewMeta(idAllocator, store, session.NewNodeManager()) + nodeMgr := session.NewNodeManager() + suite.meta = meta.NewMeta(idAllocator, store, nodeMgr) suite.broker = meta.NewMockBroker(suite.T()) suite.mockCluster = session.NewMockCluster(suite.T()) @@ -80,7 +81,7 @@ func (suite *LeaderObserverTestSuite) SetupTest() { // }, nil).Maybe() distManager := meta.NewDistributionManager() targetManager := meta.NewTargetManager(suite.broker, suite.meta) - suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.broker, suite.mockCluster) + suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.broker, suite.mockCluster, nodeMgr) } func (suite *LeaderObserverTestSuite) TearDownTest() { diff --git a/internal/querycoordv2/server.go 
b/internal/querycoordv2/server.go index 8ef648ff01ce6..18a0b0bcb8f56 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -362,6 +362,7 @@ func (s *Server) initObserver() { s.targetMgr, s.broker, s.cluster, + s.nodeMgr, ) s.targetObserver = observers.NewTargetObserver( s.meta, diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index d6237e3bce5d4..b1b8c3f2ce31d 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -1317,7 +1317,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi // translate segment action removeActions := make([]*querypb.SyncAction, 0) - addSegments := make(map[int64][]*querypb.SegmentLoadInfo) + group, ctx := errgroup.WithContext(ctx) for _, action := range req.GetActions() { log := log.With(zap.String("Action", action.GetType().String())) @@ -1331,7 +1331,26 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi log.Warn("sync request from legacy querycoord without load info, skip") continue } - addSegments[action.GetNodeID()] = append(addSegments[action.GetNodeID()], action.GetInfo()) + + // to pass segment's version, we call load segment one by one + action := action + group.Go(func() error { + return shardDelegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), + commonpbutil.WithMsgID(req.Base.GetMsgID()), + ), + Infos: []*querypb.SegmentLoadInfo{action.GetInfo()}, + Schema: req.GetSchema(), + LoadMeta: req.GetLoadMeta(), + CollectionID: req.GetCollectionID(), + ReplicaID: req.GetReplicaID(), + DstNodeID: action.GetNodeID(), + Version: action.GetVersion(), + NeedTransfer: false, + LoadScope: querypb.LoadScope_Delta, + }) + }) case querypb.SyncType_UpdateVersion: log.Info("sync action", zap.Int64("TargetVersion", action.GetTargetVersion())) pipeline := node.pipelineManager.Get(req.GetChannel()) 
@@ -1353,25 +1372,10 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi } } - for nodeID, infos := range addSegments { - err := shardDelegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), - commonpbutil.WithMsgID(req.Base.GetMsgID()), - ), - Infos: infos, - Schema: req.GetSchema(), - LoadMeta: req.GetLoadMeta(), - CollectionID: req.GetCollectionID(), - ReplicaID: req.GetReplicaID(), - DstNodeID: nodeID, - Version: req.GetVersion(), - NeedTransfer: false, - LoadScope: querypb.LoadScope_Delta, - }) - if err != nil { - return merr.Status(err), nil - } + err := group.Wait() + if err != nil { + log.Warn("failed to sync distribution", zap.Error(err)) + return merr.Status(err), nil } for _, action := range removeActions { diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 59c73e86ce154..494f27f3254c7 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -45,6 +46,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/streamrpc" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -1765,6 +1767,7 @@ func (suite *ServiceSuite) TestSyncDistribution_Normal() { suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) + // test sync target version syncVersionAction := &querypb.SyncAction{ Type: querypb.SyncType_UpdateVersion, SealedInTarget: []int64{3}, @@ 
-1777,6 +1780,37 @@ func (suite *ServiceSuite) TestSyncDistribution_Normal() { status, err = suite.node.SyncDistribution(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + + // test sync segments + segmentVersion := int64(111) + syncSegmentVersion := &querypb.SyncAction{ + Type: querypb.SyncType_Set, + SegmentID: suite.validSegmentIDs[0], + NodeID: 0, + PartitionID: suite.partitionIDs[0], + Info: &querypb.SegmentLoadInfo{}, + Version: segmentVersion, + } + req.Actions = []*querypb.SyncAction{syncSegmentVersion} + + testChannel := "test_sync_segment" + req.Channel = testChannel + + // expected call load segment with right segment version + var versionMatch bool + mockDelegator := delegator.NewMockShardDelegator(suite.T()) + mockDelegator.EXPECT().LoadSegments(mock.Anything, mock.Anything). + RunAndReturn(func(ctx context.Context, req *querypb.LoadSegmentsRequest) error { + log.Info("version", zap.Int64("versionInload", req.GetVersion())) + versionMatch = req.GetVersion() == segmentVersion + return nil + }) + suite.node.delegators.Insert(testChannel, mockDelegator) + + status, err = suite.node.SyncDistribution(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + suite.True(versionMatch) } func (suite *ServiceSuite) TestSyncDistribution_ReleaseResultCheck() {