Skip to content

Commit

Permalink
fix: add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh committed Dec 23, 2024
1 parent 730933e commit 803fa3e
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 62 deletions.
45 changes: 45 additions & 0 deletions internal/mocks/distributed/mock_streaming/mock_WALAccesser.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion internal/util/streamingutil/util/wal_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

const (
walTypeDefault = "default"
walTypeNatsmq = "natsmq"
walTypeRocksmq = "rocksmq"
walTypeKafka = "kafka"
walTypePulsar = "pulsar"
Expand Down
35 changes: 16 additions & 19 deletions internal/util/streamingutil/util/wal_selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,24 @@ import (
)

func TestValidateWALType(t *testing.T) {
assert.Error(t, validateWALName(false, walTypeNatsmq))
assert.Error(t, validateWALName(false, walTypeRocksmq))
}

func TestSelectWALType(t *testing.T) {
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{true, true, true, true}), walTypeRocksmq)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, false, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(true, walTypeDefault, walEnable{false, false, false, false}) })
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{true, true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, false, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(false, walTypeDefault, walEnable{false, false, false, false}) })
assert.Equal(t, mustSelectWALName(true, walTypeRocksmq, walEnable{true, true, true, true}), walTypeRocksmq)
assert.Equal(t, mustSelectWALName(true, walTypeNatsmq, walEnable{true, true, true, true}), walTypeNatsmq)
assert.Equal(t, mustSelectWALName(true, walTypePulsar, walEnable{true, true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeKafka, walEnable{true, true, true, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(false, walTypeRocksmq, walEnable{true, true, true, true}) })
assert.Panics(t, func() { mustSelectWALName(false, walTypeNatsmq, walEnable{true, true, true, true}) })
assert.Equal(t, mustSelectWALName(false, walTypePulsar, walEnable{true, true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeKafka, walEnable{true, true, true, true}), walTypeKafka)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{true, true, true}), walTypeRocksmq)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(true, walTypeDefault, walEnable{false, false, false}) })
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(false, walTypeDefault, walEnable{false, false, false}) })
assert.Equal(t, mustSelectWALName(true, walTypeRocksmq, walEnable{true, true, true}), walTypeRocksmq)
assert.Equal(t, mustSelectWALName(true, walTypePulsar, walEnable{true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(true, walTypeKafka, walEnable{true, true, true}), walTypeKafka)
assert.Panics(t, func() { mustSelectWALName(false, walTypeRocksmq, walEnable{true, true, true}) })
assert.Equal(t, mustSelectWALName(false, walTypePulsar, walEnable{true, true, true}), walTypePulsar)
assert.Equal(t, mustSelectWALName(false, walTypeKafka, walEnable{true, true, true}), walTypeKafka)
}
22 changes: 1 addition & 21 deletions pkg/mq/mqimpl/rocksmq/client/client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
package client

import (
"bytes"
"context"
"reflect"
"sync"
"time"

"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server"
"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
)

const (
Expand Down Expand Up @@ -201,7 +198,7 @@ func (c *client) tryToConsume(consumer *consumer) []*RmqMessage {
}
rmqMsgs := make([]*RmqMessage, 0, len(msgs))
for _, msg := range msgs {
rmqMsg, err := c.unmarshalStreamingMessage(consumer.topic, msg)
rmqMsg, err := unmarshalStreamingMessage(consumer.topic, msg)
if err == nil {
rmqMsgs = append(rmqMsgs, rmqMsg)
continue
Expand All @@ -228,23 +225,6 @@ func (c *client) tryToConsume(consumer *consumer) []*RmqMessage {
return rmqMsgs
}

func (c *client) unmarshalStreamingMessage(topic string, msg server.ConsumerMessage) (*RmqMessage, error) {
if !bytes.HasPrefix(msg.Payload, magicPrefix) {
return nil, errNotStreamingServiceMessage
}

var rmqMessage messagespb.RMQMessageLayout
if err := proto.Unmarshal(msg.Payload[len(magicPrefix):], &rmqMessage); err != nil {
return nil, err
}
return &RmqMessage{
msgID: msg.MsgID,
payload: rmqMessage.Payload,
properties: rmqMessage.Properties,
topic: topic,
}, nil
}

// Close close the channel to notify rocksmq to stop operation and close rocksmq server
func (c *client) Close() {
c.closeOnce.Do(func() {
Expand Down
10 changes: 0 additions & 10 deletions pkg/mq/mqimpl/rocksmq/client/magic.go

This file was deleted.

13 changes: 2 additions & 11 deletions pkg/mq/mqimpl/rocksmq/client/producer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ package client

import (
"go.uber.org/zap"
"google.golang.org/protobuf/proto"

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server"
"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
)

// assertion make sure implementation
Expand Down Expand Up @@ -79,19 +77,12 @@ func (p *producer) Send(message *common.ProducerMessage) (UniqueID, error) {
}

func (p *producer) SendForStreamingService(message *common.ProducerMessage) (UniqueID, error) {
rmqMessage := &messagespb.RMQMessageLayout{
Payload: message.Payload,
Properties: message.Properties,
}
payload, err := proto.Marshal(rmqMessage)
payload, err := marshalStreamingMessage(message)
if err != nil {
return 0, err
}

Check warning on line 83 in pkg/mq/mqimpl/rocksmq/client/producer_impl.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/mqimpl/rocksmq/client/producer_impl.go#L82-L83

Added lines #L82 - L83 were not covered by tests
finalPayload := make([]byte, len(payload)+len(magicPrefix))
copy(finalPayload, magicPrefix)
copy(finalPayload[len(magicPrefix):], payload)
ids, err := p.c.server.Produce(p.topic, []server.ProducerMessage{{
Payload: finalPayload,
Payload: payload,
}})
if err != nil {
return 0, err
Expand Down
53 changes: 53 additions & 0 deletions pkg/mq/mqimpl/rocksmq/client/streaming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package client

import (
"bytes"

"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"

"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server"
"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
)

var (
// magicPrefix is used to identify the rocksmq legacy message and new message for streaming service.
// Make a low probability of collision with the legacy proto message.
magicPrefix = append([]byte{0xFF, 0xFE, 0xFD, 0xFC}, []byte("STREAM")...)
errNotStreamingServiceMessage = errors.New("not a streaming service message")
)

// marshalStreamingMessage marshals a streaming message to bytes.
func marshalStreamingMessage(message *common.ProducerMessage) ([]byte, error) {
rmqMessage := &messagespb.RMQMessageLayout{
Payload: message.Payload,
Properties: message.Properties,
}
payload, err := proto.Marshal(rmqMessage)
if err != nil {
return nil, err
}

Check warning on line 30 in pkg/mq/mqimpl/rocksmq/client/streaming.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/mqimpl/rocksmq/client/streaming.go#L29-L30

Added lines #L29 - L30 were not covered by tests
finalPayload := make([]byte, len(payload)+len(magicPrefix))
copy(finalPayload, magicPrefix)
copy(finalPayload[len(magicPrefix):], payload)
return finalPayload, nil
}

// unmarshalStreamingMessage unmarshals a streaming message from bytes.
func unmarshalStreamingMessage(topic string, msg server.ConsumerMessage) (*RmqMessage, error) {
if !bytes.HasPrefix(msg.Payload, magicPrefix) {
return nil, errNotStreamingServiceMessage
}

var rmqMessage messagespb.RMQMessageLayout
if err := proto.Unmarshal(msg.Payload[len(magicPrefix):], &rmqMessage); err != nil {
return nil, err
}

Check warning on line 46 in pkg/mq/mqimpl/rocksmq/client/streaming.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/mqimpl/rocksmq/client/streaming.go#L45-L46

Added lines #L45 - L46 were not covered by tests
return &RmqMessage{
msgID: msg.MsgID,
payload: rmqMessage.Payload,
properties: rmqMessage.Properties,
topic: topic,
}, nil
}
36 changes: 36 additions & 0 deletions pkg/mq/mqimpl/rocksmq/client/streaming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package client

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server"
)

func TestStreaming(t *testing.T) {
payload, err := marshalStreamingMessage(&common.ProducerMessage{
Payload: []byte("payload"),
Properties: map[string]string{
"key": "value",
},
})
assert.NoError(t, err)
assert.NotNil(t, payload)

msg, err := unmarshalStreamingMessage("topic", server.ConsumerMessage{
MsgID: 1,
Payload: payload,
})
assert.NoError(t, err)
assert.Equal(t, string(msg.Payload()), "payload")
assert.Equal(t, msg.Properties()["key"], "value")
msg, err = unmarshalStreamingMessage("topic", server.ConsumerMessage{
MsgID: 1,
Payload: payload[1:],
})
assert.Error(t, err)
assert.ErrorIs(t, err, errNotStreamingServiceMessage)
assert.Nil(t, msg)
}

0 comments on commit 803fa3e

Please sign in to comment.