diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 69f4c4a4d9099..6e37234ba9a22 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -873,6 +873,7 @@ common: bloomFilterType: BlockedBloomFilter # bloom filter type, support BasicBloomFilter and BlockedBloomFilter maxBloomFalsePositive: 0.001 # max false positive rate for bloom filter bloomFilterApplyBatchSize: 1000 # batch size when to apply pk to bloom filter + collectionReplicateEnable: false # Whether to enable collection replication. usePartitionKeyAsClusteringKey: false # if true, do clustering compaction and segment prune on partition key field useVectorAsClusteringKey: false # if true, do clustering compaction and segment prune on vector field enableVectorClusteringKey: false # if true, enable vector clustering key and vector clustering compaction diff --git a/internal/datacoord/channel.go b/internal/datacoord/channel.go index c13309ba537ef..0a80b43a0745f 100644 --- a/internal/datacoord/channel.go +++ b/internal/datacoord/channel.go @@ -36,6 +36,7 @@ type ROChannel interface { GetSchema() *schemapb.CollectionSchema GetCreateTimestamp() Timestamp GetWatchInfo() *datapb.ChannelWatchInfo + GetDBProperties() []*commonpb.KeyValuePair } type RWChannel interface { @@ -48,6 +49,7 @@ func NewRWChannel(name string, startPos []*commonpb.KeyDataPair, schema *schemapb.CollectionSchema, createTs uint64, + dbProperties []*commonpb.KeyValuePair, ) RWChannel { return &StateChannel{ Name: name, @@ -55,9 +57,11 @@ func NewRWChannel(name string, StartPositions: startPos, Schema: schema, CreateTimestamp: createTs, + DBProperties: dbProperties, } } +// TODO(fubang): channelMeta duplicates StateChannel; the two should be consolidated. type channelMeta struct { Name string CollectionID UniqueID @@ -109,6 +113,10 @@ func (ch *channelMeta) String() string { return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions) } +func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair { + return nil +} + type ChannelState string const ( @@ -126,6 +134,7 @@ type StateChannel struct { CollectionID UniqueID StartPositions []*commonpb.KeyDataPair Schema *schemapb.CollectionSchema + DBProperties []*commonpb.KeyValuePair CreateTimestamp uint64 Info *datapb.ChannelWatchInfo @@ -143,6 +152,7 @@ func NewStateChannel(ch RWChannel) *StateChannel { Schema: ch.GetSchema(), CreateTimestamp: ch.GetCreateTimestamp(), Info: ch.GetWatchInfo(), + DBProperties: ch.GetDBProperties(), assignedNode: bufferID, } @@ -156,6 +166,7 @@ func NewStateChannelByWatchInfo(nodeID int64, info *datapb.ChannelWatchInfo) *St Name: info.GetVchan().GetChannelName(), CollectionID: info.GetVchan().GetCollectionID(), Schema: info.GetSchema(), + DBProperties: info.GetDbProperties(), Info: info, assignedNode: nodeID, } @@ -277,3 +288,7 @@ func (c *StateChannel) Assign(nodeID int64) { func (c *StateChannel) setState(state ChannelState) { c.currentState = state } + +func (c *StateChannel) GetDBProperties() []*commonpb.KeyValuePair { + return c.DBProperties +} diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index 28bab70ac18a2..15b3933d5ded7 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -736,20 +736,22 @@ func (m *ChannelManagerImpl) fillChannelWatchInfo(op *ChannelOp) error { schema := ch.GetSchema() if schema == nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() collInfo, err := m.h.GetCollection(ctx, ch.GetCollectionID()) if err != nil { +
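// NOTE: cancel is called explicitly on both paths rather than deferred: a defer would
// not fire until fillChannelWatchInfo returns, so the 5s timeout context would be held
// longer than needed instead of being released as soon as GetCollection finishes.
// A minimal sketch of the resulting shape, using the names from this hunk:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	collInfo, err := m.h.GetCollection(ctx, ch.GetCollectionID())
//	if err != nil {
//		cancel()
//		return err
//	}
//	cancel()
//	schema = collInfo.Schema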
cancel() return err } + cancel() schema = collInfo.Schema } info := &datapb.ChannelWatchInfo{ - Vchan: reduceVChanSize(vcInfo), - StartTs: startTs, - State: inferStateByOpType(op.Type), - Schema: schema, - OpID: opID, + Vchan: reduceVChanSize(vcInfo), + StartTs: startTs, + State: inferStateByOpType(op.Type), + Schema: schema, + OpID: opID, + DbProperties: ch.GetDBProperties(), } ch.UpdateWatchInfo(info) } diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index 60529aa8d2dea..d1b321ff78ae2 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -138,6 +138,12 @@ type collectionInfo struct { VChannelNames []string } +type dbInfo struct { + ID int64 + Name string + Properties []*commonpb.KeyValuePair +} + // NewMeta creates meta from provided `kv.TxnKV` func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManager storage.ChunkManager) (*meta, error) { im, err := newIndexMeta(ctx, catalog) @@ -244,12 +250,12 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker return err } for _, dbName := range resp.GetDbNames() { - resp, err := broker.ShowCollections(ctx, dbName) + collectionsResp, err := broker.ShowCollections(ctx, dbName) if err != nil { return err } - for _, collectionID := range resp.GetCollectionIds() { - resp, err := broker.DescribeCollectionInternal(ctx, collectionID) + for _, collectionID := range collectionsResp.GetCollectionIds() { + descResp, err := broker.DescribeCollectionInternal(ctx, collectionID) if err != nil { return err } @@ -259,14 +265,14 @@ func (m *meta) reloadCollectionsFromRootcoord(ctx context.Context, broker broker } collection := &collectionInfo{ ID: collectionID, - Schema: resp.GetSchema(), + Schema: descResp.GetSchema(), Partitions: partitionIDs, - StartPositions: resp.GetStartPositions(), - Properties: funcutil.KeyValuePair2Map(resp.GetProperties()), - CreatedAt: resp.GetCreatedTimestamp(), - DatabaseName: resp.GetDbName(), - DatabaseID: resp.GetDbId(), - VChannelNames: resp.GetVirtualChannelNames(), + StartPositions: descResp.GetStartPositions(), + Properties: funcutil.KeyValuePair2Map(descResp.GetProperties()), + CreatedAt: descResp.GetCreatedTimestamp(), + DatabaseName: descResp.GetDbName(), + DatabaseID: descResp.GetDbId(), + VChannelNames: descResp.GetVirtualChannelNames(), } m.AddCollection(collection) } diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 8b8d042b2a65f..69de1de58b0c8 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -951,7 +951,7 @@ func (s *Server) GetChannelRecoveryInfo(ctx context.Context, req *datapb.GetChan return resp, nil } - channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0) // TODO: remove RWChannel, just use vchannel + collectionID + channel := NewRWChannel(req.GetVchannel(), collectionID, nil, collection.Schema, 0, nil) // TODO: remove RWChannel, just use vchannel + collectionID channelInfo := s.handler.GetDataVChanPositions(channel, allPartitionID) if channelInfo.SeekPosition == nil { log.Warn("channel recovery start position is not found, may collection is on creating") @@ -1230,6 +1230,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), zap.Strings("channels", req.GetChannelNames()), + zap.Any("dbProperties", req.GetDbProperties()), ) log.Info("receive watch channels request") resp := &datapb.WatchChannelsResponse{ @@ -1242,7 
+1243,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq }, nil } for _, channelName := range req.GetChannelNames() { - ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp()) + ch := NewRWChannel(channelName, req.GetCollectionID(), req.GetStartPositions(), req.GetSchema(), req.GetCreateTimestamp(), req.GetDbProperties()) err := s.channelManager.Watch(ctx, ch) if err != nil { log.Warn("fail to watch channelName", zap.Error(err)) @@ -1562,6 +1563,7 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt StartPositions: req.GetStartPositions(), Properties: properties, DatabaseID: req.GetDbID(), + DatabaseName: req.GetSchema().GetDbName(), VChannelNames: req.GetVChannels(), } s.meta.AddCollection(collInfo) diff --git a/internal/datanode/channel/channel_manager_test.go b/internal/datanode/channel/channel_manager_test.go index 307b5be0261b3..1a4aa78c56735 100644 --- a/internal/datanode/channel/channel_manager_test.go +++ b/internal/datanode/channel/channel_manager_test.go @@ -70,7 +70,7 @@ func (s *OpRunnerSuite) SetupTest() { Return(nil).Maybe() dispClient := msgdispatcher.NewMockClient(s.T()) - dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + dispClient.EXPECT().Register(mock.Anything, mock.Anything). Return(make(chan *msgstream.MsgPack), nil).Maybe() dispClient.EXPECT().Deregister(mock.Anything).Maybe() diff --git a/internal/flushcommon/pipeline/data_sync_service.go b/internal/flushcommon/pipeline/data_sync_service.go index 3e69ad016d433..4e69cfa82669d 100644 --- a/internal/flushcommon/pipeline/data_sync_service.go +++ b/internal/flushcommon/pipeline/data_sync_service.go @@ -350,7 +350,13 @@ func NewDataSyncService(initCtx context.Context, pipelineParams *util.PipelinePa return nil, err } - input, err := createNewInputFromDispatcher(initCtx, pipelineParams.DispClient, info.GetVchan().GetChannelName(), info.GetVchan().GetSeekPosition()) + input, err := createNewInputFromDispatcher(initCtx, + pipelineParams.DispClient, + info.GetVchan().GetChannelName(), + info.GetVchan().GetSeekPosition(), + info.GetSchema(), + info.GetDbProperties(), + ) if err != nil { return nil, err } diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go index cb62f585f63a4..2aa8e3927b903 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node.go @@ -23,8 +23,11 @@ import ( "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/util/flowgraph" + pkgcommon "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/common" @@ -57,11 +60,29 @@ func newDmInputNode(dmNodeConfig *nodeConfig, input <-chan *msgstream.MsgPack) * return node } -func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgdispatcher.Client, vchannel string, seekPos *msgpb.MsgPosition) (<-chan *msgstream.MsgPack, error) { +func createNewInputFromDispatcher(initCtx context.Context, + dispatcherClient msgdispatcher.Client, + vchannel string, + seekPos *msgpb.MsgPosition, + schema *schemapb.CollectionSchema, + dbProperties 
[]*commonpb.KeyValuePair, +) (<-chan *msgstream.MsgPack, error) { log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("vchannel", vchannel)) + replicateID, _ := pkgcommon.GetReplicateID(schema.GetProperties()) + if replicateID == "" { + log.Info("no replicateID found in the collection schema, falling back to dbProperties", zap.Any("dbProperties", dbProperties)) + replicateID, _ = pkgcommon.GetReplicateID(dbProperties) + } + replicateConfig := msgstream.GetReplicateConfig(replicateID, schema.GetDbName(), schema.GetName()) + if seekPos != nil && len(seekPos.MsgID) != 0 { - input, err := dispatcherClient.Register(initCtx, vchannel, seekPos, common.SubscriptionPositionUnknown) + input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ + VChannel: vchannel, + Pos: seekPos, + SubPos: common.SubscriptionPositionUnknown, + ReplicateConfig: replicateConfig, + }) if err != nil { return nil, err } @@ -71,7 +92,12 @@ func createNewInputFromDispatcher(initCtx context.Context, dispatcherClient msgd zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp())))) return input, err } - input, err := dispatcherClient.Register(initCtx, vchannel, nil, common.SubscriptionPositionEarliest) + input, err := dispatcherClient.Register(initCtx, &msgdispatcher.StreamConfig{ + VChannel: vchannel, + Pos: nil, + SubPos: common.SubscriptionPositionEarliest, + ReplicateConfig: replicateConfig, + }) if err != nil { return nil, err } diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go index 2f6a298a206fd..e5aaa47eeaa6d 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go @@ -62,6 +62,9 @@ func (mm *mockMsgStreamFactory) NewMsgStreamDisposer(ctx context.Context) func([ type mockTtMsgStream struct{} +func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) { +} + func (mtm *mockTtMsgStream) Close() {} func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack { diff --git a/internal/flushcommon/pipeline/flow_graph_manager_test.go b/internal/flushcommon/pipeline/flow_graph_manager_test.go index c89163f59270f..6108d9911ae79 100644 --- a/internal/flushcommon/pipeline/flow_graph_manager_test.go +++ b/internal/flushcommon/pipeline/flow_graph_manager_test.go @@ -65,7 +65,7 @@ func TestFlowGraphManager(t *testing.T) { wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) dispClient := msgdispatcher.NewMockClient(t) - dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil) + dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil) dispClient.EXPECT().Deregister(mock.Anything) pipelineParams := &util.PipelineParams{ @@ -151,7 +151,7 @@ func newFlowGraphManager(t *testing.T) (string, FlowgraphManager) { wbm.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) dispClient := msgdispatcher.NewMockClient(t) - dispClient.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil) + dispClient.EXPECT().Register(mock.Anything, mock.Anything).Return(make(chan *msgstream.MsgPack), nil) pipelineParams := &util.PipelineParams{ Ctx: context.TODO(), diff --git
a/internal/metastore/model/collection.go b/internal/metastore/model/collection.go index 13cbac1d3d686..2a41804f14641 100644 --- a/internal/metastore/model/collection.go +++ b/internal/metastore/model/collection.go @@ -15,6 +15,7 @@ type Collection struct { CollectionID int64 Partitions []*Partition Name string + DBName string Description string AutoID bool Fields []*Field @@ -41,6 +42,7 @@ func (c *Collection) Clone() *Collection { DBID: c.DBID, CollectionID: c.CollectionID, Name: c.Name, + DBName: c.DBName, Description: c.Description, AutoID: c.AutoID, Fields: CloneFields(c.Fields), @@ -99,6 +101,7 @@ func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection { CollectionID: coll.ID, DBID: coll.DbId, Name: coll.Schema.Name, + DBName: coll.Schema.DbName, Description: coll.Schema.Description, AutoID: coll.Schema.AutoID, Fields: UnmarshalFieldModels(coll.GetSchema().GetFields()), @@ -154,6 +157,7 @@ func marshalCollectionModelWithConfig(coll *Collection, c *config) *pb.Collectio Description: coll.Description, AutoID: coll.AutoID, EnableDynamicField: coll.EnableDynamicField, + DbName: coll.DBName, } if c.withFields { diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index ecab1cdcfb07c..be2fd77d4db6b 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -538,6 +538,7 @@ message ChannelWatchInfo { // watch progress, deprecated int32 progress = 6; int64 opID = 7; + repeated common.KeyValuePair dbProperties = 8; } enum CompactionType { @@ -655,6 +656,7 @@ message WatchChannelsRequest { repeated common.KeyDataPair start_positions = 3; schema.CollectionSchema schema = 4; uint64 create_timestamp = 5; + repeated common.KeyValuePair db_properties = 6; } message WatchChannelsResponse { diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 8a5d688c893b0..f05a07dc7b028 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -323,6 +323,7 @@ message LoadMetaInfo { string db_name = 5; // Only used for metrics label. string resource_group = 6; // Only used for metrics label. 
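// Assumption, inferred from how dbProperties is consumed in
// flow_graph_dmstream_input_node.go earlier in this patch: db_properties (field 8,
// added below) carries database-level key/value pairs such as "replicate.id", so a
// replicate ID can still be resolved when the collection schema defines none.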
repeated int64 load_fields = 7; + repeated common.KeyValuePair db_properties = 8; } message WatchDmChannelsRequest { diff --git a/internal/proto/root_coord.proto b/internal/proto/root_coord.proto index 435d3050efea7..4fdd208ebc418 100644 --- a/internal/proto/root_coord.proto +++ b/internal/proto/root_coord.proto @@ -146,6 +146,7 @@ service RootCoord { message AllocTimestampRequest { common.MsgBase base = 1; uint32 count = 3; + uint64 blockTimestamp = 4; } message AllocTimestampResponse { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 07c74ff8b4903..4caa567128055 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -160,7 +160,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p } globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String())) - case commonpb.MsgType_DropDatabase: + case commonpb.MsgType_DropDatabase, commonpb.MsgType_AlterDatabase: globalMetaCache.RemoveDatabase(ctx, request.GetDbName()) case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField: if request.CollectionID != UniqueID(0) { @@ -6325,13 +6325,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil } - if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() { + collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool() + ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() + + // ReplicateMessage can be used in only two modes; any other combination of the two flags is rejected: + // 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode + // 2.
collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode + if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) { return &milvuspb.ReplicateMessageResponse{ Status: merr.Status(merr.ErrDenyReplicateMessage), }, nil } - var err error + var err error if req.GetChannelName() == "" { log.Ctx(ctx).Warn("channel name is empty") return &milvuspb.ReplicateMessageResponse{ @@ -6369,6 +6375,18 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate StartPositions: req.StartPositions, EndPositions: req.EndPositions, } + checkCollectionReplicateProperty := func(dbName, collectionName string) bool { + if !collectionReplicateEnable { + return true + } + replicateID, err := GetReplicateID(ctx, dbName, collectionName) + if err != nil { + log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err)) + return false + } + return replicateID != "" + } + // getTsMsgFromConsumerMsg for i, msgBytes := range req.Msgs { header := commonpb.MsgHeader{} @@ -6388,6 +6406,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate } switch realMsg := tsMsg.(type) { case *msgstream.InsertMsg: + if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) { + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil + } assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(), realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs) if err != nil { @@ -6402,6 +6423,10 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate realMsg.SegmentID = assignSegmentID break } + case *msgstream.DeleteMsg: + if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) { + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil + } } msgPack.Msgs = append(msgPack.Msgs, tsMsg) } diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 2e7287d93d8e7..ba40366211c62 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -1279,6 +1279,9 @@ func TestProxy_Delete(t *testing.T) { }, } schema := newSchemaInfo(collSchema) + basicInfo := &collectionInfo{ + collID: collectionID, + } paramtable.Init() t.Run("delete run failed", func(t *testing.T) { @@ -1311,6 +1314,7 @@ func TestProxy_Delete(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(partitionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(basicInfo, nil) chMgr.On("getVChannels", mock.Anything).Return(channels, nil) chMgr.On("getChannels", mock.Anything).Return(nil, fmt.Errorf("mock error")) globalMetaCache = cache @@ -1863,3 +1867,330 @@ func TestRegisterRestRouter(t *testing.T) { }) } } + +func TestReplicateMessageForCollectionMode(t *testing.T) { + paramtable.Init() + ctx := context.Background() + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 10, + EndTimestamp: 10, + HashValues: []uint32{0}, + MsgPosition: &msgstream.MsgPosition{ + ChannelName: "foo", + MsgID: []byte("mock message id 2"), + }, + }, + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 10001, + Timestamp: 10, + SourceID: -1, + }, + ShardName: "foo_v1", + DbName: "default", + 
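// Fixture note: the message targets db "default" / collection "foo_collection"; the
// mocked cache further down returns a collectionInfo without a replicate.id property,
// so with collectionReplicateEnable=true the ReplicateMessage calls below are expected
// to be rejected with ErrCollectionReplicateMode.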
CollectionName: "foo_collection", + PartitionName: "_default", + DbID: 1, + CollectionID: 11, + PartitionID: 22, + SegmentID: 33, + Timestamps: []uint64{10}, + RowIDs: []int64{66}, + NumRows: 1, + }, + } + insertMsgBytes, _ := insertMsg.Marshal(insertMsg) + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 20, + EndTimestamp: 20, + HashValues: []uint32{0}, + MsgPosition: &msgstream.MsgPosition{ + ChannelName: "foo", + MsgID: []byte("mock message id 2"), + }, + }, + DeleteRequest: &msgpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: 10002, + Timestamp: 20, + SourceID: -1, + }, + ShardName: "foo_v1", + DbName: "default", + CollectionName: "foo_collection", + PartitionName: "_default", + DbID: 1, + CollectionID: 11, + PartitionID: 22, + }, + } + deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg) + + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + + t.Run("replicate message in the replicate collection mode", func(t *testing.T) { + defer func() { + paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key) + paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key) + }() + + { + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false") + p := &Proxy{} + p.UpdateStateCode(commonpb.StateCode_Healthy) + r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "foo", + }) + assert.NoError(t, err) + assert.Error(t, merr.Error(r.Status)) + } + + { + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true") + p := &Proxy{} + p.UpdateStateCode(commonpb.StateCode_Healthy) + r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "foo", + }) + assert.NoError(t, err) + assert.Error(t, merr.Error(r.Status)) + } + }) + + t.Run("replicate message for the replicate collection mode", func(t *testing.T) { + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true") + defer func() { + paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key) + paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key) + }() + + mockCache := NewMockCache(t) + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice() + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice() + globalMetaCache = mockCache + + { + p := &Proxy{ + replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil), + } + p.UpdateStateCode(commonpb.StateCode_Healthy) + r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "foo", + Msgs: [][]byte{insertMsgBytes.([]byte)}, + }) + assert.NoError(t, err) + assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode)) + } + + { + p := &Proxy{ + replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil), + } + p.UpdateStateCode(commonpb.StateCode_Healthy) + r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "foo", + Msgs: [][]byte{deleteMsgBytes.([]byte)}, + }) + assert.NoError(t, err) + assert.EqualValues(t, r.GetStatus().GetCode(), 
merr.Code(merr.ErrCollectionReplicateMode)) + } + }) +} + +func TestAlterCollectionReplicateProperty(t *testing.T) { + paramtable.Init() + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true") + defer func() { + paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key) + paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key) + }() + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + replicateID: "local-milvus", + }, nil).Maybe() + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe() + mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil) + globalMetaCache = mockCache + + factory := newMockMsgStreamFactory() + msgStreamObj := msgstream.NewMockMsgStream(t) + msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe() + msgStreamObj.EXPECT().Close().Return().Maybe() + mockMsgID1 := mqcommon.NewMockMessageID(t) + mockMsgID2 := mqcommon.NewMockMessageID(t) + mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe() + msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{ + "alter_property": {mockMsgID1, mockMsgID2}, + }, nil).Maybe() + + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return msgStreamObj, nil + } + resourceManager := resource.NewManager(time.Second, 2*time.Second, nil) + manager := NewReplicateStreamManager(context.Background(), factory, resourceManager) + + ctx := context.Background() + var startTt uint64 = 10 + startTime := time.Now() + dataCoord := &mockDataCoord{} + dataCoord.expireTime = Timestamp(1000) + segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp { + return Timestamp(time.Since(startTime).Seconds()) + startTt + }) + assert.NoError(t, err) + segAllocator.Start() + + mockRootcoord := mocks.NewMockRootCoordClient(t) + mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { + return &rootcoordpb.AllocTimestampResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt, + }, nil + }) + mockRootcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, nil) + + p := &Proxy{ + ctx: ctx, + replicateStreamManager: manager, + segAssigner: segAllocator, + rootCoord: mockRootcoord, + } + tsoAllocatorIns := newMockTsoAllocator() + p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory) + assert.NoError(t, err) + p.sched.Start() + defer p.sched.Close() + p.UpdateStateCode(commonpb.StateCode_Healthy) + + getInsertMsgBytes := func(channel string, ts uint64) []byte { + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: ts, + EndTimestamp: ts, + HashValues: []uint32{0}, + MsgPosition: &msgstream.MsgPosition{ + ChannelName: 
channel, + MsgID: []byte("mock message id 2"), + }, + }, + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 10001, + Timestamp: ts, + SourceID: -1, + }, + ShardName: channel + "_v1", + DbName: "default", + CollectionName: "foo_collection", + PartitionName: "_default", + DbID: 1, + CollectionID: 11, + PartitionID: 22, + SegmentID: 33, + Timestamps: []uint64{ts}, + RowIDs: []int64{66}, + NumRows: 1, + }, + } + insertMsgBytes, _ := insertMsg.Marshal(insertMsg) + return insertMsgBytes.([]byte) + } + getDeleteMsgBytes := func(channel string, ts uint64) []byte { + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: ts, + EndTimestamp: ts, + HashValues: []uint32{0}, + MsgPosition: &msgstream.MsgPosition{ + ChannelName: "foo", + MsgID: []byte("mock message id 2"), + }, + }, + DeleteRequest: &msgpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: 10002, + Timestamp: ts, + SourceID: -1, + }, + ShardName: channel + "_v1", + DbName: "default", + CollectionName: "foo_collection", + PartitionName: "_default", + DbID: 1, + CollectionID: 11, + PartitionID: 22, + }, + } + deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg) + return deleteMsgBytes.([]byte) + } + + go func() { + // replicate message + var replicateResp *milvuspb.ReplicateMessageResponse + var err error + replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "alter_property_1", + Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)}, + }) + assert.NoError(t, err) + assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason) + + replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "alter_property_2", + Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)}, + }) + assert.NoError(t, err) + assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason) + + time.Sleep(time.Second) + + replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "alter_property_1", + Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)}, + }) + assert.NoError(t, err) + assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason) + + replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{ + ChannelName: "alter_property_2", + Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)}, + }) + assert.NoError(t, err) + assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason) + }() + time.Sleep(200 * time.Millisecond) + + // alter collection property + statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{ + DbName: "default", + CollectionName: "foo_collection", + Properties: []*commonpb.KeyValuePair{ + { + Key: "replicate.endTS", + Value: "1", + }, + }, + }) + assert.NoError(t, err) + assert.True(t, merr.Ok(statusResp)) +} diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index d3a4a0c32dbb3..feec9752f78bb 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -102,10 +102,12 @@ type collectionInfo struct { createdUtcTimestamp uint64 consistencyLevel commonpb.ConsistencyLevel partitionKeyIsolation bool + replicateID string } type databaseInfo struct { dbID typeutil.UniqueID + properties []*commonpb.KeyValuePair createdTimestamp uint64 } @@ -478,6 +480,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string, 
m.collInfo[database] = make(map[string]*collectionInfo) } + replicateID, _ := common.GetReplicateID(collection.Properties) m.collInfo[database][collectionName] = &collectionInfo{ collID: collection.CollectionID, schema: schemaInfo, @@ -486,6 +489,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string, createdUtcTimestamp: collection.CreatedUtcTimestamp, consistencyLevel: collection.ConsistencyLevel, partitionKeyIsolation: isolation, + replicateID: replicateID, } log.Ctx(ctx).Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), @@ -571,10 +575,19 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll method := "GetCollectionInfo" // if collInfo.collID != collectionID, means that the cache is not trustable // try to get collection according to collectionID - if !ok || collInfo.collID != collectionID { + // A collectionID of 0 means the caller only knows the name, so skip the ID comparison and fall back to a lookup by name below. + if !ok || (collectionID != 0 && collInfo.collID != collectionID) { tr := timerecord.NewTimeRecorder("UpdateCache") metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() + if collectionID == 0 { + collInfo, err := m.UpdateByName(ctx, database, collectionName) + if err != nil { + return nil, err + } + metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + return collInfo, nil + } collInfo, err := m.UpdateByID(ctx, database, collectionID) if err != nil { return nil, err @@ -1225,6 +1238,7 @@ func (m *MetaCache) GetDatabaseInfo(ctx context.Context, database string) (*data defer m.mu.Unlock() dbInfo := &databaseInfo{ dbID: resp.GetDbID(), + properties: resp.Properties, createdTimestamp: resp.GetCreatedTimestamp(), } m.dbInfo[database] = dbInfo diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 4592786f1cccc..344363538892a 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { wg.Wait() } +func TestMetaCacheGetCollectionWithUpdate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + ctx := context.Background() + rootCoord := mocks.NewMockRootCoordClient(t) + queryCoord := mocks.NewMockQueryCoordClient(t) + rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil) + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) + assert.NoError(t, err) + t.Run("update by id", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, +
CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) + + t.Run("update by name", func(t *testing.T) { + rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), + CollectionID: 1, + Schema: &schemapb.CollectionSchema{ + Name: "bar", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "p", + }, + { + FieldID: 100, + Name: "pk", + }, + }, + }, + ShardsNum: 1, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"}, + }, nil).Once() + rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Success(), + PartitionIDs: []typeutil.UniqueID{11}, + PartitionNames: []string{"p1"}, + CreatedTimestamps: []uint64{11}, + CreatedUtcTimestamps: []uint64{11}, + }, nil).Once() + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once() + c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0) + assert.NoError(t, err) + assert.Equal(t, c.collID, int64(1)) + assert.Equal(t, c.schema.Name, "bar") + }) +} + func TestMetaCache_GetCollectionName(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} diff --git a/internal/proxy/mock_msgstream_test.go b/internal/proxy/mock_msgstream_test.go index 93cf069bf49a7..7c16e7efdf586 100644 --- a/internal/proxy/mock_msgstream_test.go +++ b/internal/proxy/mock_msgstream_test.go @@ -40,6 +40,9 @@ func (m *mockMsgStream) ForceEnableProduce(enabled bool) { } } +func (m *mockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) { +} + func newMockMsgStream() *mockMsgStream { return &mockMsgStream{} } diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 2b89f4fbc3e5d..18da0945d6bd5 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -314,6 +314,9 @@ func (ms *simpleMockMsgStream) CheckTopicValid(topic string) error { func (ms *simpleMockMsgStream) ForceEnableProduce(enabled bool) { } +func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) { +} + func newSimpleMockMsgStream() *simpleMockMsgStream { return &simpleMockMsgStream{ msgChan: make(chan *msgstream.MsgPack, 1024), diff --git a/internal/proxy/task.go b/internal/proxy/task.go index dc0585ac1af20..e6488aa3627ff 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/ctokenizer" "github.com/milvus-io/milvus/pkg/common" @@ -1081,6 +1082,25 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error { } } + _, ok := common.IsReplicateEnabled(t.Properties) + if ok { + return merr.WrapErrParameterInvalidMsg("can't set the replicate.id property") + } + endTS, ok := common.GetReplicateEndTS(t.Properties) + if ok && collBasicInfo.replicateID != "" { + allocResp,
err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: 1, + BlockTimestamp: endTS, + }) + if err = merr.CheckRPCCall(allocResp, err); err != nil { + return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) + } + if allocResp.GetTimestamp() <= endTS { + return merr.WrapErrServiceInternal("alter collection: alloc timestamp failed, timestamp is not greater than endTS", + fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS)) + } + } + return nil } diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index c518106aa6024..bee84860d3c04 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "fmt" "go.uber.org/zap" @@ -9,6 +10,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -274,6 +276,34 @@ func (t *alterDatabaseTask) OnEnqueue() error { } func (t *alterDatabaseTask) PreExecute(ctx context.Context) error { + _, ok := common.GetReplicateID(t.Properties) + if ok { + return merr.WrapErrParameterInvalidMsg("can't set the replicate id property in alter database request") + } + endTS, ok := common.GetReplicateEndTS(t.Properties) + if !ok { // no replicate end ts property to handle + return nil + } + cacheInfo, err := globalMetaCache.GetDatabaseInfo(ctx, t.DbName) + if err != nil { + return err + } + oldReplicateEnable, _ := common.IsReplicateEnabled(cacheInfo.properties) + if !oldReplicateEnable { // replication was not enabled on the database before this request + return merr.WrapErrParameterInvalidMsg("can't set the replicate end ts property in alter database request when db replicate is disabled") + } + allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: 1, + BlockTimestamp: endTS, + }) + if err = merr.CheckRPCCall(allocResp, err); err != nil { + return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error()) + } + if allocResp.GetTimestamp() <= endTS { + return merr.WrapErrServiceInternal("alter database: alloc timestamp failed, timestamp is not greater than endTS", + fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS)) + } + return nil } diff --git a/internal/proxy/task_database_test.go b/internal/proxy/task_database_test.go index 17615fba42ee5..3be8c4a8a76e6 100644 --- a/internal/proxy/task_database_test.go +++ b/internal/proxy/task_database_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/grpc/metadata" @@ -201,6 +202,163 @@ func TestAlterDatabase(t *testing.T) { assert.Nil(t, err1) } +func TestAlterDatabaseTaskForReplicateProperty(t *testing.T) { + rc := mocks.NewMockRootCoordClient(t) + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + + t.Run("replicate id", func(t *testing.T) { + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.MmapEnabledKey, + Value: "true", + }, + { + Key: common.ReplicateIDKey, + Value: "local-test", +
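// replicate.id is a system-managed property: alterDatabaseTask.PreExecute above is
// expected to reject any request that sets it by hand, which is what this case asserts.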
}, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("fail to get database info", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("not enable replicate", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{}, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("fail to alloc ts", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("alloc wrong ts", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 999, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + + t.Run("alloc ts succeeds", func(t *testing.T) { + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{ + properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Once() + rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 1001, + }, nil).Once() + task := &alterDatabaseTask{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{}, + DbName: "test_alter_database", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1000", + }, + }, + }, + rootCoord: rc, + } + err := task.PreExecute(context.Background()) + assert.NoError(t, err) + }) +} + func TestDescribeDatabaseTask(t *testing.T) { rc := mocks.NewMockRootCoordClient(t) diff
--git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index ef758a44b385a..dd0f3e1bdd413 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -335,6 +335,15 @@ func (dr *deleteRunner) Init(ctx context.Context) error { return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound)) } + replicateID, err := GetReplicateID(ctx, dr.req.GetDbName(), collName) + if err != nil { + log.Warn("get replicate info failed", zap.String("collectionName", collName), zap.Error(err)) + return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + } + if replicateID != "" { + return merr.WrapErrCollectionReplicateMode("delete") + } + dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName) if err != nil { return ErrWithLog(log, "Failed to get collection schema", err) diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index 881fa4ee0402f..5e23c8842e07d 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -297,6 +297,45 @@ func TestDeleteRunner_Init(t *testing.T) { assert.Error(t, dr.Init(context.Background())) }) + t.Run("fail to get collection info", func(t *testing.T) { + dr := deleteRunner{req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + DbName: dbName, + }} + cache := NewMockCache(t) + cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(collectionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, + errors.New("mock get collection info")) + + globalMetaCache = cache + assert.Error(t, dr.Init(context.Background())) + }) + + t.Run("deny delete in the replicate mode", func(t *testing.T) { + dr := deleteRunner{req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + DbName: dbName, + }} + cache := NewMockCache(t) + cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(collectionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + replicateID: "local-mac", + }, nil) + + globalMetaCache = cache + assert.Error(t, dr.Init(context.Background())) + }) + t.Run("fail get collection schema", func(t *testing.T) { dr := deleteRunner{req: &milvuspb.DeleteRequest{ CollectionName: collectionName, @@ -309,6 +348,7 @@ func TestDeleteRunner_Init(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(collectionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil) cache.On("GetCollectionSchema", mock.Anything, // context.Context mock.AnythingOfType("string"), @@ -332,6 +372,7 @@ func TestDeleteRunner_Init(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(collectionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil) cache.On("GetCollectionSchema", mock.Anything, // context.Context mock.AnythingOfType("string"), @@ -376,6 +417,7 @@ func 
TestDeleteRunner_Init(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(schema, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil) globalMetaCache = cache assert.Error(t, dr.Init(context.Background())) @@ -402,6 +444,7 @@ func TestDeleteRunner_Init(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(schema, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil) cache.On("GetPartitionID", mock.Anything, // context.Context mock.AnythingOfType("string"), @@ -431,6 +474,7 @@ func TestDeleteRunner_Init(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(collectionID, nil) + cache.On("GetCollectionInfo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil) cache.On("GetCollectionSchema", mock.Anything, // context.Context mock.AnythingOfType("string"), diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 5a95577bbada4..9de31cd53d600 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -125,6 +125,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return merr.WrapErrAsInputError(merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize")) } + replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName) + if err != nil { + log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err)) + return merr.WrapErrAsInputError(err) + } + if replicateID != "" { + return merr.WrapErrCollectionReplicateMode("insert") + } + schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName) if err != nil { log.Ctx(ctx).Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err)) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index ffe88cf94604c..019063bbbc9f1 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -1708,8 +1709,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) { assert.NoError(t, err) shardsNum := int32(2) - prefix := "TestTask_all" - dbName := "" + prefix := "TestTask_int64pk" + dbName := "int64PK" collectionName := prefix + funcutil.GenRandomStr() partitionName := prefix + funcutil.GenRandomStr() @@ -1726,45 +1727,43 @@ func TestTask_Int64PrimaryKey(t *testing.T) { } nb := 10 - t.Run("create collection", func(t *testing.T) { - schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) - marshaledSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - - createColT := &createCollectionTask{ - Condition: NewTaskCondition(ctx), - CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: marshaledSchema, - ShardsNum: shardsNum, - }, - ctx: ctx, - rootCoord: rc, - result: nil, - schema: nil, - } - - assert.NoError(t, 
createColT.OnEnqueue()) - assert.NoError(t, createColT.PreExecute(ctx)) - assert.NoError(t, createColT.Execute(ctx)) - assert.NoError(t, createColT.PostExecute(ctx)) + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) - _, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_CreatePartition, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, + createColT := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, DbName: dbName, CollectionName: collectionName, - PartitionName: partitionName, - }) + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + assert.NoError(t, createColT.OnEnqueue()) + assert.NoError(t, createColT.PreExecute(ctx)) + assert.NoError(t, createColT.Execute(ctx)) + assert.NoError(t, createColT.PostExecute(ctx)) + + _, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreatePartition, + MsgID: 0, + Timestamp: 0, + SourceID: paramtable.GetNodeID(), + }, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, }) - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName) assert.NoError(t, err) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) @@ -1957,7 +1956,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { shardsNum := int32(2) prefix := "TestTask_all" - dbName := "" + dbName := "testvarchar" collectionName := prefix + funcutil.GenRandomStr() partitionName := prefix + funcutil.GenRandomStr() @@ -1975,45 +1974,43 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { } nb := 10 - t.Run("create collection", func(t *testing.T) { - schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false) - marshaledSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - - createColT := &createCollectionTask{ - Condition: NewTaskCondition(ctx), - CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: marshaledSchema, - ShardsNum: shardsNum, - }, - ctx: ctx, - rootCoord: rc, - result: nil, - schema: nil, - } - - assert.NoError(t, createColT.OnEnqueue()) - assert.NoError(t, createColT.PreExecute(ctx)) - assert.NoError(t, createColT.Execute(ctx)) - assert.NoError(t, createColT.PostExecute(ctx)) + schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) - _, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_CreatePartition, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, + createColT := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, DbName: dbName, CollectionName: collectionName, - PartitionName: partitionName, - }) + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + assert.NoError(t, createColT.OnEnqueue()) + assert.NoError(t, 
createColT.PreExecute(ctx)) + assert.NoError(t, createColT.Execute(ctx)) + assert.NoError(t, createColT.PostExecute(ctx)) + + _, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreatePartition, + MsgID: 0, + Timestamp: 0, + SourceID: paramtable.GetNodeID(), + }, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, }) - collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, collectionName) assert.NoError(t, err) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) @@ -3444,30 +3441,28 @@ func TestPartitionKey(t *testing.T) { marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) - t.Run("create collection", func(t *testing.T) { - createCollectionTask := &createCollectionTask{ - Condition: NewTaskCondition(ctx), - CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), - Timestamp: Timestamp(time.Now().UnixNano()), - }, - DbName: "", - CollectionName: collectionName, - Schema: marshaledSchema, - ShardsNum: shardsNum, - NumPartitions: common.DefaultPartitionsWithPartitionKey, + createCollectionTask := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Timestamp: Timestamp(time.Now().UnixNano()), }, - ctx: ctx, - rootCoord: rc, - result: nil, - schema: nil, - } - err = createCollectionTask.PreExecute(ctx) - assert.NoError(t, err) - err = createCollectionTask.Execute(ctx) - assert.NoError(t, err) - }) + DbName: "", + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + NumPartitions: common.DefaultPartitionsWithPartitionKey, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + err = createCollectionTask.PreExecute(ctx) + assert.NoError(t, err) + err = createCollectionTask.Execute(ctx) + assert.NoError(t, err) collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) assert.NoError(t, err) @@ -3500,7 +3495,7 @@ func TestPartitionKey(t *testing.T) { _ = segAllocator.Start() defer segAllocator.Close() - partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) + partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName) assert.NoError(t, err) assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames))) @@ -4269,3 +4264,136 @@ func TestTaskPartitionKeyIsolation(t *testing.T) { "can not alter partition key isolation mode if the collection already has a vector index. 
Please drop the index first") }) } + +func TestAlterCollectionForReplicateProperty(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + replicateID: "local-mac-1", + }, nil).Maybe() + mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe() + mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Maybe() + globalMetaCache = mockCache + ctx := context.Background() + mockRootcoord := mocks.NewMockRootCoordClient(t) + t.Run("invalid replicate id", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "xxxxx", + }, + }, + }, + rootCoord: mockRootcoord, + } + + err := task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("empty replicate id", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + CollectionName: "test", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "", + }, + }, + }, + rootCoord: mockRootcoord, + } + + err := task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("fail to alloc ts", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + CollectionName: "test", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "100", + }, + }, + }, + rootCoord: mockRootcoord, + } + + mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + err := task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("alloc wrong ts", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + CollectionName: "test", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "100", + }, + }, + }, + rootCoord: mockRootcoord, + } + + mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Success(), + Timestamp: 99, + }, nil).Once() + err := task.PreExecute(ctx) + assert.Error(t, err) + }) +} + +func TestInsertForReplicate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + + t.Run("get replicate id fail", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: "foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) + t.Run("insert with replicate id", func(t *testing.T) { + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + schema: &schemaInfo{ + CollectionSchema: &schemapb.CollectionSchema{ + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-mac", + }, + }, + }, + }, + replicateID: "local-mac", + }, nil).Once() + task := &insertTask{ + insertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: 
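// Reviewer note: the two timestamp subtests above pin down the contract for
// ending replication via `replicate.endTS`. PreExecute asks RootCoord for a
// fresh timestamp and must fail when the allocated TS has not passed the
// requested end TS (99 vs. 100 in the "alloc wrong ts" case). A sketch of the
// presumed call, using the BlockTimestamp field exercised later in this patch:
//
//	resp, err := rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
//		Count:          1,
//		BlockTimestamp: endTS, // ask TSO to move past the replicate end TS first
//	})
//	// err != nil, or a resp.Timestamp that trails endTS, aborts the alter task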
"foo", + }, + }, + } + err := task.PreExecute(context.Background()) + assert.Error(t, err) + }) +} diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 24fbaa60146fb..1de223fa4124d 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -292,6 +292,15 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { Timestamp: it.EndTs(), } + replicateID, err := GetReplicateID(ctx, it.req.GetDbName(), collectionName) + if err != nil { + log.Warn("get replicate info failed", zap.String("collectionName", collectionName), zap.Error(err)) + return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + } + if replicateID != "" { + return merr.WrapErrCollectionReplicateMode("upsert") + } + schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName) if err != nil { log.Warn("Failed to get collection schema", diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index f2f9b87b16651..75fd39964b00e 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -19,6 +19,7 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -325,3 +326,37 @@ func TestUpsertTask(t *testing.T) { assert.ElementsMatch(t, channels, ut.pChannels) }) } + +func TestUpsertTaskForReplicate(t *testing.T) { + cache := globalMetaCache + defer func() { globalMetaCache = cache }() + mockCache := NewMockCache(t) + globalMetaCache = mockCache + ctx := context.Background() + + t.Run("fail to get collection info", func(t *testing.T) { + ut := upsertTask{ + ctx: ctx, + req: &milvuspb.UpsertRequest{ + CollectionName: "col-0", + }, + } + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("foo")).Once() + err := ut.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("replicate mode", func(t *testing.T) { + ut := upsertTask{ + ctx: ctx, + req: &milvuspb.UpsertRequest{ + CollectionName: "col-0", + }, + } + mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{ + replicateID: "local-mac", + }, nil).Once() + err := ut.PreExecute(ctx) + assert.Error(t, err) + }) +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index d2f3bc30ddb73..92ae381095314 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -2212,3 +2212,22 @@ func GetFailedResponse(req any, err error) any { } return nil } + +func GetReplicateID(ctx context.Context, database, collectionName string) (string, error) { + if globalMetaCache == nil { + return "", merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. 
please wait") + } + colInfo, err := globalMetaCache.GetCollectionInfo(ctx, database, collectionName, 0) + if err != nil { + return "", err + } + if colInfo.replicateID != "" { + return colInfo.replicateID, nil + } + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, database) + if err != nil { + return "", err + } + replicateID, _ := common.GetReplicateID(dbInfo.properties) + return replicateID, nil +} diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 0a82c2000bb09..a5b893619c6eb 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -36,6 +36,7 @@ import ( coordMocks "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/checkers" "github.com/milvus-io/milvus/internal/querycoordv2/dist" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -614,6 +615,7 @@ func (suite *ServerSuite) hackServer() { ) suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe() + suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe() suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil).Maybe() for _, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index b3ba5148ed92b..69e3d8d1c56a3 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -56,9 +56,7 @@ var ( ) func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - log := log.Ctx(ctx).With(zap.Int64s("collections", req.GetCollectionIDs())) - - log.Info("show collections request received") + log.Ctx(ctx).Debug("show collections request received", zap.Int64s("collections", req.GetCollectionIDs())) if err := merr.CheckHealthy(s.State()); err != nil { msg := "failed to show collections" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 263a4e1ec2d71..b66d26b0f722e 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -341,18 +341,23 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID()) if err != nil { - log.Warn("failed to get collection info") + log.Warn("failed to get collection info", zap.Error(err)) return err } loadFields := ex.meta.GetLoadFields(ctx, task.CollectionID()) partitions, err := utils.GetPartitions(ctx, ex.targetMgr, task.CollectionID()) if err != nil { - log.Warn("failed to get partitions of collection") + log.Warn("failed to get partitions of collection", zap.Error(err)) return err } indexInfo, err := ex.broker.ListIndexes(ctx, task.CollectionID()) if err != nil { - log.Warn("fail to get index meta of collection") + log.Warn("fail to get index meta of collection", zap.Error(err)) + return err + } + dbResp, err := ex.broker.DescribeDatabase(ctx, collectionInfo.GetDbName()) + if err != nil { + log.Warn("failed to get database 
info", zap.Error(err)) return err } loadMeta := packLoadMeta( @@ -363,6 +368,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { loadFields, partitions..., ) + loadMeta.DbProperties = dbResp.GetProperties() dmChannel := ex.targetMgr.GetDmChannel(ctx, task.CollectionID(), action.ChannelName(), meta.NextTarget) if dmChannel == nil { diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 2eb53639d5f9d..6de6a70d49175 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" @@ -230,6 +231,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { }, }, nil }) + suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil) for channel, segment := range suite.growingSegments { suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment). Return([]*datapb.SegmentInfo{ diff --git a/internal/querynodev2/metrics_info_test.go b/internal/querynodev2/metrics_info_test.go index a5162ac63c1ff..03c99519a0625 100644 --- a/internal/querynodev2/metrics_info_test.go +++ b/internal/querynodev2/metrics_info_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" @@ -46,7 +47,7 @@ func TestGetPipelineJSON(t *testing.T) { collectionManager := segments.NewMockCollectionManager(t) segmentManager := segments.NewMockSegmentManager(t) - collectionManager.EXPECT().Get(mock.Anything).Return(&segments.Collection{}) + collectionManager.EXPECT().Get(mock.Anything).Return(segments.NewTestCollection(1, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{})) manager := &segments.Manager{ Collection: collectionManager, Segment: segmentManager, diff --git a/internal/querynodev2/pipeline/filter_node_test.go b/internal/querynodev2/pipeline/filter_node_test.go index 001ca4cef21bb..738a6e0399e53 100644 --- a/internal/querynodev2/pipeline/filter_node_test.go +++ b/internal/querynodev2/pipeline/filter_node_test.go @@ -72,7 +72,7 @@ func (suite *FilterNodeSuite) TestWithLoadCollection() { suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} // mock - collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection) + collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadCollection, nil) for _, partitionID := range suite.partitionIDs { collection.AddPartition(partitionID) } @@ -111,7 +111,7 @@ func (suite *FilterNodeSuite) TestWithLoadPartation() { suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} // mock - collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition) + collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadPartition, nil) collection.AddPartition(suite.partitionIDs[0]) mockCollectionManager := segments.NewMockCollectionManager(suite.T()) diff --git 
a/internal/querynodev2/pipeline/manager.go b/internal/querynodev2/pipeline/manager.go
index 81d56202296cd..6c6e26cfc7348 100644
--- a/internal/querynodev2/pipeline/manager.go
+++ b/internal/querynodev2/pipeline/manager.go
@@ -85,7 +85,7 @@ func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) {
 		return nil, merr.WrapErrChannelNotFound(channel, "delegator not found")
 	}
-	newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.dispatcher, delegator)
+	newPipeLine, err := NewPipeLine(collection, channel, m.dataManager, m.dispatcher, delegator)
 	if err != nil {
 		return nil, merr.WrapErrServiceUnavailable(err.Error(), "failed to create new pipeline")
 	}
diff --git a/internal/querynodev2/pipeline/manager_test.go b/internal/querynodev2/pipeline/manager_test.go
index 6fb460587a7ee..48f09ce3ddfd5 100644
--- a/internal/querynodev2/pipeline/manager_test.go
+++ b/internal/querynodev2/pipeline/manager_test.go
@@ -24,9 +24,10 @@ import (
 	"github.com/stretchr/testify/suite"

 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
 	"github.com/milvus-io/milvus/internal/querynodev2/segments"
-	"github.com/milvus-io/milvus/pkg/mq/common"
 	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -73,9 +74,9 @@ func (suite *PipelineManagerTestSuite) SetupTest()
 func (suite *PipelineManagerTestSuite) TestBasic() {
 	// init mock
 	// mock collection manager
-	suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
+	suite.collectionManager.EXPECT().Get(suite.collectionID).Return(segments.NewTestCollection(suite.collectionID, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{}))
 	// mock mq factory
-	suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
+	suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil)
 	suite.msgDispatcher.EXPECT().Deregister(suite.channel)

 	// build manager
diff --git a/internal/querynodev2/pipeline/pipeline.go b/internal/querynodev2/pipeline/pipeline.go
index 46bf5280447dd..add94d4b9cc18 100644
--- a/internal/querynodev2/pipeline/pipeline.go
+++ b/internal/querynodev2/pipeline/pipeline.go
@@ -19,7 +19,9 @@ package pipeline
 import (
 	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
 	base "github.com/milvus-io/milvus/internal/util/pipeline"
+	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
+	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
@@ -45,17 +47,23 @@ func (p *pipeline) Close() {
 }

 func NewPipeLine(
-	collectionID UniqueID,
+	collection *Collection,
 	channel string,
 	manager *DataManager,
 	dispatcher msgdispatcher.Client,
 	delegator delegator.ShardDelegator,
 ) (Pipeline, error) {
+	collectionID := collection.ID()
+	replicateID, _ := common.GetReplicateID(collection.Schema().GetProperties())
+	if replicateID == "" {
+		replicateID, _ = common.GetReplicateID(collection.GetDBProperties())
+	}
+	replicateConfig := msgstream.GetReplicateConfig(replicateID, collection.GetDBName(), collection.Schema().Name)
 	pipelineQueueLength := paramtable.Get().QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
 	p := &pipeline{
 		collectionID:
collectionID, - StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel), + StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel, replicateConfig), } filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength) diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index 292dea24518fd..dd153527e1c67 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -24,13 +24,14 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" - "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -103,11 +104,17 @@ func (suite *PipelineTestSuite) TestBasic() { schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, + DbProperties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, }) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) // mock mq factory - suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.msgChan, nil) + suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) // mock delegator @@ -136,16 +143,16 @@ func (suite *PipelineTestSuite) TestBasic() { Collection: suite.collectionManager, Segment: suite.segmentManager, } - pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.msgDispatcher, suite.delegator) + pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.msgDispatcher, suite.delegator) suite.NoError(err) // Init Consumer - err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) + err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{}) suite.NoError(err) - err = pipeline.Start() + err = pipelineObj.Start() suite.NoError(err) - defer pipeline.Close() + defer pipelineObj.Close() // build input msg in := suite.buildMsgPack(schema) diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 10b774eaa8806..f4ad0f930f599 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -148,6 +148,7 @@ type Collection struct { partitions *typeutil.ConcurrentSet[int64] loadType querypb.LoadType dbName string + dbProperties []*commonpb.KeyValuePair resourceGroup string // resource group of node may be changed if node transfer, // but Collection in Manager will be released before assign new replica of new resource group on these 
node. @@ -166,6 +167,10 @@ func (c *Collection) GetDBName() string { return c.dbName } +func (c *Collection) GetDBProperties() []*commonpb.KeyValuePair { + return c.dbProperties +} + // GetResourceGroup returns the resource group of collection. func (c *Collection) GetResourceGroup() string { return c.resourceGroup @@ -284,6 +289,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM partitions: typeutil.NewConcurrentSet[int64](), loadType: loadMetaInfo.GetLoadType(), dbName: loadMetaInfo.GetDbName(), + dbProperties: loadMetaInfo.GetDbProperties(), resourceGroup: loadMetaInfo.GetResourceGroup(), refCount: atomic.NewUint32(0), isGpuIndex: isGpuIndex, @@ -297,13 +303,16 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM return coll } -func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection { - return &Collection{ +// Only for test +func NewTestCollection(collectionID int64, loadType querypb.LoadType, schema *schemapb.CollectionSchema) *Collection { + col := &Collection{ id: collectionID, partitions: typeutil.NewConcurrentSet[int64](), loadType: loadType, refCount: atomic.NewUint32(0), } + col.schema.Store(schema) + return col } // new collection without segcore prepare diff --git a/internal/rootcoord/alter_collection_task.go b/internal/rootcoord/alter_collection_task.go index 3fdb5c438b4be..dd875a8e5bb55 100644 --- a/internal/rootcoord/alter_collection_task.go +++ b/internal/rootcoord/alter_collection_task.go @@ -26,11 +26,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -130,6 +132,43 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error { })) } + oldReplicateEnable, _ := common.IsReplicateEnabled(oldColl.Properties) + replicateEnable, ok := common.IsReplicateEnabled(newColl.Properties) + if ok && !replicateEnable && oldReplicateEnable { + replicateID, _ := common.GetReplicateID(oldColl.Properties) + redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for collection", func(ctx context.Context) ([]nestedStep, error) { + msgPack := &msgstream.MsgPack{} + msg := &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + BeginTimestamp: ts, + EndTimestamp: ts, + HashValues: []uint32{0}, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: ts, + ReplicateInfo: &commonpb.ReplicateInfo{ + IsReplicate: true, + ReplicateID: replicateID, + }, + }, + IsEnd: true, + Database: newColl.DBName, + Collection: newColl.Name, + }, + } + msgPack.Msgs = append(msgPack.Msgs, msg) + log.Info("send replicate end msg", + zap.String("collection", newColl.Name), + zap.String("database", newColl.DBName), + zap.String("replicateID", replicateID), + ) + return nil, a.core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack) + })) + } + return redoTask.Execute(ctx) } diff --git a/internal/rootcoord/alter_collection_task_test.go b/internal/rootcoord/alter_collection_task_test.go index e2b53e958da50..f0a0edb75ea5f 100644 --- 
a/internal/rootcoord/alter_collection_task_test.go +++ b/internal/rootcoord/alter_collection_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" ) func Test_alterCollectionTask_Prepare(t *testing.T) { @@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) { assert.NoError(t, err) }) - t.Run("alter successfully", func(t *testing.T) { + t.Run("alter successfully2", func(t *testing.T) { meta := mockrootcoord.NewIMetaTable(t) meta.On("GetCollectionByName", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Collection{CollectionID: int64(1)}, nil) + ).Return(&model.Collection{ + CollectionID: int64(1), + Name: "cn", + DBName: "foo", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"}, + }, nil) meta.On("AlterCollection", mock.Anything, mock.Anything, @@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { return nil } - - core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker)) + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") + + core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker)) + newPros := append(properties, &commonpb.KeyValuePair{ + Key: common.ReplicateEndTSKey, + Value: "10000", + }) task := &alterCollectionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.AlterCollectionRequest{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase()) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { diff --git a/internal/rootcoord/alter_database_task.go b/internal/rootcoord/alter_database_task.go index 569acd77dff22..e11e4fa058f5f 100644 --- a/internal/rootcoord/alter_database_task.go +++ b/internal/rootcoord/alter_database_task.go @@ -25,11 +25,14 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -43,6 +46,19 @@ func (a 
*alterDatabaseTask) Prepare(ctx context.Context) error { return fmt.Errorf("alter database failed, database name does not exists") } + // TODO SimFG maybe it will support to alter the replica.id properties in the future when the database has no collections + // now it can't be because the latest database properties can't be notified to the querycoord and datacoord + replicateID, _ := common.GetReplicateID(a.Req.Properties) + if replicateID != "" { + colls, err := a.core.meta.ListCollections(ctx, a.Req.DbName, a.ts, true) + if err != nil { + return err + } + if len(colls) > 0 { + return errors.New("can't set replicate id on database with collections") + } + } + return nil } @@ -85,6 +101,18 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { ts: ts, }) + redoTask.AddSyncStep(&expireCacheStep{ + baseStep: baseStep{core: a.core}, + dbName: newDB.Name, + ts: ts, + // make sure to send the "expire cache" request + // because it won't send this request when the length of collection names array is zero + collectionNames: []string{""}, + opts: []proxyutil.ExpireCacheOpt{ + proxyutil.SetMsgType(commonpb.MsgType_AlterDatabase), + }, + }) + oldReplicaNumber, _ := common.DatabaseLevelReplicaNumber(oldDB.Properties) oldResourceGroups, _ := common.DatabaseLevelResourceGroups(oldDB.Properties) newReplicaNumber, _ := common.DatabaseLevelReplicaNumber(newDB.Properties) @@ -123,6 +151,39 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { })) } + oldReplicateEnable, _ := common.IsReplicateEnabled(oldDB.Properties) + newReplicateEnable, ok := common.IsReplicateEnabled(newDB.Properties) + if ok && !newReplicateEnable && oldReplicateEnable { + replicateID, _ := common.GetReplicateID(oldDB.Properties) + redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for db", func(ctx context.Context) ([]nestedStep, error) { + msgPack := &msgstream.MsgPack{} + msg := &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + BeginTimestamp: ts, + EndTimestamp: ts, + HashValues: []uint32{0}, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: ts, + ReplicateInfo: &commonpb.ReplicateInfo{ + IsReplicate: true, + ReplicateID: replicateID, + }, + }, + IsEnd: true, + Database: newDB.Name, + Collection: "", + }, + } + msgPack.Msgs = append(msgPack.Msgs, msg) + log.Info("send replicate end msg for db", zap.String("db", newDB.Name), zap.String("replicateID", replicateID)) + return nil, a.core.chanTimeTick.broadcastDmlChannels(a.core.chanTimeTick.listDmlChannels(), msgPack) + })) + } + return redoTask.Execute(ctx) } @@ -134,6 +195,14 @@ func (a *alterDatabaseTask) GetLockerKey() LockerKey { } func MergeProperties(oldProps []*commonpb.KeyValuePair, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair { + _, existEndTS := common.GetReplicateEndTS(updatedProps) + if existEndTS { + updatedProps = append(updatedProps, &commonpb.KeyValuePair{ + Key: common.ReplicateIDKey, + Value: "", + }) + } + props := make(map[string]string) for _, prop := range oldProps { props[prop.Key] = prop.Value diff --git a/internal/rootcoord/alter_database_task_test.go b/internal/rootcoord/alter_database_task_test.go index 8b23d9a3ef75a..47b66e176b4e7 100644 --- a/internal/rootcoord/alter_database_task_test.go +++ b/internal/rootcoord/alter_database_task_test.go @@ -19,6 +19,7 @@ package rootcoord import ( "context" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -29,6 +30,8 @@ import ( 
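// Reviewer sketch: the alterDatabaseTask changes above encode the db-level
// cut-over. Disabling replication (a non-empty `replicate.id` going away)
// broadcasts a ReplicateMsg with IsEnd=true on every DML channel, and
// MergeProperties clears the ID implicitly whenever `replicate.endTS` is set:
//
//	props := MergeProperties(oldDB.Properties, []*commonpb.KeyValuePair{
//		{Key: common.ReplicateEndTSKey, Value: "1001"},
//	})
//	// props now holds replicate.endTS=1001 and replicate.id="" — consumers
//	// treat the empty ID as "replication ended" for this database.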
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/funcutil" ) func Test_alterDatabaseTask_Prepare(t *testing.T) { @@ -47,6 +50,76 @@ func Test_alterDatabaseTask_Prepare(t *testing.T) { err := task.Prepare(context.Background()) assert.NoError(t, err) }) + + t.Run("replicate id", func(t *testing.T) { + { + // no collections + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT(). + ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]*model.Collection{}, nil). + Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.NoError(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]*model.Collection{ + { + Name: "foo", + }, + }, nil).Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + { + meta := mockrootcoord.NewIMetaTable(t) + core := newTestCore(withMeta(meta)) + meta.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("err")). 
+ Once() + task := &alterDatabaseTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &rootcoordpb.AlterDatabaseRequest{ + DbName: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, + } + err := task.Prepare(context.Background()) + assert.Error(t, err) + } + }) } func Test_alterDatabaseTask_Execute(t *testing.T) { @@ -146,25 +219,51 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(&model.Database{ID: int64(1)}, nil) + ).Return(&model.Database{ + ID: int64(1), + Name: "cn", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil) meta.On("AlterDatabase", mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil) - - core := newTestCore(withMeta(meta)) + // the chan length should larger than 4, because newChanTimeTickSync will send 4 ts messages when execute the `broadcast` step + packChan := make(chan *msgstream.MsgPack, 10) + ticker := newChanTimeTickSync(packChan) + ticker.addDmlChannels("by-dev-rootcoord-dml_1") + + core := newTestCore(withMeta(meta), withValidProxyManager(), withTtSynchronizer(ticker)) + newPros := append(properties, + &commonpb.KeyValuePair{Key: common.ReplicateEndTSKey, Value: "1000"}, + ) task := &alterDatabaseTask{ baseTask: newBaseTask(context.Background(), core), Req: &rootcoordpb.AlterDatabaseRequest{ DbName: "cn", - Properties: properties, + Properties: newPros, }, } err := task.Execute(context.Background()) assert.NoError(t, err) + time.Sleep(time.Second) + select { + case pack := <-packChan: + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) + replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase()) + assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) + default: + assert.Fail(t, "no message sent") + } }) t.Run("test update collection props", func(t *testing.T) { @@ -248,3 +347,26 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { assert.Empty(t, ret2) }) } + +func TestMergeProperties(t *testing.T) { + p := MergeProperties([]*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + { + Key: "foo", + Value: "xxx", + }, + }, []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "1001", + }, + }) + assert.Len(t, p, 3) + m := funcutil.KeyValuePair2Map(p) + assert.Equal(t, "", m[common.ReplicateIDKey]) + assert.Equal(t, "1001", m[common.ReplicateEndTSKey]) + assert.Equal(t, "xxx", m["foo"]) +} diff --git a/internal/rootcoord/broker.go b/internal/rootcoord/broker.go index 7b1a281d317b0..94a3bcfe88a7c 100644 --- a/internal/rootcoord/broker.go +++ b/internal/rootcoord/broker.go @@ -43,6 +43,7 @@ type watchInfo struct { vChannels []string startPositions []*commonpb.KeyDataPair schema *schemapb.CollectionSchema + dbProperties []*commonpb.KeyValuePair } // Broker communicates with other components. 
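// Reviewer sketch: watchInfo now carries the database properties end to end,
// so a node building a flowgraph for a new channel can resolve its replicate
// ID without another RootCoord round trip. Collection-level properties take
// precedence over db-level ones, the same lookup order NewPipeLine uses:
//
//	replicateID, _ := common.GetReplicateID(schema.GetProperties())
//	if replicateID == "" {
//		replicateID, _ = common.GetReplicateID(info.GetDbProperties())
//	}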
@@ -165,6 +166,7 @@ func (b *ServerBroker) WatchChannels(ctx context.Context, info *watchInfo) error StartPositions: info.startPositions, Schema: info.schema, CreateTimestamp: info.ts, + DbProperties: info.dbProperties, }) if err != nil { return err diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index 9e9f2d8de7b1f..25b437ed91f5d 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -61,6 +61,7 @@ type createCollectionTask struct { channels collectionChannels dbID UniqueID partitionNames []string + dbProperties []*commonpb.KeyValuePair } func (t *createCollectionTask) validate(ctx context.Context) error { @@ -424,6 +425,18 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error { return err } t.dbID = db.ID + dbReplicateID, _ := common.GetReplicateID(db.Properties) + if dbReplicateID != "" { + reqProperties := make([]*commonpb.KeyValuePair, 0, len(t.Req.Properties)) + for _, prop := range t.Req.Properties { + if prop.Key == common.ReplicateIDKey { + continue + } + reqProperties = append(reqProperties, prop) + } + t.Req.Properties = reqProperties + } + t.dbProperties = db.Properties if err := t.validate(ctx); err != nil { return err @@ -565,6 +578,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { CollectionID: collID, DBID: t.dbID, Name: t.schema.Name, + DBName: t.Req.GetDbName(), Description: t.schema.Description, AutoID: t.schema.AutoID, Fields: model.UnmarshalFieldModels(t.schema.Fields), @@ -644,11 +658,14 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { startPositions: toKeyDataPairs(startPositions), schema: &schemapb.CollectionSchema{ Name: collInfo.Name, + DbName: collInfo.DBName, Description: collInfo.Description, AutoID: collInfo.AutoID, Fields: model.MarshalFieldModels(collInfo.Fields), + Properties: collInfo.Properties, Functions: model.MarshalFunctionModels(collInfo.Functions), }, + dbProperties: t.dbProperties, }, }, &nullStep{}) undoTask.AddStep(&changeCollectionStateStep{ diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index 188fdc1d7947a..62d44562cce0d 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -823,6 +823,70 @@ func Test_createCollectionTask_Prepare(t *testing.T) { }) } +func TestCreateCollectionTask_Prepare_WithProperty(t *testing.T) { + paramtable.Init() + meta := mockrootcoord.NewIMetaTable(t) + t.Run("with db properties", func(t *testing.T) { + meta.EXPECT().GetDatabaseByName(mock.Anything, mock.Anything, mock.Anything).Return(&model.Database{ + Name: "foo", + ID: 1, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "local-test", + }, + }, + }, nil).Twice() + meta.EXPECT().ListAllAvailCollections(mock.Anything).Return(map[int64][]int64{ + util.DefaultDBID: {1, 2}, + }).Once() + meta.EXPECT().GetGeneralCount(mock.Anything).Return(0).Once() + + defer cleanTestEnv() + + collectionName := funcutil.GenRandomStr() + field1 := funcutil.GenRandomStr() + + ticker := newRocksMqTtSynchronizer() + core := newTestCore(withValidIDAllocator(), withTtSynchronizer(ticker), withMeta(meta)) + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + Name: field1, + DataType: schemapb.DataType_Int64, + }, + }, + } + marshaledSchema, err := proto.Marshal(schema) 
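// Reviewer note: this test pins the precedence rule introduced in Prepare
// above — when the database already carries `replicate.id`, a collection-level
// value in the request is dropped (task.Req.Properties ends up empty below),
// while the db properties are kept on the task for the later watchInfo step.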
+ assert.NoError(t, err) + + task := createCollectionTask{ + baseTask: newBaseTask(context.Background(), core), + Req: &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreateCollection}, + CollectionName: collectionName, + Schema: marshaledSchema, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateIDKey, + Value: "hoo", + }, + }, + }, + dbID: 1, + } + task.Req.ShardsNum = common.DefaultShardsNum + err = task.Prepare(context.Background()) + assert.Len(t, task.dbProperties, 1) + assert.Len(t, task.Req.Properties, 0) + assert.NoError(t, err) + }) +} + func Test_createCollectionTask_Execute(t *testing.T) { t.Run("add same collection with different parameters", func(t *testing.T) { defer cleanTestEnv() diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index 72ff07f21cd7a..804093f019ec0 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -195,6 +195,9 @@ func (mt *MetaTable) reload() error { return err } for _, collection := range collections { + if collection.DBName == "" { + collection.DBName = dbName + } mt.collID2Meta[collection.CollectionID] = collection mt.generalCnt += len(collection.Partitions) * int(collection.ShardsNum) if collection.Available() { @@ -559,12 +562,14 @@ func filterUnavailable(coll *model.Collection) *model.Collection { func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowUnavailable bool) (*model.Collection, error) { coll, ok := mt.collID2Meta[collectionID] if !ok || coll == nil { + log.Warn("not found collection", zap.Int64("collectionID", collectionID)) return nil, merr.WrapErrCollectionNotFound(collectionID) } if allowUnavailable { return coll.Clone(), nil } if !coll.Available() { + log.Warn("collection not available", zap.Int64("collectionID", collectionID), zap.Any("state", coll.State)) return nil, merr.WrapErrCollectionNotFound(collectionID) } return filterUnavailable(coll), nil diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 9dc973be8133a..5f59f27c1246a 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -1058,6 +1058,31 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync { return ticker } +func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync { + f := msgstream.NewMockMqFactory() + f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) { + stream := msgstream.NewWastedMockMsgStream() + stream.BroadcastFunc = func(pack *msgstream.MsgPack) error { + log.Info("mock Broadcast") + packChan <- pack + return nil + } + stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { + log.Info("mock BroadcastMark") + packChan <- pack + return map[string][]msgstream.MessageID{}, nil + } + stream.AsProducerFunc = func(channels []string) { + } + stream.ChanFunc = func() <-chan *msgstream.MsgPack { + return packChan + } + return stream, nil + } + + return newTickerWithFactory(f) +} + type mockDdlTsLockManager struct { DdlTsLockManager GetMinDdlTsFunc func() Timestamp diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 5d684001d9cbf..58c609dbdc47c 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1226,6 +1226,7 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string, dbName str Fields: model.MarshalFieldModels(collInfo.Fields), Functions: 
model.MarshalFunctionModels(collInfo.Functions), EnableDynamicField: collInfo.EnableDynamicField, + Properties: collInfo.Properties, } resp.CollectionID = collInfo.CollectionID resp.VirtualChannelNames = collInfo.VirtualChannelNames @@ -1745,6 +1746,19 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam }, nil } + if in.BlockTimestamp > 0 { + blockTime, _ := tsoutil.ParseTS(in.BlockTimestamp) + lastTime := c.tsoAllocator.GetLastSavedTime() + deltaDuration := blockTime.Sub(lastTime) + if deltaDuration > 0 { + log.Info("wait for block timestamp", + zap.Time("blockTime", blockTime), + zap.Time("lastTime", lastTime), + zap.Duration("delta", deltaDuration)) + time.Sleep(deltaDuration + time.Millisecond*200) + } + } + ts, err := c.tsoAllocator.GenerateTSO(in.GetCount()) if err != nil { log.Ctx(ctx).Error("failed to allocate timestamp", zap.String("role", typeutil.RootCoordRole), diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index b8915e5deab79..e6b8c380f4e9b 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -856,6 +856,32 @@ func TestRootCoord_AllocTimestamp(t *testing.T) { assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp()) assert.Equal(t, count, resp.GetCount()) }) + + t.Run("block timestamp", func(t *testing.T) { + alloc := newMockTsoAllocator() + count := uint32(10) + current := time.Now() + ts := tsoutil.ComposeTSByTime(current.Add(time.Second), 1) + alloc.GenerateTSOF = func(count uint32) (uint64, error) { + // end ts + return ts, nil + } + alloc.GetLastSavedTimeF = func() time.Time { + return current + } + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withTsoAllocator(alloc)) + resp, err := c.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{ + Count: count, + BlockTimestamp: tsoutil.ComposeTSByTime(current.Add(time.Second), 0), + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + // begin ts + assert.Equal(t, ts-uint64(count)+1, resp.GetTimestamp()) + assert.Equal(t, count, resp.GetCount()) + }) } func TestRootCoord_AllocID(t *testing.T) { diff --git a/internal/rootcoord/task.go b/internal/rootcoord/task.go index 4b0927b3bf77f..d515979f7a09f 100644 --- a/internal/rootcoord/task.go +++ b/internal/rootcoord/task.go @@ -18,6 +18,7 @@ package rootcoord import ( "context" + "fmt" "time" "go.uber.org/zap" @@ -173,7 +174,6 @@ func NewCollectionLockerKey(collection string, rw bool) LockerKey { } func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey { - log.Info("NewLockerKeyChain", zap.Any("lockerKeys", len(lockerKeys))) if len(lockerKeys) == 0 { return nil } @@ -191,3 +191,16 @@ func NewLockerKeyChain(lockerKeys ...LockerKey) LockerKey { } return lockerKeys[0] } + +func GetLockerKeyString(k LockerKey) string { + if k == nil { + return "nil" + } + key := k.LockKey() + level := k.Level() + wLock := k.IsWLock() + if k.Next() == nil { + return fmt.Sprintf("%s-%d-%t", key, level, wLock) + } + return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next())) +} diff --git a/internal/rootcoord/task_test.go b/internal/rootcoord/task_test.go index ed4522c97a091..09f0eb5a9771d 100644 --- a/internal/rootcoord/task_test.go +++ b/internal/rootcoord/task_test.go @@ -20,7 +20,6 @@ package rootcoord import ( "context" - "fmt" "testing" "github.com/cockroachdb/errors" @@ -72,16 +71,6 @@ func TestLockerKey(t *testing.T) { } } -func GetLockerKeyString(k LockerKey) string { - key := 
k.LockKey() - level := k.Level() - wLock := k.IsWLock() - if k.Next() == nil { - return fmt.Sprintf("%s-%d-%t", key, level, wLock) - } - return fmt.Sprintf("%s-%d-%t|%s", key, level, wLock, GetLockerKeyString(k.Next())) -} - func TestGetLockerKey(t *testing.T) { t.Run("alter alias task locker key", func(t *testing.T) { tt := &alterAliasTask{ diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go index f8c8eaed5cca3..670677006a27a 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go +++ b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go @@ -116,6 +116,7 @@ func (c *channelLifetime) Run() error { // Build and add pipeline. ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams, + // TODO fubang add the db properties &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan(), func(t syncmgr.Task, err error) { if err != nil || t == nil { return diff --git a/internal/util/pipeline/stream_pipeline.go b/internal/util/pipeline/stream_pipeline.go index 46f086cb97992..204b6cfb0019c 100644 --- a/internal/util/pipeline/stream_pipeline.go +++ b/internal/util/pipeline/stream_pipeline.go @@ -45,12 +45,13 @@ type StreamPipeline interface { } type streamPipeline struct { - pipeline *pipeline - input <-chan *msgstream.MsgPack - scanner streaming.Scanner - dispatcher msgdispatcher.Client - startOnce sync.Once - vChannel string + pipeline *pipeline + input <-chan *msgstream.MsgPack + scanner streaming.Scanner + dispatcher msgdispatcher.Client + startOnce sync.Once + vChannel string + replicateConfig *msgstream.ReplicateConfig closeCh chan struct{} // notify work to exit closeWg sync.WaitGroup @@ -118,7 +119,12 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M } start := time.Now() - p.input, err = p.dispatcher.Register(ctx, p.vChannel, position, common.SubscriptionPositionUnknown) + p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{ + VChannel: p.vChannel, + Pos: position, + SubPos: common.SubscriptionPositionUnknown, + ReplicateConfig: p.replicateConfig, + }) if err != nil { log.Error("dispatcher register failed", zap.String("channel", position.ChannelName)) return WrapErrRegDispather(err) @@ -160,18 +166,24 @@ func (p *streamPipeline) Close() { }) } -func NewPipelineWithStream(dispatcher msgdispatcher.Client, nodeTtInterval time.Duration, enableTtChecker bool, vChannel string) StreamPipeline { +func NewPipelineWithStream(dispatcher msgdispatcher.Client, + nodeTtInterval time.Duration, + enableTtChecker bool, + vChannel string, + replicateConfig *msgstream.ReplicateConfig, +) StreamPipeline { pipeline := &streamPipeline{ pipeline: &pipeline{ nodes: []*nodeCtx{}, nodeTtInterval: nodeTtInterval, enableTtChecker: enableTtChecker, }, - dispatcher: dispatcher, - vChannel: vChannel, - closeCh: make(chan struct{}), - closeWg: sync.WaitGroup{}, - lastAccessTime: atomic.NewTime(time.Now()), + dispatcher: dispatcher, + vChannel: vChannel, + replicateConfig: replicateConfig, + closeCh: make(chan struct{}), + closeWg: sync.WaitGroup{}, + lastAccessTime: atomic.NewTime(time.Now()), } return pipeline diff --git a/internal/util/pipeline/stream_pipeline_test.go b/internal/util/pipeline/stream_pipeline_test.go index 8ceaf38e52194..1b28b558e8645 100644 --- a/internal/util/pipeline/stream_pipeline_test.go +++ b/internal/util/pipeline/stream_pipeline_test.go @@ -25,7 +25,6 @@ 
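// Reviewer sketch: Register now takes a single StreamConfig instead of
// positional arguments, which is what lets the replicate config ride along.
// A typical call site after this patch (ReplicateConfig may be nil for
// channels that are not replication targets):
//
//	input, err := dispatcher.Register(ctx, &msgdispatcher.StreamConfig{
//		VChannel:        vchannel,
//		Pos:             position,
//		SubPos:          common.SubscriptionPositionUnknown,
//		ReplicateConfig: msgstream.GetReplicateConfig(replicateID, dbName, collName),
//	})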
import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -47,9 +46,9 @@ func (suite *StreamPipelineSuite) SetupTest() { suite.inChannel = make(chan *msgstream.MsgPack, 1) suite.outChannel = make(chan msgstream.Timestamp) suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T()) - suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, common.SubscriptionPositionUnknown).Return(suite.inChannel, nil) + suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.inChannel, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) - suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel) + suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel, nil) suite.length = 4 } diff --git a/pkg/common/common.go b/pkg/common/common.go index e40c5825db10c..be6237a9729b1 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -191,6 +191,8 @@ const ( PartitionKeyIsolationKey = "partitionkey.isolation" FieldSkipLoadKey = "field.skipLoad" IndexOffsetCacheEnabledKey = "indexoffsetcache.enabled" + ReplicateIDKey = "replicate.id" + ReplicateEndTSKey = "replicate.endTS" ) const ( @@ -395,3 +397,31 @@ func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) { } return true, nil } + +func IsReplicateEnabled(kvs []*commonpb.KeyValuePair) (bool, bool) { + replicateID, ok := GetReplicateID(kvs) + return replicateID != "", ok +} + +func GetReplicateID(kvs []*commonpb.KeyValuePair) (string, bool) { + for _, kv := range kvs { + if kv.GetKey() == ReplicateIDKey { + return kv.GetValue(), true + } + } + return "", false +} + +func GetReplicateEndTS(kvs []*commonpb.KeyValuePair) (uint64, bool) { + for _, kv := range kvs { + if kv.GetKey() == ReplicateEndTSKey { + ts, err := strconv.ParseUint(kv.GetValue(), 10, 64) + if err != nil { + log.Warn("parse replicate end ts failed", zap.Error(err), zap.Stack("stack")) + return 0, false + } + return ts, true + } + } + return 0, false +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 7e77b782f38eb..422c020589731 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -177,3 +177,84 @@ func TestShouldFieldBeLoaded(t *testing.T) { }) } } + +func TestReplicateProperty(t *testing.T) { + t.Run("ReplicateID", func(t *testing.T) { + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateIDKey, + Value: "1001", + }, + } + e, ok := IsReplicateEnabled(p) + assert.True(t, e) + assert.True(t, ok) + i, ok := GetReplicateID(p) + assert.True(t, ok) + assert.Equal(t, "1001", i) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateIDKey, + Value: "", + }, + } + e, ok := IsReplicateEnabled(p) + assert.False(t, e) + assert.True(t, ok) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: "foo", + Value: "1001", + }, + } + e, ok := IsReplicateEnabled(p) + assert.False(t, e) + assert.False(t, ok) + } + }) + + t.Run("ReplicateTS", func(t *testing.T) { + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateEndTSKey, + Value: "1001", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.True(t, ok) + assert.EqualValues(t, 1001, ts) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateEndTSKey, + Value: "foo", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.False(t, ok) + assert.EqualValues(t, 0, ts) + } + + { + p := 
[]*commonpb.KeyValuePair{ + { + Key: "foo", + Value: "1001", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.False(t, ok) + assert.EqualValues(t, 0, ts) + } + }) +} diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 95e5ff1730811..7875c67dabaf4 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -36,8 +36,23 @@ type ( SubPos = common.SubscriptionInitialPosition ) +type StreamConfig struct { + VChannel string + Pos *Pos + SubPos SubPos + ReplicateConfig *msgstream.ReplicateConfig +} + +func NewStreamConfig(vchannel string, pos *Pos, subPos SubPos) *StreamConfig { + return &StreamConfig{ + VChannel: vchannel, + Pos: pos, + SubPos: subPos, + } +} + type Client interface { - Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) + Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) Deregister(vchannel string) Close() } @@ -62,7 +77,8 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client { } } -func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { +func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) { + vchannel := streamConfig.VChannel log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) pchannel := funcutil.ToPhysicalChannel(vchannel) @@ -75,7 +91,7 @@ func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos c.managers.Insert(pchannel, manager) go manager.Run() } - ch, err := manager.Add(ctx, vchannel, pos, subPos) + ch, err := manager.Add(ctx, streamConfig) if err != nil { if manager.Num() == 0 { manager.Close() diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 707e0becfd467..11ddab3d8feae 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -34,9 +34,9 @@ import ( func TestClient(t *testing.T) { client := NewClient(newMockFactory(), typeutil.ProxyRole, 1) assert.NotNil(t, client) - _, err := client.Register(context.Background(), "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = client.Register(context.Background(), "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) assert.NotPanics(t, func() { client.Deregister("mock_vchannel_0") @@ -51,7 +51,7 @@ func TestClient(t *testing.T) { client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1) defer client.Close() assert.NotNil(t, client) - _, err := client.Register(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) }) } @@ -66,7 +66,7 @@ func TestClient_Concurrency(t *testing.T) { vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int()) wg.Add(1) go func() { - _, err := client1.Register(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown) + _, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) for j := 0; j < rand.Intn(2); j++ { 
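// Reviewer note: the new property helpers return (value, ok) pairs so callers
// can tell "key absent" apart from "key present but empty" — the empty string
// matters, since MergeProperties uses it to mark replication as ended:
//
//	enabled, ok := common.IsReplicateEnabled(props) // ("", present) => (false, true)
//	id, ok := common.GetReplicateID(props)          // absent        => ("", false)
//	ts, ok := common.GetReplicateEndTS(props)       // parse failure => (0, false)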
client1.Deregister(vchannel) diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index c41aa7be83311..2df78b2c901b8 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -80,10 +80,14 @@ type Dispatcher struct { stream msgstream.MsgStream } -func NewDispatcher(ctx context.Context, - factory msgstream.Factory, isMain bool, - pchannel string, position *Pos, - subName string, subPos SubPos, +func NewDispatcher( + ctx context.Context, + factory msgstream.Factory, + isMain bool, + pchannel string, + position *Pos, + subName string, + subPos SubPos, lagNotifyChan chan struct{}, lagTargets *typeutil.ConcurrentMap[string, *target], includeCurrentMsg bool, @@ -260,7 +264,8 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { // init packs for all targets, even though there's no msg in pack, // but we still need to dispatch time ticks to the targets. targetPacks := make(map[string]*MsgPack) - for vchannel := range d.targets { + replicateConfigs := make(map[string]*msgstream.ReplicateConfig) + for vchannel, t := range d.targets { targetPacks[vchannel] = &MsgPack{ BeginTs: pack.BeginTs, EndTs: pack.EndTs, @@ -268,6 +273,9 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { StartPositions: pack.StartPositions, EndPositions: pack.EndPositions, } + if t.replicateConfig != nil { + replicateConfigs[vchannel] = t.replicateConfig + } } // group messages by vchannel for _, msg := range pack.Msgs { @@ -287,9 +295,16 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10) } if vchannel == "" { - // for non-dml msg, such as CreateCollection, DropCollection, ... 
// we need to dispatch it to the vchannel of this collection for k := range targetPacks { + if msg.Type() == commonpb.MsgType_Replicate { + config := replicateConfigs[k] + if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) { + targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg) + } + continue + } + if !strings.Contains(k, collectionID) { continue } @@ -303,9 +318,63 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg) } } + replicateEndChannels := make(map[string]struct{}) + for vchannel, c := range replicateConfigs { + if len(targetPacks[vchannel].Msgs) == 0 { + delete(targetPacks, vchannel) // no replicate msg, can't send pack + continue + } + // calculate the new pack ts + beginTs := targetPacks[vchannel].Msgs[0].BeginTs() + endTs := targetPacks[vchannel].Msgs[0].EndTs() + newMsgs := make([]msgstream.TsMsg, 0) + for _, msg := range targetPacks[vchannel].Msgs { + if msg.BeginTs() < beginTs { + beginTs = msg.BeginTs() + } + if msg.EndTs() > endTs { + endTs = msg.EndTs() + } + if msg.Type() == commonpb.MsgType_Replicate { + replicateMsg := msg.(*msgstream.ReplicateMsg) + if c.CheckFunc(replicateMsg) { + replicateEndChannels[vchannel] = struct{}{} + } + continue + } + newMsgs = append(newMsgs, msg) + } + targetPacks[vchannel].Msgs = newMsgs + d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs) + } + for vchannel := range replicateEndChannels { + if t, ok := d.targets[vchannel]; ok { + t.replicateConfig = nil + log.Info("replicate end, set replicate config nil", zap.String("vchannel", vchannel)) + } + } return targetPacks } +func (d *Dispatcher) resetMsgPackTS(pack *MsgPack, newBeginTs, newEndTs typeutil.Timestamp) { + pack.BeginTs = newBeginTs + pack.EndTs = newEndTs + startPositions := make([]*msgstream.MsgPosition, 0) + endPositions := make([]*msgstream.MsgPosition, 0) + for _, pos := range pack.StartPositions { + startPosition := typeutil.Clone(pos) + startPosition.Timestamp = newBeginTs + startPositions = append(startPositions, startPosition) + } + for _, pos := range pack.EndPositions { + endPosition := typeutil.Clone(pos) + endPosition.Timestamp = newEndTs + endPositions = append(endPositions, endPosition) + } + pack.StartPositions = startPositions + pack.EndPositions = endPositions +} + func (d *Dispatcher) nonBlockingNotify() { select { case d.lagNotifyChan <- struct{}{}: diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index 02c9d89cc70e9..d4c20fae8c3fb 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -17,6 +17,8 @@ package msgdispatcher import ( + "fmt" + "math/rand" "sync" "testing" "time" @@ -26,6 +28,8 @@ import ( "github.com/stretchr/testify/mock" "golang.org/x/net/context" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -73,7 +77,7 @@ func TestDispatcher(t *testing.T) { output := make(chan *msgstream.MsgPack, 1024) getTarget := func(vchannel string, pos *Pos, ch chan *msgstream.MsgPack) *target { - target := newTarget(vchannel, pos) + target := newTarget(vchannel, pos, nil) target.ch = ch return target } @@ -103,7 +107,7 @@ func TestDispatcher(t *testing.T) { t.Run("test concurrent send and close", func(t *testing.T) { for i := 0; i < 100; i++ { output := make(chan *msgstream.MsgPack, 1024) - target := 
newTarget("mock_vchannel_0", nil) + target := newTarget("mock_vchannel_0", nil, nil) target.ch = output assert.Equal(t, cap(output), cap(target.ch)) wg := &sync.WaitGroup{} @@ -138,3 +142,195 @@ func BenchmarkDispatcher_handle(b *testing.B) { // BenchmarkDispatcher_handle-12 9568 122123 ns/op // PASS } + +func TestGroupMessage(t *testing.T) { + d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0"+fmt.Sprintf("%d", rand.Int()), common.SubscriptionPositionEarliest, nil, nil, false) + assert.NoError(t, err) + d.AddTarget(newTarget("mock_pchannel_0_1v0", nil, nil)) + d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo"))) + { + // no replicate msg + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 5, + EndTimestamp: 5, + }, + InsertRequest: &msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + Timestamp: 5, + }, + ShardName: "mock_pchannel_0_1v0", + }, + }, + }, + }) + assert.Len(t, packs, 1) + } + + { + // equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + }, + }, + }, + }) + assert.Len(t, packs, 2) + { + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + { + replicatePack := packs["mock_pchannel_0_1v0"] + assert.EqualValues(t, 1, replicatePack.BeginTs) + assert.EqualValues(t, 10, replicatePack.EndTs) + assert.EqualValues(t, 1, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 10, replicatePack.EndPositions[0].Timestamp) + assert.Len(t, replicatePack.Msgs, 0) + } + } + + { + // not equal to replicateID + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test-1", // not equal to replicateID + }, + 
}, + }, + }, + }, + }) + assert.Len(t, packs, 1) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.Nil(t, replicatePack) + } + + { + // replicate end + replicateTarget := d.targets["mock_pchannel_0_2v0"] + assert.NotNil(t, replicateTarget.replicateConfig) + packs := d.groupingMsgs(&MsgPack{ + BeginTs: 1, + EndTs: 10, + StartPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("1"), + Timestamp: 1, + }, + }, + EndPositions: []*msgstream.MsgPosition{ + { + ChannelName: "mock_pchannel_0", + MsgID: []byte("10"), + Timestamp: 10, + }, + }, + Msgs: []msgstream.TsMsg{ + &msgstream.ReplicateMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 100, + EndTimestamp: 100, + }, + ReplicateMsg: &msgpb.ReplicateMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Replicate, + Timestamp: 100, + ReplicateInfo: &commonpb.ReplicateInfo{ + ReplicateID: "local-test", + }, + }, + IsEnd: true, + Database: "foo", + }, + }, + }, + }) + assert.Len(t, packs, 2) + replicatePack := packs["mock_pchannel_0_2v0"] + assert.EqualValues(t, 100, replicatePack.BeginTs) + assert.EqualValues(t, 100, replicatePack.EndTs) + assert.EqualValues(t, 100, replicatePack.StartPositions[0].Timestamp) + assert.EqualValues(t, 100, replicatePack.EndPositions[0].Timestamp) + assert.Nil(t, replicateTarget.replicateConfig) + } +} diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index 6fd9a22f20354..d046953b6417a 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -36,7 +36,7 @@ import ( ) type DispatcherManager interface { - Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) + Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) Remove(vchannel string) Num() int Run() @@ -82,7 +82,8 @@ func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) strin return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain) } -func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { +func (c *dispatcherManager) Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) { + vchannel := streamConfig.VChannel log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) @@ -102,11 +103,11 @@ func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, } isMain := c.mainDispatcher == nil - d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets, false) + d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, streamConfig.Pos, c.constructSubName(vchannel, isMain), streamConfig.SubPos, c.lagNotifyChan, c.lagTargets, false) if err != nil { return nil, err } - t := newTarget(vchannel, pos) + t := newTarget(vchannel, streamConfig.Pos, streamConfig.ReplicateConfig) d.AddTarget(t) if isMain { c.mainDispatcher = d diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index feb55799651be..b02ba95621286 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -48,7 +48,7 @@ func TestManager(t *testing.T) { offset++ vchannel := fmt.Sprintf("mock-pchannel-dml_0_vchannelv%d", offset) t.Logf("add vchannel, %s", vchannel) - _, err := c.Add(context.Background(), vchannel, nil, common.SubscriptionPositionUnknown) + _, err := c.Add(context.Background(), 
NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) assert.Equal(t, offset, c.Num()) } @@ -67,11 +67,11 @@ func TestManager(t *testing.T) { ctx := context.Background() c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) assert.NotNil(t, c) - _, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) assert.Equal(t, 3, c.Num()) c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) @@ -98,11 +98,11 @@ func TestManager(t *testing.T) { ctx := context.Background() c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) assert.NotNil(t, c) - _, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) assert.Equal(t, 3, c.Num()) c.(*dispatcherManager).mainDispatcher.curTs.Store(1000) @@ -134,11 +134,11 @@ func TestManager(t *testing.T) { c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) go c.Run() assert.NotNil(t, c) - _, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) assert.Equal(t, 0, c.Num()) @@ -153,18 +153,18 @@ func TestManager(t *testing.T) { go c.Run() assert.NotNil(t, c) ctx := context.Background() - _, err := c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err := c.Add(ctx, NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.NoError(t, err) - _, err = c.Add(ctx, "mock_vchannel_0", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, 
NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, "mock_vchannel_1", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) - _, err = c.Add(ctx, "mock_vchannel_2", nil, common.SubscriptionPositionUnknown) + _, err = c.Add(ctx, NewStreamConfig("mock_vchannel_2", nil, common.SubscriptionPositionUnknown)) assert.Error(t, err) assert.NotPanics(t, func() { @@ -325,7 +325,7 @@ func (suite *SimulationSuite) TestDispatchToVchannels() { suite.vchannels = make(map[string]*vchannelHelper, vchannelNum) for i := 0; i < vchannelNum; i++ { vchannel := fmt.Sprintf("%s_%dv%d", suite.pchannel, collectionID, i) - output, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest) + output, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) assert.NoError(suite.T(), err) suite.vchannels[vchannel] = &vchannelHelper{output: output} } @@ -360,8 +360,10 @@ func (suite *SimulationSuite) TestMerge() { for i := 0; i < vchannelNum; i++ { vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - output, err := suite.manager.Add(context.Background(), vchannel, positions[rand.Intn(len(positions))], - common.SubscriptionPositionUnknown) // seek from random position + output, err := suite.manager.Add(context.Background(), NewStreamConfig( + vchannel, positions[rand.Intn(len(positions))], + common.SubscriptionPositionUnknown, + )) // seek from random position assert.NoError(suite.T(), err) suite.vchannels[vchannel] = &vchannelHelper{output: output} } @@ -402,7 +404,7 @@ func (suite *SimulationSuite) TestSplit() { paramtable.Get().Save(targetBufSizeK, "10") } vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - _, err := suite.manager.Add(context.Background(), vchannel, nil, common.SubscriptionPositionEarliest) + _, err := suite.manager.Add(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionEarliest)) assert.NoError(suite.T(), err) } diff --git a/pkg/mq/msgdispatcher/mock_client.go b/pkg/mq/msgdispatcher/mock_client.go index bc279459d2fd4..9a056a730b8d5 100644 --- a/pkg/mq/msgdispatcher/mock_client.go +++ b/pkg/mq/msgdispatcher/mock_client.go @@ -5,13 +5,8 @@ package msgdispatcher import ( context "context" - common "github.com/milvus-io/milvus/pkg/mq/common" - - mock "github.com/stretchr/testify/mock" - - msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" + mock "github.com/stretchr/testify/mock" ) // MockClient is an autogenerated mock type for the Client type @@ -92,9 +87,9 @@ func (_c *MockClient_Deregister_Call) RunAndReturn(run func(string)) *MockClient return _c } -// Register provides a mock function with given fields: ctx, vchannel, pos, subPos -func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) { - ret := _m.Called(ctx, vchannel, pos, subPos) +// Register provides a mock function with given fields: ctx, streamConfig +func (_m *MockClient) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *msgstream.MsgPack, error) { + ret := _m.Called(ctx, streamConfig) if len(ret) == 0 { panic("no return value specified for Register") @@ -102,19 +97,19 @@ func (_m *MockClient) Register(ctx context.Context, 
vchannel string, pos *msgpb. var r0 <-chan *msgstream.MsgPack var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok { - return rf(ctx, vchannel, pos, subPos) + if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)); ok { + return rf(ctx, streamConfig) } - if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok { - r0 = rf(ctx, vchannel, pos, subPos) + if rf, ok := ret.Get(0).(func(context.Context, *StreamConfig) <-chan *msgstream.MsgPack); ok { + r0 = rf(ctx, streamConfig) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(<-chan *msgstream.MsgPack) } } - if rf, ok := ret.Get(1).(func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) error); ok { - r1 = rf(ctx, vchannel, pos, subPos) + if rf, ok := ret.Get(1).(func(context.Context, *StreamConfig) error); ok { + r1 = rf(ctx, streamConfig) } else { r1 = ret.Error(1) } @@ -129,16 +124,14 @@ type MockClient_Register_Call struct { // Register is a helper method to define mock.On call // - ctx context.Context -// - vchannel string -// - pos *msgpb.MsgPosition -// - subPos common.SubscriptionInitialPosition -func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call { - return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)} +// - streamConfig *StreamConfig +func (_e *MockClient_Expecter) Register(ctx interface{}, streamConfig interface{}) *MockClient_Register_Call { + return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, streamConfig)} } -func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos common.SubscriptionInitialPosition)) *MockClient_Register_Call { +func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, streamConfig *StreamConfig)) *MockClient_Register_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition), args[3].(common.SubscriptionInitialPosition)) + run(args[0].(context.Context), args[1].(*StreamConfig)) }) return _c } @@ -148,7 +141,7 @@ func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 er return _c } -func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition, common.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call { +func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, *StreamConfig) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call { _c.Call.Return(run) return _c } diff --git a/pkg/mq/msgdispatcher/target.go b/pkg/mq/msgdispatcher/target.go index d1ccab6f9ad72..251131e3f826d 100644 --- a/pkg/mq/msgdispatcher/target.go +++ b/pkg/mq/msgdispatcher/target.go @@ -24,6 +24,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -33,26 +34,33 @@ type target struct { ch chan *MsgPack pos *Pos - closeMu sync.Mutex - closeOnce sync.Once - closed bool - maxLag time.Duration - timer *time.Timer + closeMu sync.Mutex + closeOnce sync.Once + closed bool + maxLag time.Duration + timer *time.Timer 
+	replicateConfig *msgstream.ReplicateConfig

 	cancelCh lifetime.SafeChan
 }

-func newTarget(vchannel string, pos *Pos) *target {
+func newTarget(vchannel string, pos *Pos, replicateConfig *msgstream.ReplicateConfig) *target {
 	maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)
 	t := &target{
-		vchannel: vchannel,
-		ch:       make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
-		pos:      pos,
-		cancelCh: lifetime.NewSafeChan(),
-		maxLag:   maxTolerantLag,
-		timer:    time.NewTimer(maxTolerantLag),
+		vchannel:        vchannel,
+		ch:              make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
+		pos:             pos,
+		cancelCh:        lifetime.NewSafeChan(),
+		maxLag:          maxTolerantLag,
+		timer:           time.NewTimer(maxTolerantLag),
+		replicateConfig: replicateConfig,
 	}
 	t.closed = false
+	if replicateConfig != nil {
+		log.Info("target has replicate config",
+			zap.String("vchannel", vchannel),
+			zap.String("replicateID", replicateConfig.ReplicateID))
+	}
 	return t
 }
diff --git a/pkg/mq/msgdispatcher/target_test.go b/pkg/mq/msgdispatcher/target_test.go
index 29e970068ea36..b0f4a37405723 100644
--- a/pkg/mq/msgdispatcher/target_test.go
+++ b/pkg/mq/msgdispatcher/target_test.go
@@ -14,7 +14,7 @@ import (
 )

 func TestSendTimeout(t *testing.T) {
-	target := newTarget("test1", &msgpb.MsgPosition{})
+	target := newTarget("test1", &msgpb.MsgPosition{}, nil)

 	time.Sleep(paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second))
diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go
index 61246421bea00..76486865e81a1 100644
--- a/pkg/mq/msgstream/mq_msgstream.go
+++ b/pkg/mq/msgstream/mq_msgstream.go
@@ -72,6 +72,9 @@ type mqMsgStream struct {
 	ttMsgEnable        atomic.Value
 	forceEnableProduce atomic.Value
 	configEvent        config.EventHandler
+
+	replicateID string
+	checkFunc   CheckReplicateMsgFunc
 }

 // NewMqMsgStream is used to generate a new mqMsgStream object
@@ -276,6 +279,23 @@ func (ms *mqMsgStream) isEnabledProduce() bool {
 	return ms.forceEnableProduce.Load().(bool) || ms.ttMsgEnable.Load().(bool)
 }

+func (ms *mqMsgStream) isSkipSystemTT() bool {
+	return ms.replicateID != ""
+}
+
+// checkReplicateID checks the replicate ID of the message; it returns (isMatch, isReplicate).
+func (ms *mqMsgStream) checkReplicateID(msg TsMsg) (bool, bool) {
+	if !ms.isSkipSystemTT() {
+		return true, false
+	}
+	msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
+	if !ok {
+		log.Warn("failed to get msg base", zap.Any("type", msg.Type()))
+		return false, false
+	}
+	return msgBase.GetBase().GetReplicateInfo().GetReplicateID() == ms.replicateID, true
+}
+
 func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error {
 	if !ms.isEnabledProduce() {
 		log.Ctx(ms.ctx).Warn("can't produce the msg in the backup instance", zap.Stack("stack"))
@@ -688,9 +708,9 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
 		startBufTime := time.Now()
 		var endTs uint64
 		var size uint64
-		var containsDropCollectionMsg bool
+		var containsEndBufferMsg bool

-		for ms.continueBuffering(endTs, size, startBufTime) && !containsDropCollectionMsg {
+		for ms.continueBuffering(endTs, size, startBufTime) && !containsEndBufferMsg {
 			ms.consumerLock.Lock()
 			// wait all channels get ttMsg
 			for _, consumer := range ms.consumers {
@@ -726,15 +746,16 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() {
 					timeTickMsg = v
 					continue
 				}
-				if v.EndTs() <= currTs {
+				if v.EndTs() <= currTs ||
+					GetReplicateID(v) != "" {
 					size += uint64(v.Size())
 					timeTickBuf = append(timeTickBuf, v)
 				} else {
 					tempBuffer = append(tempBuffer, v)
 				}
-				// when drop collection, force to exit the buffer loop
-				if v.Type() == commonpb.MsgType_DropCollection {
-					containsDropCollectionMsg = true
+				// force to exit the buffer loop when a drop-collection or replicate msg arrives
+				if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate {
+					containsEndBufferMsg = true
 				}
 			}
 			ms.chanMsgBuf[consumer] = tempBuffer
@@ -860,7 +881,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
 	}
 	for consumer := range ms.chanTtMsgTime {
 		ms.chanTtMsgTimeMutex.RLock()
-		chanTtMsgSync[consumer] = (ms.chanTtMsgTime[consumer] == maxTime)
+		chanTtMsgSync[consumer] = ms.chanTtMsgTime[consumer] == maxTime
 		ms.chanTtMsgTimeMutex.RUnlock()
 	}

@@ -960,6 +981,10 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition,
 			if err != nil {
 				return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error())
 			}
+			// skip the replicate msg because it must have been consumed
+			if GetReplicateID(tsMsg) != "" {
+				continue
+			}
 			if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp {
 				runLoop = false
 				if time.Since(loopStarTime) > 30*time.Second {
diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go
index 6c9973f3373cd..1ff559c191c63 100644
--- a/pkg/mq/msgstream/mq_msgstream_test.go
+++ b/pkg/mq/msgstream/mq_msgstream_test.go
@@ -708,6 +708,21 @@ func TestStream_PulsarTtMsgStream_UnMarshalHeader(t *testing.T) {
 	msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 1))
 	msgPack1.Msgs = append(msgPack1.Msgs, getTsMsg(commonpb.MsgType_Insert, 3))

+	replicatePack := MsgPack{}
+	replicatePack.Msgs = append(replicatePack.Msgs, &ReplicateMsg{
+		BaseMsg: BaseMsg{
+			BeginTimestamp: 0,
+			EndTimestamp:   0,
+			HashValues:     []uint32{100},
+		},
+		ReplicateMsg: &msgpb.ReplicateMsg{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_Replicate,
+				Timestamp: 100,
+			},
+		},
+	})
+
 	msgPack2 := MsgPack{}
 	msgPack2.Msgs = append(msgPack2.Msgs, getTimeTickMsg(5))

@@ -721,6 +736,9 @@
 	err = inputStream.Produce(ctx, &msgPack1)
 	require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))

+	err = inputStream.Produce(ctx, &replicatePack)
+	require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
+
 	_, err = inputStream.Broadcast(ctx, &msgPack2)
 	require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err))
diff --git a/pkg/mq/msgstream/msg_for_replicate.go b/pkg/mq/msgstream/msg_for_replicate.go
new file mode 100644
index 0000000000000..3f6b9699c69ab
--- /dev/null
+++ b/pkg/mq/msgstream/msg_for_replicate.go
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the LF AI & Data foundation under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
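The new file below defines ReplicateMsg as a TsMsg. For orientation (a sketch, not part of the patch): a marshal/unmarshal round trip re-derives Begin/EndTimestamp from Base.Timestamp, which is what lets the dispatcher re-time packs off a consumed marker.

	// Round-trip sketch; error handling elided.
	rm := &msgstream.ReplicateMsg{
		BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
		ReplicateMsg: &msgpb.ReplicateMsg{
			Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Replicate, Timestamp: 100},
		},
	}
	raw, _ := rm.Marshal(rm)        // proto-encodes the embedded msgpb.ReplicateMsg
	decoded, _ := rm.Unmarshal(raw) // decoded.BeginTs() == decoded.EndTs() == 100
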
+package msgstream
+
+import (
+	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+)
+
+type ReplicateMsg struct {
+	BaseMsg
+	*msgpb.ReplicateMsg
+}
+
+var _ TsMsg = (*ReplicateMsg)(nil)
+
+func (r *ReplicateMsg) ID() UniqueID {
+	return r.Base.MsgID
+}
+
+func (r *ReplicateMsg) SetID(id UniqueID) {
+	r.Base.MsgID = id
+}
+
+func (r *ReplicateMsg) Type() MsgType {
+	return r.Base.MsgType
+}
+
+func (r *ReplicateMsg) SourceID() int64 {
+	return r.Base.SourceID
+}
+
+func (r *ReplicateMsg) Marshal(input TsMsg) (MarshalType, error) {
+	replicateMsg := input.(*ReplicateMsg)
+	mb, err := proto.Marshal(replicateMsg.ReplicateMsg)
+	if err != nil {
+		return nil, err
+	}
+	return mb, nil
+}
+
+func (r *ReplicateMsg) Unmarshal(input MarshalType) (TsMsg, error) {
+	replicateMsg := &msgpb.ReplicateMsg{}
+	in, err := convertToByteArray(input)
+	if err != nil {
+		return nil, err
+	}
+	err = proto.Unmarshal(in, replicateMsg)
+	if err != nil {
+		return nil, err
+	}
+	rr := &ReplicateMsg{ReplicateMsg: replicateMsg}
+	rr.BeginTimestamp = replicateMsg.GetBase().GetTimestamp()
+	rr.EndTimestamp = replicateMsg.GetBase().GetTimestamp()
+
+	return rr, nil
+}
+
+func (r *ReplicateMsg) Size() int {
+	return proto.Size(r.ReplicateMsg)
+}
diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go
index 24709de81cf90..3b1d4b3c32d8d 100644
--- a/pkg/mq/msgstream/msgstream.go
+++ b/pkg/mq/msgstream/msgstream.go
@@ -19,7 +19,11 @@ package msgstream
 import (
 	"context"

+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/mq/common"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
@@ -73,6 +77,50 @@ type MsgStream interface {
 	ForceEnableProduce(can bool)
 }

+type ReplicateConfig struct {
+	ReplicateID string
+	CheckFunc   CheckReplicateMsgFunc
+}
+
+type CheckReplicateMsgFunc func(*ReplicateMsg) bool
+
+func GetReplicateConfig(replicateID, dbName, colName string) *ReplicateConfig {
+	if replicateID == "" {
+		return nil
+	}
+	replicateConfig := &ReplicateConfig{
+		ReplicateID: replicateID,
+		CheckFunc: func(msg *ReplicateMsg) bool {
+			if !msg.GetIsEnd() {
+				return false
+			}
+			log.Info("check replicate msg",
+				zap.String("replicateID", replicateID),
+				zap.String("dbName", dbName),
+				zap.String("colName", colName),
+				zap.Any("msg", msg))
+			if msg.GetIsCluster() {
+				return true
+			}
+			return msg.GetDatabase() == dbName && (msg.GetCollection() == colName || msg.GetCollection() == "")
+		},
+	}
+	return replicateConfig
+}
+
+func GetReplicateID(msg TsMsg) string {
+	msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
+	if !ok {
+		log.Warn("failed to get msg base", zap.Any("type", msg.Type()))
+		return ""
+	}
+	return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
+}
+
+func MatchReplicateID(msg TsMsg, replicateID string) bool {
+	return GetReplicateID(msg) == replicateID
+}
+
 type Factory interface {
 	NewMsgStream(ctx context.Context) (MsgStream, error)
 	NewTtMsgStream(ctx context.Context) (MsgStream, error)
diff --git a/pkg/mq/msgstream/msgstream_util_test.go b/pkg/mq/msgstream/msgstream_util_test.go
index f8d1754eac205..9811ea5e78670 100644
--- a/pkg/mq/msgstream/msgstream_util_test.go
+++ b/pkg/mq/msgstream/msgstream_util_test.go
@@ -24,6 +24,8 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"

+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/pkg/mq/common"
 )
@@ -80,3 +82,90 @@ func TestGetLatestMsgID(t *testing.T) {
 		assert.Equal(t, []byte("mock"), id)
 	}
 }
+
+func TestReplicateConfig(t *testing.T) {
+	t.Run("get replicate id", func(t *testing.T) {
+		{
+			msg := &InsertMsg{
+				InsertRequest: &msgpb.InsertRequest{
+					Base: &commonpb.MsgBase{
+						ReplicateInfo: &commonpb.ReplicateInfo{
+							ReplicateID: "local",
+						},
+					},
+				},
+			}
+			assert.Equal(t, "local", GetReplicateID(msg))
+			assert.True(t, MatchReplicateID(msg, "local"))
+		}
+		{
+			msg := &InsertMsg{
+				InsertRequest: &msgpb.InsertRequest{
+					Base: &commonpb.MsgBase{},
+				},
+			}
+			assert.Equal(t, "", GetReplicateID(msg))
+			assert.False(t, MatchReplicateID(msg, "local"))
+		}
+		{
+			msg := &MarshalFailTsMsg{}
+			assert.Equal(t, "", GetReplicateID(msg))
+		}
+	})
+
+	t.Run("get replicate config", func(t *testing.T) {
+		{
+			assert.Nil(t, GetReplicateConfig("", "", ""))
+		}
+		{
+			rc := GetReplicateConfig("local", "db", "col")
+			assert.Equal(t, "local", rc.ReplicateID)
+			checkFunc := rc.CheckFunc
+			assert.False(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{},
+			}))
+			assert.True(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:     true,
+					IsCluster: true,
+				},
+			}))
+			assert.False(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:    true,
+					Database: "db1",
+				},
+			}))
+			assert.True(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:    true,
+					Database: "db",
+				},
+			}))
+			assert.False(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:      true,
+					Database:   "db",
+					Collection: "col1",
+				},
+			}))
+		}
+		{
+			rc := GetReplicateConfig("local", "db", "col")
+			checkFunc := rc.CheckFunc
+			assert.True(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:    true,
+					Database: "db",
+				},
+			}))
+			assert.False(t, checkFunc(&ReplicateMsg{
+				ReplicateMsg: &msgpb.ReplicateMsg{
+					IsEnd:      true,
+					Database:   "db1",
+					Collection: "col1",
+				},
+			}))
+		}
+	})
+}
diff --git a/pkg/mq/msgstream/unmarshal.go b/pkg/mq/msgstream/unmarshal.go
index c4427a2bea1c0..8cd8de9eece3a 100644
--- a/pkg/mq/msgstream/unmarshal.go
+++ b/pkg/mq/msgstream/unmarshal.go
@@ -84,6 +84,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
 	dropRoleMsg := DropRoleMsg{}
 	operateUserRoleMsg := OperateUserRoleMsg{}
 	operatePrivilegeMsg := OperatePrivilegeMsg{}
+	replicateMsg := ReplicateMsg{}

 	p := &ProtoUnmarshalDispatcher{}
 	p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc)
@@ -113,6 +114,7 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
 	p.TempMap[commonpb.MsgType_DropRole] = dropRoleMsg.Unmarshal
 	p.TempMap[commonpb.MsgType_OperateUserRole] = operateUserRoleMsg.Unmarshal
 	p.TempMap[commonpb.MsgType_OperatePrivilege] = operatePrivilegeMsg.Unmarshal
+	p.TempMap[commonpb.MsgType_Replicate] = replicateMsg.Unmarshal

 	return p
 }
diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go
index 7d81d8a4da70d..8de8203461eaa 100644
--- a/pkg/util/merr/errors.go
+++ b/pkg/util/merr/errors.go
@@ -70,6 +70,7 @@ var (
 	ErrCollectionIllegalSchema                 = newMilvusError("illegal collection schema", 105, false)
 	ErrCollectionOnRecovering                  = newMilvusError("collection on recovering", 106, true)
 	ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false)
+	ErrCollectionReplicateMode                 = newMilvusError("can't operate on the collection in standby mode", 108, false)

 	// Partition related
 	ErrPartitionNotFound = newMilvusError("partition not found", 200, false)
diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go
index 3cf0d888458b6..aef0bed7a344b 100644
--- a/pkg/util/merr/utils.go
+++ b/pkg/util/merr/utils.go
@@ -330,6 +330,10 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
 	return err
 }

+func WrapErrCollectionReplicateMode(operation string) error {
+	return wrapFields(ErrCollectionReplicateMode, value("operation", operation))
+}
+
 func GetErrorType(err error) ErrorType {
 	if merr, ok := err.(milvusError); ok {
 		return merr.errType
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 0950233569af4..8174cdbaa544f 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -268,6 +268,7 @@ type commonConfig struct {
 	MaxBloomFalsePositive     ParamItem `refreshable:"true"`
 	BloomFilterApplyBatchSize ParamItem `refreshable:"true"`
 	PanicWhenPluginFail       ParamItem `refreshable:"false"`
+	CollectionReplicateEnable ParamItem `refreshable:"true"`

 	UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
 	UseVectorAsClusteringKey       ParamItem `refreshable:"true"`
@@ -784,6 +785,15 @@ This helps Milvus-CDC synchronize incremental data`,
 	}
 	p.TTMsgEnabled.Init(base.mgr)

+	p.CollectionReplicateEnable = ParamItem{
+		Key:          "common.collectionReplicateEnable",
+		Version:      "2.4.16",
+		DefaultValue: "false",
+		Doc:          `Whether to enable collection replication.`,
+		Export:       true,
+	}
+	p.CollectionReplicateEnable.Init(base.mgr)
+
 	p.TraceLogMode = ParamItem{
 		Key:     "common.traceLogMode",
 		Version: "2.3.4",
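
Putting the pieces together, the runtime switch and the per-collection property are expected to gate writes roughly as sketched below; allowDirectWrite is a hypothetical helper, not part of this diff, and assumes paramtable.Init() has already run.

	// Hypothetical gating helper; "common" here is pkg/common,
	// and "commonpb" is github.com/milvus-io/milvus-proto/go-api/v2/commonpb.
	func allowDirectWrite(collProps []*commonpb.KeyValuePair) bool {
		if !paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool() {
			return true // feature off: behave exactly as before
		}
		replicated, _ := common.IsReplicateEnabled(collProps)
		// Replicated collections presumably reject direct writes with
		// merr.WrapErrCollectionReplicateMode(<operation>).
		return !replicated
	}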