Skip to content

Commit

Permalink
fix: rocksmq data race on consumers modification
Browse files Browse the repository at this point in the history
Related Issue: #29101

- Fix data race by add new lock and CopyOnWrite

- Add new unittest to verify it

Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh committed Dec 18, 2023
1 parent 4731c1b commit bad3b45
Show file tree
Hide file tree
Showing 9 changed files with 389 additions and 88 deletions.
6 changes: 6 additions & 0 deletions internal/mq/mqimpl/rocksmq/client/client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ func (c *client) Subscribe(options ConsumerOptions) (Consumer, error) {
return nil, newError(0, "Rmq server is nil")
}

// Create a topic in rocksmq, ignore if topic exists
err := c.server.CreateTopic(options.Topic)
if err != nil {
return nil, err
}

exist, con, err := c.server.ExistConsumerGroup(options.Topic, options.SubscriptionName)
if err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions internal/mq/mqimpl/rocksmq/client/client_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func TestClient_SubscribeError(t *testing.T) {
testGroupName := newConsumerName()

assert.NoError(t, err)
mockMQ.EXPECT().CreateTopic(testTopic).Return(nil)
mockMQ.EXPECT().ExistConsumerGroup(testTopic, testGroupName).Return(false, nil, nil)
mockMQ.EXPECT().CreateConsumerGroup(testTopic, testGroupName).Return(nil)
mockMQ.EXPECT().RegisterConsumer(mock.Anything).Return(nil)
Expand Down
123 changes: 60 additions & 63 deletions internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ func checkRetention() bool {
return params.RocksmqCfg.RetentionSizeInMB.GetAsInt64() != -1 || params.RocksmqCfg.RetentionTimeInMinutes.GetAsInt64() != -1
}

var topicMu = sync.Map{}
// TODO: We shouldn't use a mutex group to check if a topic is being created or not, it cause a lot of bugs.
// Refactor it in future.
var topicMu = typeutil.NewRemovableGroupMutex[string]()

type rocksmq struct {
store *gorocksdb.DB
Expand Down Expand Up @@ -376,7 +378,7 @@ func (rmq *rocksmq) stopRetention() {
}

// CreateTopic writes initialized messages for topic in rocksdb
func (rmq *rocksmq) CreateTopic(topicName string) error {
func (rmq *rocksmq) CreateTopic(topicName string) (err error) {
if rmq.isClosed() {
return errors.New(RmqNotServingErrMsg)
}
Expand All @@ -399,9 +401,16 @@ func (rmq *rocksmq) CreateTopic(topicName string) error {
return nil
}

if _, ok := topicMu.Load(topicName); !ok {
topicMu.Store(topicName, new(sync.Mutex))
}
lockGuard := topicMu.Lock(topicName)
defer func() {
if err != nil {
// If create topic failed, we should remove the lock.
// Some logic in rocksmq use these lock to check if a topic is being created or not.
lockGuard.UnlockAndRemove()
} else {
lockGuard.Unlock()
}
}()

// msgSizeKey -> msgSize
// topicIDKey -> topic creating time
Expand All @@ -426,24 +435,27 @@ func (rmq *rocksmq) CreateTopic(topicName string) error {
}

// DestroyTopic removes messages for topic in rocksmq
func (rmq *rocksmq) DestroyTopic(topicName string) error {
func (rmq *rocksmq) DestroyTopic(topicName string) (err error) {
start := time.Now()
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return fmt.Errorf("topic name = %s not exist", topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer func() {
if err == nil {
// If destroy topic success, we should remove the lock.
lockGuard.UnlockAndRemove()
} else {
lockGuard.Unlock()
}
}()
// TODO: Lost transaction promised here. half destroy topic can be seen by others.

rmq.consumers.Delete(topicName)

// clean the topic data it self
fixTopicName := topicName + "/"
err := rmq.kv.RemoveWithPrefix(fixTopicName)
err = rmq.kv.RemoveWithPrefix(fixTopicName)
if err != nil {
return err
}
Expand Down Expand Up @@ -482,7 +494,6 @@ func (rmq *rocksmq) DestroyTopic(topicName string) error {
}

// clean up retention info
topicMu.Delete(topicName)
rmq.retentionInfo.topicRetetionTime.GetAndRemove(topicName)

log.Debug("Rocksmq destroy topic successfully ", zap.String("topic", topicName), zap.Int64("elapsed", time.Since(start).Milliseconds()))
Expand Down Expand Up @@ -529,12 +540,19 @@ func (rmq *rocksmq) RegisterConsumer(consumer *Consumer) error {
return errors.New(RmqNotServingErrMsg)
}
start := time.Now()
lockGuard := topicMu.LockIfExist(consumer.Topic)
if lockGuard == nil {
return fmt.Errorf("topic name = %s not exist at RegisterConsumer", consumer.Topic)
}
defer lockGuard.Unlock()

if vals, ok := rmq.consumers.Load(consumer.Topic); ok {
for _, v := range vals.([]*Consumer) {
if v.GroupName == consumer.GroupName {
return nil
}
}
// Append operation always CopyOnWrite, so it's safe to be accessed by other method.
consumers := vals.([]*Consumer)
consumers = append(consumers, consumer)
rmq.consumers.Store(consumer.Topic, consumers)
Expand Down Expand Up @@ -570,25 +588,25 @@ func (rmq *rocksmq) DestroyConsumerGroup(topicName, groupName string) error {
// DestroyConsumerGroup removes a consumer group from rocksdb_kv
func (rmq *rocksmq) destroyConsumerGroupInternal(topicName, groupName string) error {
start := time.Now()
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return fmt.Errorf("topic name = %s not exist", topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

key := constructCurrentID(topicName, groupName)
rmq.consumersID.Delete(key)
if vals, ok := rmq.consumers.Load(topicName); ok {
consumers := vals.([]*Consumer)
for index, v := range consumers {
if v.GroupName == groupName {
close(v.MsgMutex)
consumers = append(consumers[:index], consumers[index+1:]...)
rmq.consumers.Store(topicName, consumers)
// Need CopyOnWrite operation here.
// Operate on a copy of the slice, so that the original slice is not modified and be safe to be accessed by other method.
newConsumers := make([]*Consumer, 0, len(consumers)-1)
newConsumers = append(newConsumers, consumers[:index]...)
newConsumers = append(newConsumers, consumers[index+1:]...)
rmq.consumers.Store(topicName, newConsumers)
break
}
}
Expand All @@ -605,16 +623,12 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni
return nil, errors.New(RmqNotServingErrMsg)
}
start := time.Now()
ll, ok := topicMu.Load(topicName)
if !ok {

lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return []UniqueID{}, fmt.Errorf("topic name = %s not exist", topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return []UniqueID{}, fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

getLockTime := time.Since(start).Milliseconds()

Expand Down Expand Up @@ -741,17 +755,11 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum
return nil, errors.New(RmqNotServingErrMsg)
}
start := time.Now()
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return nil, fmt.Errorf("topic name = %s not exist", topicName)
}

lock, ok := ll.(*sync.Mutex)
if !ok {
return nil, fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

currentID, ok := rmq.getCurrentID(topicName, groupName)
if !ok {
Expand Down Expand Up @@ -900,17 +908,11 @@ func (rmq *rocksmq) Seek(topicName string, groupName string, msgID UniqueID) err
if rmq.isClosed() {
return errors.New(RmqNotServingErrMsg)
}
/* Step I: Check if key exists */
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return merr.WrapErrMqTopicNotFound(topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

err := rmq.seek(topicName, groupName, msgID)
if err != nil {
Expand All @@ -927,21 +929,17 @@ func (rmq *rocksmq) ForceSeek(topicName string, groupName string, msgID UniqueID
return errors.New(RmqNotServingErrMsg)
}
/* Step I: Check if key exists */
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return merr.WrapErrMqTopicNotFound(topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

rmq.storeMu.Lock()
defer rmq.storeMu.Unlock()

key := constructCurrentID(topicName, groupName)
_, ok = rmq.consumersID.Load(key)
_, ok := rmq.consumersID.Load(key)
if !ok {
return fmt.Errorf("ConsumerGroup %s, channel %s not exists", groupName, topicName)
}
Expand Down Expand Up @@ -1114,8 +1112,7 @@ func (rmq *rocksmq) updateAckedInfo(topicName, groupName string, firstID UniqueI
}

func (rmq *rocksmq) CheckTopicValid(topic string) error {
_, ok := topicMu.Load(topic)
if !ok {
if !topicMu.Exists(topic) {
return merr.WrapErrMqTopicNotFound(topic, "failed to get topic")
}

Expand Down
51 changes: 37 additions & 14 deletions internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,11 @@ func (rmq *rocksmq) produceBefore2(topicName string, messages []producerMessageB
return nil, errors.New(RmqNotServingErrMsg)
}
start := time.Now()
ll, ok := topicMu.Load(topicName)
if !ok {
lockGuard := topicMu.LockIfExist(topicName)
if lockGuard == nil {
return []UniqueID{}, fmt.Errorf("topic name = %s not exist", topicName)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return []UniqueID{}, fmt.Errorf("get mutex failed, topic name = %s", topicName)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

getLockTime := time.Since(start).Milliseconds()

Expand Down Expand Up @@ -221,6 +216,34 @@ func TestRocksmq_RegisterConsumer(t *testing.T) {
MsgMutex: make(chan struct{}),
}
rmq.RegisterConsumer(consumer2)

// Concurrent RegisterConsumer and Produce test.
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 1000; i++ {
_, err := rmq.Produce(topicName, pMsgs)
assert.NoError(t, err)
}
}()

for j := 0; j < 100; j++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
groupName := fmt.Sprintf("group_register_%d", j)
consumer := &Consumer{
Topic: topicName,
GroupName: groupName,
MsgMutex: make(chan struct{}),
}
err := rmq.RegisterConsumer(consumer)
assert.NoError(t, err)
rmq.DestroyConsumerGroup(topicName, consumer.GroupName)
}(j)
}
wg.Wait()
}

func TestRocksmq_Basic(t *testing.T) {
Expand Down Expand Up @@ -361,7 +384,7 @@ func TestRocksmq_Dummy(t *testing.T) {
assert.NoError(t, err)

channelName1 := "channel_dummy"
topicMu.Store(channelName1, new(sync.Mutex))
topicMu.Lock(channelName1).Unlock()
err = rmq.DestroyTopic(channelName1)
assert.NoError(t, err)

Expand Down Expand Up @@ -393,10 +416,10 @@ func TestRocksmq_Dummy(t *testing.T) {
pMsgA := ProducerMessage{Payload: []byte(msgA)}
pMsgs[0] = pMsgA

topicMu.Delete(channelName)
topicMu.Lock(channelName).UnlockAndRemove()
_, err = rmq.Consume(channelName, groupName1, 1)
assert.Error(t, err)
topicMu.Store(channelName, channelName)
topicMu.Lock(channelName).Unlock()
_, err = rmq.Produce(channelName, nil)
assert.Error(t, err)

Expand Down Expand Up @@ -1003,7 +1026,7 @@ func TestRocksmq_CheckPreTopicValid(t *testing.T) {
err = rmq.CreateTopic(channelName2)
defer rmq.DestroyTopic(channelName2)
assert.NoError(t, err)
topicMu.Store(channelName2, new(sync.Mutex))
topicMu.Lock(channelName2).Unlock()

pMsgs := make([]ProducerMessage, 10)
for i := 0; i < 10; i++ {
Expand All @@ -1023,7 +1046,7 @@ func TestRocksmq_CheckPreTopicValid(t *testing.T) {
defer rmq.DestroyTopic(channelName3)
assert.NoError(t, err)

topicMu.Store(channelName3, new(sync.Mutex))
topicMu.Lock(channelName3).Unlock()
err = rmq.CheckTopicValid(channelName3)
assert.NoError(t, err)
}
Expand Down Expand Up @@ -1131,7 +1154,7 @@ func TestRocksmq_SeekTopicMutexError(t *testing.T) {
assert.NoError(t, err)
defer rmq.Close()

topicMu.Store("test_topic_mutix_error", nil)
topicMu.Lock("test_topic_mutix_error").Unlock()
assert.Error(t, rmq.Seek("test_topic_mutix_error", "", 0))
assert.Error(t, rmq.ForceSeek("test_topic_mutix_error", "", 0))
}
Expand Down
15 changes: 6 additions & 9 deletions internal/mq/mqimpl/rocksmq/server/rocksmq_retention.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func initRetentionInfo(kv *rocksdbkv.RocksdbKV, db *gorocksdb.DB) (*retentionInf
for _, key := range topicKeys {
topic := key[len(TopicIDTitle):]
ri.topicRetetionTime.Insert(topic, time.Now().Unix())
topicMu.Store(topic, new(sync.Mutex))
// TODO: we are using a global lock here group to check if topic is exist, which is not good implementation.
lockGuard := topicMu.Lock(topic)
lockGuard.Unlock()
}
return ri, nil
}
Expand Down Expand Up @@ -311,16 +313,11 @@ func (ri *retentionInfo) cleanData(topic string, pageEndID UniqueID) error {
ackedEndIDKey := fixedAckedTsKey + "/" + strconv.FormatInt(pageEndID+1, 10)
writeBatch.DeleteRange([]byte(ackedStartIDKey), []byte(ackedEndIDKey))

ll, ok := topicMu.Load(topic)
if !ok {
lockGuard := topicMu.LockIfExist(topic)
if lockGuard == nil {
return fmt.Errorf("topic name = %s not exist", topic)
}
lock, ok := ll.(*sync.Mutex)
if !ok {
return fmt.Errorf("get mutex failed, topic name = %s", topic)
}
lock.Lock()
defer lock.Unlock()
defer lockGuard.Unlock()

err := DeleteMessages(ri.db, topic, 0, pageEndID)
if err != nil {
Expand Down
Loading

0 comments on commit bad3b45

Please sign in to comment.