From 12cc500009dfeb2a5ca90a13a364b27f4bcbccf9 Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Fri, 6 Dec 2024 16:34:41 +0800 Subject: [PATCH] enhance: [2.4] Reduce segmentManager lock granularity (#37869) Use a channel level key lock for segments in segmentManager. issue: https://github.com/milvus-io/milvus/issues/37633, https://github.com/milvus-io/milvus/issues/37630 pr: https://github.com/milvus-io/milvus/pull/37836 --------- Signed-off-by: bigsheeper --- internal/datacoord/mock_segment_manager.go | 66 ++-- internal/datacoord/segment_manager.go | 372 ++++++++++++--------- internal/datacoord/segment_manager_test.go | 100 ++++-- internal/datacoord/server_test.go | 49 +-- internal/datacoord/services.go | 35 +- internal/datacoord/services_test.go | 64 ++-- 6 files changed, 371 insertions(+), 315 deletions(-) diff --git a/internal/datacoord/mock_segment_manager.go b/internal/datacoord/mock_segment_manager.go index 8a61177baa829..58f68a81c847f 100644 --- a/internal/datacoord/mock_segment_manager.go +++ b/internal/datacoord/mock_segment_manager.go @@ -79,9 +79,9 @@ func (_c *MockManager_AllocSegment_Call) RunAndReturn(run func(context.Context, return _c } -// DropSegment provides a mock function with given fields: ctx, segmentID -func (_m *MockManager) DropSegment(ctx context.Context, segmentID int64) { - _m.Called(ctx, segmentID) +// DropSegment provides a mock function with given fields: ctx, channel, segmentID +func (_m *MockManager) DropSegment(ctx context.Context, channel string, segmentID int64) { + _m.Called(ctx, channel, segmentID) } // MockManager_DropSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropSegment' @@ -91,14 +91,15 @@ type MockManager_DropSegment_Call struct { // DropSegment is a helper method to define mock.On call // - ctx context.Context +// - channel string // - segmentID int64 -func (_e *MockManager_Expecter) DropSegment(ctx interface{}, segmentID interface{}) *MockManager_DropSegment_Call { - return &MockManager_DropSegment_Call{Call: _e.mock.On("DropSegment", ctx, segmentID)} +func (_e *MockManager_Expecter) DropSegment(ctx interface{}, channel interface{}, segmentID interface{}) *MockManager_DropSegment_Call { + return &MockManager_DropSegment_Call{Call: _e.mock.On("DropSegment", ctx, channel, segmentID)} } -func (_c *MockManager_DropSegment_Call) Run(run func(ctx context.Context, segmentID int64)) *MockManager_DropSegment_Call { +func (_c *MockManager_DropSegment_Call) Run(run func(ctx context.Context, channel string, segmentID int64)) *MockManager_DropSegment_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) + run(args[0].(context.Context), args[1].(string), args[2].(int64)) }) return _c } @@ -108,7 +109,7 @@ func (_c *MockManager_DropSegment_Call) Return() *MockManager_DropSegment_Call { return _c } -func (_c *MockManager_DropSegment_Call) RunAndReturn(run func(context.Context, int64)) *MockManager_DropSegment_Call { +func (_c *MockManager_DropSegment_Call) RunAndReturn(run func(context.Context, string, int64)) *MockManager_DropSegment_Call { _c.Call.Return(run) return _c } @@ -148,17 +149,8 @@ func (_c *MockManager_DropSegmentsOfChannel_Call) RunAndReturn(run func(context. 
} // ExpireAllocations provides a mock function with given fields: channel, ts -func (_m *MockManager) ExpireAllocations(channel string, ts uint64) error { - ret := _m.Called(channel, ts) - - var r0 error - if rf, ok := ret.Get(0).(func(string, uint64) error); ok { - r0 = rf(channel, ts) - } else { - r0 = ret.Error(0) - } - - return r0 +func (_m *MockManager) ExpireAllocations(channel string, ts uint64) { + _m.Called(channel, ts) } // MockManager_ExpireAllocations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExpireAllocations' @@ -180,12 +172,12 @@ func (_c *MockManager_ExpireAllocations_Call) Run(run func(channel string, ts ui return _c } -func (_c *MockManager_ExpireAllocations_Call) Return(_a0 error) *MockManager_ExpireAllocations_Call { - _c.Call.Return(_a0) +func (_c *MockManager_ExpireAllocations_Call) Return() *MockManager_ExpireAllocations_Call { + _c.Call.Return() return _c } -func (_c *MockManager_ExpireAllocations_Call) RunAndReturn(run func(string, uint64) error) *MockManager_ExpireAllocations_Call { +func (_c *MockManager_ExpireAllocations_Call) RunAndReturn(run func(string, uint64)) *MockManager_ExpireAllocations_Call { _c.Call.Return(run) return _c } @@ -246,25 +238,25 @@ func (_c *MockManager_GetFlushableSegments_Call) RunAndReturn(run func(context.C return _c } -// SealAllSegments provides a mock function with given fields: ctx, collectionID, segIDs -func (_m *MockManager) SealAllSegments(ctx context.Context, collectionID int64, segIDs []int64) ([]int64, error) { - ret := _m.Called(ctx, collectionID, segIDs) +// SealAllSegments provides a mock function with given fields: ctx, channel, segIDs +func (_m *MockManager) SealAllSegments(ctx context.Context, channel string, segIDs []int64) ([]int64, error) { + ret := _m.Called(ctx, channel, segIDs) var r0 []int64 var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64, []int64) ([]int64, error)); ok { - return rf(ctx, collectionID, segIDs) + if rf, ok := ret.Get(0).(func(context.Context, string, []int64) ([]int64, error)); ok { + return rf(ctx, channel, segIDs) } - if rf, ok := ret.Get(0).(func(context.Context, int64, []int64) []int64); ok { - r0 = rf(ctx, collectionID, segIDs) + if rf, ok := ret.Get(0).(func(context.Context, string, []int64) []int64); ok { + r0 = rf(ctx, channel, segIDs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int64) } } - if rf, ok := ret.Get(1).(func(context.Context, int64, []int64) error); ok { - r1 = rf(ctx, collectionID, segIDs) + if rf, ok := ret.Get(1).(func(context.Context, string, []int64) error); ok { + r1 = rf(ctx, channel, segIDs) } else { r1 = ret.Error(1) } @@ -279,15 +271,15 @@ type MockManager_SealAllSegments_Call struct { // SealAllSegments is a helper method to define mock.On call // - ctx context.Context -// - collectionID int64 +// - channel string // - segIDs []int64 -func (_e *MockManager_Expecter) SealAllSegments(ctx interface{}, collectionID interface{}, segIDs interface{}) *MockManager_SealAllSegments_Call { - return &MockManager_SealAllSegments_Call{Call: _e.mock.On("SealAllSegments", ctx, collectionID, segIDs)} +func (_e *MockManager_Expecter) SealAllSegments(ctx interface{}, channel interface{}, segIDs interface{}) *MockManager_SealAllSegments_Call { + return &MockManager_SealAllSegments_Call{Call: _e.mock.On("SealAllSegments", ctx, channel, segIDs)} } -func (_c *MockManager_SealAllSegments_Call) Run(run func(ctx context.Context, collectionID int64, segIDs []int64)) *MockManager_SealAllSegments_Call { +func (_c 
*MockManager_SealAllSegments_Call) Run(run func(ctx context.Context, channel string, segIDs []int64)) *MockManager_SealAllSegments_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64), args[2].([]int64)) + run(args[0].(context.Context), args[1].(string), args[2].([]int64)) }) return _c } @@ -297,7 +289,7 @@ func (_c *MockManager_SealAllSegments_Call) Return(_a0 []int64, _a1 error) *Mock return _c } -func (_c *MockManager_SealAllSegments_Call) RunAndReturn(run func(context.Context, int64, []int64) ([]int64, error)) *MockManager_SealAllSegments_Call { +func (_c *MockManager_SealAllSegments_Call) RunAndReturn(run func(context.Context, string, []int64) ([]int64, error)) *MockManager_SealAllSegments_Call { _c.Call.Return(run) return _c } diff --git a/internal/datacoord/segment_manager.go b/internal/datacoord/segment_manager.go index ccc1b255399e0..52507f2f7e4b0 100644 --- a/internal/datacoord/segment_manager.go +++ b/internal/datacoord/segment_manager.go @@ -75,14 +75,14 @@ type Manager interface { // AllocSegment allocates rows and record the allocation. AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error) // DropSegment drops the segment from manager. - DropSegment(ctx context.Context, segmentID UniqueID) + DropSegment(ctx context.Context, channel string, segmentID UniqueID) // SealAllSegments seals all segments of collection with collectionID and return sealed segments. // If segIDs is not empty, also seals segments in segIDs. - SealAllSegments(ctx context.Context, collectionID UniqueID, segIDs []UniqueID) ([]UniqueID, error) + SealAllSegments(ctx context.Context, channel string, segIDs []UniqueID) ([]UniqueID, error) // GetFlushableSegments returns flushable segment ids GetFlushableSegments(ctx context.Context, channel string, ts Timestamp) ([]UniqueID, error) // ExpireAllocations notifies segment status to expire old allocations - ExpireAllocations(channel string, ts Timestamp) error + ExpireAllocations(channel string, ts Timestamp) // DropSegmentsOfChannel drops all segments in a channel DropSegmentsOfChannel(ctx context.Context, channel string) } @@ -104,11 +104,15 @@ var _ Manager = (*SegmentManager)(nil) // SegmentManager handles L1 segment related logic type SegmentManager struct { - meta *meta - mu lock.RWMutex - allocator allocator - helper allocHelper - segments []UniqueID + meta *meta + allocator allocator + helper allocHelper + + channelLock *lock.KeyLock[string] + channel2Growing *typeutil.ConcurrentMap[string, typeutil.UniqueSet] + channel2Sealed *typeutil.ConcurrentMap[string, typeutil.UniqueSet] + + // Policies estimatePolicy calUpperLimitPolicy allocPolicy AllocatePolicy segmentSealPolicies []SegmentSealPolicy @@ -209,7 +213,9 @@ func newSegmentManager(meta *meta, allocator allocator, opts ...allocOption) (*S meta: meta, allocator: allocator, helper: defaultAllocHelper(), - segments: make([]UniqueID, 0), + channelLock: lock.NewKeyLock[string](), + channel2Growing: typeutil.NewConcurrentMap[string, typeutil.UniqueSet](), + channel2Sealed: typeutil.NewConcurrentMap[string, typeutil.UniqueSet](), estimatePolicy: defaultCalUpperLimitPolicy(), allocPolicy: defaultAllocatePolicy(), segmentSealPolicies: defaultSegmentSealPolicy(), @@ -219,49 +225,57 @@ func newSegmentManager(meta *meta, allocator allocator, opts ...allocOption) (*S for _, opt := range opts { opt.apply(manager) } - manager.loadSegmentsFromMeta() - if err := 
manager.maybeResetLastExpireForSegments(); err != nil { + latestTs, err := manager.genLastExpireTsForSegments() + if err != nil { return nil, err } + manager.loadSegmentsFromMeta(latestTs) return manager, nil } // loadSegmentsFromMeta generate corresponding segment status for each segment from meta -func (s *SegmentManager) loadSegmentsFromMeta() { - segments := s.meta.GetUnFlushedSegments() - segmentsID := make([]UniqueID, 0, len(segments)) - for _, segment := range segments { - if segment.Level != datapb.SegmentLevel_L0 { - segmentsID = append(segmentsID, segment.GetID()) +func (s *SegmentManager) loadSegmentsFromMeta(latestTs Timestamp) { + unflushed := s.meta.GetUnFlushedSegments() + unflushed = lo.Filter(unflushed, func(segment *SegmentInfo, _ int) bool { + return segment.Level != datapb.SegmentLevel_L0 + }) + channel2Segments := lo.GroupBy(unflushed, func(segment *SegmentInfo) string { + return segment.GetInsertChannel() + }) + for channel, segmentInfos := range channel2Segments { + growing := typeutil.NewUniqueSet() + sealed := typeutil.NewUniqueSet() + for _, segment := range segmentInfos { + // for all sealed and growing segments, need to reset last expire + if segment != nil && segment.GetState() == commonpb.SegmentState_Growing { + s.meta.SetLastExpire(segment.GetID(), latestTs) + growing.Insert(segment.GetID()) + } + if segment != nil && segment.GetState() == commonpb.SegmentState_Sealed { + sealed.Insert(segment.GetID()) + } } + s.channel2Growing.Insert(channel, growing) + s.channel2Sealed.Insert(channel, sealed) } - s.segments = segmentsID } -func (s *SegmentManager) maybeResetLastExpireForSegments() error { - // for all sealed and growing segments, need to reset last expire - if len(s.segments) > 0 { - var latestTs uint64 - allocateErr := retry.Do(context.Background(), func() error { - ts, tryErr := s.genExpireTs(context.Background()) +func (s *SegmentManager) genLastExpireTsForSegments() (Timestamp, error) { + var latestTs uint64 + allocateErr := retry.Do(context.Background(), func() error { + ts, tryErr := s.genExpireTs(context.Background()) + if tryErr != nil { log.Warn("failed to get ts from rootCoord for globalLastExpire", zap.Error(tryErr)) - if tryErr != nil { - return tryErr - } - latestTs = ts - return nil - }, retry.Attempts(Params.DataCoordCfg.AllocLatestExpireAttempt.GetAsUint()), retry.Sleep(200*time.Millisecond)) - if allocateErr != nil { - log.Warn("cannot allocate latest lastExpire from rootCoord", zap.Error(allocateErr)) - return errors.New("global max expire ts is unavailable for segment manager") - } - for _, sID := range s.segments { - if segment := s.meta.GetSegment(sID); segment != nil && segment.GetState() == commonpb.SegmentState_Growing { - s.meta.SetLastExpire(sID, latestTs) - } + return tryErr } + latestTs = ts + return nil + }, retry.Attempts(Params.DataCoordCfg.AllocLatestExpireAttempt.GetAsUint()), retry.Sleep(200*time.Millisecond)) + if allocateErr != nil { + log.Warn("cannot allocate latest lastExpire from rootCoord", zap.Error(allocateErr)) + return 0, errors.New("global max expire ts is unavailable for segment manager") } - return nil + return latestTs, nil } // AllocSegment allocate segment per request collcation, partication, channel and rows @@ -275,38 +289,33 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID With(zap.Int64("requestRows", requestRows)) _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Alloc-Segment") defer sp.End() - s.mu.Lock() - defer s.mu.Unlock() + + 
s.channelLock.Lock(channelName) + defer s.channelLock.Unlock(channelName) // filter segments - validSegments := make(map[UniqueID]struct{}) - invalidSegments := make(map[UniqueID]struct{}) - segments := make([]*SegmentInfo, 0) - for _, segmentID := range s.segments { + segmentInfos := make([]*SegmentInfo, 0) + growing, _ := s.channel2Growing.Get(channelName) + growing.Range(func(segmentID int64) bool { segment := s.meta.GetHealthySegment(segmentID) if segment == nil { - invalidSegments[segmentID] = struct{}{} - continue + log.Warn("failed to get segment, remove it", zap.String("channel", channelName), zap.Int64("segmentID", segmentID)) + growing.Remove(segmentID) + return true } - validSegments[segmentID] = struct{}{} - - if !satisfy(segment, collectionID, partitionID, channelName) || !isGrowing(segment) || segment.GetLevel() == datapb.SegmentLevel_L0 { - continue + if segment.GetPartitionID() != partitionID { + return true } - segments = append(segments, segment) - } - - if len(invalidSegments) > 0 { - log.Warn("Failed to get segments infos from meta, clear them", zap.Int64s("segmentIDs", lo.Keys(invalidSegments))) - } - s.segments = lo.Keys(validSegments) + segmentInfos = append(segmentInfos, segment) + return true + }) // Apply allocation policy. maxCountPerSegment, err := s.estimateMaxNumOfRows(collectionID) if err != nil { return nil, err } - newSegmentAllocations, existedSegmentAllocations := s.allocPolicy(segments, + newSegmentAllocations, existedSegmentAllocations := s.allocPolicy(segmentInfos, requestRows, int64(maxCountPerSegment), datapb.SegmentLevel_L1) // create new segments and add allocations @@ -339,15 +348,6 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID return allocations, nil } -func satisfy(segment *SegmentInfo, collectionID, partitionID UniqueID, channel string) bool { - return segment.GetCollectionID() == collectionID && segment.GetPartitionID() == partitionID && - segment.GetInsertChannel() == channel -} - -func isGrowing(segment *SegmentInfo) bool { - return segment.GetState() == commonpb.SegmentState_Growing -} - func (s *SegmentManager) genExpireTs(ctx context.Context) (Timestamp, error) { ts, err := s.allocator.allocTimestamp(ctx) if err != nil { @@ -392,7 +392,8 @@ func (s *SegmentManager) openNewSegment(ctx context.Context, collectionID Unique log.Error("failed to add segment to DataCoord", zap.Error(err)) return nil, err } - s.segments = append(s.segments, id) + growing, _ := s.channel2Growing.GetOrInsert(channelName, typeutil.NewUniqueSet()) + growing.Insert(id) log.Info("datacoord: estimateTotalRows: ", zap.Int64("CollectionID", segmentInfo.CollectionID), zap.Int64("SegmentID", segmentInfo.ID), @@ -412,17 +413,20 @@ func (s *SegmentManager) estimateMaxNumOfRows(collectionID UniqueID) (int, error } // DropSegment drop the segment from manager. -func (s *SegmentManager) DropSegment(ctx context.Context, segmentID UniqueID) { +func (s *SegmentManager) DropSegment(ctx context.Context, channel string, segmentID UniqueID) { _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Drop-Segment") defer sp.End() - s.mu.Lock() - defer s.mu.Unlock() - for i, id := range s.segments { - if id == segmentID { - s.segments = append(s.segments[:i], s.segments[i+1:]...) 
- break - } + + s.channelLock.Lock(channel) + defer s.channelLock.Unlock(channel) + + if growing, ok := s.channel2Growing.Get(channel); ok { + growing.Remove(segmentID) } + if sealed, ok := s.channel2Sealed.Get(channel); ok { + sealed.Remove(segmentID) + } + segment := s.meta.GetHealthySegment(segmentID) if segment == nil { log.Warn("Failed to get segment", zap.Int64("id", segmentID)) @@ -435,30 +439,46 @@ func (s *SegmentManager) DropSegment(ctx context.Context, segmentID UniqueID) { } // SealAllSegments seals all segments of collection with collectionID and return sealed segments -func (s *SegmentManager) SealAllSegments(ctx context.Context, collectionID UniqueID, segIDs []UniqueID) ([]UniqueID, error) { +func (s *SegmentManager) SealAllSegments(ctx context.Context, channel string, segIDs []UniqueID) ([]UniqueID, error) { _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Seal-Segments") defer sp.End() - s.mu.Lock() - defer s.mu.Unlock() - var ret []UniqueID - segCandidates := s.segments + s.channelLock.Lock(channel) + defer s.channelLock.Unlock(channel) + + sealed, _ := s.channel2Sealed.GetOrInsert(channel, typeutil.NewUniqueSet()) + growing, _ := s.channel2Growing.Get(channel) + + var ( + sealedSegments []int64 + growingSegments []int64 + ) + if len(segIDs) != 0 { - segCandidates = segIDs + sealedSegments = s.meta.GetSegments(segIDs, func(segment *SegmentInfo) bool { + return isSegmentHealthy(segment) && segment.State == commonpb.SegmentState_Sealed + }) + growingSegments = s.meta.GetSegments(segIDs, func(segment *SegmentInfo) bool { + return isSegmentHealthy(segment) && segment.State == commonpb.SegmentState_Growing + }) + } else { + sealedSegments = s.meta.GetSegments(sealed.Collect(), func(segment *SegmentInfo) bool { + return isSegmentHealthy(segment) + }) + growingSegments = s.meta.GetSegments(growing.Collect(), func(segment *SegmentInfo) bool { + return isSegmentHealthy(segment) + }) } - sealedSegments := s.meta.GetSegments(segCandidates, func(segment *SegmentInfo) bool { - return segment.CollectionID == collectionID && isSegmentHealthy(segment) && segment.State == commonpb.SegmentState_Sealed - }) - growingSegments := s.meta.GetSegments(segCandidates, func(segment *SegmentInfo) bool { - return segment.CollectionID == collectionID && isSegmentHealthy(segment) && segment.State == commonpb.SegmentState_Growing - }) + var ret []UniqueID ret = append(ret, sealedSegments...) for _, id := range growingSegments { if err := s.meta.SetState(id, commonpb.SegmentState_Sealed); err != nil { return nil, err } + sealed.Insert(id) + growing.Remove(id) ret = append(ret, id) } return ret, nil @@ -468,37 +488,54 @@ func (s *SegmentManager) SealAllSegments(ctx context.Context, collectionID Uniqu func (s *SegmentManager) GetFlushableSegments(ctx context.Context, channel string, t Timestamp) ([]UniqueID, error) { _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Get-Segments") defer sp.End() - s.mu.Lock() - defer s.mu.Unlock() + + s.channelLock.Lock(channel) + defer s.channelLock.Unlock(channel) + // TODO:move tryToSealSegment and dropEmptySealedSegment outside if err := s.tryToSealSegment(t, channel); err != nil { return nil, err } + // TODO: It's too frequent; perhaps each channel could check once per minute instead. 
s.cleanupSealedSegment(t, channel) - ret := make([]UniqueID, 0, len(s.segments)) - for _, id := range s.segments { - info := s.meta.GetHealthySegment(id) - if info == nil || info.InsertChannel != channel { - continue + sealed, ok := s.channel2Sealed.Get(channel) + if !ok { + return nil, nil + } + + ret := make([]UniqueID, 0, sealed.Len()) + sealed.Range(func(segmentID int64) bool { + info := s.meta.GetHealthySegment(segmentID) + if info == nil { + return true } if s.flushPolicy(info, t) { - ret = append(ret, id) + ret = append(ret, segmentID) } - } + return true + }) return ret, nil } // ExpireAllocations notify segment status to expire old allocations -func (s *SegmentManager) ExpireAllocations(channel string, ts Timestamp) error { - s.mu.Lock() - defer s.mu.Unlock() - for _, id := range s.segments { +func (s *SegmentManager) ExpireAllocations(channel string, ts Timestamp) { + s.channelLock.Lock(channel) + defer s.channelLock.Unlock(channel) + + growing, ok := s.channel2Growing.Get(channel) + if !ok { + return + } + + growing.Range(func(id int64) bool { segment := s.meta.GetHealthySegment(id) - if segment == nil || segment.InsertChannel != channel { - continue + if segment == nil { + log.Warn("failed to get segment, remove it", zap.String("channel", channel), zap.Int64("segmentID", id)) + growing.Remove(id) + return true } allocations := make([]*Allocation, 0, len(segment.allocations)) for i := 0; i < len(segment.allocations); i++ { @@ -510,76 +547,89 @@ func (s *SegmentManager) ExpireAllocations(channel string, ts Timestamp) error { } } s.meta.SetAllocations(segment.GetID(), allocations) - } - return nil + return true + }) } func (s *SegmentManager) cleanupSealedSegment(ts Timestamp, channel string) { - valids := make([]int64, 0, len(s.segments)) - for _, id := range s.segments { + sealed, ok := s.channel2Sealed.Get(channel) + if !ok { + return + } + sealed.Range(func(id int64) bool { segment := s.meta.GetHealthySegment(id) - if segment == nil || segment.InsertChannel != channel { - valids = append(valids, id) - continue + if segment == nil { + log.Warn("failed to get segment, remove it", zap.String("channel", channel), zap.Int64("segmentID", id)) + sealed.Remove(id) + return true } - - if isEmptySealedSegment(segment, ts) { + // Check if segment is empty + if segment.GetLastExpireTime() <= ts && segment.currRows == 0 { log.Info("remove empty sealed segment", zap.Int64("collection", segment.CollectionID), zap.Int64("segment", id)) - s.meta.SetState(id, commonpb.SegmentState_Dropped) - continue + if err := s.meta.SetState(id, commonpb.SegmentState_Dropped); err != nil { + log.Warn("failed to set segment state to dropped", zap.String("channel", channel), + zap.Int64("segmentID", id), zap.Error(err)) + } else { + sealed.Remove(id) + } } - - valids = append(valids, id) - } - s.segments = valids -} - -func isEmptySealedSegment(segment *SegmentInfo, ts Timestamp) bool { - return segment.GetState() == commonpb.SegmentState_Sealed && segment.GetLastExpireTime() <= ts && segment.currRows == 0 + return true + }) } // tryToSealSegment applies segment & channel seal policies func (s *SegmentManager) tryToSealSegment(ts Timestamp, channel string) error { - channelInfo := make(map[string][]*SegmentInfo) + growing, ok := s.channel2Growing.Get(channel) + if !ok { + return nil + } + sealed, _ := s.channel2Sealed.GetOrInsert(channel, typeutil.NewUniqueSet()) + + channelSegmentInfos := make([]*SegmentInfo, 0, len(growing)) sealedSegments := make(map[int64]struct{}) - for _, id := range s.segments { + + var 
setStateErr error + growing.Range(func(id int64) bool { info := s.meta.GetHealthySegment(id) - if info == nil || info.InsertChannel != channel { - continue - } - channelInfo[info.InsertChannel] = append(channelInfo[info.InsertChannel], info) - if info.State != commonpb.SegmentState_Growing { - continue + if info == nil { + return true } + channelSegmentInfos = append(channelSegmentInfos, info) // change shouldSeal to segment seal policy logic for _, policy := range s.segmentSealPolicies { if shouldSeal, reason := policy.ShouldSeal(info, ts); shouldSeal { log.Info("Seal Segment for policy matched", zap.Int64("segmentID", info.GetID()), zap.String("reason", reason)) if err := s.meta.SetState(id, commonpb.SegmentState_Sealed); err != nil { - return err + setStateErr = err + return false } sealedSegments[id] = struct{}{} + sealed.Insert(id) + growing.Remove(id) break } } + return true + }) + + if setStateErr != nil { + return setStateErr } - for channel, segmentInfos := range channelInfo { - for _, policy := range s.channelSealPolicies { - vs, reason := policy(channel, segmentInfos, ts) - for _, info := range vs { - if _, ok := sealedSegments[info.GetID()]; ok { - continue - } - if info.State != commonpb.SegmentState_Growing { - continue - } - if err := s.meta.SetState(info.GetID(), commonpb.SegmentState_Sealed); err != nil { - return err - } - log.Info("seal segment for channel seal policy matched", - zap.Int64("segmentID", info.GetID()), zap.String("channel", channel), zap.String("reason", reason)) - sealedSegments[info.GetID()] = struct{}{} + + for _, policy := range s.channelSealPolicies { + vs, reason := policy(channel, channelSegmentInfos, ts) + for _, info := range vs { + if _, ok := sealedSegments[info.GetID()]; ok { + continue + } + if err := s.meta.SetState(info.GetID(), commonpb.SegmentState_Sealed); err != nil { + return err } + log.Info("seal segment for channel seal policy matched", + zap.Int64("segmentID", info.GetID()), zap.String("channel", channel), zap.String("reason", reason)) + sealedSegments[info.GetID()] = struct{}{} + sealed.Insert(info.GetID()) + growing.Remove(info.GetID()) } } return nil @@ -587,24 +637,26 @@ func (s *SegmentManager) tryToSealSegment(ts Timestamp, channel string) error { // DropSegmentsOfChannel drops all segments in a channel func (s *SegmentManager) DropSegmentsOfChannel(ctx context.Context, channel string) { - s.mu.Lock() - defer s.mu.Unlock() + s.channelLock.Lock(channel) + defer s.channelLock.Unlock(channel) - validSegments := make([]int64, 0, len(s.segments)) - for _, sid := range s.segments { + s.channel2Sealed.Remove(channel) + growing, ok := s.channel2Growing.Get(channel) + if !ok { + return + } + growing.Range(func(sid int64) bool { segment := s.meta.GetHealthySegment(sid) if segment == nil { - continue - } - if segment.GetInsertChannel() != channel { - validSegments = append(validSegments, sid) - continue + log.Warn("failed to get segment, remove it", zap.String("channel", channel), zap.Int64("segmentID", sid)) + growing.Remove(sid) + return true } s.meta.SetAllocations(sid, nil) for _, allocation := range segment.allocations { putAllocation(allocation) } - } - - s.segments = validSegments + return true + }) + s.channel2Growing.Remove(channel) } diff --git a/internal/datacoord/segment_manager_test.go b/internal/datacoord/segment_manager_test.go index 329841c29f757..42e5fe7c51ca9 100644 --- a/internal/datacoord/segment_manager_test.go +++ b/internal/datacoord/segment_manager_test.go @@ -34,7 +34,9 @@ import ( 
"github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestManagerOptions(t *testing.T) { @@ -142,19 +144,25 @@ func TestAllocSegment(t *testing.T) { }) t.Run("alloc clear unhealthy segment", func(t *testing.T) { - allocations1, err := segmentManager.AllocSegment(ctx, collID, 100, "c1", 100) + vchannel := "c1" + partitionID := int64(100) + allocations1, err := segmentManager.AllocSegment(ctx, collID, partitionID, vchannel, 100) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations1)) - assert.EqualValues(t, 1, len(segmentManager.segments)) + segments, ok := segmentManager.channel2Growing.Get(vchannel) + assert.True(t, ok) + assert.EqualValues(t, 1, segments.Len()) err = meta.SetState(allocations1[0].SegmentID, commonpb.SegmentState_Dropped) assert.NoError(t, err) - allocations2, err := segmentManager.AllocSegment(ctx, collID, 100, "c1", 100) + allocations2, err := segmentManager.AllocSegment(ctx, collID, partitionID, vchannel, 100) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations2)) // clear old healthy and alloc new - assert.EqualValues(t, 1, len(segmentManager.segments)) + segments, ok = segmentManager.channel2Growing.Get(vchannel) + assert.True(t, ok) + assert.EqualValues(t, 1, segments.Len()) assert.NotEqual(t, allocations1[0].SegmentID, allocations2[0].SegmentID) }) } @@ -217,7 +225,8 @@ func TestLastExpireReset(t *testing.T) { meta.SetCurrentRows(segmentID1, bigRows) meta.SetCurrentRows(segmentID2, bigRows) meta.SetCurrentRows(segmentID3, smallRows) - segmentManager.tryToSealSegment(expire1, channelName) + err = segmentManager.tryToSealSegment(expire1, channelName) + assert.NoError(t, err) assert.Equal(t, commonpb.SegmentState_Sealed, meta.GetSegment(segmentID1).GetState()) assert.Equal(t, commonpb.SegmentState_Sealed, meta.GetSegment(segmentID2).GetState()) assert.Equal(t, commonpb.SegmentState_Growing, meta.GetSegment(segmentID3).GetState()) @@ -270,11 +279,14 @@ func TestLoadSegmentsFromMeta(t *testing.T) { assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) + vchannel := "ch0" + partitionID := int64(100) + sealedSegment := &datapb.SegmentInfo{ ID: 1, CollectionID: collID, - PartitionID: 0, - InsertChannel: "", + PartitionID: partitionID, + InsertChannel: vchannel, State: commonpb.SegmentState_Sealed, MaxRowNum: 100, LastExpireTime: 1000, @@ -282,8 +294,8 @@ func TestLoadSegmentsFromMeta(t *testing.T) { growingSegment := &datapb.SegmentInfo{ ID: 2, CollectionID: collID, - PartitionID: 0, - InsertChannel: "", + PartitionID: partitionID, + InsertChannel: vchannel, State: commonpb.SegmentState_Growing, MaxRowNum: 100, LastExpireTime: 1000, @@ -291,8 +303,8 @@ func TestLoadSegmentsFromMeta(t *testing.T) { flushedSegment := &datapb.SegmentInfo{ ID: 3, CollectionID: collID, - PartitionID: 0, - InsertChannel: "", + PartitionID: partitionID, + InsertChannel: vchannel, State: commonpb.SegmentState_Flushed, MaxRowNum: 100, LastExpireTime: 1000, @@ -304,9 +316,14 @@ func TestLoadSegmentsFromMeta(t *testing.T) { err = meta.AddSegment(context.TODO(), NewSegmentInfo(flushedSegment)) assert.NoError(t, err) - segmentManager, _ := newSegmentManager(meta, mockAllocator) - segments := segmentManager.segments - assert.EqualValues(t, 2, len(segments)) + segmentManager, err := 
newSegmentManager(meta, mockAllocator) + assert.NoError(t, err) + growing, ok := segmentManager.channel2Growing.Get(vchannel) + assert.True(t, ok) + assert.EqualValues(t, 1, growing.Len()) + sealed, ok := segmentManager.channel2Sealed.Get(vchannel) + assert.True(t, ok) + assert.EqualValues(t, 1, sealed.Len()) } func TestSaveSegmentsToMeta(t *testing.T) { @@ -323,7 +340,7 @@ func TestSaveSegmentsToMeta(t *testing.T) { allocations, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) - _, err = segmentManager.SealAllSegments(context.Background(), collID, nil) + _, err = segmentManager.SealAllSegments(context.Background(), "c1", nil) assert.NoError(t, err) segment := meta.GetHealthySegment(allocations[0].SegmentID) assert.NotNil(t, segment) @@ -345,7 +362,7 @@ func TestSaveSegmentsToMetaWithSpecificSegments(t *testing.T) { allocations, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) - _, err = segmentManager.SealAllSegments(context.Background(), collID, []int64{allocations[0].SegmentID}) + _, err = segmentManager.SealAllSegments(context.Background(), "c1", []int64{allocations[0].SegmentID}) assert.NoError(t, err) segment := meta.GetHealthySegment(allocations[0].SegmentID) assert.NotNil(t, segment) @@ -364,14 +381,14 @@ func TestDropSegment(t *testing.T) { assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) segmentManager, _ := newSegmentManager(meta, mockAllocator) - allocations, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000) + allocations, err := segmentManager.AllocSegment(context.Background(), collID, 100, "c1", 1000) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) segID := allocations[0].SegmentID segment := meta.GetHealthySegment(segID) assert.NotNil(t, segment) - segmentManager.DropSegment(context.Background(), segID) + segmentManager.DropSegment(context.Background(), "c1", segID) segment = meta.GetHealthySegment(segID) assert.NotNil(t, segment) } @@ -433,8 +450,7 @@ func TestExpireAllocation(t *testing.T) { segment := meta.GetHealthySegment(id) assert.NotNil(t, segment) assert.EqualValues(t, 100, len(segment.allocations)) - err = segmentManager.ExpireAllocations("ch1", maxts) - assert.NoError(t, err) + segmentManager.ExpireAllocations("ch1", maxts) segment = meta.GetHealthySegment(id) assert.NotNil(t, segment) assert.EqualValues(t, 0, len(segment.allocations)) @@ -456,7 +472,7 @@ func TestGetFlushableSegments(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) - ids, err := segmentManager.SealAllSegments(context.TODO(), collID, nil) + ids, err := segmentManager.SealAllSegments(context.TODO(), "c1", nil) assert.NoError(t, err) assert.EqualValues(t, 1, len(ids)) assert.EqualValues(t, allocations[0].SegmentID, ids[0]) @@ -750,6 +766,7 @@ func TestAllocationPool(t *testing.T) { } func TestSegmentManager_DropSegmentsOfChannel(t *testing.T) { + partitionID := int64(100) type fields struct { meta *meta segments []UniqueID @@ -772,15 +789,17 @@ func TestSegmentManager_DropSegmentsOfChannel(t *testing.T) { 1: { SegmentInfo: &datapb.SegmentInfo{ ID: 1, + PartitionID: partitionID, InsertChannel: "ch1", - State: commonpb.SegmentState_Flushed, + State: commonpb.SegmentState_Sealed, }, }, 2: { SegmentInfo: &datapb.SegmentInfo{ ID: 2, + PartitionID: partitionID, InsertChannel: "ch2", - State: 
commonpb.SegmentState_Flushed, + State: commonpb.SegmentState_Growing, }, }, }, @@ -802,13 +821,15 @@ func TestSegmentManager_DropSegmentsOfChannel(t *testing.T) { 1: { SegmentInfo: &datapb.SegmentInfo{ ID: 1, + PartitionID: partitionID, InsertChannel: "ch1", - State: commonpb.SegmentState_Dropped, + State: commonpb.SegmentState_Sealed, }, }, 2: { SegmentInfo: &datapb.SegmentInfo{ ID: 2, + PartitionID: partitionID, InsertChannel: "ch2", State: commonpb.SegmentState_Growing, }, @@ -827,11 +848,36 @@ func TestSegmentManager_DropSegmentsOfChannel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &SegmentManager{ - meta: tt.fields.meta, - segments: tt.fields.segments, + meta: tt.fields.meta, + channelLock: lock.NewKeyLock[string](), + channel2Growing: typeutil.NewConcurrentMap[string, typeutil.UniqueSet](), + channel2Sealed: typeutil.NewConcurrentMap[string, typeutil.UniqueSet](), + } + for _, segmentID := range tt.fields.segments { + segmentInfo := tt.fields.meta.GetSegment(segmentID) + channel := tt.args.channel + if segmentInfo != nil { + channel = segmentInfo.GetInsertChannel() + } + if segmentInfo == nil || segmentInfo.GetState() == commonpb.SegmentState_Growing { + growing, _ := s.channel2Growing.GetOrInsert(channel, typeutil.NewUniqueSet()) + growing.Insert(segmentID) + } else if segmentInfo.GetState() == commonpb.SegmentState_Sealed { + sealed, _ := s.channel2Sealed.GetOrInsert(channel, typeutil.NewUniqueSet()) + sealed.Insert(segmentID) + } } s.DropSegmentsOfChannel(context.TODO(), tt.args.channel) - assert.ElementsMatch(t, tt.want, s.segments) + all := make([]int64, 0) + s.channel2Sealed.Range(func(_ string, segments typeutil.UniqueSet) bool { + all = append(all, segments.Collect()...) + return true + }) + s.channel2Growing.Range(func(_ string, segments typeutil.UniqueSet) bool { + all = append(all, segments.Collect()...) + return true + }) + assert.ElementsMatch(t, tt.want, all) }) } } diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 501aaa7c9c9d8..50ad0d78e9d36 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -822,52 +822,10 @@ func TestServer_getSystemInfoMetrics(t *testing.T) { } } -type spySegmentManager struct { - spyCh chan struct{} -} - -// AllocSegment allocates rows and record the allocation. -func (s *spySegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error) { - panic("not implemented") // TODO: Implement -} - -func (s *spySegmentManager) allocSegmentForImport(ctx context.Context, collectionID UniqueID, partitionID UniqueID, channelName string, requestRows int64, taskID int64) (*Allocation, error) { - panic("not implemented") // TODO: Implement -} - -// DropSegment drops the segment from manager. -func (s *spySegmentManager) DropSegment(ctx context.Context, segmentID UniqueID) { -} - -// FlushImportSegments set importing segment state to Flushed. 
-func (s *spySegmentManager) FlushImportSegments(ctx context.Context, collectionID UniqueID, segmentIDs []UniqueID) error { - panic("not implemented") -} - -// SealAllSegments seals all segments of collection with collectionID and return sealed segments -func (s *spySegmentManager) SealAllSegments(ctx context.Context, collectionID UniqueID, segIDs []UniqueID) ([]UniqueID, error) { - panic("not implemented") // TODO: Implement -} - -// GetFlushableSegments returns flushable segment ids -func (s *spySegmentManager) GetFlushableSegments(ctx context.Context, channel string, ts Timestamp) ([]UniqueID, error) { - panic("not implemented") // TODO: Implement -} - -// ExpireAllocations notifies segment status to expire old allocations -func (s *spySegmentManager) ExpireAllocations(channel string, ts Timestamp) error { - panic("not implemented") // TODO: Implement -} - -// DropSegmentsOfChannel drops all segments in a channel -func (s *spySegmentManager) DropSegmentsOfChannel(ctx context.Context, channel string) { - s.spyCh <- struct{}{} -} - func TestDropVirtualChannel(t *testing.T) { t.Run("normal DropVirtualChannel", func(t *testing.T) { - spyCh := make(chan struct{}, 1) - svr := newTestServer(t, WithSegmentManager(&spySegmentManager{spyCh: spyCh})) + segmentManager := NewMockManager(t) + svr := newTestServer(t, WithSegmentManager(segmentManager)) defer closeTestServer(t, svr) @@ -994,12 +952,11 @@ func TestDropVirtualChannel(t *testing.T) { } req.Segments = append(req.Segments, seg2Drop) } + segmentManager.EXPECT().DropSegmentsOfChannel(mock.Anything, mock.Anything).Return() resp, err := svr.DropVirtualChannel(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - <-spyCh - // resend resp, err = svr.DropVirtualChannel(ctx, req) assert.NoError(t, err) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index fd46a5cb2e363..58c26f54b2d52 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -112,17 +112,19 @@ func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F } timeOfSeal, _ := tsoutil.ParseTS(ts) - sealedSegmentIDs, err := s.segmentManager.SealAllSegments(ctx, req.GetCollectionID(), req.GetSegmentIDs()) - if err != nil { - return &datapb.FlushResponse{ - Status: merr.Status(errors.Wrapf(err, "failed to flush collection %d", - req.GetCollectionID())), - }, nil - } - sealedSegmentsIDDict := make(map[UniqueID]bool) - for _, sealedSegmentID := range sealedSegmentIDs { - sealedSegmentsIDDict[sealedSegmentID] = true + + for _, channel := range coll.VChannelNames { + sealedSegmentIDs, err := s.segmentManager.SealAllSegments(ctx, channel, req.GetSegmentIDs()) + if err != nil { + return &datapb.FlushResponse{ + Status: merr.Status(errors.Wrapf(err, "failed to flush collection %d", + req.GetCollectionID())), + }, nil + } + for _, sealedSegmentID := range sealedSegmentIDs { + sealedSegmentsIDDict[sealedSegmentID] = true + } } segments := s.meta.GetSegmentsOfCollection(req.GetCollectionID()) @@ -167,7 +169,7 @@ func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F log.Info("flush response with segments", zap.Int64("collectionID", req.GetCollectionID()), - zap.Int64s("sealSegments", sealedSegmentIDs), + zap.Int64s("sealSegments", lo.Keys(sealedSegmentsIDDict)), zap.Int("flushedSegmentsCount", len(flushSegmentIDs)), zap.Time("timeOfSeal", timeOfSeal), zap.Uint64("flushTs", ts), @@ -177,7 +179,7 @@ func (s *Server) Flush(ctx context.Context, req 
*datapb.FlushRequest) (*datapb.F Status: merr.Success(), DbID: req.GetDbID(), CollectionID: req.GetCollectionID(), - SegmentIDs: sealedSegmentIDs, + SegmentIDs: lo.Keys(sealedSegmentsIDDict), TimeOfSeal: timeOfSeal.Unix(), FlushSegmentIDs: flushSegmentIDs, FlushTs: ts, @@ -507,10 +509,10 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // Set segment state if req.GetDropped() { // segmentManager manages growing segments - s.segmentManager.DropSegment(ctx, req.GetSegmentID()) + s.segmentManager.DropSegment(ctx, req.GetChannel(), req.GetSegmentID()) operators = append(operators, UpdateStatusOperator(req.GetSegmentID(), commonpb.SegmentState_Dropped)) } else if req.GetFlushed() { - s.segmentManager.DropSegment(ctx, req.GetSegmentID()) + s.segmentManager.DropSegment(ctx, req.GetChannel(), req.GetSegmentID()) // set segment to SegmentState_Flushing operators = append(operators, UpdateStatusOperator(req.GetSegmentID(), commonpb.SegmentState_Flushing)) } @@ -1446,10 +1448,7 @@ func (s *Server) handleDataNodeTtMsg(ctx context.Context, ttMsg *msgpb.DataNodeT s.updateSegmentStatistics(segmentStats) - if err := s.segmentManager.ExpireAllocations(channel, ts); err != nil { - log.Warn("failed to expire allocations", zap.Error(err)) - return err - } + s.segmentManager.ExpireAllocations(channel, ts) flushableIDs, err := s.segmentManager.GetFlushableSegments(ctx, channel, ts) if err != nil { diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 90a2372b0c9bc..2b2d44e8ac3cd 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -72,7 +72,7 @@ func TestServerSuite(t *testing.T) { suite.Run(t, new(ServerSuite)) } -func genMsg(msgType commonpb.MsgType, ch string, t Timestamp, sourceID int64) *msgstream.DataNodeTtMsg { +func genMsg(msgType commonpb.MsgType, ch string, t Timestamp, sourceID int64, segmentID int64) *msgstream.DataNodeTtMsg { return &msgstream.DataNodeTtMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{0}, @@ -85,7 +85,7 @@ func genMsg(msgType commonpb.MsgType, ch string, t Timestamp, sourceID int64) *m }, ChannelName: ch, Timestamp: t, - SegmentsStats: []*commonpb.SegmentStats{{SegmentID: 2, NumRows: 100}}, + SegmentsStats: []*commonpb.SegmentStats{{SegmentID: segmentID, NumRows: 100}}, }, } } @@ -123,7 +123,7 @@ func (s *ServerSuite) TestHandleDataNodeTtMsg() { s.Equal(1, len(segment.allocations)) ts := tsoutil.AddPhysicalDurationOnTs(assign.ExpireTime, -3*time.Minute) - msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, ts, sourceID) + msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, ts, sourceID, assign.GetSegID()) msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ SegmentID: assign.GetSegID(), NumRows: 1, @@ -135,7 +135,7 @@ func (s *ServerSuite) TestHandleDataNodeTtMsg() { s.EqualValues(chanName, channel) s.EqualValues(sourceID, nodeID) s.Equal(1, len(segments)) - s.EqualValues(2, segments[0].GetID()) + s.EqualValues(assign.GetSegID(), segments[0].GetID()) return fmt.Errorf("mock error") }).Once() @@ -146,7 +146,7 @@ func (s *ServerSuite) TestHandleDataNodeTtMsg() { s.NoError(err) tt := tsoutil.AddPhysicalDurationOnTs(assign.ExpireTime, 48*time.Hour) - msg = genMsg(commonpb.MsgType_DataNodeTt, chanName, tt, sourceID) + msg = genMsg(commonpb.MsgType_DataNodeTt, chanName, tt, sourceID, assign.GetSegID()) msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ SegmentID: assign.GetSegID(), NumRows: 1, @@ -162,9 +162,10 @@ func (s 
*ServerSuite) initSuiteForTtChannel() { s.testServer.startDataNodeTtLoop(s.testServer.serverLoopCtx) s.testServer.meta.AddCollection(&collectionInfo{ - ID: 1, - Schema: newTestSchema(), - Partitions: []int64{10}, + ID: 1, + Schema: newTestSchema(), + Partitions: []int64{10}, + VChannelNames: []string{"ch-1", "ch-2"}, }) } @@ -184,19 +185,6 @@ func (s *ServerSuite) TestDataNodeTtChannel_ExpireAfterTt() { signal = make(chan struct{}) collID int64 = 1 ) - mockCluster := NewMockCluster(s.T()) - mockCluster.EXPECT().Close().Once() - mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn( - func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error { - s.EqualValues(chanName, channel) - s.EqualValues(sourceID, nodeID) - s.Equal(1, len(segments)) - s.EqualValues(2, segments[0].GetID()) - - signal <- struct{}{} - return nil - }).Once() - s.testServer.cluster = mockCluster s.mockChMgr.EXPECT().Match(sourceID, chanName).Return(true).Once() resp, err := s.testServer.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ @@ -219,14 +207,32 @@ func (s *ServerSuite) TestDataNodeTtChannel_ExpireAfterTt() { s.Require().NotNil(segment) s.Equal(1, len(segment.allocations)) + mockCluster := NewMockCluster(s.T()) + mockCluster.EXPECT().Close().Once() + mockCluster.EXPECT().Flush(mock.Anything, sourceID, chanName, mock.Anything).RunAndReturn( + func(ctx context.Context, nodeID int64, channel string, segments []*datapb.SegmentInfo) error { + s.EqualValues(chanName, channel) + s.EqualValues(sourceID, nodeID) + s.Equal(1, len(segments)) + s.EqualValues(assignedSegmentID, segments[0].GetID()) + + signal <- struct{}{} + return nil + }).Once() + s.testServer.cluster = mockCluster + msgPack := msgstream.MsgPack{} tt := tsoutil.AddPhysicalDurationOnTs(resp.SegIDAssignments[0].ExpireTime, 48*time.Hour) - msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", tt, sourceID) + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", tt, sourceID, assignedSegmentID) msgPack.Msgs = append(msgPack.Msgs, msg) err = ttMsgStream.Produce(&msgPack) s.Require().NoError(err) - <-signal + select { + case <-signal: + case <-time.After(10 * time.Second): + s.Fail("test timeout") + } segment = s.testServer.meta.GetHealthySegment(assignedSegmentID) s.NotNil(segment) s.Equal(0, len(segment.allocations)) @@ -308,7 +314,7 @@ func (s *ServerSuite) TestDataNodeTtChannel_FlushWithDiffChan() { s.Require().True(merr.Ok(resp2.GetStatus())) msgPack := msgstream.MsgPack{} - msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, assign.ExpireTime, sourceID) + msg := genMsg(commonpb.MsgType_DataNodeTt, chanName, assign.ExpireTime, sourceID, assign.GetSegID()) msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ SegmentID: assign.GetSegID(), NumRows: 1, @@ -317,7 +323,11 @@ func (s *ServerSuite) TestDataNodeTtChannel_FlushWithDiffChan() { err = ttMsgStream.Produce(&msgPack) s.NoError(err) - <-signal + select { + case <-signal: + case <-time.After(10 * time.Second): + s.Fail("test timeout") + } } func (s *ServerSuite) TestDataNodeTtChannel_SegmentFlushAfterTt() { @@ -381,7 +391,7 @@ func (s *ServerSuite) TestDataNodeTtChannel_SegmentFlushAfterTt() { s.Require().True(merr.Ok(resp2.GetStatus())) msgPack := msgstream.MsgPack{} - msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime, 9999) + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime, 9999, assign.GetSegID()) msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ 
SegmentID: assign.GetSegID(), NumRows: 1, @@ -773,7 +783,7 @@ func (s *ServerSuite) TestFlush_NormalCase() { s.testServer.cluster = mockCluster schema := newTestSchema() - s.testServer.meta.AddCollection(&collectionInfo{ID: 0, Schema: schema, Partitions: []int64{}}) + s.testServer.meta.AddCollection(&collectionInfo{ID: 0, Schema: schema, Partitions: []int64{}, VChannelNames: []string{"channel-1"}}) allocations, err := s.testServer.segmentManager.AllocSegment(context.TODO(), 0, 1, "channel-1", 1) s.NoError(err) s.EqualValues(1, len(allocations))
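
A minimal standalone sketch of the channel-level locking pattern this patch introduces: the single segmentManager-wide mutex and flat segments slice are replaced by a per-channel key lock (channelLock *lock.KeyLock[string]) plus per-channel growing/sealed sets held in concurrent maps, so AllocSegment, SealAllSegments, GetFlushableSegments, and DropSegmentsOfChannel on different vchannels no longer serialize on one lock. The keyLock and channelSegments types below are simplified, hypothetical stand-ins built only on the standard library; they are not the pkg/util/lock.KeyLock or typeutil.ConcurrentMap implementations used in the real code.

package main

import (
	"fmt"
	"sync"
)

// keyLock is a simplified stand-in for a per-key lock: each channel name maps
// to its own mutex, so operations on different channels do not contend on a
// single manager-wide lock. (The real KeyLock also refcounts and frees entries.)
type keyLock struct {
	mu    sync.Mutex
	locks map[string]*sync.Mutex
}

func newKeyLock() *keyLock {
	return &keyLock{locks: make(map[string]*sync.Mutex)}
}

func (k *keyLock) Lock(key string) {
	k.mu.Lock()
	l, ok := k.locks[key]
	if !ok {
		l = &sync.Mutex{}
		k.locks[key] = l
	}
	k.mu.Unlock()
	l.Lock()
}

func (k *keyLock) Unlock(key string) {
	k.mu.Lock()
	l := k.locks[key]
	k.mu.Unlock()
	l.Unlock()
}

// channelSegments mirrors the channel2Growing bookkeeping: segment IDs are
// bucketed per channel instead of kept in one flat slice. The outer map is
// guarded by its own small mutex (a stand-in for a concurrent map), while
// per-channel work is serialized by the key lock.
type channelSegments struct {
	channelLock *keyLock
	mu          sync.RWMutex
	growing     map[string]map[int64]struct{}
}

func newChannelSegments() *channelSegments {
	return &channelSegments{
		channelLock: newKeyLock(),
		growing:     make(map[string]map[int64]struct{}),
	}
}

// AddGrowing records a growing segment under its channel while holding only
// that channel's lock (roughly what openNewSegment does in the patch).
func (c *channelSegments) AddGrowing(channel string, segmentID int64) {
	c.channelLock.Lock(channel)
	defer c.channelLock.Unlock(channel)

	c.mu.Lock()
	set, ok := c.growing[channel]
	if !ok {
		set = make(map[int64]struct{})
		c.growing[channel] = set
	}
	c.mu.Unlock()

	// Safe without c.mu: only the holder of this channel's key lock touches the set.
	set[segmentID] = struct{}{}
}

// DropChannel removes all bookkeeping for one channel without blocking
// callers working on other channels (roughly DropSegmentsOfChannel).
func (c *channelSegments) DropChannel(channel string) {
	c.channelLock.Lock(channel)
	defer c.channelLock.Unlock(channel)

	c.mu.Lock()
	delete(c.growing, channel)
	c.mu.Unlock()
}

func main() {
	cs := newChannelSegments()
	cs.AddGrowing("ch-1", 101)
	cs.AddGrowing("ch-2", 202) // does not wait on ch-1's lock
	cs.DropChannel("ch-1")

	cs.mu.RLock()
	fmt.Println("channels tracked:", len(cs.growing)) // 1
	cs.mu.RUnlock()
}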