Skip to content

Commit

Permalink
enhance: add the includeCurrentMsg param for the Seek method (milvus-…
Browse files Browse the repository at this point in the history
…io#33326)

/kind improvement
- issue: milvus-io#33325

Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG authored May 27, 2024
1 parent 58ee613 commit cb99e3d
Show file tree
Hide file tree
Showing 14 changed files with 59 additions and 54 deletions.
2 changes: 1 addition & 1 deletion internal/datanode/flow_graph_dmstream_input_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstrea
return nil, nil
}

func (mtm *mockTtMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error {
func (mtm *mockTtMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}

Expand Down
8 changes: 4 additions & 4 deletions internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func TestMqMsgStream_SeekNotSubscribed(t *testing.T) {
ChannelName: "b",
},
}
err = m.Seek(context.Background(), p)
err = m.Seek(context.Background(), p, false)
assert.Error(t, err)
}

Expand Down Expand Up @@ -403,7 +403,7 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher())
consumerSubName = funcutil.RandomString(8)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(ctx, receivedMsg.StartPositions)
outputStream.Seek(ctx, receivedMsg.StartPositions, false)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 1+2)
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
Expand Down Expand Up @@ -506,7 +506,7 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) {
consumerSubName = funcutil.RandomString(8)
outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown)

outputStream.Seek(ctx, receivedMsg3.StartPositions)
outputStream.Seek(ctx, receivedMsg3.StartPositions, false)
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 3)
result := []uint64{14, 12, 13}
Expand Down Expand Up @@ -565,7 +565,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}

err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)

for i := 10; i < 20; i++ {
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string {
return nil
}

func (ms *simpleMockMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error {
func (ms *simpleMockMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/delegator/delegator_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position
}

ts = time.Now()
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position})
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/querynodev2/delegator/delegator_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
}, 10)

s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
close(ch)
Expand Down Expand Up @@ -1173,7 +1173,7 @@ func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
defer cancel()

s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.MsgPack, 10)
s.mq.EXPECT().Chan().Return(ch)
Expand Down
6 changes: 3 additions & 3 deletions internal/querynodev2/services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()

Expand Down Expand Up @@ -358,7 +358,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()

Expand Down Expand Up @@ -432,7 +432,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Close().Return()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()

status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
Expand Down
6 changes: 4 additions & 2 deletions internal/rootcoord/dml_channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,10 @@ func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.M
}
return nil, nil
}
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil }
func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil }
func (ms *FailMsgStream) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error {
return nil
}

func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) {
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/mq/msgdispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func NewDispatcher(ctx context.Context,
return nil, err
}

err = stream.Seek(ctx, []*Pos{position})
err = stream.Seek(ctx, []*Pos{position}, false)
if err != nil {
stream.Close()
log.Error("seek failed", zap.Error(err))
Expand Down
2 changes: 1 addition & 1 deletion pkg/mq/msgstream/factory_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer,
consumer, err := newer(ctx)
assert.NoError(t, err)
consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
err = consumer.Seek(context.Background(), seekPositions)
err = consumer.Seek(context.Background(), seekPositions, false)
assert.NoError(t, err)
return consumer
}
Expand Down
47 changes: 24 additions & 23 deletions pkg/mq/msgstream/mock_msgstream.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/mq/msgstream/mq_kafka_msgstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) {
defer outputStream2.Close()
assert.NoError(t, err)

err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)
assert.NoError(t, err)

cnt := 0
Expand Down Expand Up @@ -482,6 +482,6 @@ func getKafkaTtOutputStreamAndSeek(ctx context.Context, kafkaAddress string, pos
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
outputStream.Seek(context.Background(), positions, false)
return outputStream
}
8 changes: 4 additions & 4 deletions pkg/mq/msgstream/mq_msgstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack {

// Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive
// User has to ensure mq_msgstream is not closed before seek, and the seek position is already written.
func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error {
for _, mp := range msgPositions {
consumer, ok := ms.consumers[mp.ChannelName]
if !ok {
Expand All @@ -493,8 +493,8 @@ func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPositi
}
}

log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID))
err = consumer.Seek(messageID, false)
log.Info("MsgStream seek begin", zap.String("channel", mp.ChannelName), zap.Any("MessageID", mp.MsgID), zap.Bool("includeCurrentMsg", includeCurrentMsg))
err = consumer.Seek(messageID, includeCurrentMsg)
if err != nil {
log.Warn("Failed to seek", zap.String("channel", mp.ChannelName), zap.Error(err))
return err
Expand Down Expand Up @@ -840,7 +840,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu
}

// Seek to the specified position
func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error {
func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error {
var consumer mqwrapper.Consumer
var mp *MsgPosition
var err error
Expand Down
16 changes: 8 additions & 8 deletions pkg/mq/msgstream/mq_msgstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) {
defer outputStream2.Close()
assert.NoError(t, err)

err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)
assert.NoError(t, err)

cnt := 0
Expand Down Expand Up @@ -946,7 +946,7 @@ func TestStream_MqMsgStream_Seek(t *testing.T) {
pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress})
outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher())
outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest)
outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition})
outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}, false)

for i := 6; i < 10; i++ {
result := consumer(ctx, outputStream2)
Expand Down Expand Up @@ -1001,7 +1001,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) {
},
}

err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)

for i := 10; i < 20; i++ {
Expand Down Expand Up @@ -1070,15 +1070,15 @@ func TestSTream_MqMsgStream_SeekBadMessageID(t *testing.T) {
}

paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "false")
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.Error(t, err)
err = outputStream3.Seek(ctx, p)
err = outputStream3.Seek(ctx, p, false)
assert.Error(t, err)

paramtable.Get().Save(paramtable.Get().MQCfg.IgnoreBadPosition.Key, "true")
err = outputStream2.Seek(ctx, p)
err = outputStream2.Seek(ctx, p, false)
assert.NoError(t, err)
err = outputStream3.Seek(ctx, p)
err = outputStream3.Seek(ctx, p, false)
assert.NoError(t, err)
}

Expand Down Expand Up @@ -1466,7 +1466,7 @@ func getPulsarTtOutputStreamAndSeek(ctx context.Context, pulsarAddress string, p
consumerName = append(consumerName, c.ChannelName)
}
outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown)
outputStream.Seek(context.Background(), positions)
outputStream.Seek(context.Background(), positions, false)
return outputStream
}

Expand Down
4 changes: 3 additions & 1 deletion pkg/mq/msgstream/msgstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ type MsgStream interface {

AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error
Chan() <-chan *MsgPack
Seek(ctx context.Context, offset []*MsgPosition) error
// Seek consume message from the specified position
// includeCurrentMsg indicates whether to consume the current message, and in the milvus system, it should be always false
Seek(ctx context.Context, msgPositions []*MsgPosition, includeCurrentMsg bool) error

GetLatestMsgID(channel string) (MessageID, error)
CheckTopicValid(channel string) error
Expand Down

0 comments on commit cb99e3d

Please sign in to comment.