diff --git a/internal/datanode/flow_graph_dmstream_input_node_test.go b/internal/datanode/flow_graph_dmstream_input_node_test.go index 75df57af0b49c..ae804fe19e5d2 100644 --- a/internal/datanode/flow_graph_dmstream_input_node_test.go +++ b/internal/datanode/flow_graph_dmstream_input_node_test.go @@ -91,7 +91,7 @@ func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstrea return nil, nil } -func (mtm *mockTtMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error { +func (mtm *mockTtMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error { return nil } diff --git a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go index e29171d1437a5..96be3628c9056 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go @@ -240,7 +240,7 @@ func TestMqMsgStream_SeekNotSubscribed(t *testing.T) { ChannelName: "b", }, } - err = m.Seek(context.Background(), p) + err = m.Seek(context.Background(), p, false) assert.Error(t, err) } @@ -403,7 +403,7 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) consumerSubName = funcutil.RandomString(8) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(ctx, receivedMsg.StartPositions) + outputStream.Seek(ctx, receivedMsg.StartPositions, false) seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 1+2) assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1) @@ -506,7 +506,7 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) { consumerSubName = funcutil.RandomString(8) outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(ctx, receivedMsg3.StartPositions) + outputStream.Seek(ctx, receivedMsg3.StartPositions, false) seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} @@ -565,7 +565,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { }, } - err = outputStream2.Seek(ctx, p) + err = outputStream2.Seek(ctx, p, false) assert.NoError(t, err) for i := 10; i < 20; i++ { diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 5675b100fa54d..96ee30669a0fe 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -298,7 +298,7 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string { return nil } -func (ms *simpleMockMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { +func (ms *simpleMockMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error { return nil } diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 6055d8ec731ea..3990e63ba14b1 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -698,7 +698,7 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position } ts = time.Now() - err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}) + err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false) if err != nil { return nil, err } diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 47a284afd4c89..50665425aa8af 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -725,7 +725,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() { }, 10) s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) + s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mq.EXPECT().Close() ch := make(chan *msgstream.MsgPack, 10) close(ch) @@ -1173,7 +1173,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() { defer cancel() s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) + s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mq.EXPECT().Close() ch := make(chan *msgstream.MsgPack, 10) s.mq.EXPECT().Chan().Return(ch) diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index d636c0dd43863..dcaf420b8cbe5 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -306,7 +306,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { // mocks suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Chan().Return(suite.msgChan) suite.msgStream.EXPECT().Close() @@ -358,7 +358,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { // mocks suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Chan().Return(suite.msgChan) suite.msgStream.EXPECT().Close() @@ -432,7 +432,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Close().Return() - suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() status, err = suite.node.WatchDmChannels(ctx, req) suite.NoError(err) diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index db61ff1327db9..e27117b0268aa 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -293,8 +293,10 @@ func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.M } return nil, nil } -func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil } -func (ms *FailMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil } +func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil } +func (ms *FailMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error { + return nil +} func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) { return nil, nil diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index ee552046ddc08..4d0ab3e2c606e 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -103,7 +103,7 @@ func NewDispatcher(ctx context.Context, return nil, err } - err = stream.Seek(ctx, []*Pos{position}) + err = stream.Seek(ctx, []*Pos{position}, false) if err != nil { stream.Close() log.Error("seek failed", zap.Error(err)) diff --git a/pkg/mq/msgstream/factory_stream_test.go b/pkg/mq/msgstream/factory_stream_test.go index cb7ff8702cd08..d07e74cdfc0f1 100644 --- a/pkg/mq/msgstream/factory_stream_test.go +++ b/pkg/mq/msgstream/factory_stream_test.go @@ -766,7 +766,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer, consumer, err := newer(ctx) assert.NoError(t, err) consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - err = consumer.Seek(context.Background(), seekPositions) + err = consumer.Seek(context.Background(), seekPositions, false) assert.NoError(t, err) return consumer } diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index e97b0e30d91a5..adbf233246bf3 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -44,10 +44,10 @@ type MockMsgStream_AsConsumer_Call struct { } // AsConsumer is a helper method to define mock.On call -// - ctx context.Context -// - channels []string -// - subName string -// - position mqwrapper.SubscriptionInitialPosition +// - ctx context.Context +// - channels []string +// - subName string +// - position mqwrapper.SubscriptionInitialPosition func (_e *MockMsgStream_Expecter) AsConsumer(ctx interface{}, channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call { return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", ctx, channels, subName, position)} } @@ -80,7 +80,7 @@ type MockMsgStream_AsProducer_Call struct { } // AsProducer is a helper method to define mock.On call -// - channels []string +// - channels []string func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call { return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)} } @@ -134,7 +134,7 @@ type MockMsgStream_Broadcast_Call struct { } // Broadcast is a helper method to define mock.On call -// - _a0 *MsgPack +// - _a0 *MsgPack func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call { return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)} } @@ -219,7 +219,7 @@ type MockMsgStream_CheckTopicValid_Call struct { } // CheckTopicValid is a helper method to define mock.On call -// - channel string +// - channel string func (_e *MockMsgStream_Expecter) CheckTopicValid(channel interface{}) *MockMsgStream_CheckTopicValid_Call { return &MockMsgStream_CheckTopicValid_Call{Call: _e.mock.On("CheckTopicValid", channel)} } @@ -284,7 +284,7 @@ type MockMsgStream_EnableProduce_Call struct { } // EnableProduce is a helper method to define mock.On call -// - can bool +// - can bool func (_e *MockMsgStream_Expecter) EnableProduce(can interface{}) *MockMsgStream_EnableProduce_Call { return &MockMsgStream_EnableProduce_Call{Call: _e.mock.On("EnableProduce", can)} } @@ -338,7 +338,7 @@ type MockMsgStream_GetLatestMsgID_Call struct { } // GetLatestMsgID is a helper method to define mock.On call -// - channel string +// - channel string func (_e *MockMsgStream_Expecter) GetLatestMsgID(channel interface{}) *MockMsgStream_GetLatestMsgID_Call { return &MockMsgStream_GetLatestMsgID_Call{Call: _e.mock.On("GetLatestMsgID", channel)} } @@ -423,7 +423,7 @@ type MockMsgStream_Produce_Call struct { } // Produce is a helper method to define mock.On call -// - _a0 *MsgPack +// - _a0 *MsgPack func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call { return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)} } @@ -445,13 +445,13 @@ func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *Mo return _c } -// Seek provides a mock function with given fields: ctx, offset -func (_m *MockMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error { - ret := _m.Called(ctx, offset) +// Seek provides a mock function with given fields: ctx, msgPositions, includeCurrentMsg +func (_m *MockMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition, includeCurrentMsg bool) error { + ret := _m.Called(ctx, msgPositions, includeCurrentMsg) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition) error); ok { - r0 = rf(ctx, offset) + if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition, bool) error); ok { + r0 = rf(ctx, msgPositions, includeCurrentMsg) } else { r0 = ret.Error(0) } @@ -465,15 +465,16 @@ type MockMsgStream_Seek_Call struct { } // Seek is a helper method to define mock.On call -// - ctx context.Context -// - offset []*msgpb.MsgPosition -func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, offset interface{}) *MockMsgStream_Seek_Call { - return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, offset)} +// - ctx context.Context +// - msgPositions []*msgpb.MsgPosition +// - includeCurrentMsg bool +func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, msgPositions interface{}, includeCurrentMsg interface{}) *MockMsgStream_Seek_Call { + return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, msgPositions, includeCurrentMsg)} } -func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call { +func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, msgPositions []*msgpb.MsgPosition, includeCurrentMsg bool)) *MockMsgStream_Seek_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition)) + run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition), args[2].(bool)) }) return _c } @@ -483,7 +484,7 @@ func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call { return _c } -func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call { +func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition, bool) error) *MockMsgStream_Seek_Call { _c.Call.Return(run) return _c } @@ -499,7 +500,7 @@ type MockMsgStream_SetRepackFunc_Call struct { } // SetRepackFunc is a helper method to define mock.On call -// - repackFunc RepackFunc +// - repackFunc RepackFunc func (_e *MockMsgStream_Expecter) SetRepackFunc(repackFunc interface{}) *MockMsgStream_SetRepackFunc_Call { return &MockMsgStream_SetRepackFunc_Call{Call: _e.mock.On("SetRepackFunc", repackFunc)} } diff --git a/pkg/mq/msgstream/mq_kafka_msgstream_test.go b/pkg/mq/msgstream/mq_kafka_msgstream_test.go index 468d4e054a96f..fe39f8f082e2d 100644 --- a/pkg/mq/msgstream/mq_kafka_msgstream_test.go +++ b/pkg/mq/msgstream/mq_kafka_msgstream_test.go @@ -145,7 +145,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) { defer outputStream2.Close() assert.NoError(t, err) - err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) + err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false) assert.NoError(t, err) cnt := 0 @@ -482,6 +482,6 @@ func getKafkaTtOutputStreamAndSeek(ctx context.Context, kafkaAddress string, pos consumerName = append(consumerName, c.ChannelName) } outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(context.Background(), positions) + outputStream.Seek(context.Background(), positions, false) return outputStream } diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 86ad3f7dfe578..a93c9962f414d 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -473,7 +473,7 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack { // Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive // User has to ensure mq_msgstream is not closed before seek, and the seek position is already written. -func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error { +func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error { for _, mp := range msgPositions { consumer, ok := ms.consumers[mp.ChannelName] if !ok { @@ -493,8 +493,8 @@ func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPositi } } - log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID)) - err = consumer.Seek(messageID, false) + log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID), zap.Bool("includeCurrentMsg", includeCurrentMsg)) + err = consumer.Seek(messageID, includeCurrentMsg) if err != nil { log.Warn("Failed to seek", zap.String("channel", mp.ChannelName), zap.Error(err)) return err @@ -840,7 +840,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu } // Seek to the specified position -func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error { +func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error { var consumer mqwrapper.Consumer var mp *MsgPosition var err error diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index 8705eddf13499..ee4a5d57ffacc 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -517,7 +517,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { defer outputStream2.Close() assert.NoError(t, err) - err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) + err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false) assert.NoError(t, err) cnt := 0 @@ -946,7 +946,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) { pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) - outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) + outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false) for i := 6; i < 10; i++ { result := consumer(ctx, outputStream2) @@ -1001,7 +1001,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { }, } - err = outputStream2.Seek(ctx, p) + err = outputStream2.Seek(ctx, p, false) assert.NoError(t, err) for i := 10; i < 20; i++ { @@ -1070,15 +1070,15 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) { } paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "false") - err = outputStream2.Seek(ctx, p) + err = outputStream2.Seek(ctx, p, false) assert.Error(t, err) - err = outputStream3.Seek(ctx, p) + err = outputStream3.Seek(ctx, p, false) assert.Error(t, err) paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "true") - err = outputStream2.Seek(ctx, p) + err = outputStream2.Seek(ctx, p, false) assert.NoError(t, err) - err = outputStream3.Seek(ctx, p) + err = outputStream3.Seek(ctx, p, false) assert.NoError(t, err) } @@ -1466,7 +1466,7 @@ func getPulsarTtOutputStreamAndSeek(ctx context.Context, pulsarAddress string, p consumerName = append(consumerName, c.ChannelName) } outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(context.Background(), positions) + outputStream.Seek(context.Background(), positions, false) return outputStream } diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 184d44967d098..62f8c8737e026 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -63,7 +63,9 @@ type MsgStream interface { AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error Chan() <-chan *MsgPack - Seek(ctx context.Context, offset []*MsgPosition) error + // Seek consume message from the specified position + // includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false + Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error GetLatestMsgID(channel string) (MessageID, error) CheckTopicValid(channel string) error