From e6d9067fe0ff98393984cfb411a2f69b07a12f53 Mon Sep 17 00:00:00 2001 From: SimFG Date: Sat, 22 Jun 2024 11:59:03 +0800 Subject: [PATCH] enhance: remove the lock in the msgdispatch/target struct Signed-off-by: SimFG --- pkg/mq/msgdispatcher/dispatcher.go | 11 ++++-- pkg/mq/msgdispatcher/dispatcher_test.go | 30 ++++++++--------- pkg/mq/msgdispatcher/manager_test.go | 10 +++--- pkg/mq/msgdispatcher/target.go | 45 +++++++++++++++++++------ 4 files changed, 62 insertions(+), 34 deletions(-) diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index 4d0ab3e2c606e..a6fc71cc0b343 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -234,9 +234,14 @@ func (d *Dispatcher) work() { } } if err != nil { - t.pos = pack.StartPositions[0] - // replace the pChannel with vChannel - t.pos.ChannelName = t.vchannel + // can't directly use the t.pos, since it's shared by all the targets and may be modified + t.pos = &Pos{ + // replace the pChannel with vChannel + ChannelName: t.vchannel, + MsgID: t.pos.MsgID, + MsgGroup: t.pos.MsgGroup, + Timestamp: t.pos.Timestamp, + } d.lagTargets.Insert(t.vchannel, t) d.nonBlockingNotify() delete(d.targets, vchannel) diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index e7c79b54fc0fa..63830f0db9ee5 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -73,16 +73,16 @@ func TestDispatcher(t *testing.T) { "mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil) assert.NoError(t, err) output := make(chan *msgstream.MsgPack, 1024) - d.AddTarget(&target{ - vchannel: "mock_vchannel_0", - pos: nil, - ch: output, - }) - d.AddTarget(&target{ - vchannel: "mock_vchannel_1", - pos: nil, - ch: nil, - }) + d.AddTarget(newTargetWithChan( + "mock_vchannel_0", + nil, + output, + )) + d.AddTarget(newTargetWithChan( + "mock_vchannel_1", + nil, + nil, + )) num := d.TargetNum() assert.Equal(t, 2, num) @@ -106,11 +106,11 @@ func TestDispatcher(t *testing.T) { t.Run("test concurrent send and close", func(t *testing.T) { for i := 0; i < 100; i++ { output := make(chan *msgstream.MsgPack, 1024) - target := &target{ - vchannel: "mock_vchannel_0", - pos: nil, - ch: output, - } + target := newTargetWithChan( + "mock_vchannel_0", + nil, + output, + ) assert.Equal(t, cap(output), cap(target.ch)) wg := &sync.WaitGroup{} for j := 0; j < 100; j++ { diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index 51c7790b40810..2dbcc4e7709f5 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -76,11 +76,11 @@ func TestManager(t *testing.T) { c.(*dispatcherManager).tryMerge() assert.Equal(t, 1, c.Num()) - info := &target{ - vchannel: "mock_vchannel_2", - pos: nil, - ch: nil, - } + info := newTargetWithChan( + "mock_vchannel_2", + nil, + nil, + ) c.(*dispatcherManager).split(info) assert.Equal(t, 2, c.Num()) }) diff --git a/pkg/mq/msgdispatcher/target.go b/pkg/mq/msgdispatcher/target.go index 8fd231e296fef..8e4951e31d694 100644 --- a/pkg/mq/msgdispatcher/target.go +++ b/pkg/mq/msgdispatcher/target.go @@ -19,6 +19,7 @@ package msgdispatcher import ( "fmt" "sync" + "sync/atomic" "time" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -29,36 +30,58 @@ type target struct { ch chan *MsgPack pos *Pos - closeMu sync.Mutex closeOnce sync.Once - closed bool + closed atomic.Bool + sendCnt int64 + cntLock *sync.Mutex + cntCond *sync.Cond } func newTarget(vchannel string, pos *Pos) *target { + ch := make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()) + return newTargetWithChan(vchannel, pos, ch) +} + +func newTargetWithChan(vchannel string, pos *Pos, packChan chan *MsgPack) *target { t := &target{ vchannel: vchannel, - ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()), + ch: packChan, pos: pos, } - t.closed = false + t.cntLock = &sync.Mutex{} + t.cntCond = sync.NewCond(t.cntLock) + t.closed.Store(false) return t } func (t *target) close() { - t.closeMu.Lock() - defer t.closeMu.Unlock() t.closeOnce.Do(func() { - t.closed = true - close(t.ch) + t.closed.Store(true) + // not block the close method + go func() { + t.cntLock.Lock() + for t.sendCnt > 0 { + t.cntCond.Wait() + } + t.cntLock.Unlock() + close(t.ch) + }() }) } func (t *target) send(pack *MsgPack) error { - t.closeMu.Lock() - defer t.closeMu.Unlock() - if t.closed { + if t.closed.Load() { return nil } + t.cntLock.Lock() + t.sendCnt++ + t.cntLock.Unlock() + defer func() { + t.cntLock.Lock() + t.sendCnt-- + t.cntLock.Unlock() + t.cntCond.Signal() + }() maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second) select { case <-time.After(maxTolerantLag):