diff --git a/internal/datacoord/channel.go b/internal/datacoord/channel.go index d222d274d0222..0a80b43a0745f 100644 --- a/internal/datacoord/channel.go +++ b/internal/datacoord/channel.go @@ -113,7 +113,7 @@ func (ch *channelMeta) String() string { return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions) } -func (channelMeta) GetDBProperties() []*commonpb.KeyValuePair { +func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair { return nil } diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index a944f29248a27..92637c78f8b2f 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -2026,13 +2026,13 @@ func TestAlterCollectionReplicateProperty(t *testing.T) { factory := newMockMsgStreamFactory() msgStreamObj := msgstream.NewMockMsgStream(t) msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe() - msgStreamObj.EXPECT().AsProducer(mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe() msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return().Maybe() msgStreamObj.EXPECT().Close().Return().Maybe() mockMsgID1 := mqcommon.NewMockMessageID(t) mockMsgID2 := mqcommon.NewMockMessageID(t) mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe() - msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{ + msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{ "alter_property": {mockMsgID1, mockMsgID2}, }, nil).Maybe() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 4592786f1cccc..344363538892a 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { wg.Wait() } +func TestMetaCacheGetCollectionWithUpdate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + ctx := context.Background() + rootCoord := mocks.NewMockRootCoordClient(t) + queryCoord := mocks.NewMockQueryCoordClient(t) + rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil) + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) + assert.NoError(t, err) + t.Run("update with name", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, + CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) + + t.Run("update with name", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, + CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) +} + func TestMetaCache_GetCollectionName(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 66dadcc49eb16..097ea60cac85f 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -128,7 +128,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName) if err != nil { log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err)) - return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + return merr.WrapErrAsInputError(err) } if replicateID != "" { return merr.WrapErrCollectionReplicateMode("insert") diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 94e89f4504b63..a1fa550375702 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -4312,28 +4312,48 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { err := task.PreExecute(ctx) assert.Error(t, err) }) +} - // t.Run("fail to wait ts", func(t *testing.T) { - // task := &alterCollectionTask{ - // AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ - // CollectionName: "test", - // Properties: []*commonpb.KeyValuePair{ - // { - // Key: common.ReplicateIDKey, - // Value: "", - // }, - // }, - // }, - // rootCoord: mockRootcoord, - // } - // - // mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ - // Status: merr.Success(), - // Timestamp: 100, - // Count: 1, - // }, nil).Once() - // mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() - // err := task.PreExecute(ctx) - // assert.Error(t, err) - // }) +func TestInsertForReplicate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + + t.Run("get replicate id fail", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: "foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + t.Run("insert with replicate id", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + schema: &schemaInfo{ + CollectionSchema: &schemapb.CollectionSchema{ + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-mac", + }, + }, + }, + }, + replicateID: "local-mac", + }, nil).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: "foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) } diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index e5e02296ad3d6..75e45deb04c66 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" @@ -31,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -111,6 +113,12 @@ func (suite *PipelineTestSuite) TestBasic() { schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, + DbProperties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, }) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) @@ -143,16 +151,16 @@ func (suite *PipelineTestSuite) TestBasic() { Collection: suite.collectionManager, Segment: suite.segmentManager, } - pipeline, err := NewPipeLine(collection, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator) + pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator) suite.NoError(err) // Init Consumer - err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) + err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) suite.NoError(err) - err = pipeline.Start() + err = pipelineObj.Start() suite.NoError(err) - defer pipeline.Close() + defer pipelineObj.Close() // watch tsafe manager listener := suite.tSafeManager.WatchChannel(suite.channel) @@ -161,7 +169,7 @@ func (suite *PipelineTestSuite) TestBasic() { in := suite.buildMsgPack(schema) suite.msgChan <- in - // wait pipeline work + // wait pipelineObj work <-listener.On() // check tsafe diff --git a/internal/rootcoord/alter_collection_task_test.go b/internal/rootcoord/alter_collection_task_test.go index 716f2aa28b87d..8092e9ed56742 100644 --- a/internal/rootcoord/alter_collection_task_test.go +++ b/internal/rootcoord/alter_collection_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" ) func Test_alterCollectionTask_Prepare(t *testing.T) { @@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) { assert.NoError(t, err) }) - t.Run("alter successfully", func(t *testing.T) { + t.Run("alter successfully2", func(t *testing.T) { meta := mockrootcoord.NewIMetaTable(t) meta.On("GetCollectionByName", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Collection{CollectionID: int64(1)}, nil) + ).Return(&model.Collection{ + CollectionID: int64(1), + Name: "cn", + DBName: "foo", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + }, nil) meta.On("AlterCollection", mock.Anything, mock.Anything, @@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { return nil } - - core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker)) + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") + + core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker)) + newPros := append(properties, &commonpb.KeyValuePair{ + Key: common.ReplicateEndTSKey, + Value: "10000", + }) task := &alterCollectionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.AlterCollectionRequest{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase()) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { diff --git a/internal/rootcoord/alter_database_task_test.go b/internal/rootcoord/alter_database_task_test.go index 69ad9ecfaa832..06f252e9c216c 100644 --- a/internal/rootcoord/alter_database_task_test.go +++ b/internal/rootcoord/alter_database_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/funcutil" ) func Test_alterDatabaseTask_Prepare(t *testing.T) { @@ -47,6 +50,76 @@ func Test_alterDatabaseTask_Prepare(t *testing.T) { err := task.Prepare(context.Background()) assert.NoError(t, err) }) + + t.Run("replicate id", func(t *testing.T) { + { + // no collections + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT(). + ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*model.Collection{}, nil). + Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.NoError(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + Name: "foo", + }, + }, nil).Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("err")). + Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + }) } func Test_alterDatabaseTask_Execute(t *testing.T) { @@ -146,25 +219,51 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Database{ID: int64(1)}, nil) + ).Return(&model.Database{ + ID: int64(1), + Name: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil) meta.On("AlterDatabase", mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + // the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") - core := newTestCore(withMeta(meta), withValidProxyManager()) + core := newTestCore(withMeta(meta), withValidProxyManager(), withTtSynchronizer(ticker)) + newPros := append(properties, + &commonpb.KeyValuePair{Key: common.ReplicateEndTSKey, Value: "1000"}, + ) task := &alterDatabaseTask{ baseTask: newBaseTask(context.Background(), core), Req: &rootcoordpb.AlterDatabaseRequest{ DbName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { @@ -213,3 +312,26 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { }) }) } + +func TestMergeProperties(t *testing.T) { + p := MergeProperties([]*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + { + Key: "foo", + Value: "xxx", + }, + }, []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1001", + }, + }) + assert.Len(t, p, 3) + m := funcutil.KeyValuePair2Map(p) + assert.Equal(t, "", m[common.ReplicateIDKey]) + assert.Equal(t, "1001", m[common.ReplicateEndTSKey]) + assert.Equal(t, "xxx", m["foo"]) +} diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index e7d7cc2ce9c88..1a3a7e18ce36f 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -823,6 +823,70 @@ func Test_createCollectionTask_Prepare(t *testing.T) { }) } +func TestCreateCollectionTask_Prepare_WithProperty(t *testing.T) { + paramtable.Init() + meta := mockrootcoord.NewIMetaTable(t) + t.Run("with db properties", func(t *testing.T) { + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + Name: "foo", + ID: 1, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Twice() + meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{ + util.DefaultDBID: {1, 2}, + }).Once() + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0).Once() + + defer cleanTestEnv() + + collectionName := funcutil.GenRandomStr() + field1 := funcutil.GenRandomStr() + + ticker := newRocksMqTtSynchronizer() + core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + Name: field1, + DataType: schemapb.DataType_Int64, + }, + }, + } + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task := createCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + CollectionName: collectionName, + Schema: marshaledSchema, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "hoo", + }, + }, + }, + dbID: 1, + } + task.Req.ShardsNum = common.DefaultShardsNum + err = task.Prepare(context.Background()) + assert.Len(t, task.dbProperties, 1) + assert.Len(t, task.Req.Properties, 0) + assert.NoError(t, err) + }) +} + func Test_createCollectionTask_Execute(t *testing.T) { t.Run("add same collection with different parameters", func(t *testing.T) { defer cleanTestEnv() diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 9dc973be8133a..5f59f27c1246a 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -1058,6 +1058,31 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync { return ticker } +func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync { + f := msgstream.NewMockMqFactory() + f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) { + stream := msgstream.NewWastedMockMsgStream() + stream.BroadcastFunc = func(pack *msgstream.MsgPack) error { + log.Info("mock Broadcast") + packChan <- pack + return nil + } + stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { + log.Info("mock BroadcastMark") + packChan <- pack + return map[string][]msgstream.MessageID{}, nil + } + stream.AsProducerFunc = func(channels []string) { + } + stream.ChanFunc = func() <-chan *msgstream.MsgPack { + return packChan + } + return stream, nil + } + + return newTickerWithFactory(f) +} + type mockDdlTsLockManager struct { DdlTsLockManager GetMinDdlTsFunc func() Timestamp diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index 5af0f5cc5ee01..0c53d1106e6b3 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -17,6 +17,8 @@ package msgdispatcher import ( + "fmt" + "math/rand" "sync" "testing" "time" @@ -26,6 +28,8 @@ import ( "github.com/stretchr/testify/mock" "golang.org/x/net/context" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -137,3 +141,195 @@ func BenchmarkDispatcher_handle(b *testing.B) { // BenchmarkDispatcher_handle-12 9568 122123 ns/op // PASS } + +func TestGroupMessage(t *testing.T) { + d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false) + assert.NoError(t, err) + d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil)) + d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo"))) + { + // no replicate msg + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 5, + EndTimestamp: 5, + }, + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + Timestamp: 5, + }, + ShardName: "mock_pchannel_0_1v0", + }, + }, + }, + }) + assert.Len(t, packs, 1) + } + + { + // equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + }, + }, + }, + }) + assert.Len(t, packs, 2) + { + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + { + replicatePack := packs["mock_pchannel_0_1v0"] + assert.EqualValues(t, 1, replicatePack.BeginTs) + assert.EqualValues(t, 10, replicatePack.EndTs) + assert.EqualValues(t, 1, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 10, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + } + + { + // not equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test-1", // not equal to replicateID + }, + }, + }, + }, + }, + }) + assert.Len(t, packs, 1) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.Nil(t, replicatePack) + } + + { + // replicate end + replicateTarget := d.targets["mock_pchannel_0_2v0"] + assert.NotNil(t, replicateTarget.replicateConfig) + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + IsEnd: true, + Database: "foo", + }, + }, + }, + }) + assert.Len(t, packs, 2) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Nil(t, replicateTarget.replicateConfig) + } +} diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index b46ca44533942..4d8595aee9dbb 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -704,6 +704,21 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1)) msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3)) + replicatePack := MsgPack{} + replicatePack.Msgs = append(replicatePack.Msgs, &ReplicateMsg{ + BaseMsg: BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{100}, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + }, + }, + }) + msgPack2 := MsgPack{} msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5)) @@ -717,6 +732,9 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) { err = inputStream.Produce(ctx, &msgPack1) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) + err = inputStream.Produce(ctx, &replicatePack) + require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) + _, err = inputStream.Broadcast(ctx, &msgPack2) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 9393baa660955..8bc70721e1379 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -91,7 +91,7 @@ func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig { replicateConfig := &ReplicateConfig{ ReplicateID: replicateID, CheckFunc: func(msg *ReplicateMsg) bool { - if !msg.IsEnd { + if !msg.GetIsEnd() { return false } log.Info("check replicate msg", @@ -99,10 +99,10 @@ func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig { zap.String("dbName", dbName), zap.String("colName", colName), zap.Any("msg", msg)) - if msg.IsCluster { + if msg.GetIsCluster() { return true } - return msg.Database == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "") + return msg.GetDatabase() == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "") }, } return replicateConfig diff --git a/pkg/mq/msgstream/msgstream_util_test.go b/pkg/mq/msgstream/msgstream_util_test.go index f8d1754eac205..9811ea5e78670 100644 --- a/pkg/mq/msgstream/msgstream_util_test.go +++ b/pkg/mq/msgstream/msgstream_util_test.go @@ -24,6 +24,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/common" ) @@ -80,3 +82,90 @@ func TestGetLatestMsgID(t *testing.T) { assert.Equal(t, []byte("mock"), id) } } + +func TestReplicateConfig(t *testing.T) { + t.Run("get replicate id", func(t *testing.T) { + { + msg := &InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local", + }, + }, + }, + } + assert.Equal(t, "local", GetReplicateID(msg)) + assert.True(t, MatchReplicateID(msg, "local")) + } + { + msg := &InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{}, + }, + } + assert.Equal(t, "", GetReplicateID(msg)) + assert.False(t, MatchReplicateID(msg, "local")) + } + { + msg := &MarshalFailTsMsg{} + assert.Equal(t, "", GetReplicateID(msg)) + } + }) + + t.Run("get replicate config", func(t *testing.T) { + { + assert.Nil(t, GetReplicateConfig("", "", "")) + } + { + rc := GetReplicateConfig("local", "db", "col") + assert.Equal(t, "local", rc.ReplicateID) + checkFunc := rc.CheckFunc + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{}, + })) + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + IsCluster: true, + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db1", + }, + })) + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + Collection: "col1", + }, + })) + } + { + rc := GetReplicateConfig("local", "db", "col") + checkFunc := rc.CheckFunc + assert.True(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db", + }, + })) + assert.False(t, checkFunc(&ReplicateMsg{ + ReplicateMsg: &msgpb.ReplicateMsg{ + IsEnd: true, + Database: "db1", + Collection: "col1", + }, + })) + } + }) +}