diff --git a/internal/datacoord/handler.go b/internal/datacoord/handler.go index b192f3e98d1b1..a16824790627d 100644 --- a/internal/datacoord/handler.go +++ b/internal/datacoord/handler.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/retry" @@ -42,6 +43,15 @@ type Handler interface { CheckShouldDropChannel(ch string) bool FinishDropChannel(ch string, collectionID int64) error GetCollection(ctx context.Context, collectionID UniqueID) (*collectionInfo, error) + GetCurrentSegmentsView(ctx context.Context, channel RWChannel, partitionIDs ...UniqueID) *SegmentsView +} + +type SegmentsView struct { + FlushedSegmentIDs []int64 + GrowingSegmentIDs []int64 + DroppedSegmentIDs []int64 + L0SegmentIDs []int64 + ImportingSegmentIDs []int64 } // ServerHandler is a helper of Server @@ -107,27 +117,26 @@ func (h *ServerHandler) GetDataVChanPositions(channel RWChannel, partitionID Uni // dropped segmentIDs ---> dropped segments // level zero segmentIDs ---> L0 segments func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs ...UniqueID) *datapb.VchannelInfo { - partStatsVersionsMap := make(map[int64]int64) validPartitions := lo.Filter(partitionIDs, func(partitionID int64, _ int) bool { return partitionID > allPartitionID }) - if len(validPartitions) <= 0 { - collInfo, err := h.s.handler.GetCollection(h.s.ctx, channel.GetCollectionID()) - if err != nil || collInfo == nil { - log.Warn("collectionInfo is nil") - return nil + filterWithPartition := len(validPartitions) > 0 + validPartitionsMap := make(map[int64]bool) + partStatsVersions := h.s.meta.partitionStatsMeta.GetChannelPartitionsStatsVersion(channel.GetCollectionID(), channel.GetName()) + partStatsVersionsMap := make(map[int64]int64) + if filterWithPartition { + for _, partitionID := range validPartitions { + partStatsVersionsMap[partitionID] = partStatsVersions[partitionID] + validPartitionsMap[partitionID] = true } - validPartitions = collInfo.Partitions - } - for _, partitionID := range validPartitions { - currentPartitionStatsVersion := h.s.meta.partitionStatsMeta.GetCurrentPartitionStatsVersion(channel.GetCollectionID(), partitionID, channel.GetName()) - partStatsVersionsMap[partitionID] = currentPartitionStatsVersion + validPartitionsMap[common.AllPartitionsID] = true + } else { + partStatsVersionsMap = partStatsVersions } var ( - flushedIDs = make(typeutil.UniqueSet) - droppedIDs = make(typeutil.UniqueSet) - growingIDs = make(typeutil.UniqueSet) - levelZeroIDs = make(typeutil.UniqueSet) - newFlushedIDs = make(typeutil.UniqueSet) + flushedIDs = make(typeutil.UniqueSet) + droppedIDs = make(typeutil.UniqueSet) + growingIDs = make(typeutil.UniqueSet) + levelZeroIDs = make(typeutil.UniqueSet) ) // cannot use GetSegmentsByChannel since dropped segments are needed here @@ -138,6 +147,9 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . indexed := typeutil.NewUniqueSet(lo.Map(indexedSegments, func(segment *SegmentInfo, _ int) int64 { return segment.GetID() })...) for _, s := range segments { + if filterWithPartition && !validPartitionsMap[s.GetPartitionID()] { + continue + } if s.GetStartPosition() == nil && s.GetDmlPosition() == nil { continue } @@ -182,6 +194,41 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . // Retrieve unIndexed expected result: // unIndexed: c, d // ================================================ + + segmentIndexed := func(segID UniqueID) bool { + return indexed.Contain(segID) || validSegmentInfos[segID].GetNumOfRows() < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() + } + + flushedIDs, droppedIDs = retrieveSegment(validSegmentInfos, flushedIDs, droppedIDs, segmentIndexed) + + log.Info("GetQueryVChanPositions", + zap.Int64("collectionID", channel.GetCollectionID()), + zap.String("channel", channel.GetName()), + zap.Int("numOfSegments", len(segments)), + zap.Int("result flushed", len(flushedIDs)), + zap.Int("result growing", len(growingIDs)), + zap.Int("result L0", len(levelZeroIDs)), + zap.Any("partition stats", partStatsVersionsMap), + ) + + return &datapb.VchannelInfo{ + CollectionID: channel.GetCollectionID(), + ChannelName: channel.GetName(), + SeekPosition: h.GetChannelSeekPosition(channel, partitionIDs...), + FlushedSegmentIds: flushedIDs.Collect(), + UnflushedSegmentIds: growingIDs.Collect(), + DroppedSegmentIds: droppedIDs.Collect(), + LevelZeroSegmentIds: levelZeroIDs.Collect(), + PartitionStatsVersions: partStatsVersionsMap, + } +} + +func retrieveSegment(validSegmentInfos map[int64]*SegmentInfo, + flushedIDs, droppedIDs typeutil.UniqueSet, + segmentIndexed func(segID UniqueID) bool, +) (typeutil.UniqueSet, typeutil.UniqueSet) { + newFlushedIDs := make(typeutil.UniqueSet) + isValid := func(ids ...UniqueID) bool { for _, id := range ids { if seg, ok := validSegmentInfos[id]; !ok || seg == nil || seg.GetIsInvisible() { @@ -192,7 +239,6 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . } var compactionFromExist func(segID UniqueID) bool - compactionFromExist = func(segID UniqueID) bool { compactionFrom := validSegmentInfos[segID].GetCompactionFrom() if len(compactionFrom) == 0 || !isValid(compactionFrom...) { @@ -209,10 +255,6 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . return false } - segmentIndexed := func(segID UniqueID) bool { - return indexed.Contain(segID) || validSegmentInfos[segID].GetNumOfRows() < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() - } - retrieve := func() bool { continueRetrieve := false for id := range flushedIDs { @@ -239,27 +281,73 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . newFlushedIDs = make(typeutil.UniqueSet) } - flushedIDs = newFlushedIDs + return newFlushedIDs, droppedIDs +} - log.Info("GetQueryVChanPositions", +func (h *ServerHandler) GetCurrentSegmentsView(ctx context.Context, channel RWChannel, partitionIDs ...UniqueID) *SegmentsView { + validPartitions := lo.Filter(partitionIDs, func(partitionID int64, _ int) bool { return partitionID > allPartitionID }) + filterWithPartition := len(validPartitions) > 0 + validPartitionsMap := make(map[int64]bool) + validPartitionsMap[common.AllPartitionsID] = true + for _, partitionID := range validPartitions { + validPartitionsMap[partitionID] = true + } + + var ( + flushedIDs = make(typeutil.UniqueSet) + droppedIDs = make(typeutil.UniqueSet) + growingIDs = make(typeutil.UniqueSet) + importingIDs = make(typeutil.UniqueSet) + levelZeroIDs = make(typeutil.UniqueSet) + ) + + // cannot use GetSegmentsByChannel since dropped segments are needed here + segments := h.s.meta.GetRealSegmentsForChannel(channel.GetName()) + + validSegmentInfos := make(map[int64]*SegmentInfo) + for _, s := range segments { + if filterWithPartition && !validPartitionsMap[s.GetPartitionID()] { + continue + } + if s.GetStartPosition() == nil && s.GetDmlPosition() == nil { + continue + } + + validSegmentInfos[s.GetID()] = s + switch { + case s.GetState() == commonpb.SegmentState_Dropped: + droppedIDs.Insert(s.GetID()) + case s.GetState() == commonpb.SegmentState_Importing: + importingIDs.Insert(s.GetID()) + case s.GetLevel() == datapb.SegmentLevel_L0: + levelZeroIDs.Insert(s.GetID()) + case s.GetState() == commonpb.SegmentState_Growing: + growingIDs.Insert(s.GetID()) + default: + flushedIDs.Insert(s.GetID()) + } + } + + flushedIDs, droppedIDs = retrieveSegment(validSegmentInfos, flushedIDs, droppedIDs, func(segID UniqueID) bool { + return true + }) + + log.Ctx(ctx).Info("GetCurrentSegmentsView", zap.Int64("collectionID", channel.GetCollectionID()), zap.String("channel", channel.GetName()), zap.Int("numOfSegments", len(segments)), zap.Int("result flushed", len(flushedIDs)), zap.Int("result growing", len(growingIDs)), + zap.Int("result importing", len(importingIDs)), zap.Int("result L0", len(levelZeroIDs)), - zap.Any("partition stats", partStatsVersionsMap), ) - return &datapb.VchannelInfo{ - CollectionID: channel.GetCollectionID(), - ChannelName: channel.GetName(), - SeekPosition: h.GetChannelSeekPosition(channel, partitionIDs...), - FlushedSegmentIds: flushedIDs.Collect(), - UnflushedSegmentIds: growingIDs.Collect(), - DroppedSegmentIds: droppedIDs.Collect(), - LevelZeroSegmentIds: levelZeroIDs.Collect(), - PartitionStatsVersions: partStatsVersionsMap, + return &SegmentsView{ + FlushedSegmentIDs: flushedIDs.Collect(), + GrowingSegmentIDs: growingIDs.Collect(), + DroppedSegmentIDs: droppedIDs.Collect(), + L0SegmentIDs: levelZeroIDs.Collect(), + ImportingSegmentIDs: importingIDs.Collect(), } } diff --git a/internal/datacoord/handler_test.go b/internal/datacoord/handler_test.go index de202ca37c307..9774ddf951252 100644 --- a/internal/datacoord/handler_test.go +++ b/internal/datacoord/handler_test.go @@ -30,8 +30,9 @@ func TestGetQueryVChanPositionsRetrieveM2N(t *testing.T) { channel := "ch1" svr.meta.AddCollection(&collectionInfo{ - ID: 1, - Schema: schema, + ID: 1, + Partitions: []int64{0}, + Schema: schema, StartPositions: []*commonpb.KeyDataPair{ { Key: channel, @@ -130,8 +131,9 @@ func TestGetQueryVChanPositions(t *testing.T) { defer closeTestServer(t, svr) schema := newTestSchema() svr.meta.AddCollection(&collectionInfo{ - ID: 0, - Schema: schema, + ID: 0, + Partitions: []int64{0, 1}, + Schema: schema, StartPositions: []*commonpb.KeyDataPair{ { Key: "ch1", @@ -302,14 +304,22 @@ func TestGetQueryVChanPositions_PartitionStats(t *testing.T) { version: {Version: version}, }, }, + partitionID + 1: { + currentVersion: version + 1, + infos: map[int64]*datapb.PartitionStatsInfo{ + version + 1: {Version: version + 1}, + }, + }, }, } - partitionIDs := make([]UniqueID, 0) - partitionIDs = append(partitionIDs, partitionID) - vChannelInfo := svr.handler.GetQueryVChanPositions(&channelMeta{Name: vchannel, CollectionID: collectionID}, partitionIDs...) + vChannelInfo := svr.handler.GetQueryVChanPositions(&channelMeta{Name: vchannel, CollectionID: collectionID}, partitionID) statsVersions := vChannelInfo.GetPartitionStatsVersions() assert.Equal(t, 1, len(statsVersions)) assert.Equal(t, int64(100), statsVersions[partitionID]) + + vChannelInfo2 := svr.handler.GetQueryVChanPositions(&channelMeta{Name: vchannel, CollectionID: collectionID}) + statsVersions2 := vChannelInfo2.GetPartitionStatsVersions() + assert.Equal(t, 2, len(statsVersions2)) } func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { @@ -583,8 +593,9 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { defer closeTestServer(t, svr) schema := newTestSchema() svr.meta.AddCollection(&collectionInfo{ - ID: 0, - Schema: schema, + ID: 0, + Partitions: []int64{0}, + Schema: schema, }) err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{ TenantID: "", @@ -959,6 +970,178 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { }) } +func TestGetCurrentSegmentsView(t *testing.T) { + svr := newTestServer(t) + defer closeTestServer(t, svr) + schema := newTestSchema() + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Partitions: []int64{0}, + Schema: schema, + }) + err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{ + TenantID: "", + CollectionID: 0, + FieldID: 2, + IndexID: 1, + }) + assert.NoError(t, err) + seg1 := &datapb.SegmentInfo{ + ID: 1, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) + assert.NoError(t, err) + seg2 := &datapb.SegmentInfo{ + ID: 2, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{1}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) + assert.NoError(t, err) + seg3 := &datapb.SegmentInfo{ + ID: 3, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg3)) + assert.NoError(t, err) + seg4 := &datapb.SegmentInfo{ + ID: 4, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg4)) + assert.NoError(t, err) + seg5 := &datapb.SegmentInfo{ + ID: 5, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{3, 4}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5)) + assert.NoError(t, err) + seg6 := &datapb.SegmentInfo{ + ID: 6, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{3, 4}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg6)) + assert.NoError(t, err) + seg7 := &datapb.SegmentInfo{ + ID: 7, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + Level: datapb.SegmentLevel_L0, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg7)) + assert.NoError(t, err) + seg8 := &datapb.SegmentInfo{ + ID: 8, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Growing, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg8)) + assert.NoError(t, err) + seg9 := &datapb.SegmentInfo{ + ID: 9, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Importing, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg9)) + assert.NoError(t, err) + + view := svr.handler.GetCurrentSegmentsView(context.Background(), &channelMeta{Name: "ch1", CollectionID: 0}) + assert.ElementsMatch(t, []int64{2, 3, 4}, view.FlushedSegmentIDs) + assert.ElementsMatch(t, []int64{8}, view.GrowingSegmentIDs) + assert.ElementsMatch(t, []int64{1}, view.DroppedSegmentIDs) + assert.ElementsMatch(t, []int64{7}, view.L0SegmentIDs) + assert.ElementsMatch(t, []int64{9}, view.ImportingSegmentIDs) +} + func TestShouldDropChannel(t *testing.T) { type myRootCoord struct { mocks2.MockRootCoordClient diff --git a/internal/datacoord/mock_handler.go b/internal/datacoord/mock_handler.go index 20f84a8f53094..7c3ec969e8a48 100644 --- a/internal/datacoord/mock_handler.go +++ b/internal/datacoord/mock_handler.go @@ -174,6 +174,70 @@ func (_c *NMockHandler_GetCollection_Call) RunAndReturn(run func(context.Context return _c } +// GetCurrentSegmentsView provides a mock function with given fields: ctx, channel, partitionIDs +func (_m *NMockHandler) GetCurrentSegmentsView(ctx context.Context, channel RWChannel, partitionIDs ...int64) *SegmentsView { + _va := make([]interface{}, len(partitionIDs)) + for _i := range partitionIDs { + _va[_i] = partitionIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, channel) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetCurrentSegmentsView") + } + + var r0 *SegmentsView + if rf, ok := ret.Get(0).(func(context.Context, RWChannel, ...int64) *SegmentsView); ok { + r0 = rf(ctx, channel, partitionIDs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*SegmentsView) + } + } + + return r0 +} + +// NMockHandler_GetCurrentSegmentsView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCurrentSegmentsView' +type NMockHandler_GetCurrentSegmentsView_Call struct { + *mock.Call +} + +// GetCurrentSegmentsView is a helper method to define mock.On call +// - ctx context.Context +// - channel RWChannel +// - partitionIDs ...int64 +func (_e *NMockHandler_Expecter) GetCurrentSegmentsView(ctx interface{}, channel interface{}, partitionIDs ...interface{}) *NMockHandler_GetCurrentSegmentsView_Call { + return &NMockHandler_GetCurrentSegmentsView_Call{Call: _e.mock.On("GetCurrentSegmentsView", + append([]interface{}{ctx, channel}, partitionIDs...)...)} +} + +func (_c *NMockHandler_GetCurrentSegmentsView_Call) Run(run func(ctx context.Context, channel RWChannel, partitionIDs ...int64)) *NMockHandler_GetCurrentSegmentsView_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(int64) + } + } + run(args[0].(context.Context), args[1].(RWChannel), variadicArgs...) + }) + return _c +} + +func (_c *NMockHandler_GetCurrentSegmentsView_Call) Return(_a0 *SegmentsView) *NMockHandler_GetCurrentSegmentsView_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *NMockHandler_GetCurrentSegmentsView_Call) RunAndReturn(run func(context.Context, RWChannel, ...int64) *SegmentsView) *NMockHandler_GetCurrentSegmentsView_Call { + _c.Call.Return(run) + return _c +} + // GetDataVChanPositions provides a mock function with given fields: ch, partitionID func (_m *NMockHandler) GetDataVChanPositions(ch RWChannel, partitionID int64) *datapb.VchannelInfo { ret := _m.Called(ch, partitionID) diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index 963ec0ac73cf0..f3f1890a9c972 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -763,6 +763,10 @@ func (h *mockHandler) GetCollection(_ context.Context, collectionID UniqueID) (* return &collectionInfo{ID: collectionID}, nil } +func (h *mockHandler) GetCurrentSegmentsView(ctx context.Context, channel RWChannel, partitionIDs ...UniqueID) *SegmentsView { + return nil +} + func newMockHandlerWithMeta(meta *meta) *mockHandler { return &mockHandler{ meta: meta, diff --git a/internal/datacoord/partition_stats_meta.go b/internal/datacoord/partition_stats_meta.go index ce762fae052a7..1429106e719b9 100644 --- a/internal/datacoord/partition_stats_meta.go +++ b/internal/datacoord/partition_stats_meta.go @@ -202,3 +202,16 @@ func (psm *partitionStatsMeta) GetPartitionStats(collectionID, partitionID int64 } return psm.partitionStatsInfos[vChannel][partitionID].infos[version] } + +func (psm *partitionStatsMeta) GetChannelPartitionsStatsVersion(collectionID int64, vChannel string) map[int64]int64 { + psm.RLock() + defer psm.RUnlock() + + result := make(map[int64]int64) + partitionsStats := psm.partitionStatsInfos[vChannel] + for partitionID, info := range partitionsStats { + result[partitionID] = info.currentVersion + } + + return result +} diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 19bf9c92d5cd9..4a9499dca2278 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -18,6 +18,7 @@ package datacoord import ( "context" + "fmt" "math/rand" "os" "os/signal" @@ -574,6 +575,17 @@ func TestGetSegmentsByStates(t *testing.T) { t.Run("normal case", func(t *testing.T) { svr := newTestServer(t) defer closeTestServer(t, svr) + channelManager := NewMockChannelManager(t) + channelName := "ch" + channelManager.EXPECT().GetChannelsByCollectionID(mock.Anything).RunAndReturn(func(id int64) []RWChannel { + return []RWChannel{ + &channelMeta{ + Name: channelName + fmt.Sprint(id), + CollectionID: id, + }, + } + }).Maybe() + svr.channelManager = channelManager type testCase struct { collID int64 partID int64 @@ -622,31 +634,92 @@ func TestGetSegmentsByStates(t *testing.T) { expected: []int64{9, 10}, }, } + svr.meta.AddCollection(&collectionInfo{ + ID: 1, + Partitions: []int64{1, 2}, + Schema: nil, + StartPositions: []*commonpb.KeyDataPair{ + { + Key: "ch1", + Data: []byte{8, 9, 10}, + }, + }, + }) + svr.meta.AddCollection(&collectionInfo{ + ID: 2, + Partitions: []int64{3}, + Schema: nil, + StartPositions: []*commonpb.KeyDataPair{ + { + Key: "ch1", + Data: []byte{8, 9, 10}, + }, + }, + }) for _, tc := range cases { for _, fs := range tc.flushedSegments { segInfo := &datapb.SegmentInfo{ - ID: fs, - CollectionID: tc.collID, - PartitionID: tc.partID, - State: commonpb.SegmentState_Flushed, + ID: fs, + CollectionID: tc.collID, + PartitionID: tc.partID, + InsertChannel: channelName + fmt.Sprint(tc.collID), + State: commonpb.SegmentState_Flushed, + NumOfRows: 1024, + StartPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{8, 9, 10}, + MsgGroup: "", + }, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{11, 12, 13}, + MsgGroup: "", + Timestamp: 2, + }, } assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } for _, us := range tc.sealedSegments { segInfo := &datapb.SegmentInfo{ - ID: us, - CollectionID: tc.collID, - PartitionID: tc.partID, - State: commonpb.SegmentState_Sealed, + ID: us, + CollectionID: tc.collID, + PartitionID: tc.partID, + InsertChannel: channelName + fmt.Sprint(tc.collID), + State: commonpb.SegmentState_Sealed, + NumOfRows: 1024, + StartPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{8, 9, 10}, + MsgGroup: "", + }, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{11, 12, 13}, + MsgGroup: "", + Timestamp: 2, + }, } assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } for _, us := range tc.growingSegments { segInfo := &datapb.SegmentInfo{ - ID: us, - CollectionID: tc.collID, - PartitionID: tc.partID, - State: commonpb.SegmentState_Growing, + ID: us, + CollectionID: tc.collID, + PartitionID: tc.partID, + InsertChannel: channelName + fmt.Sprint(tc.collID), + State: commonpb.SegmentState_Growing, + NumOfRows: 1024, + StartPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{8, 9, 10}, + MsgGroup: "", + }, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{11, 12, 13}, + MsgGroup: "", + Timestamp: 2, + }, } assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 4eacf2fd54d34..4e4f222ed49c9 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1039,10 +1039,16 @@ func (s *Server) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegment }, nil } var segmentIDs []UniqueID - if partitionID < 0 { - segmentIDs = s.meta.GetSegmentsIDOfCollection(ctx, collectionID) - } else { - segmentIDs = s.meta.GetSegmentsIDOfPartition(ctx, collectionID, partitionID) + channels := s.channelManager.GetChannelsByCollectionID(collectionID) + for _, channel := range channels { + channelSegmentsView := s.handler.GetCurrentSegmentsView(ctx, channel, partitionID) + if channelSegmentsView == nil { + continue + } + segmentIDs = append(segmentIDs, channelSegmentsView.FlushedSegmentIDs...) + segmentIDs = append(segmentIDs, channelSegmentsView.GrowingSegmentIDs...) + segmentIDs = append(segmentIDs, channelSegmentsView.L0SegmentIDs...) + segmentIDs = append(segmentIDs, channelSegmentsView.ImportingSegmentIDs...) } ret := make([]UniqueID, 0, len(segmentIDs)) diff --git a/internal/datacoord/sync_segments_scheduler.go b/internal/datacoord/sync_segments_scheduler.go index 849b41a4460cb..7f672e4f1bad0 100644 --- a/internal/datacoord/sync_segments_scheduler.go +++ b/internal/datacoord/sync_segments_scheduler.go @@ -105,61 +105,65 @@ func (sss *SyncSegmentsScheduler) SyncSegmentsForCollections(ctx context.Context zap.String("channelName", channelName), zap.Error(err)) continue } - for _, partitionID := range collInfo.Partitions { - if err := sss.SyncSegments(ctx, collID, partitionID, channelName, nodeID, pkField.GetFieldID()); err != nil { - log.Warn("sync segment with channel failed, retry next ticker", - zap.Int64("collectionID", collID), - zap.Int64("partitionID", partitionID), - zap.String("channel", channelName), - zap.Error(err)) - continue - } + if err := sss.SyncSegments(ctx, collID, channelName, nodeID, pkField.GetFieldID()); err != nil { + log.Warn("sync segment with channel failed, retry next ticker", + zap.Int64("collectionID", collID), + zap.String("channel", channelName), + zap.Error(err)) + continue } } } } -func (sss *SyncSegmentsScheduler) SyncSegments(ctx context.Context, collectionID, partitionID int64, channelName string, nodeID, pkFieldID int64) error { - log := log.With(zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), +func (sss *SyncSegmentsScheduler) SyncSegments(ctx context.Context, collectionID int64, channelName string, nodeID, pkFieldID int64) error { + log := log.With(zap.Int64("collectionID", collectionID), zap.String("channelName", channelName), zap.Int64("nodeID", nodeID)) // sync all healthy segments, but only check flushed segments on datanode. Because L0 growing segments may not in datacoord's meta. // upon receiving the SyncSegments request, the datanode's segment state may have already transitioned from Growing/Flushing // to Flushed, so the view must include this segment. - segments := sss.meta.SelectSegments(ctx, WithChannel(channelName), SegmentFilterFunc(func(info *SegmentInfo) bool { - return info.GetPartitionID() == partitionID && info.GetLevel() != datapb.SegmentLevel_L0 && isSegmentHealthy(info) + channelSegments := sss.meta.SelectSegments(ctx, WithChannel(channelName), SegmentFilterFunc(func(info *SegmentInfo) bool { + return info.GetLevel() != datapb.SegmentLevel_L0 && isSegmentHealthy(info) })) - req := &datapb.SyncSegmentsRequest{ - ChannelName: channelName, - PartitionId: partitionID, - CollectionId: collectionID, - SegmentInfos: make(map[int64]*datapb.SyncSegmentInfo), - } - for _, seg := range segments { - req.SegmentInfos[seg.ID] = &datapb.SyncSegmentInfo{ - SegmentId: seg.GetID(), - State: seg.GetState(), - Level: seg.GetLevel(), - NumOfRows: seg.GetNumOfRows(), + partitionSegments := lo.GroupBy(channelSegments, func(segment *SegmentInfo) int64 { + return segment.GetPartitionID() + }) + for partitionID, segments := range partitionSegments { + req := &datapb.SyncSegmentsRequest{ + ChannelName: channelName, + PartitionId: partitionID, + CollectionId: collectionID, + SegmentInfos: make(map[int64]*datapb.SyncSegmentInfo), } - statsLogs := make([]*datapb.Binlog, 0) - for _, statsLog := range seg.GetStatslogs() { - if statsLog.GetFieldID() == pkFieldID { - statsLogs = append(statsLogs, statsLog.GetBinlogs()...) + + for _, seg := range segments { + req.SegmentInfos[seg.ID] = &datapb.SyncSegmentInfo{ + SegmentId: seg.GetID(), + State: seg.GetState(), + Level: seg.GetLevel(), + NumOfRows: seg.GetNumOfRows(), + } + statsLogs := make([]*datapb.Binlog, 0) + for _, statsLog := range seg.GetStatslogs() { + if statsLog.GetFieldID() == pkFieldID { + statsLogs = append(statsLogs, statsLog.GetBinlogs()...) + } + } + req.SegmentInfos[seg.ID].PkStatsLog = &datapb.FieldBinlog{ + FieldID: pkFieldID, + Binlogs: statsLogs, } } - req.SegmentInfos[seg.ID].PkStatsLog = &datapb.FieldBinlog{ - FieldID: pkFieldID, - Binlogs: statsLogs, + + if err := sss.sessions.SyncSegments(ctx, nodeID, req); err != nil { + log.Warn("fail to sync segments with node", zap.Error(err)) + return err } + log.Info("sync segments success", zap.Int64("partitionID", partitionID), zap.Int64s("segments", lo.Map(segments, func(t *SegmentInfo, i int) int64 { + return t.GetID() + }))) } - if err := sss.sessions.SyncSegments(ctx, nodeID, req); err != nil { - log.Warn("fail to sync segments with node", zap.Error(err)) - return err - } - log.Info("sync segments success", zap.Int64s("segments", lo.Map(segments, func(t *SegmentInfo, i int) int64 { - return t.GetID() - }))) return nil }