From b5a94cccb18f6bd81b7c78ce7997a56f47caf1de Mon Sep 17 00:00:00 2001 From: aoiasd Date: Mon, 23 Dec 2024 14:33:29 +0800 Subject: [PATCH] unmarshal ts msg in dispatcher instead of in msgstream Signed-off-by: aoiasd --- .../pipeline/data_sync_service_test.go | 9 +- .../flow_graph_dmstream_input_node_test.go | 8 +- internal/proxy/mock_test.go | 13 +- .../querynodev2/delegator/delegator_data.go | 10 +- .../delegator/delegator_data_test.go | 9 +- internal/querynodev2/services_test.go | 2 +- .../rootcoord/alter_collection_task_test.go | 12 +- .../rootcoord/alter_database_task_test.go | 13 +- internal/rootcoord/dml_channels_test.go | 9 +- internal/rootcoord/mock_test.go | 8 +- internal/util/flowgraph/input_node_test.go | 7 +- internal/util/flowgraph/node_test.go | 3 +- pkg/mq/common/message.go | 7 +- pkg/mq/msgdispatcher/dispatcher.go | 53 ++-- pkg/mq/msgdispatcher/dispatcher_test.go | 16 +- pkg/mq/msgstream/factory_stream_test.go | 50 +-- pkg/mq/msgstream/mock_msgstream.go | 59 +++- pkg/mq/msgstream/mq_kafka_msgstream_test.go | 37 ++- pkg/mq/msgstream/mq_msgstream.go | 172 ++++++----- pkg/mq/msgstream/mq_msgstream_test.go | 52 ++-- pkg/mq/msgstream/mq_rocksmq_msgstream_test.go | 20 +- pkg/mq/msgstream/msgstream.go | 289 +++++++++++++++++- pkg/mq/msgstream/msgstream_util.go | 29 ++ pkg/mq/msgstream/trace.go | 14 +- pkg/mq/msgstream/wasted_mock_msgstream.go | 4 +- 25 files changed, 681 insertions(+), 224 deletions(-) diff --git a/internal/flushcommon/pipeline/data_sync_service_test.go b/internal/flushcommon/pipeline/data_sync_service_test.go index 1a55231b0e8d8..caa0ddeb6557d 100644 --- a/internal/flushcommon/pipeline/data_sync_service_test.go +++ b/internal/flushcommon/pipeline/data_sync_service_test.go @@ -308,7 +308,7 @@ type DataSyncServiceSuite struct { channelCheckpointUpdater *util2.ChannelCheckpointUpdater factory *dependency.MockFactory ms *msgstream.MockMsgStream - msChan chan *msgstream.MsgPack + msChan chan *msgstream.ConsumeMsgPack } func (s *DataSyncServiceSuite) SetupSuite() { @@ -330,7 +330,7 @@ func (s *DataSyncServiceSuite) SetupTest() { s.channelCheckpointUpdater = util2.NewChannelCheckpointUpdater(s.broker) go s.channelCheckpointUpdater.Start() - s.msChan = make(chan *msgstream.MsgPack, 1) + s.msChan = make(chan *msgstream.ConsumeMsgPack, 1) s.factory = dependency.NewMockFactory(s.T()) s.ms = msgstream.NewMockMsgStream(s.T()) @@ -338,6 +338,7 @@ func (s *DataSyncServiceSuite) SetupTest() { s.ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) s.ms.EXPECT().Chan().Return(s.msChan) s.ms.EXPECT().Close().Return() + s.ms.EXPECT().GetUnmarshalDispatcher().Return(nil) s.pipelineParams = &util2.PipelineParams{ Ctx: context.TODO(), @@ -487,8 +488,8 @@ func (s *DataSyncServiceSuite) TestStartStop() { close(ch) return nil }) - s.msChan <- &msgPack - s.msChan <- &timeTickMsgPack + s.msChan <- msgstream.BuildConsumeMsgPack(&msgPack) + s.msChan <- msgstream.BuildConsumeMsgPack(&timeTickMsgPack) <-ch } diff --git a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go index e5aaa47eeaa6d..2b0cca39238a3 100644 --- a/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go +++ b/internal/flushcommon/pipeline/flow_graph_dmstream_input_node_test.go @@ -67,8 +67,8 @@ func (mtm *mockTtMsgStream) SetReplicate(config *msgstream.ReplicateConfig) { func (mtm *mockTtMsgStream) Close() {} -func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack { - return make(chan 
*msgstream.MsgPack, 100) +func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack { + return make(chan *msgstream.ConsumeMsgPack, 100) } func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {} @@ -77,6 +77,10 @@ func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, s return nil } +func (mtm *mockTtMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { + return nil +} + func (mtm *mockTtMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} func (mtm *mockTtMsgStream) GetProduceChannels() []string { diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 18da0945d6bd5..c90b50ab15df6 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -235,7 +235,7 @@ func newDefaultMockDqlTask() *mockDqlTask { } type simpleMockMsgStream struct { - msgChan chan *msgstream.MsgPack + msgChan chan *msgstream.ConsumeMsgPack msgCount int msgCountMtx sync.RWMutex @@ -244,7 +244,7 @@ type simpleMockMsgStream struct { func (ms *simpleMockMsgStream) Close() { } -func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack { +func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack { if ms.getMsgCount() <= 0 { ms.msgChan <- nil return ms.msgChan @@ -255,6 +255,10 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack { return ms.msgChan } +func (ms *simpleMockMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { + return nil +} + func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) { } @@ -286,8 +290,7 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) { func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error { defer ms.increaseMsgCount(1) - ms.msgChan <- pack - + ms.msgChan <- msgstream.BuildConsumeMsgPack(pack) return nil } @@ -319,7 +322,7 @@ func (ms *simpleMockMsgStream) SetReplicate(config *msgstream.ReplicateConfig) { func newSimpleMockMsgStream() *simpleMockMsgStream { return &simpleMockMsgStream{ - msgChan: make(chan *msgstream.MsgPack, 1024), + msgChan: make(chan *msgstream.ConsumeMsgPack, 1024), msgCount: 0, } } diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index a6ffc4f1b0d86..3573be34a4101 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -727,7 +727,15 @@ func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, positio if err != nil { return nil, stream.Close, err } - return stream.Chan(), stream.Close, nil + + dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.PackMsg) bool { + if pm.GetType() != commonpb.MsgType_Delete || pm.GetChannel() != vchannelName { + return false + } + return true + }) + + return dispatcher.Chan(), dispatcher.Close, nil } func (sd *shardDelegator) createDeleteStreamFromStreamingService(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) { diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index f164b9c93ff7a..08ac5e4898a8f 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -207,6 +207,8 @@ func (s *DelegatorDataSuite) SetupTest() { // init schema s.genNormalCollection() s.mq = &msgstream.MockMsgStream{} + s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil) + 
s.rootPath = s.Suite.T().Name() chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) @@ -916,8 +918,9 @@ func (s *DelegatorDataSuite) TestLoadSegments() { s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) + s.mq.EXPECT().GetUnmarshalDispatcher().Return(nil) s.mq.EXPECT().Close() - ch := make(chan *msgstream.MsgPack, 10) + ch := make(chan *msgstream.ConsumeMsgPack, 10) close(ch) s.mq.EXPECT().Chan().Return(ch) @@ -1584,7 +1587,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() { s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mq.EXPECT().Close() - ch := make(chan *msgstream.MsgPack, 10) + ch := make(chan *msgstream.ConsumeMsgPack, 10) s.mq.EXPECT().Chan().Return(ch) oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed) @@ -1602,7 +1605,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() { } for _, data := range datas { - ch <- data + ch <- msgstream.BuildConsumeMsgPack(data) } result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle) diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index f3edb64965629..6b0a690f5b977 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -64,7 +64,7 @@ import ( type ServiceSuite struct { suite.Suite // Data - msgChan chan *msgstream.MsgPack + msgChan chan *msgstream.ConsumeMsgPack collectionID int64 collectionName string schema *schemapb.CollectionSchema diff --git a/internal/rootcoord/alter_collection_task_test.go b/internal/rootcoord/alter_collection_task_test.go index f0a0edb75ea5f..b333e192578e5 100644 --- a/internal/rootcoord/alter_collection_task_test.go +++ b/internal/rootcoord/alter_collection_task_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -250,7 +251,7 @@ func Test_alterCollectionTask_Execute(t *testing.T) { broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error { return nil } - packChan := make(chan *msgstream.MsgPack, 10) + packChan := make(chan *msgstream.ConsumeMsgPack, 10) ticker := newChanTimeTickSync(packChan) ticker.addDmlChannels("by-dev-rootcoord-dml_1") @@ -268,13 +269,18 @@ func Test_alterCollectionTask_Execute(t *testing.T) { }, } + unmarshalFactory := &msgstream.ProtoUDFactory{} + unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher() + err := task.Execute(context.Background()) assert.NoError(t, err) time.Sleep(time.Second) select { case pack := <-packChan: - assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) - replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType()) + tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher) + require.NoError(t, err) + replicateMsg := tsMsg.(*msgstream.ReplicateMsg) assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase()) assert.Equal(t, "cn", 
replicateMsg.ReplicateMsg.GetCollection()) assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) diff --git a/internal/rootcoord/alter_database_task_test.go b/internal/rootcoord/alter_database_task_test.go index 47b66e176b4e7..2edf9ff53338b 100644 --- a/internal/rootcoord/alter_database_task_test.go +++ b/internal/rootcoord/alter_database_task_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/metastore/model" @@ -236,7 +237,7 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { mock.Anything, ).Return(nil) // the chan length should be larger than 4, because newChanTimeTickSync will send 4 ts messages when executing the `broadcast` step - packChan := make(chan *msgstream.MsgPack, 10) + packChan := make(chan *msgstream.ConsumeMsgPack, 10) ticker := newChanTimeTickSync(packChan) ticker.addDmlChannels("by-dev-rootcoord-dml_1") @@ -252,13 +253,19 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { }, } + unmarshalFactory := &msgstream.ProtoUDFactory{} + unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher() + err := task.Execute(context.Background()) assert.NoError(t, err) time.Sleep(time.Second) select { case pack := <-packChan: - assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type()) - replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg) + assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType()) + + tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher) + require.NoError(t, err) + replicateMsg := tsMsg.(*msgstream.ReplicateMsg) assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetDatabase()) assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd()) default: diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index 7d7f3284835f2..12a64bec19329 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -277,10 +277,11 @@ type FailMsgStream struct { errBroadcast bool } -func (ms *FailMsgStream) Close() {} -func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil } -func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {} -func (ms *FailMsgStream) AsReader(channels []string, subName string) {} +func (ms *FailMsgStream) Close() {} +func (ms *FailMsgStream) Chan() <-chan *msgstream.ConsumeMsgPack { return nil } +func (ms *FailMsgStream) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { return nil } +func (ms *FailMsgStream) AsProducer(ctx context.Context, channels []string) {} +func (ms *FailMsgStream) AsReader(channels []string, subName string) {} func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { return nil } diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 19c28253c9da8..9b46f2563cbee 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -1059,23 +1059,23 @@ func newTickerWithFactory(factory msgstream.Factory) *timetickSync { return ticker } -func newChanTimeTickSync(packChan chan *msgstream.MsgPack) *timetickSync { +func newChanTimeTickSync(packChan chan *msgstream.ConsumeMsgPack) *timetickSync { f := msgstream.NewMockMqFactory() f.NewMsgStreamFunc = func(ctx context.Context) (msgstream.MsgStream, error) { stream := msgstream.NewWastedMockMsgStream() 
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error { log.Info("mock Broadcast") - packChan <- pack + packChan <- msgstream.BuildConsumeMsgPack(pack) return nil } stream.BroadcastMarkFunc = func(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { log.Info("mock BroadcastMark") - packChan <- pack + packChan <- msgstream.BuildConsumeMsgPack(pack) return map[string][]msgstream.MessageID{}, nil } stream.AsProducerFunc = func(channels []string) { } - stream.ChanFunc = func() <-chan *msgstream.MsgPack { + stream.ChanFunc = func() <-chan *msgstream.ConsumeMsgPack { return packChan } return stream, nil diff --git a/internal/util/flowgraph/input_node_test.go b/internal/util/flowgraph/input_node_test.go index bd7087b44c476..42f37d64a4cda 100644 --- a/internal/util/flowgraph/input_node_test.go +++ b/internal/util/flowgraph/input_node_test.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -45,8 +46,9 @@ func TestInputNode(t *testing.T) { produceStream.AsProducer(context.TODO(), channels) produceStream.Produce(context.TODO(), &msgPack) + dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.PackMsg) bool { return true }) nodeName := "input_node" - inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") + inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "") defer inputNode.Close() isInputNode := inputNode.IsInputNode() @@ -89,7 +91,8 @@ func Test_InputNodeSkipMode(t *testing.T) { outputCh := make(chan bool) nodeName := "input_node" - inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "") + dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.PackMsg) bool { return true }) + inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, typeutil.DataNodeRole, 0, 0, "") defer inputNode.Close() outputCount := 0 diff --git a/internal/util/flowgraph/node_test.go b/internal/util/flowgraph/node_test.go index 08752ca435ec5..a442b6410c99b 100644 --- a/internal/util/flowgraph/node_test.go +++ b/internal/util/flowgraph/node_test.go @@ -89,7 +89,8 @@ func TestNodeManager_Start(t *testing.T) { produceStream.Produce(context.TODO(), &msgPack) nodeName := "input_node" - inputNode := NewInputNode(msgStream.Chan(), nodeName, 100, 100, "", 0, 0, "") + dispatcher := msgstream.NewSimpleMsgDispatcher(msgStream, func(pm msgstream.PackMsg) bool { return true }) + inputNode := NewInputNode(dispatcher.Chan(), nodeName, 100, 100, "", 0, 0, "") ddNode := BaseNode{} diff --git a/pkg/mq/common/message.go b/pkg/mq/common/message.go index e0b215d8d1804..0debdf5e0a5e1 100644 --- a/pkg/mq/common/message.go +++ b/pkg/mq/common/message.go @@ -74,7 +74,12 @@ const ( SubscriptionPositionUnknown ) -const MsgTypeKey = "msg_type" +const ( + MsgTypeKey = "msg_type" + TimestampTypeKey = "timestamp" + ChannelTypeKey = "vchannel" + ReplicateIDTypeKey = "replicate_id" +) func GetMsgType(msg Message) (commonpb.MsgType, error) { msgType := commonpb.MsgType_Undefined diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index 2df78b2c901b8..ff85611db0476 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -19,7 +19,6 @@ package msgdispatcher import ( "context" 
"fmt" - "strconv" "strings" "sync" "time" @@ -229,7 +228,7 @@ func (d *Dispatcher) work() { } d.curTs.Store(pack.EndPositions[0].GetTimestamp()) - targetPacks := d.groupingMsgs(pack) + targetPacks := d.groupAndParseMsgs(pack, d.stream.GetUnmarshalDispatcher()) for vchannel, p := range targetPacks { var err error t := d.targets[vchannel] @@ -260,7 +259,7 @@ func (d *Dispatcher) work() { } } -func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { +func (d *Dispatcher) groupAndParseMsgs(pack *msgstream.ConsumeMsgPack, unmarshalDispatcher msgstream.UnmarshalDispatcher) map[string]*MsgPack { // init packs for all targets, even though there's no msg in pack, // but we still need to dispatch time ticks to the targets. targetPacks := make(map[string]*MsgPack) @@ -280,27 +279,24 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { // group messages by vchannel for _, msg := range pack.Msgs { var vchannel, collectionID string - switch msg.Type() { - case commonpb.MsgType_Insert: - vchannel = msg.(*msgstream.InsertMsg).GetShardName() - case commonpb.MsgType_Delete: - vchannel = msg.(*msgstream.DeleteMsg).GetShardName() - case commonpb.MsgType_CreateCollection: - collectionID = strconv.FormatInt(msg.(*msgstream.CreateCollectionMsg).GetCollectionID(), 10) - case commonpb.MsgType_DropCollection: - collectionID = strconv.FormatInt(msg.(*msgstream.DropCollectionMsg).GetCollectionID(), 10) - case commonpb.MsgType_CreatePartition: - collectionID = strconv.FormatInt(msg.(*msgstream.CreatePartitionMsg).GetCollectionID(), 10) - case commonpb.MsgType_DropPartition: - collectionID = strconv.FormatInt(msg.(*msgstream.DropPartitionMsg).GetCollectionID(), 10) + + if msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete { + vchannel = msg.GetChannel() + } else if msg.GetType() == commonpb.MsgType_CreateCollection || + msg.GetType() == commonpb.MsgType_DropCollection || + msg.GetType() == commonpb.MsgType_CreatePartition || + msg.GetType() == commonpb.MsgType_DropPartition { + collectionID = msg.GetChannel() // TODO AOIASD } + if vchannel == "" { // we need to dispatch it to the vchannel of this collection + targets := []string{} for k := range targetPacks { - if msg.Type() == commonpb.MsgType_Replicate { + if msg.GetType() == commonpb.MsgType_Replicate { config := replicateConfigs[k] - if config != nil && msgstream.MatchReplicateID(msg, config.ReplicateID) { - targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg) + if config != nil && msg.GetReplicateID() == config.ReplicateID { + targets = append(targets, k) } continue } @@ -308,14 +304,29 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { if !strings.Contains(k, collectionID) { continue } + targets = append(targets, k) + } + if len(targets) > 0 { + tsMsg, err := msg.Unmarshal(unmarshalDispatcher) + if err != nil { + log.Warn("unmarshl message failed", zap.Error(err)) + continue + } // TODO: There's data race when non-dml msg is sent to different flow graph. // Wrong open-trancing information is generated, Fix in future. 
- targetPacks[k].Msgs = append(targetPacks[k].Msgs, msg) + for _, target := range targets { + targetPacks[target].Msgs = append(targetPacks[target].Msgs, tsMsg) + } } continue } if _, ok := targetPacks[vchannel]; ok { - targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, msg) + tsMsg, err := msg.Unmarshal(unmarshalDispatcher) // TODO AOIASD UNMARSHAL + if err != nil { + log.Warn("unmarshal message failed", zap.Error(err)) + continue + } + targetPacks[vchannel].Msgs = append(targetPacks[vchannel].Msgs, tsMsg) } } replicateEndChannels := make(map[string]struct{}) diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index d4c20fae8c3fb..7da33191fbc9a 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -150,7 +150,7 @@ func TestGroupMessage(t *testing.T) { d.AddTarget(newTarget("mock_pchannel_0_2v0", nil, msgstream.GetReplicateConfig("local-test", "foo", "coo"))) { // no replicate msg - packs := d.groupingMsgs(&MsgPack{ + packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{ BeginTs: 1, EndTs: 10, StartPositions: []*msgstream.MsgPosition{ @@ -182,13 +182,13 @@ func TestGroupMessage(t *testing.T) { }, }, }, - }) + }), nil) assert.Len(t, packs, 1) } { // equal to replicateID - packs := d.groupingMsgs(&MsgPack{ + packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{ BeginTs: 1, EndTs: 10, StartPositions: []*msgstream.MsgPosition{ @@ -222,7 +222,7 @@ func TestGroupMessage(t *testing.T) { }, }, }, - }) + }), nil) assert.Len(t, packs, 2) { replicatePack := packs["mock_pchannel_0_2v0"] @@ -244,7 +244,7 @@ func TestGroupMessage(t *testing.T) { { // not equal to replicateID - packs := d.groupingMsgs(&MsgPack{ + packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{ BeginTs: 1, EndTs: 10, StartPositions: []*msgstream.MsgPosition{ @@ -278,7 +278,7 @@ func TestGroupMessage(t *testing.T) { }, }, }, - }) + }), nil) assert.Len(t, packs, 1) replicatePack := packs["mock_pchannel_0_2v0"] assert.Nil(t, replicatePack) @@ -288,7 +288,7 @@ func TestGroupMessage(t *testing.T) { // replicate end replicateTarget := d.targets["mock_pchannel_0_2v0"] assert.NotNil(t, replicateTarget.replicateConfig) - packs := d.groupingMsgs(&MsgPack{ + packs := d.groupAndParseMsgs(msgstream.BuildConsumeMsgPack(&MsgPack{ BeginTs: 1, EndTs: 10, StartPositions: []*msgstream.MsgPosition{ @@ -324,7 +324,7 @@ func TestGroupMessage(t *testing.T) { }, }, }, - }) + }), nil) assert.Len(t, packs, 2) replicatePack := packs["mock_pchannel_0_2v0"] assert.EqualValues(t, 100, replicatePack.BeginTs) diff --git a/pkg/mq/msgstream/factory_stream_test.go b/pkg/mq/msgstream/factory_stream_test.go index be0e6fb503924..1d33b29b2e073 100644 --- a/pkg/mq/msgstream/factory_stream_test.go +++ b/pkg/mq/msgstream/factory_stream_test.go @@ -266,7 +266,7 @@ func testSeekToLast(t *testing.T, f []Factory) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consume(ctx, consumer) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) if i == 5 { seekPosition = result.EndPositions[0] } @@ -295,11 +295,11 @@ func testSeekToLast(t *testing.T, f []Factory) { assert.Equal(t, 1, len(msgPack.Msgs)) for _, tsMsg := range msgPack.Msgs { - assert.Equal(t, value, tsMsg.ID()) + assert.Equal(t, value, tsMsg.GetID()) value++ cnt++ - ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID) + ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID) 
assert.NoError(t, err) if ret { hasMore = false @@ -398,13 +398,17 @@ func testTimeTickerSeek(t *testing.T, f []Factory) { assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} for i, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), result[i]) + tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), result[i]) } seekMsg2 := consume(ctx, consumer) assert.Equal(t, len(seekMsg2.Msgs), 1) for _, msg := range seekMsg2.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } consumer.Close() @@ -412,7 +416,9 @@ func testTimeTickerSeek(t *testing.T, f []Factory) { seekMsg = consume(ctx, consumer) assert.Equal(t, len(seekMsg.Msgs), 1) for _, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } consumer.Close() } @@ -473,9 +479,11 @@ func testTimeTickerStream1(t *testing.T, f []Factory) { rcvMsg += len(msgPack.Msgs) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } log.Println("================") } @@ -525,7 +533,7 @@ func testTimeTickerStream2(t *testing.T, f []Factory) { // consume msg log.Println("=============receive msg===================") - rcvMsgPacks := make([]*MsgPack, 0) + rcvMsgPacks := make([]*ConsumeMsgPack, 0) resumeMsgPack := func(t *testing.T) int { var consumer MsgStream @@ -539,9 +547,11 @@ func testTimeTickerStream2(t *testing.T, f []Factory) { rcvMsgPacks = append(rcvMsgPacks, msgPack) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(consumer.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } log.Println("================") } @@ -576,7 +586,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consume(ctx, consumer) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) if i == 5 { seekPosition = result.EndPositions[0] } @@ -586,7 +596,7 @@ func testMqMsgStreamSeek(t *testing.T, f []Factory) { consumer = createAndSeekConsumer(ctx, t, f[0].NewMsgStream, channels, []*msgpb.MsgPosition{seekPosition}) for i := 6; i < 10; i++ { result := consume(ctx, consumer) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) } consumer.Close() } @@ -610,7 +620,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := 
consume(ctx, consumer) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) seekPosition = result.EndPositions[0] } @@ -625,7 +635,7 @@ func testMqMsgStreamSeekInvalidMessage(t *testing.T, f []Factory, pg positionGen err = producer.Produce(ctx, msgPack) assert.NoError(t, err) result := consume(ctx, consumer2) - assert.Equal(t, result.Msgs[0].ID(), int64(1)) + assert.Equal(t, result.Msgs[0].GetID(), int64(1)) } func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) { @@ -658,7 +668,7 @@ func testMqMsgStreamSeekLatest(t *testing.T, f []Factory) { for i := 10; i < 20; i++ { result := consume(ctx, consumer2) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) } } @@ -748,7 +758,7 @@ func applyProduceAndConsume( receiveAndValidateMsg(context.Background(), consumer, len(msgPack.Msgs)) } -func consume(ctx context.Context, mq MsgStream) *MsgPack { +func consume(ctx context.Context, mq MsgStream) *ConsumeMsgPack { for { select { case msgPack, ok := <-mq.Chan(): @@ -829,7 +839,7 @@ func receiveAndValidateMsg(ctx context.Context, outputStream MsgStream, msgCount msgs := result.Msgs for _, v := range msgs { receiveCount++ - log.Println("msg type: ", v.Type(), ", msg value: ", v) + log.Println("msg type: ", v.GetType(), ", msg value: ", v) } log.Println("================") } diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index 169d03a756607..758666f2b441d 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -168,19 +168,19 @@ func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(context.Context, * } // Chan provides a mock function with given fields: -func (_m *MockMsgStream) Chan() <-chan *MsgPack { +func (_m *MockMsgStream) Chan() <-chan *ConsumeMsgPack { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Chan") } - var r0 <-chan *MsgPack - if rf, ok := ret.Get(0).(func() <-chan *MsgPack); ok { + var r0 <-chan *ConsumeMsgPack + if rf, ok := ret.Get(0).(func() <-chan *ConsumeMsgPack); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan *MsgPack) + r0 = ret.Get(0).(<-chan *ConsumeMsgPack) } } @@ -204,12 +204,12 @@ func (_c *MockMsgStream_Chan_Call) Run(run func()) *MockMsgStream_Chan_Call { return _c } -func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Chan_Call { +func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call { _c.Call.Return(_a0) return _c } -func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *MsgPack) *MockMsgStream_Chan_Call { +func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *ConsumeMsgPack) *MockMsgStream_Chan_Call { _c.Call.Return(run) return _c } @@ -430,6 +430,53 @@ func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []strin return _c } +// GetUnmarshalDispatcher provides a mock function with given fields: +func (_m *MockMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetUnmarshalDispatcher") + } + + var r0 UnmarshalDispatcher + if rf, ok := ret.Get(0).(func() UnmarshalDispatcher); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(UnmarshalDispatcher) + } + } + + return r0 +} + +// MockMsgStream_GetUnmarshalDispatcher_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 
'GetUnmarshalDispatcher' +type MockMsgStream_GetUnmarshalDispatcher_Call struct { + *mock.Call +} + +// GetUnmarshalDispatcher is a helper method to define mock.On call +func (_e *MockMsgStream_Expecter) GetUnmarshalDispatcher() *MockMsgStream_GetUnmarshalDispatcher_Call { + return &MockMsgStream_GetUnmarshalDispatcher_Call{Call: _e.mock.On("GetUnmarshalDispatcher")} +} + +func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Run(run func()) *MockMsgStream_GetUnmarshalDispatcher_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) Return(_a0 UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsgStream_GetUnmarshalDispatcher_Call) RunAndReturn(run func() UnmarshalDispatcher) *MockMsgStream_GetUnmarshalDispatcher_Call { + _c.Call.Return(run) + return _c +} + // Produce provides a mock function with given fields: _a0, _a1 func (_m *MockMsgStream) Produce(_a0 context.Context, _a1 *MsgPack) error { ret := _m.Called(_a0, _a1) diff --git a/pkg/mq/msgstream/mq_kafka_msgstream_test.go b/pkg/mq/msgstream/mq_kafka_msgstream_test.go index 18317e4501643..054f1aed73d22 100644 --- a/pkg/mq/msgstream/mq_kafka_msgstream_test.go +++ b/pkg/mq/msgstream/mq_kafka_msgstream_test.go @@ -24,6 +24,7 @@ import ( "github.com/confluentinc/confluent-kafka-go/kafka" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -131,7 +132,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) { outputStream := getKafkaOutputStream(ctx, kafkaAddress, consumerChannels, consumerSubName, common.SubscriptionPositionEarliest) for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) if i == 5 { seekPosition = result.EndPositions[0] break @@ -162,11 +163,11 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) { assert.Equal(t, 1, len(msgPack.Msgs)) for _, tsMsg := range msgPack.Msgs { - assert.Equal(t, value, tsMsg.ID()) + assert.Equal(t, value, tsMsg.GetID()) value++ cnt++ - ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID) + ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID) assert.NoError(t, err) if ret { hasMore = false @@ -272,20 +273,26 @@ func TestStream_KafkaTtMsgStream_Seek(t *testing.T) { assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} for i, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), result[i]) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), result[i]) } seekMsg2 := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg2.Msgs), 1) for _, msg := range seekMsg2.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } outputStream2 := getKafkaTtOutputStreamAndSeek(ctx, kafkaAddress, receivedMsg3.EndPositions) seekMsg = consumer(ctx, outputStream2) assert.Equal(t, len(seekMsg.Msgs), 1) for _, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } inputStream.Close() @@ -320,9 +327,11 @@ func 
TestStream_KafkaTtMsgStream_1(t *testing.T) { rcvMsg += len(msgPack.Msgs) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } } } @@ -361,7 +370,7 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) { // consume msg log.Println("=============receive msg===================") - rcvMsgPacks := make([]*MsgPack, 0) + rcvMsgPacks := make([]*ConsumeMsgPack, 0) resumeMsgPack := func(t *testing.T) int { var outputStream MsgStream @@ -376,9 +385,11 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) { rcvMsgPacks = append(rcvMsgPacks, msgPack) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("TestStream_KafkaTtMsgStream_2 msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } log.Println("================") } diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 76486865e81a1..bcbb2acbe2298 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -61,7 +61,7 @@ type mqMsgStream struct { repackFunc RepackFunc unmarshal UnmarshalDispatcher - receiveBuf chan *MsgPack + receiveBuf chan *ConsumeMsgPack closeRWMutex *sync.RWMutex streamCancel func() bufSize int64 @@ -89,7 +89,7 @@ func NewMqMsgStream(ctx context.Context, consumers := make(map[string]mqwrapper.Consumer) producerChannels := make([]string, 0) consumerChannels := make([]string, 0) - receiveBuf := make(chan *MsgPack, receiveBufSize) + receiveBuf := make(chan *ConsumeMsgPack, receiveBufSize) stream := &mqMsgStream{ ctx: streamCtx, @@ -355,9 +355,7 @@ func (ms *mqMsgStream) Produce(ctx context.Context, msgPack *MsgPack) error { return err } - msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{ - common.MsgTypeKey: v.Msgs[i].Type().String(), - }} + msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v.Msgs[i])} InjectCtx(spanCtx, msg.Properties) if _, err := producer.Send(spanCtx, msg); err != nil { @@ -399,7 +397,7 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str return ids, err } - msg := &common.ProducerMessage{Payload: m, Properties: map[string]string{}} + msg := &common.ProducerMessage{Payload: m, Properties: GetPorperties(v)} InjectCtx(spanCtx, msg.Properties) ms.producerLock.RLock() @@ -421,10 +419,6 @@ func (ms *mqMsgStream) Broadcast(ctx context.Context, msgPack *MsgPack) (map[str return ids, nil } -func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg common.Message) (TsMsg, error) { - return GetTsMsgFromConsumerMsg(ms.unmarshal, msg) -} - // GetTsMsgFromConsumerMsg get TsMsg from consumer message func GetTsMsgFromConsumerMsg(unmarshalDispatcher UnmarshalDispatcher, msg common.Message) (TsMsg, error) { msgType, err := common.GetMsgType(msg) @@ 
-464,31 +458,35 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { log.Ctx(ms.ctx).Warn("MqMsgStream get msg whose payload is nil") continue } - // not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295 - // if the message not belong to the topic, will skip it - tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) + + var err error + var packMsg PackMsg + + packMsg, err = NewMarshaledMsg(msg, consumer.Subscription()) if err != nil { - log.Ctx(ms.ctx).Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) - continue + packMsg, err = UnmarshalMsg(msg, ms.unmarshal) + if err != nil { + log.Ctx(ms.ctx).Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + continue + } } - pos := tsMsg.Position() - tsMsg.SetPosition(&MsgPosition{ - ChannelName: pos.ChannelName, - MsgID: pos.MsgID, - MsgGroup: consumer.Subscription(), - Timestamp: tsMsg.BeginTs(), - }) - ctx, _ := ExtractCtx(tsMsg, msg.Properties()) - tsMsg.SetTraceCtx(ctx) + pos := &msgpb.MsgPosition{ + ChannelName: packMsg.GetChannel(), + MsgID: packMsg.GetMessageID(), + MsgGroup: consumer.Subscription(), + Timestamp: packMsg.GetTimestamp(), + } - msgPack := MsgPack{ - Msgs: []TsMsg{tsMsg}, - StartPositions: []*msgpb.MsgPosition{tsMsg.Position()}, - EndPositions: []*msgpb.MsgPosition{tsMsg.Position()}, - BeginTs: tsMsg.BeginTs(), - EndTs: tsMsg.EndTs(), + packMsg.SetPosition(pos) + msgPack := ConsumeMsgPack{ + Msgs: []PackMsg{packMsg}, + StartPositions: []*msgpb.MsgPosition{pos}, + EndPositions: []*msgpb.MsgPosition{pos}, + BeginTs: packMsg.GetTimestamp(), + EndTs: packMsg.GetTimestamp(), } + select { case ms.receiveBuf <- &msgPack: case <-ms.ctx.Done(): @@ -498,7 +496,11 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { } } -func (ms *mqMsgStream) Chan() <-chan *MsgPack { +func (ms *mqMsgStream) GetUnmarshalDispatcher() UnmarshalDispatcher { + return ms.unmarshal +} + +func (ms *mqMsgStream) Chan() <-chan *ConsumeMsgPack { ms.onceChan.Do(func() { for _, c := range ms.consumers { go ms.receiveMsg(c) @@ -546,7 +548,7 @@ var _ MsgStream = (*MqTtMsgStream)(nil) // MqTtMsgStream is a msgstream that contains timeticks type MqTtMsgStream struct { *mqMsgStream - chanMsgBuf map[mqwrapper.Consumer][]TsMsg + chanMsgBuf map[mqwrapper.Consumer][]PackMsg chanMsgPos map[mqwrapper.Consumer]*msgpb.MsgPosition chanStopChan map[mqwrapper.Consumer]chan bool chanTtMsgTime map[mqwrapper.Consumer]Timestamp @@ -568,7 +570,7 @@ func NewMqTtMsgStream(ctx context.Context, if err != nil { return nil, err } - chanMsgBuf := make(map[mqwrapper.Consumer][]TsMsg) + chanMsgBuf := make(map[mqwrapper.Consumer][]PackMsg) chanMsgPos := make(map[mqwrapper.Consumer]*msgpb.MsgPosition) chanStopChan := make(map[mqwrapper.Consumer]chan bool) chanTtMsgTime := make(map[mqwrapper.Consumer]Timestamp) @@ -593,7 +595,7 @@ func (ms *MqTtMsgStream) addConsumer(consumer mqwrapper.Consumer, channel string } ms.consumers[channel] = consumer ms.consumerChannels = append(ms.consumerChannels, channel) - ms.chanMsgBuf[consumer] = make([]TsMsg, 0) + ms.chanMsgBuf[consumer] = make([]PackMsg, 0) ms.chanMsgPos[consumer] = &msgpb.MsgPosition{ ChannelName: channel, MsgID: make([]byte, 0), @@ -649,8 +651,8 @@ func (ms *MqTtMsgStream) Close() { ms.mqMsgStream.Close() } -func isDMLMsg(msg TsMsg) bool { - return msg.Type() == commonpb.MsgType_Insert || msg.Type() == commonpb.MsgType_Delete +func isDMLMsg(msg PackMsg) bool { + return msg.GetType() == commonpb.MsgType_Insert || msg.GetType() == commonpb.MsgType_Delete 
} func (ms *MqTtMsgStream) continueBuffering(endTs, size uint64, startTime time.Time) bool { @@ -700,7 +702,7 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { case <-ms.ctx.Done(): return default: - timeTickBuf := make([]TsMsg, 0) + timeTickBuf := make([]PackMsg, 0) // startMsgPosition := make([]*msgpb.MsgPosition, 0) // endMsgPositions := make([]*msgpb.MsgPosition, 0) startPositions := make(map[string]*msgpb.MsgPosition) @@ -739,22 +741,22 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { if _, ok := startPositions[channelName]; !ok { startPositions[channelName] = startPos } - tempBuffer := make([]TsMsg, 0) - var timeTickMsg TsMsg + tempBuffer := make([]PackMsg, 0) + var timeTickMsg PackMsg for _, v := range msgs { - if v.Type() == commonpb.MsgType_TimeTick { + if v.GetType() == commonpb.MsgType_TimeTick { timeTickMsg = v continue } - if v.EndTs() <= currTs || - GetReplicateID(v) != "" { - size += uint64(v.Size()) + if v.GetTimestamp() <= currTs || + v.GetReplicateID() != "" { + size += uint64(v.GetSize()) timeTickBuf = append(timeTickBuf, v) } else { tempBuffer = append(tempBuffer, v) } // when drop collection, force to exit the buffer loop - if v.Type() == commonpb.MsgType_DropCollection || v.Type() == commonpb.MsgType_Replicate { + if v.GetType() == commonpb.MsgType_DropCollection || v.GetType() == commonpb.MsgType_Replicate { containsEndBufferMsg = true } } @@ -765,8 +767,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { if len(tempBuffer) > 0 { // if tempBuffer is not empty, use tempBuffer[0] to seek newPos = &msgpb.MsgPosition{ - ChannelName: tempBuffer[0].Position().ChannelName, - MsgID: tempBuffer[0].Position().MsgID, + ChannelName: tempBuffer[0].GetChannel(), + MsgID: tempBuffer[0].GetMessageID(), Timestamp: currTs, MsgGroup: consumer.Subscription(), } @@ -774,8 +776,8 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { } else if timeTickMsg != nil { // if tempBuffer is empty, use timeTickMsg to seek newPos = &msgpb.MsgPosition{ - ChannelName: timeTickMsg.Position().ChannelName, - MsgID: timeTickMsg.Position().MsgID, + ChannelName: timeTickMsg.GetChannel(), + MsgID: timeTickMsg.GetMessageID(), Timestamp: currTs, MsgGroup: consumer.Subscription(), } @@ -787,20 +789,20 @@ func (ms *MqTtMsgStream) bufMsgPackToChannel() { ms.consumerLock.Unlock() } - idset := make(typeutil.UniqueSet) - uniqueMsgs := make([]TsMsg, 0, len(timeTickBuf)) + idset := make(typeutil.Set[int64]) + uniqueMsgs := make([]PackMsg, 0, len(timeTickBuf)) for _, msg := range timeTickBuf { - if isDMLMsg(msg) && idset.Contain(msg.ID()) { - log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.ID())) + if isDMLMsg(msg) && idset.Contain(msg.GetID()) { + log.Ctx(ms.ctx).Warn("mqTtMsgStream, found duplicated msg", zap.Int64("msgID", msg.GetID())) continue } - idset.Insert(msg.ID()) + idset.Insert(msg.GetID()) uniqueMsgs = append(uniqueMsgs, msg) } // skip endTs = 0 (no run for ctx error) if endTs > 0 { - msgPack := MsgPack{ + msgPack := ConsumeMsgPack{ BeginTs: ms.lastTimeStamp, EndTs: endTs, Msgs: uniqueMsgs, @@ -840,21 +842,26 @@ func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqwrapper.Consumer) { log.Warn("MqTtMsgStream get msg whose payload is nil") continue } - // not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295 - // if the message not belong to the topic, will skip it - tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) + + var err error + var packMsg PackMsg + + packMsg, err = NewMarshaledMsg(msg, 
consumer.Subscription()) if err != nil { - log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) - continue + packMsg, err = UnmarshalMsg(msg, ms.unmarshal) + if err != nil { + log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + continue + } } ms.chanMsgBufMutex.Lock() - ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg) + ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg) ms.chanMsgBufMutex.Unlock() - if tsMsg.Type() == commonpb.MsgType_TimeTick { + if packMsg.GetType() == commonpb.MsgType_TimeTick { ms.chanTtMsgTimeMutex.Lock() - ms.chanTtMsgTime[consumer] = tsMsg.(*TimeTickMsg).Base.Timestamp + ms.chanTtMsgTime[consumer] = packMsg.GetTimestamp() ms.chanTtMsgTimeMutex.Unlock() return } @@ -972,20 +979,23 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, loopMsgCnt++ consumer.Ack(msg) - headerMsg := commonpb.MsgHeader{} - err := proto.Unmarshal(msg.Payload(), &headerMsg) - if err != nil { - return fmt.Errorf("failed to unmarshal message header, err %s", err.Error()) - } - tsMsg, err := ms.unmarshal.Unmarshal(msg.Payload(), headerMsg.Base.MsgType) + var err error + var packMsg PackMsg + + packMsg, err = NewMarshaledMsg(msg, consumer.Subscription()) if err != nil { - return fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error()) + packMsg, err = UnmarshalMsg(msg, ms.unmarshal) + if err != nil { + log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + continue + } } + // skip the replicate msg because it must have been consumed - if GetReplicateID(tsMsg) != "" { + if packMsg.GetReplicateID() != "" { continue } - if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp { + if packMsg.GetType() == commonpb.MsgType_TimeTick && packMsg.GetTimestamp() >= mp.Timestamp { runLoop = false if time.Since(loopStarTime) > 30*time.Second { log.Info("seek loop finished long time", @@ -993,21 +1003,21 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, zap.String("channel", mp.ChannelName), zap.Duration("cost", time.Since(loopStarTime))) } - } else if tsMsg.BeginTs() > mp.Timestamp { - ctx, _ := ExtractCtx(tsMsg, msg.Properties()) - tsMsg.SetTraceCtx(ctx) + } else if packMsg.GetTimestamp() > mp.Timestamp { + ctx, _ := ExtractCtx(packMsg, msg.Properties()) + packMsg.SetTraceCtx(ctx) - tsMsg.SetPosition(&MsgPosition{ + packMsg.SetPosition(&MsgPosition{ ChannelName: filepath.Base(msg.Topic()), MsgID: msg.ID().Serialize(), }) - ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg) + ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], packMsg) } else { log.Info("skip msg", - zap.Int64("source", tsMsg.SourceID()), - zap.String("type", tsMsg.Type().String()), - zap.Int("size", tsMsg.Size()), - zap.Uint64("msgTs", tsMsg.BeginTs()), + // zap.Int64("source", tsMsg.SourceID()), // TODO AOIASD SOURCE ID ? 
+ zap.String("type", packMsg.GetType().String()), + zap.Int("size", packMsg.GetSize()), + zap.Uint64("msgTs", packMsg.GetTimestamp()), zap.Uint64("posTs", mp.GetTimestamp()), ) } @@ -1017,7 +1027,7 @@ func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, return nil } -func (ms *MqTtMsgStream) Chan() <-chan *MsgPack { +func (ms *MqTtMsgStream) Chan() <-chan *ConsumeMsgPack { ms.onceChan.Do(func() { if ms.consumers != nil { go ms.bufMsgPackToChannel() diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index 1ff559c191c63..c07cd644636c0 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -85,7 +85,7 @@ func getKafkaBrokerList() string { return brokerList } -func consumer(ctx context.Context, mq MsgStream) *MsgPack { +func consumer(ctx context.Context, mq MsgStream) *ConsumeMsgPack { for { select { case msgPack, ok := <-mq.Chan(): @@ -506,7 +506,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) if i == 5 { seekPosition = result.EndPositions[0] } @@ -539,11 +539,11 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { assert.Equal(t, 1, len(msgPack.Msgs)) for _, tsMsg := range msgPack.Msgs { - assert.Equal(t, value, tsMsg.ID()) + assert.Equal(t, value, tsMsg.GetID()) value++ cnt++ - ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID) + ret, err := lastMsgID.LessOrEqualThan(tsMsg.GetPosition().MsgID) assert.NoError(t, err) if ret { hasMore = false @@ -674,20 +674,26 @@ func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} for i, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), result[i]) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), result[i]) } seekMsg2 := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg2.Msgs), 1) for _, msg := range seekMsg2.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } outputStream2 := getPulsarTtOutputStreamAndSeek(ctx, pulsarAddress, receivedMsg3.EndPositions) seekMsg = consumer(ctx, outputStream2) assert.Equal(t, len(seekMsg.Msgs), 1) for _, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } inputStream.Close() @@ -882,9 +888,11 @@ func TestStream_PulsarTtMsgStream_1(t *testing.T) { rcvMsg += len(msgPack.Msgs) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } log.Println("================") } @@ -940,7 +948,7 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) { // consume msg log.Println("=============receive 
msg===================") - rcvMsgPacks := make([]*MsgPack, 0) + rcvMsgPacks := make([]*ConsumeMsgPack, 0) resumeMsgPack := func(t *testing.T) int { var outputStream MsgStream @@ -954,9 +962,11 @@ func TestStream_PulsarTtMsgStream_2(t *testing.T) { rcvMsgPacks = append(rcvMsgPacks, msgPack) if len(msgPack.Msgs) > 0 { for _, msg := range msgPack.Msgs { - log.Println("msg type: ", msg.Type(), ", msg value: ", msg) - assert.Greater(t, msg.BeginTs(), msgPack.BeginTs) - assert.LessOrEqual(t, msg.BeginTs(), msgPack.EndTs) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + log.Println("msg type: ", tsMsg.Type(), ", msg value: ", msg) + assert.Greater(t, tsMsg.BeginTs(), msgPack.BeginTs) + assert.LessOrEqual(t, tsMsg.BeginTs(), msgPack.EndTs) } log.Println("================") } @@ -998,7 +1008,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) if i == 5 { seekPosition = result.EndPositions[0] } @@ -1013,7 +1023,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) { for i := 6; i < 10; i++ { result := consumer(ctx, outputStream2) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) } outputStream2.Close() } @@ -1042,7 +1052,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) seekPosition = result.EndPositions[0] } @@ -1074,7 +1084,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { err = inputStream.Produce(ctx, msgPack) assert.NoError(t, err) result := consumer(ctx, outputStream2) - assert.Equal(t, result.Msgs[0].ID(), int64(1)) + assert.Equal(t, result.Msgs[0].GetID(), int64(1)) } func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) { @@ -1101,7 +1111,7 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) seekPosition = result.EndPositions[0] } @@ -1179,7 +1189,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) { for i := 10; i < 20; i++ { result := consumer(ctx, outputStream2) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) } inputStream.Close() @@ -1570,7 +1580,7 @@ func receiveMsg(ctx context.Context, outputStream MsgStream, msgCount int) { msgs := result.Msgs for _, v := range msgs { receiveCount++ - log.Println("msg type: ", v.Type(), ", msg value: ", v) + log.Println("msg type: ", v.GetType(), ", msg value: ", v) } log.Println("================") } diff --git a/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go b/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go index c982d401b4fb5..9a5cf91a929e2 100644 --- a/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_rocksmq_msgstream_test.go @@ -380,9 +380,9 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { outputStream.Seek(ctx, receivedMsg.StartPositions, false) seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 1+2) - assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1) - assert.Equal(t, 
commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type()) - assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type()) + assert.EqualValues(t, seekMsg.Msgs[0].GetTimestamp(), 1) + assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].GetType()) + assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].GetType()) inputStream.Close() outputStream.Close() @@ -485,13 +485,17 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) { assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} for i, msg := range seekMsg.Msgs { - assert.Equal(t, msg.BeginTs(), result[i]) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), result[i]) } seekMsg2 := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg2.Msgs), 1) for _, msg := range seekMsg2.Msgs { - assert.Equal(t, msg.BeginTs(), uint64(19)) + tsMsg, err := msg.Unmarshal(outputStream.GetUnmarshalDispatcher()) + require.NoError(t, err) + assert.Equal(t, tsMsg.BeginTs(), uint64(19)) } inputStream.Close() @@ -517,7 +521,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { var seekPosition *msgpb.MsgPosition for i := 0; i < 10; i++ { result := consumer(ctx, outputStream) - assert.Equal(t, result.Msgs[0].ID(), int64(i)) + assert.Equal(t, result.Msgs[0].GetID(), int64(i)) seekPosition = result.EndPositions[0] } outputStream.Close() @@ -550,7 +554,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { assert.NoError(t, err) result := consumer(ctx, outputStream2) - assert.Equal(t, result.Msgs[0].ID(), int64(1)) + assert.Equal(t, result.Msgs[0].GetID(), int64(1)) inputStream.Close() outputStream2.Close() @@ -585,7 +589,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { pack := <-outputStream.Chan() assert.NotNil(t, pack) assert.Equal(t, 1, len(pack.Msgs)) - assert.EqualValues(t, 1000, pack.Msgs[0].BeginTs()) + assert.EqualValues(t, 1000, pack.Msgs[0].GetTimestamp()) inputStream.Close() outputStream.Close() diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 3b1d4b3c32d8d..9744b16f3911e 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -18,6 +18,9 @@ package msgstream import ( "context" + "fmt" + "strconv" + "sync" "go.uber.org/zap" @@ -25,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -52,6 +56,201 @@ type MsgPack struct { EndPositions []*MsgPosition } +// ConsumeMsgPack represents a batch of messages received by a consumer +type ConsumeMsgPack struct { + BeginTs Timestamp + EndTs Timestamp + Msgs []PackMsg + StartPositions []*MsgPosition + EndPositions []*MsgPosition } + +// PackMsg is the message type carried by a ConsumeMsgPack; +// it exposes message properties without requiring a full unmarshal +type PackMsg interface { + GetPosition() *msgpb.MsgPosition + SetPosition(*msgpb.MsgPosition) + + GetSize() int + GetTimestamp() uint64 + GetChannel() string + GetMessageID() []byte + GetID() int64 + GetType() commonpb.MsgType + GetReplicateID() string + + SetTraceCtx(ctx context.Context) + Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) +} + +// UnmarshalledMsg wraps an already unmarshalled TsMsg as a PackMsg, +// for compatibility or tests +type UnmarshalledMsg struct { + msg TsMsg +} + +func (m *UnmarshalledMsg) GetTimestamp() uint64 { + return m.msg.BeginTs() +} + +func (m 
diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go
index 3b1d4b3c32d8d..9744b16f3911e 100644
--- a/pkg/mq/msgstream/msgstream.go
+++ b/pkg/mq/msgstream/msgstream.go
@@ -18,6 +18,9 @@ package msgstream
 
 import (
 	"context"
+	"fmt"
+	"strconv"
+	"sync"
 
 	"go.uber.org/zap"
 
@@ -25,6 +28,7 @@
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/mq/common"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
@@ -52,6 +56,201 @@ type MsgPack struct {
 	EndPositions   []*MsgPosition
 }
 
+// ConsumeMsgPack represents a batch of messages on the consumer side.
+type ConsumeMsgPack struct {
+	BeginTs        Timestamp
+	EndTs          Timestamp
+	Msgs           []PackMsg
+	StartPositions []*MsgPosition
+	EndPositions   []*MsgPosition
+}
+
+// PackMsg is the message type carried by a ConsumeMsgPack.
+// It exposes commonly used message properties without forcing a full unmarshal.
+type PackMsg interface {
+	GetPosition() *msgpb.MsgPosition
+	SetPosition(*msgpb.MsgPosition)
+
+	GetSize() int
+	GetTimestamp() uint64
+	GetChannel() string
+	GetMessageID() []byte
+	GetID() int64
+	GetType() commonpb.MsgType
+	GetReplicateID() string
+
+	SetTraceCtx(ctx context.Context)
+	Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error)
+}
+
+// UnmarshalledMsg wraps an already-unmarshalled TsMsg as a PackMsg,
+// kept for compatibility and for tests.
+type UnmarshalledMsg struct {
+	msg TsMsg
+}
+
+func (m *UnmarshalledMsg) GetTimestamp() uint64 {
+	return m.msg.BeginTs()
+}
+
+func (m *UnmarshalledMsg) GetChannel() string {
+	return m.msg.Position().GetChannelName()
+}
+
+func (m *UnmarshalledMsg) GetMessageID() []byte {
+	return m.msg.Position().GetMsgID()
+}
+
+func (m *UnmarshalledMsg) GetID() int64 {
+	return m.msg.ID()
+}
+
+func (m *UnmarshalledMsg) GetType() commonpb.MsgType {
+	return m.msg.Type()
+}
+
+func (m *UnmarshalledMsg) GetSize() int {
+	return m.msg.Size()
+}
+
+func (m *UnmarshalledMsg) GetReplicateID() string {
+	msgBase, ok := m.msg.(interface{ GetBase() *commonpb.MsgBase })
+	if !ok {
+		log.Warn("fail to get msg base, please check it", zap.Any("type", m.msg.Type()))
+		return ""
+	}
+	return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
+}
+
+func (m *UnmarshalledMsg) SetPosition(pos *msgpb.MsgPosition) {
+	m.msg.SetPosition(pos)
+}
+
+func (m *UnmarshalledMsg) GetPosition() *msgpb.MsgPosition {
+	return m.msg.Position()
+}
+
+func (m *UnmarshalledMsg) SetTraceCtx(ctx context.Context) {
+	m.msg.SetTraceCtx(ctx)
+}
+
+func (m *UnmarshalledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
+	return m.msg, nil
+}
+
+// MarshaledMsg wraps a raw marshaled message and serves the PackMsg getters
+// from the metadata parsed out of its properties.
+type MarshaledMsg struct {
+	msg         common.Message
+	pos         *MsgPosition
+	msgType     MsgType
+	msgID       int64
+	timestamp   uint64
+	vchannel    string
+	replicateID string
+	traceCtx    context.Context
+}
+
+func (m *MarshaledMsg) GetTimestamp() uint64 {
+	return m.timestamp
+}
+
+func (m *MarshaledMsg) GetChannel() string {
+	return m.vchannel
+}
+
+func (m *MarshaledMsg) GetMessageID() []byte {
+	return m.msg.ID().Serialize()
+}
+
+func (m *MarshaledMsg) GetID() int64 {
+	return m.msgID
+}
+
+func (m *MarshaledMsg) GetType() commonpb.MsgType {
+	return m.msgType
+}
+
+func (m *MarshaledMsg) GetSize() int {
+	return len(m.msg.Payload())
+}
+
+func (m *MarshaledMsg) GetReplicateID() string {
+	return m.replicateID
+}
+
+func (m *MarshaledMsg) SetPosition(pos *msgpb.MsgPosition) {
+	m.pos = pos
+}
+
+func (m *MarshaledMsg) GetPosition() *msgpb.MsgPosition {
+	return m.pos
+}
+
+func (m *MarshaledMsg) SetTraceCtx(ctx context.Context) {
+	m.traceCtx = ctx
+}
+
+func (m *MarshaledMsg) Unmarshal(unmarshalDispatcher UnmarshalDispatcher) (TsMsg, error) {
+	tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, m.msg)
+	if err != nil {
+		return nil, err
+	}
+	tsMsg.SetTraceCtx(m.traceCtx)
+	tsMsg.SetPosition(m.pos)
+	return tsMsg, nil
+}
+
+func NewMarshaledMsg(msg common.Message, group string) (PackMsg, error) {
+	properties := msg.Properties()
+	vchannel, ok := properties[common.ChannelTypeKey]
+	if !ok {
+		return nil, fmt.Errorf("get channel name from msg properties failed")
+	}
+
+	tsStr, ok := properties[common.TimestampTypeKey]
+	if !ok {
+		return nil, fmt.Errorf("get timestamp from msg properties failed")
+	}
+
+	timestamp, err := strconv.ParseUint(tsStr, 10, 64)
+	if err != nil {
+		log.Warn("parse timestamp from message properties failed, unknown message", zap.Error(err))
+		return nil, fmt.Errorf("parse timestamp from msg properties failed")
+	}
+
+	val, ok := properties[common.MsgTypeKey]
+	if !ok {
+		return nil, fmt.Errorf("get msgType from msg properties failed")
+	}
+	msgType := commonpb.MsgType(commonpb.MsgType_value[val])
+
+	result := &MarshaledMsg{
+		msg:       msg,
+		timestamp: timestamp,
+		msgType:   msgType,
+		vchannel:  vchannel,
+	}
+
+	replicateID, ok := properties[common.ReplicateIDTypeKey]
+	if ok {
+		result.replicateID = replicateID
+	}
+	return result, nil
+}
+
+// UnmarshalMsg eagerly unmarshals a common message into an UnmarshalledMsg.
+func UnmarshalMsg(msg common.Message, unmarshalDispatcher UnmarshalDispatcher) (PackMsg, error) {
+	tsMsg, err := GetTsMsgFromConsumerMsg(unmarshalDispatcher, msg)
+	if err != nil {
+		return nil, err
+	}
+
+	return &UnmarshalledMsg{
+		msg: tsMsg,
+	}, nil
+}
+
 // RepackFunc is a function type which used to repack message after hash by primary key
 type RepackFunc func(msgs []TsMsg, hashKeys [][]int32) (map[int32]*MsgPack, error)
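NewMarshaledMsg fails when the produce-time properties are absent, which is the case for any message written before this patch. A consumer can keep working on such data by falling back to the eager path; a minimal sketch under that assumption (toPackMsg is a hypothetical helper, not part of this patch):

// toPackMsg prefers the lazy MarshaledMsg and falls back to eager
// unmarshalling when the produce-time properties are missing.
func toPackMsg(msg common.Message, group string, dispatcher UnmarshalDispatcher) (PackMsg, error) {
	if packMsg, err := NewMarshaledMsg(msg, group); err == nil {
		return packMsg, nil
	}
	// Old-format message without properties: decode it up front.
	return UnmarshalMsg(msg, dispatcher)
}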
@@ -66,7 +265,8 @@
 	Broadcast(context.Context, *MsgPack) (map[string][]MessageID, error)
 
 	AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error
-	Chan() <-chan *MsgPack
+	Chan() <-chan *ConsumeMsgPack
+	GetUnmarshalDispatcher() UnmarshalDispatcher
 	// Seek consume message from the specified position
 	// includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false
 	Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error
@@ -117,6 +317,15 @@ func GetReplicateID(msg TsMsg) string {
 	return msgBase.GetBase().GetReplicateInfo().GetReplicateID()
 }
 
+func GetTimestamp(msg TsMsg) uint64 {
+	msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
+	if !ok {
+		log.Warn("fail to get msg base, please check it", zap.Any("type", msg.Type()))
+		return 0
+	}
+	return msgBase.GetBase().GetTimestamp()
+}
+
 func MatchReplicateID(msg TsMsg, replicateID string) bool {
 	return GetReplicateID(msg) == replicateID
 }
@@ -126,3 +335,81 @@ type Factory interface {
 	NewTtMsgStream(ctx context.Context) (MsgStream, error)
 	NewMsgStreamDisposer(ctx context.Context) func([]string, string) error
 }
+
+// SimpleMsgDispatcher filters and unmarshals ts messages from a temporary stream.
+type SimpleMsgDispatcher struct {
+	stream              MsgStream
+	unmarshalDispatcher UnmarshalDispatcher
+	filter              func(PackMsg) bool
+	ch                  chan *MsgPack
+	chOnce              sync.Once
+
+	closeCh   chan struct{}
+	closeOnce sync.Once
+	wg        sync.WaitGroup
+}
+
+func NewSimpleMsgDispatcher(stream MsgStream, filter func(PackMsg) bool) *SimpleMsgDispatcher {
+	return &SimpleMsgDispatcher{
+		stream:              stream,
+		filter:              filter,
+		unmarshalDispatcher: stream.GetUnmarshalDispatcher(),
+		closeCh:             make(chan struct{}),
+	}
+}
+
+func (p *SimpleMsgDispatcher) filterAndParse() {
+	defer func() {
+		close(p.ch)
+		p.wg.Done()
+	}()
+	for {
+		select {
+		case <-p.closeCh:
+			return
+		case marshalPack, ok := <-p.stream.Chan():
+			if !ok {
+				log.Warn("simple msg dispatcher failed to read msg, stream channel closed")
+				return
+			}
+
+			msgPack := &MsgPack{
+				BeginTs:        marshalPack.BeginTs,
+				EndTs:          marshalPack.EndTs,
+				Msgs:           make([]TsMsg, 0),
+				StartPositions: marshalPack.StartPositions,
+				EndPositions:   marshalPack.EndPositions,
+			}
+			for _, marshalMsg := range marshalPack.Msgs {
+				if !p.filter(marshalMsg) {
+					continue
+				}
+				// unmarshal only the messages that passed the filter
+				msg, err := marshalMsg.Unmarshal(p.unmarshalDispatcher)
+				if err != nil {
+					log.Warn("unmarshal message failed, invalid message", zap.Error(err))
+					continue
+				}
+				msgPack.Msgs = append(msgPack.Msgs, msg)
+			}
+			// guard the send with closeCh so Close cannot block on a full channel
+			select {
+			case p.ch <- msgPack:
+			case <-p.closeCh:
+				return
+			}
+		}
+	}
+}
+
+func (p *SimpleMsgDispatcher) Chan() <-chan *MsgPack {
+	p.chOnce.Do(func() {
+		p.ch = make(chan *MsgPack, paramtable.Get().MQCfg.ReceiveBufSize.GetAsInt64())
+		p.wg.Add(1)
+		go p.filterAndParse()
+	})
+	return p.ch
+}
+
+func (p *SimpleMsgDispatcher) Close() {
+	p.closeOnce.Do(func() {
+		p.stream.Close()
+		close(p.closeCh)
+		p.wg.Wait()
+	})
+}
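The dispatcher gives a short-lived stream a filtered, already-unmarshalled view: the filter runs against cheap PackMsg metadata, and only accepted messages are decoded. A minimal usage sketch (tailInserts, vchannel, and handle are assumptions, not part of this patch):

// tailInserts forwards only the insert messages of one vchannel to handle
// and returns a closer for the dispatcher (which also closes the stream).
func tailInserts(stream MsgStream, vchannel string, handle func(TsMsg)) func() {
	dispatcher := NewSimpleMsgDispatcher(stream, func(pm PackMsg) bool {
		// Rejected messages are never unmarshalled.
		return pm.GetType() == commonpb.MsgType_Insert && pm.GetChannel() == vchannel
	})
	go func() {
		for pack := range dispatcher.Chan() {
			for _, tsMsg := range pack.Msgs { // already unmarshalled TsMsg values
				handle(tsMsg)
			}
		}
	}()
	return dispatcher.Close
}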
diff --git a/pkg/mq/msgstream/msgstream_util.go b/pkg/mq/msgstream/msgstream_util.go
index 7544c4d129e77..8202cabf51dc9 100644
--- a/pkg/mq/msgstream/msgstream_util.go
+++ b/pkg/mq/msgstream/msgstream_util.go
@@ -20,10 +20,13 @@ import (
 	"context"
 	"fmt"
 	"math/rand"
+	"strconv"
 
 	"github.com/confluentinc/confluent-kafka-go/kafka"
+	"github.com/samber/lo"
 	"go.uber.org/zap"
 
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	pcommon "github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/mq/common"
@@ -136,3 +139,29 @@ func KafkaHealthCheck(clusterStatus *pcommon.MQClusterStatus) {
 	clusterStatus.Health = true
 	clusterStatus.Members = healthList
 }
+
+// GetPorperties builds the dispatch properties (channel, msg type, timestamp,
+// replicate ID) that are attached to a message at produce time.
+func GetPorperties(msg TsMsg) map[string]string {
+	properties := map[string]string{}
+
+	properties[common.ChannelTypeKey] = msg.Position().GetChannelName()
+	properties[common.MsgTypeKey] = msg.Type().String()
+	msgBase, ok := msg.(interface{ GetBase() *commonpb.MsgBase })
+	if ok {
+		properties[common.TimestampTypeKey] = strconv.FormatUint(msgBase.GetBase().GetTimestamp(), 10)
+		properties[common.ReplicateIDTypeKey] = msgBase.GetBase().GetReplicateInfo().GetReplicateID()
+	}
+
+	return properties
+}
+
+// BuildConsumeMsgPack converts a produced MsgPack into its consumer-side
+// representation, wrapping each TsMsg as an UnmarshalledMsg.
+func BuildConsumeMsgPack(pack *MsgPack) *ConsumeMsgPack {
+	return &ConsumeMsgPack{
+		BeginTs: pack.BeginTs,
+		EndTs:   pack.EndTs,
+		Msgs: lo.Map(pack.Msgs, func(msg TsMsg, _ int) PackMsg {
+			return &UnmarshalledMsg{msg: msg}
+		}),
+		StartPositions: pack.StartPositions,
+		EndPositions:   pack.EndPositions,
+	}
+}
diff --git a/pkg/mq/msgstream/trace.go b/pkg/mq/msgstream/trace.go
index db1d027615750..cd203836af206 100644
--- a/pkg/mq/msgstream/trace.go
+++ b/pkg/mq/msgstream/trace.go
@@ -29,21 +29,17 @@ import (
 
 // ExtractCtx extracts trace span from msg.properties.
 // And it will attach some default tags to the span.
-func ExtractCtx(msg TsMsg, properties map[string]string) (context.Context, trace.Span) {
-	ctx := msg.TraceCtx()
-	if ctx == nil {
-		ctx = context.Background()
-	}
+func ExtractCtx(msg PackMsg, properties map[string]string) (context.Context, trace.Span) {
+	ctx := context.Background()
 	if !allowTrace(msg) {
 		return ctx, trace.SpanFromContext(ctx)
 	}
 	ctx = otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(properties))
 	name := "ReceieveMsg"
 	return otel.Tracer(name).Start(ctx, name, trace.WithAttributes(
-		attribute.Int64("ID", msg.ID()),
-		attribute.String("Type", msg.Type().String()),
-		// attribute.Int64Value("HashKeys", msg.HashKeys()),
-		attribute.String("Position", msg.Position().String()),
+		attribute.Int64("ID", msg.GetID()),
+		attribute.String("Type", msg.GetType().String()),
+		attribute.String("Position", msg.GetPosition().String()),
 	))
 }
diff --git a/pkg/mq/msgstream/wasted_mock_msgstream.go b/pkg/mq/msgstream/wasted_mock_msgstream.go
index 2efc0ff0e5b6e..0154e6eaf2e10 100644
--- a/pkg/mq/msgstream/wasted_mock_msgstream.go
+++ b/pkg/mq/msgstream/wasted_mock_msgstream.go
@@ -7,7 +7,7 @@ type WastedMockMsgStream struct {
 	AsProducerFunc    func(channels []string)
 	BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
 	BroadcastFunc     func(*MsgPack) error
-	ChanFunc          func() <-chan *MsgPack
+	ChanFunc          func() <-chan *ConsumeMsgPack
 }
 
 func NewWastedMockMsgStream() *WastedMockMsgStream {
@@ -22,6 +22,6 @@ func (m WastedMockMsgStream) Broadcast(ctx context.Context, pack *MsgPack) (map[
 	return m.BroadcastMarkFunc(pack)
 }
 
-func (m WastedMockMsgStream) Chan() <-chan *MsgPack {
+func (m WastedMockMsgStream) Chan() <-chan *ConsumeMsgPack {
 	return m.ChanFunc()
 }
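With the trace.go change above, trace context on the receive path is recovered from the MQ message properties rather than from an already-unmarshalled TsMsg. A minimal sketch of how a receive loop could wire this up (attachTrace is a hypothetical helper, not part of this patch):

// attachTrace extracts the propagated trace context from the raw message
// properties and stores it on the PackMsg, so that the TsMsg produced by a
// later Unmarshal call inherits it.
func attachTrace(pm PackMsg, properties map[string]string) {
	ctx, span := ExtractCtx(pm, properties)
	defer span.End()
	pm.SetTraceCtx(ctx)
}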