diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 6b391ff89a062..e077e04030f15 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -331,6 +331,51 @@ func (_c *MockWALAccesser_Txn_Call) RunAndReturn(run func(context.Context, strea return _c } +// WALName provides a mock function with given fields: +func (_m *MockWALAccesser) WALName() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for WALName") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockWALAccesser_WALName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WALName' +type MockWALAccesser_WALName_Call struct { + *mock.Call +} + +// WALName is a helper method to define mock.On call +func (_e *MockWALAccesser_Expecter) WALName() *MockWALAccesser_WALName_Call { + return &MockWALAccesser_WALName_Call{Call: _e.mock.On("WALName")} +} + +func (_c *MockWALAccesser_WALName_Call) Run(run func()) *MockWALAccesser_WALName_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWALAccesser_WALName_Call) Return(_a0 string) *MockWALAccesser_WALName_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALAccesser_WALName_Call) RunAndReturn(run func() string) *MockWALAccesser_WALName_Call { + _c.Call.Return(run) + return _c +} + // NewMockWALAccesser creates a new instance of MockWALAccesser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockWALAccesser(t interface { diff --git a/internal/util/streamingutil/util/wal_selector.go b/internal/util/streamingutil/util/wal_selector.go index 91ec7ff76a71d..ef6c7abfba914 100644 --- a/internal/util/streamingutil/util/wal_selector.go +++ b/internal/util/streamingutil/util/wal_selector.go @@ -9,7 +9,6 @@ import ( const ( walTypeDefault = "default" - walTypeNatsmq = "natsmq" walTypeRocksmq = "rocksmq" walTypeKafka = "kafka" walTypePulsar = "pulsar" diff --git a/internal/util/streamingutil/util/wal_selector_test.go b/internal/util/streamingutil/util/wal_selector_test.go index 6343eaf1b3718..a3cc1804254e0 100644 --- a/internal/util/streamingutil/util/wal_selector_test.go +++ b/internal/util/streamingutil/util/wal_selector_test.go @@ -7,27 +7,24 @@ import ( ) func TestValidateWALType(t *testing.T) { - assert.Error(t, validateWALName(false, walTypeNatsmq)) assert.Error(t, validateWALName(false, walTypeRocksmq)) } func TestSelectWALType(t *testing.T) { - assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{true, true, true, true}), walTypeRocksmq) - assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, false, true}), walTypeKafka) - assert.Panics(t, func() { mustSelectWALName(true, walTypeDefault, walEnable{false, false, false, false}) }) - assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{true, true, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, false, true}), walTypeKafka) - assert.Panics(t, func() { mustSelectWALName(false, walTypeDefault, walEnable{false, false, false, false}) }) - assert.Equal(t, mustSelectWALName(true, walTypeRocksmq, walEnable{true, true, true, true}), walTypeRocksmq) - assert.Equal(t, mustSelectWALName(true, walTypeNatsmq, walEnable{true, true, true, true}), walTypeNatsmq) - assert.Equal(t, mustSelectWALName(true, walTypePulsar, walEnable{true, true, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(true, walTypeKafka, walEnable{true, true, true, true}), walTypeKafka) - assert.Panics(t, func() { mustSelectWALName(false, walTypeRocksmq, walEnable{true, true, true, true}) }) - assert.Panics(t, func() { mustSelectWALName(false, walTypeNatsmq, walEnable{true, true, true, true}) }) - assert.Equal(t, mustSelectWALName(false, walTypePulsar, walEnable{true, true, true, true}), walTypePulsar) - assert.Equal(t, mustSelectWALName(false, walTypeKafka, walEnable{true, true, true, true}), walTypeKafka) + assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{true, true, true}), walTypeRocksmq) + assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(true, walTypeDefault, walEnable{false, false, true}), walTypeKafka) + assert.Panics(t, func() { mustSelectWALName(true, walTypeDefault, walEnable{false, false, false}) }) + assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{true, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(false, walTypeDefault, walEnable{false, false, true}), walTypeKafka) + assert.Panics(t, func() { mustSelectWALName(false, walTypeDefault, walEnable{false, false, false}) }) + assert.Equal(t, mustSelectWALName(true, walTypeRocksmq, walEnable{true, true, true}), walTypeRocksmq) + assert.Equal(t, mustSelectWALName(true, walTypePulsar, walEnable{true, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(true, walTypeKafka, walEnable{true, true, true}), walTypeKafka) + assert.Panics(t, func() { mustSelectWALName(false, walTypeRocksmq, walEnable{true, true, true}) }) + assert.Equal(t, mustSelectWALName(false, walTypePulsar, walEnable{true, true, true}), walTypePulsar) + assert.Equal(t, mustSelectWALName(false, walTypeKafka, walEnable{true, true, true}), walTypeKafka) } diff --git a/pkg/mq/mqimpl/rocksmq/client/client_impl.go b/pkg/mq/mqimpl/rocksmq/client/client_impl.go index d51335b2a390d..f68e6c602f2e1 100644 --- a/pkg/mq/mqimpl/rocksmq/client/client_impl.go +++ b/pkg/mq/mqimpl/rocksmq/client/client_impl.go @@ -12,7 +12,6 @@ package client import ( - "bytes" "context" "reflect" "sync" @@ -20,12 +19,10 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" - "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server" - "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" ) const ( @@ -201,7 +198,7 @@ func (c *client) tryToConsume(consumer *consumer) []*RmqMessage { } rmqMsgs := make([]*RmqMessage, 0, len(msgs)) for _, msg := range msgs { - rmqMsg, err := c.unmarshalStreamingMessage(consumer.topic, msg) + rmqMsg, err := unmarshalStreamingMessage(consumer.topic, msg) if err == nil { rmqMsgs = append(rmqMsgs, rmqMsg) continue @@ -228,23 +225,6 @@ func (c *client) tryToConsume(consumer *consumer) []*RmqMessage { return rmqMsgs } -func (c *client) unmarshalStreamingMessage(topic string, msg server.ConsumerMessage) (*RmqMessage, error) { - if !bytes.HasPrefix(msg.Payload, magicPrefix) { - return nil, errNotStreamingServiceMessage - } - - var rmqMessage messagespb.RMQMessageLayout - if err := proto.Unmarshal(msg.Payload[len(magicPrefix):], &rmqMessage); err != nil { - return nil, err - } - return &RmqMessage{ - msgID: msg.MsgID, - payload: rmqMessage.Payload, - properties: rmqMessage.Properties, - topic: topic, - }, nil -} - // Close close the channel to notify rocksmq to stop operation and close rocksmq server func (c *client) Close() { c.closeOnce.Do(func() { diff --git a/pkg/mq/mqimpl/rocksmq/client/magic.go b/pkg/mq/mqimpl/rocksmq/client/magic.go deleted file mode 100644 index a1940f9ce566c..0000000000000 --- a/pkg/mq/mqimpl/rocksmq/client/magic.go +++ /dev/null @@ -1,10 +0,0 @@ -package client - -import "github.com/cockroachdb/errors" - -var ( - // magicPrefix is used to identify the rocksmq legacy message and new message for streaming service. - // Make a low probability of collision with the legacy proto message. - magicPrefix = append([]byte{0xFF, 0xFE, 0xFD, 0xFC}, []byte("STREAM")...) - errNotStreamingServiceMessage = errors.New("not a streaming service message") -) diff --git a/pkg/mq/mqimpl/rocksmq/client/producer_impl.go b/pkg/mq/mqimpl/rocksmq/client/producer_impl.go index da7ba3ddb1bed..4f2aad064c293 100644 --- a/pkg/mq/mqimpl/rocksmq/client/producer_impl.go +++ b/pkg/mq/mqimpl/rocksmq/client/producer_impl.go @@ -13,12 +13,10 @@ package client import ( "go.uber.org/zap" - "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server" - "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" ) // assertion make sure implementation @@ -79,19 +77,12 @@ func (p *producer) Send(message *common.ProducerMessage) (UniqueID, error) { } func (p *producer) SendForStreamingService(message *common.ProducerMessage) (UniqueID, error) { - rmqMessage := &messagespb.RMQMessageLayout{ - Payload: message.Payload, - Properties: message.Properties, - } - payload, err := proto.Marshal(rmqMessage) + payload, err := marshalStreamingMessage(message) if err != nil { return 0, err } - finalPayload := make([]byte, len(payload)+len(magicPrefix)) - copy(finalPayload, magicPrefix) - copy(finalPayload[len(magicPrefix):], payload) ids, err := p.c.server.Produce(p.topic, []server.ProducerMessage{{ - Payload: finalPayload, + Payload: payload, }}) if err != nil { return 0, err diff --git a/pkg/mq/mqimpl/rocksmq/client/streaming.go b/pkg/mq/mqimpl/rocksmq/client/streaming.go new file mode 100644 index 0000000000000..c317bb26d9ed9 --- /dev/null +++ b/pkg/mq/mqimpl/rocksmq/client/streaming.go @@ -0,0 +1,53 @@ +package client + +import ( + "bytes" + + "github.com/cockroachdb/errors" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" +) + +var ( + // magicPrefix is used to identify the rocksmq legacy message and new message for streaming service. + // Make a low probability of collision with the legacy proto message. + magicPrefix = append([]byte{0xFF, 0xFE, 0xFD, 0xFC}, []byte("STREAM")...) + errNotStreamingServiceMessage = errors.New("not a streaming service message") +) + +// marshalStreamingMessage marshals a streaming message to bytes. +func marshalStreamingMessage(message *common.ProducerMessage) ([]byte, error) { + rmqMessage := &messagespb.RMQMessageLayout{ + Payload: message.Payload, + Properties: message.Properties, + } + payload, err := proto.Marshal(rmqMessage) + if err != nil { + return nil, err + } + finalPayload := make([]byte, len(payload)+len(magicPrefix)) + copy(finalPayload, magicPrefix) + copy(finalPayload[len(magicPrefix):], payload) + return finalPayload, nil +} + +// unmarshalStreamingMessage unmarshals a streaming message from bytes. +func unmarshalStreamingMessage(topic string, msg server.ConsumerMessage) (*RmqMessage, error) { + if !bytes.HasPrefix(msg.Payload, magicPrefix) { + return nil, errNotStreamingServiceMessage + } + + var rmqMessage messagespb.RMQMessageLayout + if err := proto.Unmarshal(msg.Payload[len(magicPrefix):], &rmqMessage); err != nil { + return nil, err + } + return &RmqMessage{ + msgID: msg.MsgID, + payload: rmqMessage.Payload, + properties: rmqMessage.Properties, + topic: topic, + }, nil +} diff --git a/pkg/mq/mqimpl/rocksmq/client/streaming_test.go b/pkg/mq/mqimpl/rocksmq/client/streaming_test.go new file mode 100644 index 0000000000000..3028c98fb8c13 --- /dev/null +++ b/pkg/mq/mqimpl/rocksmq/client/streaming_test.go @@ -0,0 +1,36 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/mq/common" + "github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server" +) + +func TestStreaming(t *testing.T) { + payload, err := marshalStreamingMessage(&common.ProducerMessage{ + Payload: []byte("payload"), + Properties: map[string]string{ + "key": "value", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, payload) + + msg, err := unmarshalStreamingMessage("topic", server.ConsumerMessage{ + MsgID: 1, + Payload: payload, + }) + assert.NoError(t, err) + assert.Equal(t, string(msg.Payload()), "payload") + assert.Equal(t, msg.Properties()["key"], "value") + msg, err = unmarshalStreamingMessage("topic", server.ConsumerMessage{ + MsgID: 1, + Payload: payload[1:], + }) + assert.Error(t, err) + assert.ErrorIs(t, err, errNotStreamingServiceMessage) + assert.Nil(t, msg) +}