diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index dfd232a977fb4..7d76b8c981ef5 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -81,14 +81,12 @@ type Dispatcher struct { } func NewDispatcher(ctx context.Context, - factory msgstream.Factory, - isMain bool, - pchannel string, - position *Pos, - subName string, - subPos SubPos, + factory msgstream.Factory, isMain bool, + pchannel string, position *Pos, + subName string, subPos SubPos, lagNotifyChan chan struct{}, lagTargets *typeutil.ConcurrentMap[string, *target], + includeCurrentMsg bool, ) (*Dispatcher, error) { log := log.With(zap.String("pchannel", pchannel), zap.String("subName", subName), zap.Bool("isMain", isMain)) @@ -106,7 +104,7 @@ func NewDispatcher(ctx context.Context, return nil, err } - err = stream.Seek(ctx, []*Pos{position}, false) + err = stream.Seek(ctx, []*Pos{position}, includeCurrentMsg) if err != nil { stream.Close() log.Error("seek failed", zap.Error(err)) diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index e7fb70392194e..2ef4bbff66f06 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -34,8 +34,7 @@ import ( func TestDispatcher(t *testing.T) { ctx := context.Background() t.Run("test base", func(t *testing.T) { - d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, - "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil) + d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) assert.NoError(t, err) assert.NotPanics(t, func() { d.Handle(start) @@ -62,16 +61,14 @@ func TestDispatcher(t *testing.T) { return ms, nil }, } - d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil, - "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil) + d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) assert.Error(t, err) assert.Nil(t, d) }) t.Run("test target", func(t *testing.T) { - d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, - "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil) + d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) assert.NoError(t, err) output := make(chan *msgstream.MsgPack, 1024) d.AddTarget(&target{ @@ -136,8 +133,7 @@ func TestDispatcher(t *testing.T) { } func BenchmarkDispatcher_handle(b *testing.B) { - d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, - "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil) + d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", common.SubscriptionPositionEarliest, nil, nil, false) assert.NoError(b, err) for i := 0; i < b.N; i++ { diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index 8b4dd944eb8dd..72ff293825fde 100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -89,8 +89,7 @@ func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, c.mu.Lock() defer c.mu.Unlock() isMain := c.mainDispatcher == nil - d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, - c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets) + d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets, false) if err != nil { return nil, err } @@ -234,8 +233,7 @@ func (c *dispatcherManager) split(t *target) { var newSolo *Dispatcher err := retry.Do(context.Background(), func() error { var err error - newSolo, err = NewDispatcher(context.Background(), c.factory, false, c.pchannel, t.pos, - c.constructSubName(t.vchannel, false), common.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets) + newSolo, err = NewDispatcher(context.Background(), c.factory, false, c.pchannel, t.pos, c.constructSubName(t.vchannel, false), common.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets, true) return err }, retry.Attempts(10)) if err != nil {