Skip to content

Commit

Permalink
fix: streaming consumer may get stuck when handler is un-consumed (m…
Browse files Browse the repository at this point in the history
…ilvus-io#36818)

issue: milvus-io#36378

Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh authored Oct 14, 2024
1 parent 8905b04 commit f0f5147
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ func (rc *resumableConsumerImpl) resumeLoop() {
// The consumer needs to resume when an error occurs, so the message handler shouldn't be closed if the internal consumer encounters a failure.
nopCloseMH := nopCloseHandler{
Handler: rc.mh,
HandleInterceptor: func(msg message.ImmutableMessage, handle func(message.ImmutableMessage)) {
HandleInterceptor: func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error) {
g := rc.metrics.StartConsume(msg.EstimateSize())
handle(msg)
ok, err := handle(ctx, msg)
g.Finish()
return ok, err
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestResumableConsumer(t *testing.T) {
rc := NewResumableConsumer(func(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) {
if i == 0 {
i++
opts.MessageHandler.Handle(message.NewImmutableMesasge(
ok, err := opts.MessageHandler.Handle(context.Background(), message.NewImmutableMesasge(
walimplstest.NewTestMessageID(123),
[]byte("payload"),
map[string]string{
Expand All @@ -36,6 +36,8 @@ func TestResumableConsumer(t *testing.T) {
"_v": "1",
"_lc": walimplstest.NewTestMessageID(123).Marshal(),
}))
assert.True(t, ok)
assert.NoError(t, err)
return c, nil
} else if i == 1 {
i++
Expand Down Expand Up @@ -76,7 +78,7 @@ func TestHandler(t *testing.T) {
hNop := nopCloseHandler{
Handler: message.ChanMessageHandler(ch),
}
hNop.Handle(nil)
hNop.Handle(context.Background(), nil)
assert.Nil(t, <-ch)
hNop.Close()
select {
Expand Down
17 changes: 11 additions & 6 deletions internal/distributed/streaming/internal/consumer/handler.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
package consumer

import "github.com/milvus-io/milvus/pkg/streaming/util/message"
import (
"context"

"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

type handleFunc func(ctx context.Context, msg message.ImmutableMessage) (bool, error)

// nopCloseHandler is a handler that does nothing when it is closed.
type nopCloseHandler struct {
message.Handler
HandleInterceptor func(msg message.ImmutableMessage, handle func(message.ImmutableMessage))
HandleInterceptor func(ctx context.Context, msg message.ImmutableMessage, handle handleFunc) (bool, error)
}

// Handle is the callback for handling message.
func (nch nopCloseHandler) Handle(msg message.ImmutableMessage) {
func (nch nopCloseHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
if nch.HandleInterceptor != nil {
nch.HandleInterceptor(msg, nch.Handler.Handle)
return
return nch.HandleInterceptor(ctx, msg, nch.Handler.Handle)
}
nch.Handler.Handle(msg)
return nch.Handler.Handle(ctx, msg)
}

// Close is called after all messages are handled or handling is interrupted.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
package consumer

import (
"context"

"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// timeTickOrderMessageHandler is a message handler that will do metrics and record the last sent message id.
// timeTickOrderMessageHandler is a message handler that will record the last sent message id.
type timeTickOrderMessageHandler struct {
// inner is the wrapped handler that Handle forwards every message to.
inner message.Handler
// lastConfirmedMessageID holds the LastConfirmedMessageID of the most recent message recorded by Handle.
lastConfirmedMessageID message.MessageID
// lastTimeTick holds the TimeTick of the most recent message recorded by Handle.
lastTimeTick uint64
}

func (mh *timeTickOrderMessageHandler) Handle(msg message.ImmutableMessage) {
func (mh *timeTickOrderMessageHandler) Handle(ctx context.Context, msg message.ImmutableMessage) (bool, error) {
lastConfirmedMessageID := msg.LastConfirmedMessageID()
timetick := msg.TimeTick()

mh.inner.Handle(msg)

mh.lastConfirmedMessageID = lastConfirmedMessageID
mh.lastTimeTick = timetick
ok, err := mh.inner.Handle(ctx, msg)
if ok {
mh.lastConfirmedMessageID = lastConfirmedMessageID
mh.lastTimeTick = timetick
}
return ok, err
}

func (mh *timeTickOrderMessageHandler) Close() {
Expand Down
29 changes: 20 additions & 9 deletions internal/streamingnode/client/handler/consumer/consumer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func CreateConsumer(
opts *ConsumerOptions,
handlerClient streamingpb.StreamingNodeHandlerServiceClient,
) (Consumer, error) {
ctx, err := createConsumeRequest(ctx, opts)
ctxWithReq, err := createConsumeRequest(ctx, opts)
if err != nil {
return nil, err
}

// TODO: configurable or auto adjust grpc.MaxCallRecvMsgSize
// The messages are always managed by the milvus cluster, so the message size shouldn't be limited here,
// to avoid blocking indefinitely.
streamClient, err := handlerClient.Consume(ctx, grpc.MaxCallRecvMsgSize(math.MaxInt32))
streamClient, err := handlerClient.Consume(ctxWithReq, grpc.MaxCallRecvMsgSize(math.MaxInt32))
if err != nil {
return nil, err
}
Expand All @@ -64,6 +64,7 @@ func CreateConsumer(
return nil, status.NewInvalidRequestSeq("first message arrive must be create response")
}
cli := &consumerImpl{
ctx: ctx,
walName: createResp.GetWalName(),
assignment: *opts.Assignment,
grpcStreamClient: streamClient,
Expand Down Expand Up @@ -93,6 +94,7 @@ func createConsumeRequest(ctx context.Context, opts *ConsumerOptions) (context.C
}

type consumerImpl struct {
ctx context.Context // TODO: the cancel method of consumer should be managed by consumerImpl, fix it in future.
walName string
assignment types.PChannelInfoAssigned
grpcStreamClient streamingpb.StreamingNodeHandlerService_ConsumeClient
Expand Down Expand Up @@ -177,12 +179,17 @@ func (c *consumerImpl) recvLoop() (err error) {
resp.Consume.GetMessage().GetProperties(),
)
if newImmutableMsg.TxnContext() != nil {
c.handleTxnMessage(newImmutableMsg)
if err := c.handleTxnMessage(newImmutableMsg); err != nil {
return err
}
} else {
if c.txnBuilder != nil {
panic("unreachable code: txn builder should be nil if we receive a non-txn message")
}
c.msgHandler.Handle(newImmutableMsg)
if _, err := c.msgHandler.Handle(c.ctx, newImmutableMsg); err != nil {
c.logger.Warn("message handle canceled", zap.Error(err))
return errors.Wrapf(err, "At Handler")
}
}
case *streamingpb.ConsumeResponse_Close:
// Should receive io.EOF after that.
Expand All @@ -193,7 +200,7 @@ func (c *consumerImpl) recvLoop() (err error) {
}
}

func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) error {
switch msg.MessageType() {
case message.MessageTypeBeginTxn:
if c.txnBuilder != nil {
Expand All @@ -202,7 +209,7 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
beginMsg, err := message.AsImmutableBeginTxnMessageV2(msg)
if err != nil {
c.logger.Warn("failed to convert message to begin txn message", zap.Any("messageID", beginMsg.MessageID()), zap.Error(err))
return
return nil
}
c.txnBuilder = message.NewImmutableTxnMessageBuilder(beginMsg)
case message.MessageTypeCommitTxn:
Expand All @@ -213,19 +220,23 @@ func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) {
if err != nil {
c.logger.Warn("failed to convert message to commit txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err))
c.txnBuilder = nil
return
return nil
}
msg, err := c.txnBuilder.Build(commitMsg)
c.txnBuilder = nil
if err != nil {
c.logger.Warn("failed to build txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err))
return
return nil
}
if _, err := c.msgHandler.Handle(c.ctx, msg); err != nil {
c.logger.Warn("message handle canceled at txn", zap.Error(err))
return errors.Wrap(err, "At Handler Of Txn")
}
c.msgHandler.Handle(msg)
default:
if c.txnBuilder == nil {
panic("unreachable code: txn builder should not be nil if we receive a non-begin txn message")
}
c.txnBuilder.Add(msg)
}
return nil
}
149 changes: 102 additions & 47 deletions internal/streamingnode/client/handler/consumer/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,101 @@ import (
)

// TestConsumer feeds mocked consume responses into a consumer and checks what reaches the
// message handler: a plain message is delivered as-is, and a begin/body/commit transaction
// sequence is delivered as a single txn message. It then closes the consumer and expects no error.
func TestConsumer(t *testing.T) {
// The handler is a buffered channel; the test reads the delivered messages back from it.
resultCh := make(message.ChanMessageHandler, 1)
c := newMockedConsumerImpl(t, context.Background(), resultCh)

// Phase 1: a single non-txn insert message with ID 1.
mmsg, _ := message.NewInsertMessageBuilderV1().
WithHeader(&message.InsertMessageHeader{}).
WithBody(&msgpb.InsertRequest{}).
WithVChannel("test-1").
BuildMutable()
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)

// The non-txn message must reach the handler with its message ID unchanged.
msg := <-resultCh
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1)))

// Phase 2: a transaction made of begin (ID 2), one insert (ID 3) and commit (ID 4),
// all carrying the same txn context.
txnCtx := message.TxnContext{
TxnID: 1,
Keepalive: time.Second,
}
mmsg, _ = message.NewBeginTxnMessageBuilderV2().
WithVChannel("test-1").
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
BuildMutable()
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx))

mmsg, _ = message.NewInsertMessageBuilderV1().
WithVChannel("test-1").
WithHeader(&message.InsertMessageHeader{}).
WithBody(&msgpb.InsertRequest{}).
BuildMutable()
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx))

mmsg, _ = message.NewCommitTxnMessageBuilderV2().
WithVChannel("test-1").
WithHeader(&message.CommitTxnMessageHeader{}).
WithBody(&message.CommitTxnMessageBody{}).
BuildMutable()
c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx))

// The handler receives exactly one message for the whole transaction: it carries the
// commit message's ID (4), the txn ID, and the txn message type.
msg = <-resultCh
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4)))
assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID)
assert.Equal(t, message.MessageTypeTxn, msg.MessageType())

// Phase 3: close the consumer, wait for it to finish, and expect no error.
c.consumer.Close()
<-c.consumer.Done()
assert.NoError(t, c.consumer.Error())
}

// TestConsumerWithCancellation verifies that a consumer whose handler is never drained
// does not finish on Close alone, but does finish once the context passed at creation
// is canceled, reporting context.Canceled as its error.
//
// Assertion failures use t.Fatal rather than panic so that a failure is reported for
// this test only instead of aborting the whole test binary.
func TestConsumerWithCancellation(t *testing.T) {
	// The handler channel has capacity 1 and is never read by this test.
	resultCh := make(message.ChanMessageHandler, 1)
	ctx, cancel := context.WithCancel(context.Background())
	c := newMockedConsumerImpl(t, ctx, resultCh)

	mmsg, _ := message.NewInsertMessageBuilderV1().
		WithHeader(&message.InsertMessageHeader{}).
		WithBody(&msgpb.InsertRequest{}).
		WithVChannel("test-1").
		BuildMutable()
	c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)
	// The recv goroutine will be blocked until the context is canceled.
	mmsg, _ = message.NewInsertMessageBuilderV1().
		WithHeader(&message.InsertMessageHeader{}).
		WithBody(&msgpb.InsertRequest{}).
		WithVChannel("test-1").
		BuildMutable()
	c.recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)

	// The background recv loop should be started.
	time.Sleep(20 * time.Millisecond)

	// Close is issued from a separate goroutine so the test goroutine stays free
	// to observe Done below.
	go func() {
		c.consumer.Close()
	}()

	// Close alone must not finish the consumer while the handler is still un-consumed.
	select {
	case <-c.consumer.Done():
		t.Fatal("consumer finished before the context was canceled")
	case <-time.After(10 * time.Millisecond):
	}

	// Canceling the context must release the consumer promptly.
	cancel()
	select {
	case <-c.consumer.Done():
	case <-time.After(20 * time.Millisecond):
		t.Fatal("consumer did not finish after the context was canceled")
	}
	assert.ErrorIs(t, c.consumer.Error(), context.Canceled)
}

// mockedConsumer bundles a Consumer under test with the channel that tests use to
// feed it consume responses.
type mockedConsumer struct {
// consumer is the Consumer returned by CreateConsumer in newMockedConsumerImpl.
consumer Consumer
// recvCh is the channel tests push ConsumeResponse values into.
// NOTE(review): presumably the mocked stream client's Recv reads from it — confirm in newMockedConsumerImpl.
recvCh chan *streamingpb.ConsumeResponse
}

func newMockedConsumerImpl(t *testing.T, ctx context.Context, h message.Handler) *mockedConsumer {
c := mock_streamingpb.NewMockStreamingNodeHandlerServiceClient(t)
cc := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeClient(t)
recvCh := make(chan *streamingpb.ConsumeResponse, 10)
Expand All @@ -43,8 +138,6 @@ func TestConsumer(t *testing.T) {
return nil
})

ctx := context.Background()
resultCh := make(message.ChanMessageHandler, 1)
opts := &ConsumerOptions{
Assignment: &types.PChannelInfoAssigned{
Channel: types.PChannelInfo{Name: "test", Term: 1},
Expand All @@ -55,7 +148,7 @@ func TestConsumer(t *testing.T) {
options.DeliverFilterVChannel("test-1"),
options.DeliverFilterTimeTickGT(100),
},
MessageHandler: resultCh,
MessageHandler: h,
}

recvCh <- &streamingpb.ConsumeResponse{
Expand All @@ -65,53 +158,15 @@ func TestConsumer(t *testing.T) {
},
},
}

mmsg, _ := message.NewInsertMessageBuilderV1().
WithHeader(&message.InsertMessageHeader{}).
WithBody(&msgpb.InsertRequest{}).
WithVChannel("test-1").
BuildMutable()
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg)

consumer, err := CreateConsumer(ctx, opts, c)
assert.NoError(t, err)
assert.NotNil(t, consumer)
msg := <-resultCh
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1)))

txnCtx := message.TxnContext{
TxnID: 1,
Keepalive: time.Second,
if err != nil {
panic(err)
}
mmsg, _ = message.NewBeginTxnMessageBuilderV2().
WithVChannel("test-1").
WithHeader(&message.BeginTxnMessageHeader{}).
WithBody(&message.BeginTxnMessageBody{}).
BuildMutable()
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx))

mmsg, _ = message.NewInsertMessageBuilderV1().
WithVChannel("test-1").
WithHeader(&message.InsertMessageHeader{}).
WithBody(&msgpb.InsertRequest{}).
BuildMutable()
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx))

mmsg, _ = message.NewCommitTxnMessageBuilderV2().
WithVChannel("test-1").
WithHeader(&message.CommitTxnMessageHeader{}).
WithBody(&message.CommitTxnMessageBody{}).
BuildMutable()
recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx))

msg = <-resultCh
assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4)))
assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID)
assert.Equal(t, message.MessageTypeTxn, msg.MessageType())

consumer.Close()
<-consumer.Done()
assert.NoError(t, consumer.Error())
return &mockedConsumer{
consumer: consumer,
recvCh: recvCh,
}
}

func newConsumeResponse(id message.MessageID, msg message.MutableMessage) *streamingpb.ConsumeResponse {
Expand Down
Loading

0 comments on commit f0f5147

Please sign in to comment.