From 2f5637471cc4485e426e9594bf227a50ea0be748 Mon Sep 17 00:00:00 2001 From: SimFG Date: Tue, 3 Dec 2024 10:25:38 +0800 Subject: [PATCH] enhance: add unit test case for the replicate message feature Signed-off-by: SimFG --- internal/datacoord/channel.go | 2 +- internal/proto/root_coord.proto | 1 + internal/proxy/impl_test.go | 6 +- internal/proxy/meta_cache_test.go | 81 ++++++++ internal/proxy/task.go | 24 +-- internal/proxy/task_database.go | 28 ++- internal/proxy/task_database_test.go | 158 ++++++++++++++ internal/proxy/task_insert.go | 2 +- internal/proxy/task_test.go | 93 ++++++--- internal/querynodev2/pipeline/manager_test.go | 1 - .../querynodev2/pipeline/pipeline_test.go | 17 +- .../rootcoord/alter_collection_task_test.go | 41 +++- .../rootcoord/alter_database_task_test.go | 130 +++++++++++- .../rootcoord/create_collection_task_test.go | 64 ++++++ internal/rootcoord/mock_test.go | 25 +++ internal/rootcoord/root_coord.go | 13 ++ internal/rootcoord/root_coord_test.go | 26 +++ pkg/mq/msgdispatcher/dispatcher_test.go | 196 ++++++++++++++++++ pkg/mq/msgstream/mq_msgstream.go | 14 -- pkg/mq/msgstream/mq_msgstream_test.go | 18 ++ pkg/mq/msgstream/msgstream.go | 7 +- pkg/mq/msgstream/msgstream_util_test.go | 89 ++++++++ pkg/util/merr/errors.go | 2 +- 23 files changed, 944 insertions(+), 94 deletions(-) diff --git a/internal/datacoord/channel.go b/internal/datacoord/channel.go index d222d274d0222..0a80b43a0745f 100644 --- a/internal/datacoord/channel.go +++ b/internal/datacoord/channel.go @@ -113,7 +113,7 @@ func (ch *channelMeta) String() string { return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions) } -func (channelMeta) GetDBProperties() []*commonpb.KeyValuePair { +func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair { return nil } diff --git a/internal/proto/root_coord.proto b/internal/proto/root_coord.proto index 435d3050efea7..4fdd208ebc418 100644 --- a/internal/proto/root_coord.proto +++ b/internal/proto/root_coord.proto @@ -146,6 +146,7 @@ service RootCoord { message AllocTimestampRequest { common.MsgBase base = 1; uint32 count = 3; + uint64 blockTimestamp = 4; } message AllocTimestampResponse { diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 974b14e87a424..ba40366211c62 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -2026,13 +2026,13 @@ func TestAlterCollectionReplicateProperty(t *testing.T) { factory := newMockMsgStreamFactory() msgStreamObj := msgstream.NewMockMsgStream(t) msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe() - msgStreamObj.EXPECT().AsProducer(mock.Anything).Return().Maybe() - msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe() msgStreamObj.EXPECT().Close().Return().Maybe() mockMsgID1 := mqcommon.NewMockMessageID(t) mockMsgID2 := mqcommon.NewMockMessageID(t) mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe() - msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ + msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{ "alter_property": {mockMsgID1, mockMsgID2}, }, nil).Maybe() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 4592786f1cccc..344363538892a 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { wg.Wait() } +func TestMetaCacheGetCollectionWithUpdate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + ctx := context.Background() + rootCoord := mocks.NewMockRootCoordClient(t) + queryCoord := mocks.NewMockQueryCoordClient(t) + rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil) + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) + assert.NoError(t, err) + t.Run("update with name", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, + CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) + + t.Run("update with name", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, + CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) +} + func TestMetaCache_GetCollectionName(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index e57c3c265e898..e6488aa3627ff 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1088,20 +1088,16 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error { } endTS, ok := common.GetReplicateEndTS(t.Properties) if ok && collBasicInfo.replicateID != "" { - var rootcoordTS uint64 - for { - allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ - Count: 1, - }) - if err = merr.CheckRPCCall(allocResp, err); err != nil { - return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) - } - rootcoordTS = allocResp.GetTimestamp() - if rootcoordTS > endTS { - break - } - log.Info("wait for rootcoord ts", zap.Uint64("rootcoord ts", rootcoordTS), zap.Uint64("end ts", endTS)) - time.Sleep(500 * time.Millisecond) + allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: 1, + BlockTimestamp: endTS, + }) + if err = merr.CheckRPCCall(allocResp, err); err != nil { + return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) + } + if allocResp.GetTimestamp() <= endTS { + return merr.WrapErrServiceInternal("alter collection: alloc timestamp failed, timestamp is not greater than endTS", + fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS)) } } diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index 72884ed9b6c3b..bee84860d3c04 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -2,7 +2,7 @@ package proxy import ( "context" - "time" + "fmt" "go.uber.org/zap" @@ -290,22 +290,18 @@ func (t *alterDatabaseTask) PreExecute(ctx context.Context) error { } oldReplicateEnable, _ := common.IsReplicateEnabled(cacheInfo.properties) if !oldReplicateEnable { // old replicate enable is false - return nil + return merr.WrapErrParameterInvalidMsg("can't set the replicate end ts property in alter database request when db replicate is disabled") + } + allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: 1, + BlockTimestamp: endTS, + }) + if err = merr.CheckRPCCall(allocResp, err); err != nil { + return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) } - var rootcoordTS uint64 - for { - allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ - Count: 1, - }) - if err = merr.CheckRPCCall(allocResp, err); err != nil { - return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) - } - rootcoordTS = allocResp.GetTimestamp() - if rootcoordTS > endTS { - break - } - log.Info("wait for rootcoord ts", zap.Uint64("rootcoord ts", rootcoordTS), zap.Uint64("end ts", endTS)) - time.Sleep(500 * time.Millisecond) + if allocResp.GetTimestamp() <= endTS { + return merr.WrapErrServiceInternal("alter database: alloc timestamp failed, timestamp is not greater than endTS", + fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS)) } return nil diff --git a/internal/proxy/task_database_test.go b/internal/proxy/task_database_test.go index 17615fba42ee5..3be8c4a8a76e6 100644 --- a/internal/proxy/task_database_test.go +++ b/internal/proxy/task_database_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/grpc/metadata" @@ -201,6 +202,163 @@ func TestAlterDatabase(t *testing.T) { assert.Nil(t, err1) } +func TestAlterDatabaseTaskForReplicateProperty(t *testing.T) { + rc := mocks.NewMockRootCoordClient(t) + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + + t.Run("replicate id", func(t *testing.T) { + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.MmapEnabledKey, + Value: "true", + }, + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("fail to get database info", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("not enable replicate", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{}, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("fail to alloc ts", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("alloc wrong ts", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 999, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("alloc wrong ts", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 1001, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.NoError(t, err) + }) +} + func TestDescribeDatabaseTask(t *testing.T) { rc := mocks.NewMockRootCoordClient(t) diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 8eee9fc4c810e..9de31cd53d600 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -128,7 +128,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName) if err != nil { log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err)) - return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + return merr.WrapErrAsInputError(err) } if replicateID != "" { return merr.WrapErrCollectionReplicateMode("insert") diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 83b0574c598e0..019063bbbc9f1 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -4276,7 +4277,7 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { globalMetaCache = mockCache ctx := context.Background() mockRootcoord := mocks.NewMockRootCoordClient(t) - t.Run("set replicate property to true", func(t *testing.T) { + t.Run("invalid replicate id", func(t *testing.T) { task := &alterCollectionTask{ AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ Properties: []*commonpb.KeyValuePair{ @@ -4293,7 +4294,7 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { assert.Error(t, err) }) - t.Run("invalid replicate id", func(t *testing.T) { + t.Run("empty replicate id", func(t *testing.T) { task := &alterCollectionTask{ AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ CollectionName: "test", @@ -4330,27 +4331,69 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { assert.Error(t, err) }) - // t.Run("fail to wait ts", func(t *testing.T) { - // task := &alterCollectionTask{ - // AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ - // CollectionName: "test", - // Properties: []*commonpb.KeyValuePair{ - // { - // Key: common.ReplicateIDKey, - // Value: "", - // }, - // }, - // }, - // rootCoord: mockRootcoord, - // } - // - // mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ - // Status: merr.Success(), - // Timestamp: 100, - // Count: 1, - // }, nil).Once() - // mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() - // err := task.PreExecute(ctx) - // assert.Error(t, err) - // }) + t.Run("alloc wrong ts", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + CollectionName: "test", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "100", + }, + }, + }, + rootCoord: mockRootcoord, + } + + mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 99, + }, nil).Once() + err := task.PreExecute(ctx) + assert.Error(t, err) + }) +} + +func TestInsertForReplicate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + + t.Run("get replicate id fail", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: "foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + t.Run("insert with replicate id", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + schema: &schemaInfo{ + CollectionSchema: &schemapb.CollectionSchema{ + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-mac", + }, + }, + }, + }, + replicateID: "local-mac", + }, nil).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: "foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) } diff --git a/internal/querynodev2/pipeline/manager_test.go b/internal/querynodev2/pipeline/manager_test.go index e344fed82ad69..48f09ce3ddfd5 100644 --- a/internal/querynodev2/pipeline/manager_test.go +++ b/internal/querynodev2/pipeline/manager_test.go @@ -28,7 +28,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" - "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index bfd020e8bf37c..dd153527e1c67 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -24,13 +24,14 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" - "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -103,6 +104,12 @@ func (suite *PipelineTestSuite) TestBasic() { schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, + DbProperties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, }) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) @@ -136,16 +143,16 @@ func (suite *PipelineTestSuite) TestBasic() { Collection: suite.collectionManager, Segment: suite.segmentManager, } - pipeline, err := NewPipeLine(collection, suite.channel, manager, suite.msgDispatcher, suite.delegator) + pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.msgDispatcher, suite.delegator) suite.NoError(err) // Init Consumer - err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) + err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) suite.NoError(err) - err = pipeline.Start() + err = pipelineObj.Start() suite.NoError(err) - defer pipeline.Close() + defer pipelineObj.Close() // build input msg in := suite.buildMsgPack(schema) diff --git a/internal/rootcoord/alter_collection_task_test.go b/internal/rootcoord/alter_collection_task_test.go index e2b53e958da50..f0a0edb75ea5f 100644 --- a/internal/rootcoord/alter_collection_task_test.go +++ b/internal/rootcoord/alter_collection_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" ) func Test_alterCollectionTask_Prepare(t *testing.T) { @@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) { assert.NoError(t, err) }) - t.Run("alter successfully", func(t *testing.T) { + t.Run("alter successfully2", func(t *testing.T) { meta := mockrootcoord.NewIMetaTable(t) meta.On("GetCollectionByName", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Collection{CollectionID: int64(1)}, nil) + ).Return(&model.Collection{ + CollectionID: int64(1), + Name: "cn", + DBName: "foo", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + }, nil) meta.On("AlterCollection", mock.Anything, mock.Anything, @@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { return nil } - - core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker)) + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") + + core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker)) + newPros := append(properties, &commonpb.KeyValuePair{ + Key: common.ReplicateEndTSKey, + Value: "10000", + }) task := &alterCollectionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.AlterCollectionRequest{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase()) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { diff --git a/internal/rootcoord/alter_database_task_test.go b/internal/rootcoord/alter_database_task_test.go index 7e5887f6538ba..47b66e176b4e7 100644 --- a/internal/rootcoord/alter_database_task_test.go +++ b/internal/rootcoord/alter_database_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/funcutil" ) func Test_alterDatabaseTask_Prepare(t *testing.T) { @@ -47,6 +50,76 @@ func Test_alterDatabaseTask_Prepare(t *testing.T) { err := task.Prepare(context.Background()) assert.NoError(t, err) }) + + t.Run("replicate id", func(t *testing.T) { + { + // no collections + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT(). + ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*model.Collection{}, nil). + Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.NoError(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + Name: "foo", + }, + }, nil).Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("err")). + Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + }) } func Test_alterDatabaseTask_Execute(t *testing.T) { @@ -146,25 +219,51 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Database{ID: int64(1)}, nil) + ).Return(&model.Database{ + ID: int64(1), + Name: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil) meta.On("AlterDatabase", mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil) - - core := newTestCore(withMeta(meta), withValidProxyManager()) + // the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") + + core := newTestCore(withMeta(meta), withValidProxyManager(), withTtSynchronizer(ticker)) + newPros := append(properties, + &commonpb.KeyValuePair{Key: common.ReplicateEndTSKey, Value: "1000"}, + ) task := &alterDatabaseTask{ baseTask: newBaseTask(context.Background(), core), Req: &rootcoordpb.AlterDatabaseRequest{ DbName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { @@ -248,3 +347,26 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { assert.Empty(t, ret2) }) } + +func TestMergeProperties(t *testing.T) { + p := MergeProperties([]*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + { + Key: "foo", + Value: "xxx", + }, + }, []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1001", + }, + }) + assert.Len(t, p, 3) + m := funcutil.KeyValuePair2Map(p) + assert.Equal(t, "", m[common.ReplicateIDKey]) + assert.Equal(t, "1001", m[common.ReplicateEndTSKey]) + assert.Equal(t, "xxx", m["foo"]) +} diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index 188fdc1d7947a..62d44562cce0d 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -823,6 +823,70 @@ func Test_createCollectionTask_Prepare(t *testing.T) { }) } +func TestCreateCollectionTask_Prepare_WithProperty(t *testing.T) { + paramtable.Init() + meta := mockrootcoord.NewIMetaTable(t) + t.Run("with db properties", func(t *testing.T) { + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + Name: "foo", + ID: 1, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Twice() + meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{ + util.DefaultDBID: {1, 2}, + }).Once() + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0).Once() + + defer cleanTestEnv() + + collectionName := funcutil.GenRandomStr() + field1 := funcutil.GenRandomStr() + + ticker := newRocksMqTtSynchronizer() + core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + Name: field1, + DataType: schemapb.DataType_Int64, + }, + }, + } + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task := createCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + CollectionName: collectionName, + Schema: marshaledSchema, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "hoo", + }, + }, + }, + dbID: 1, + } + task.Req.ShardsNum = common.DefaultShardsNum + err = task.Prepare(context.Background()) + assert.Len(t, task.dbProperties, 1) + assert.Len(t, task.Req.Properties, 0) + assert.NoError(t, err) + }) +} + func Test_createCollectionTask_Execute(t *testing.T) { t.Run("add same collection with different parameters", func(t *testing.T) { defer cleanTestEnv() diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 9dc973be8133a..5f59f27c1246a 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -1058,6 +1058,31 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync { return ticker } +func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync { + f := msgstream.NewMockMqFactory() + f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) { + stream := msgstream.NewWastedMockMsgStream() + stream.BroadcastFunc = func(pack *msgstream.MsgPack) error { + log.Info("mock Broadcast") + packChan <- pack + return nil + } + stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { + log.Info("mock BroadcastMark") + packChan <- pack + return map[string][]msgstream.MessageID{}, nil + } + stream.AsProducerFunc = func(channels []string) { + } + stream.ChanFunc = func() <-chan *msgstream.MsgPack { + return packChan + } + return stream, nil + } + + return newTickerWithFactory(f) +} + type mockDdlTsLockManager struct { DdlTsLockManager GetMinDdlTsFunc func() Timestamp diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index d13d7fe8dd992..58c609dbdc47c 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1746,6 +1746,19 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam }, nil } + if in.BlockTimestamp > 0 { + blockTime, _ := tsoutil.ParseTS(in.BlockTimestamp) + lastTime := c.tsoAllocator.GetLastSavedTime() + deltaDuration := blockTime.Sub(lastTime) + if deltaDuration > 0 { + log.Info("wait for block timestamp", + zap.Time("blockTime", blockTime), + zap.Time("lastTime", lastTime), + zap.Duration("delta", deltaDuration)) + time.Sleep(deltaDuration + time.Millisecond*200) + } + } + ts, err := c.tsoAllocator.GenerateTSO(in.GetCount()) if err != nil { log.Ctx(ctx).Error("failed to allocate timestamp", zap.String("role", typeutil.RootCoordRole), diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index b8915e5deab79..e6b8c380f4e9b 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -856,6 +856,32 @@ func TestRootCoord_AllocTimestamp(t *testing.T) { assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp()) assert.Equal(t, count, resp.GetCount()) }) + + t.Run("block timestamp", func(t *testing.T) { + alloc := newMockTsoAllocator() + count := uint32(10) + current := time.Now() + ts := tsoutil.ComposeTSByTime(current.Add(time.Second), 1) + alloc.GenerateTSOF = func(count uint32) (uint64, error) { + // end ts + return ts, nil + } + alloc.GetLastSavedTimeF = func() time.Time { + return current + } + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withTsoAllocator(alloc)) + resp, err := c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: count, + BlockTimestamp: tsoutil.ComposeTSByTime(current.Add(time.Second), 0), + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + // begin ts + assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp()) + assert.Equal(t, count, resp.GetCount()) + }) } func TestRootCoord_AllocID(t *testing.T) { diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index 242980e6f8b34..d4c20fae8c3fb 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -17,6 +17,8 @@ package msgdispatcher import ( + "fmt" + "math/rand" "sync" "testing" "time" @@ -26,6 +28,8 @@ import ( "github.com/stretchr/testify/mock" "golang.org/x/net/context" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -138,3 +142,195 @@ func BenchmarkDispatcher_handle(b *testing.B) { // BenchmarkDispatcher_handle-12 9568 122123 ns/op // PASS } + +func TestGroupMessage(t *testing.T) { + d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false) + assert.NoError(t, err) + d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil)) + d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo"))) + { + // no replicate msg + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 5, + EndTimestamp: 5, + }, + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + Timestamp: 5, + }, + ShardName: "mock_pchannel_0_1v0", + }, + }, + }, + }) + assert.Len(t, packs, 1) + } + + { + // equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + }, + }, + }, + }) + assert.Len(t, packs, 2) + { + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + { + replicatePack := packs["mock_pchannel_0_1v0"] + assert.EqualValues(t, 1, replicatePack.BeginTs) + assert.EqualValues(t, 10, replicatePack.EndTs) + assert.EqualValues(t, 1, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 10, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + } + + { + // not equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test-1", // not equal to replicateID + }, + }, + }, + }, + }, + }) + assert.Len(t, packs, 1) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.Nil(t, replicatePack) + } + + { + // replicate end + replicateTarget := d.targets["mock_pchannel_0_2v0"] + assert.NotNil(t, replicateTarget.replicateConfig) + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + IsEnd: true, + Database: "foo", + }, + }, + }, + }) + assert.Len(t, packs, 2) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Nil(t, replicateTarget.replicateConfig) + } +} diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 9ba194af992d5..76486865e81a1 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -283,20 +283,6 @@ func (ms *mqMsgStream) isSkipSystemTT() bool { return ms.replicateID != "" } -func (ms *mqMsgStream) checkReplicateEndMsg(msg TsMsg) bool { - if !ms.isSkipSystemTT() || - msg.Type() != commonpb.MsgType_Replicate || - ms.checkFunc == nil { - return false - } - replicateMsg := msg.(*ReplicateMsg) - check := ms.checkFunc(replicateMsg) - if check { - ms.replicateID = "" - } - return check -} - // checkReplicateID check the replicate id of the message, return values: isMatch, isReplicate func (ms *mqMsgStream) checkReplicateID(msg TsMsg) (bool, bool) { if !ms.isSkipSystemTT() { diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index 6c9973f3373cd..1ff559c191c63 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -708,6 +708,21 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) + replicatePack := MsgPack{} + replicatePack.Msgs = append(replicatePack.Msgs, &ReplicateMsg{ + BaseMsg: BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{100}, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + }, + }, + }) + msgPack2 := MsgPack{} msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5)) @@ -721,6 +736,9 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) + err = inputStream.Produce(ctx, &replicatePack) + require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index de7c403316bc2..3b1d4b3c32d8d 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -75,7 +75,6 @@ type MsgStream interface { CheckTopicValid(channel string) error ForceEnableProduce(can bool) - SetReplicate(config *ReplicateConfig) } type ReplicateConfig struct { @@ -92,7 +91,7 @@ func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig { replicateConfig := &ReplicateConfig{ ReplicateID: replicateID, CheckFunc: func(msg *ReplicateMsg) bool { - if !msg.IsEnd { + if !msg.GetIsEnd() { return false } log.Info("check replicate msg", @@ -100,10 +99,10 @@ func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig { zap.String("dbName", dbName), zap.String("colName", colName), zap.Any("msg", msg)) - if msg.IsCluster { + if msg.GetIsCluster() { return true } - return msg.Database == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "") + return msg.GetDatabase() == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "") }, } return replicateConfig diff --git a/pkg/mq/msgstream/msgstream_util_test.go b/pkg/mq/msgstream/msgstream_util_test.go index f8d1754eac205..9811ea5e78670 100644 --- a/pkg/mq/msgstream/msgstream_util_test.go +++ b/pkg/mq/msgstream/msgstream_util_test.go @@ -24,6 +24,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/common" ) @@ -80,3 +82,90 @@ func TestGetLatestMsgID(t *testing.T) { assert.Equal(t, []byte("mock"), id) } } + +func TestReplicateConfig(t *testing.T) { + t.Run("get replicate id", func(t *testing.T) { + { + msg := &InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local", + }, + }, + }, + } + assert.Equal(t, "local", GetReplicateID(msg)) + assert.True(t, MatchReplicateID(msg, "local")) + } + { + msg := &InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{}, + }, + } + assert.Equal(t, "", GetReplicateID(msg)) + assert.False(t, MatchReplicateID(msg, "local")) + } + { + msg := &MarshalFailTsMsg{} + assert.Equal(t, "", GetReplicateID(msg)) + } + }) + + t.Run("get replicate config", func(t *testing.T) { + { + assert.Nil(t, GetReplicateConfig("", "", "")) + } + { + rc := GetReplicateConfig("local", "db", "col") + assert.Equal(t, "local", rc.ReplicateID) + checkFunc := rc.CheckFunc + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{}, + })) + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + IsCluster: true, + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db1", + }, + })) + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + Collection: "col1", + }, + })) + } + { + rc := GetReplicateConfig("local", "db", "col") + checkFunc := rc.CheckFunc + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db1", + Collection: "col1", + }, + })) + } + }) +} diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index cc4cc7fdf8828..8de8203461eaa 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -70,7 +70,7 @@ var ( ErrCollectionIllegalSchema = newMilvusError("illegal collection schema", 105, false) ErrCollectionOnRecovering = newMilvusError("collection on recovering", 106, true) ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false) - ErrCollectionReplicateMode = newMilvusError("can't operate when the collection is replicate mode", 108, false) + ErrCollectionReplicateMode = newMilvusError("can't operate on the collection under standby mode", 108, false) // Partition related ErrPartitionNotFound = newMilvusError("partition not found", 200, false)