Skip to content

Commit

Permalink
enhance: remove the lock in the msgdispatch/target struct
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed Jun 24, 2024
1 parent d08cb88 commit e6d9067
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 34 deletions.
11 changes: 8 additions & 3 deletions pkg/mq/msgdispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,14 @@ func (d *Dispatcher) work() {
}
}
if err != nil {
t.pos = pack.StartPositions[0]
// replace the pChannel with vChannel
t.pos.ChannelName = t.vchannel
// can't directly use the t.pos, since it's shared by all the targets and may be modified
t.pos = &Pos{
// replace the pChannel with vChannel
ChannelName: t.vchannel,
MsgID: t.pos.MsgID,
MsgGroup: t.pos.MsgGroup,
Timestamp: t.pos.Timestamp,
}
d.lagTargets.Insert(t.vchannel, t)
d.nonBlockingNotify()
delete(d.targets, vchannel)
Expand Down
30 changes: 15 additions & 15 deletions pkg/mq/msgdispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ func TestDispatcher(t *testing.T) {
"mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil)
assert.NoError(t, err)
output := make(chan *msgstream.MsgPack, 1024)
d.AddTarget(&target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
})
d.AddTarget(&target{
vchannel: "mock_vchannel_1",
pos: nil,
ch: nil,
})
d.AddTarget(newTargetWithChan(
"mock_vchannel_0",
nil,
output,
))
d.AddTarget(newTargetWithChan(
"mock_vchannel_1",
nil,
nil,
))
num := d.TargetNum()
assert.Equal(t, 2, num)

Expand All @@ -106,11 +106,11 @@ func TestDispatcher(t *testing.T) {
t.Run("test concurrent send and close", func(t *testing.T) {
for i := 0; i < 100; i++ {
output := make(chan *msgstream.MsgPack, 1024)
target := &target{
vchannel: "mock_vchannel_0",
pos: nil,
ch: output,
}
target := newTargetWithChan(
"mock_vchannel_0",
nil,
output,
)
assert.Equal(t, cap(output), cap(target.ch))
wg := &sync.WaitGroup{}
for j := 0; j < 100; j++ {
Expand Down
10 changes: 5 additions & 5 deletions pkg/mq/msgdispatcher/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func TestManager(t *testing.T) {
c.(*dispatcherManager).tryMerge()
assert.Equal(t, 1, c.Num())

info := &target{
vchannel: "mock_vchannel_2",
pos: nil,
ch: nil,
}
info := newTargetWithChan(
"mock_vchannel_2",
nil,
nil,
)
c.(*dispatcherManager).split(info)
assert.Equal(t, 2, c.Num())
})
Expand Down
45 changes: 34 additions & 11 deletions pkg/mq/msgdispatcher/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package msgdispatcher
import (
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/milvus-io/milvus/pkg/util/paramtable"
Expand All @@ -29,36 +30,58 @@ type target struct {
ch chan *MsgPack
pos *Pos

closeMu sync.Mutex
closeOnce sync.Once
closed bool
closed atomic.Bool
sendCnt int64
cntLock *sync.Mutex
cntCond *sync.Cond
}

func newTarget(vchannel string, pos *Pos) *target {
ch := make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt())
return newTargetWithChan(vchannel, pos, ch)
}

func newTargetWithChan(vchannel string, pos *Pos, packChan chan *MsgPack) *target {
t := &target{
vchannel: vchannel,
ch: make(chan *MsgPack, paramtable.Get().MQCfg.TargetBufSize.GetAsInt()),
ch: packChan,
pos: pos,
}
t.closed = false
t.cntLock = &sync.Mutex{}
t.cntCond = sync.NewCond(t.cntLock)
t.closed.Store(false)
return t
}

func (t *target) close() {
t.closeMu.Lock()
defer t.closeMu.Unlock()
t.closeOnce.Do(func() {
t.closed = true
close(t.ch)
t.closed.Store(true)
// not block the close method
go func() {
t.cntLock.Lock()
for t.sendCnt > 0 {
t.cntCond.Wait()
}
t.cntLock.Unlock()
close(t.ch)
}()
})
}

func (t *target) send(pack *MsgPack) error {
t.closeMu.Lock()
defer t.closeMu.Unlock()
if t.closed {
if t.closed.Load() {
return nil
}
t.cntLock.Lock()
t.sendCnt++
t.cntLock.Unlock()
defer func() {
t.cntLock.Lock()
t.sendCnt--
t.cntLock.Unlock()
t.cntCond.Signal()
}()
maxTolerantLag := paramtable.Get().MQCfg.MaxTolerantLag.GetAsDuration(time.Second)
select {
case <-time.After(maxTolerantLag):
Expand Down

0 comments on commit e6d9067

Please sign in to comment.