Skip to content

Commit

Permalink
fix panic in getFlushState (milvus-io#27237)
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Liu <[email protected]>
  • Loading branch information
weiliu1031 authored Sep 21, 2023
1 parent ab2d8dd commit fc9a9a7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 26 deletions.
20 changes: 10 additions & 10 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3604,11 +3604,11 @@ func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStat
log.Debug("received get flush state request",
zap.Any("request", req))
var err error
resp := &milvuspb.GetFlushStateResponse{}
failResp := &milvuspb.GetFlushStateResponse{}
if err := merr.CheckHealthy(node.stateCode.Load().(commonpb.StateCode)); err != nil {
resp.Status = merr.Status(err)
failResp.Status = merr.Status(err)
log.Warn("unable to get flush state because of closed server")
return resp, nil
return failResp, nil
}

stateReq := &datapb.GetFlushStateRequest{
Expand All @@ -3618,23 +3618,23 @@ func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStat

if len(req.GetCollectionName()) > 0 { // For compatibility with old client
if err = validateCollectionName(req.GetCollectionName()); err != nil {
resp.Status = merr.Status(err)
return resp, nil
failResp.Status = merr.Status(err)
return failResp, nil
}
collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
resp.Status = merr.Status(err)
return resp, nil
failResp.Status = merr.Status(err)
return failResp, nil
}
stateReq.CollectionID = collectionID
}

resp, err = node.dataCoord.GetFlushState(ctx, stateReq)
resp, err := node.dataCoord.GetFlushState(ctx, stateReq)
if err != nil {
log.Warn("failed to get flush state response",
zap.Error(err))
resp.Status = merr.Status(err)
return resp, nil
failResp.Status = merr.Status(err)
return failResp, nil
}
log.Debug("received get flush state response",
zap.Any("response", resp))
Expand Down
117 changes: 101 additions & 16 deletions internal/proxy/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,24 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
{"flushAll set db", &milvuspb.FlushAllRequest{DbName: "default"}, true},
{"flushAll set db, db not exist", &milvuspb.FlushAllRequest{DbName: "default2"}, false},
}

cacheBak := globalMetaCache
defer func() { globalMetaCache = cacheBak }()
// set expectations
cache := NewMockCache(t)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), nil).Maybe()

cache.On("RemoveDatabase",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
).Maybe()

globalMetaCache = cache

for _, test := range tests {
factory := dependency.NewDefaultFactory(true)
ctx := context.Background()
Expand All @@ -395,22 +413,6 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
node.dataCoord = mocks.NewMockDataCoord(t)
node.rootCoord = mocks.NewRootCoord(t)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}

// set expectations
cache := NewMockCache(t)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), nil).Maybe()

cache.On("RemoveDatabase",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
).Maybe()

globalMetaCache = cache

node.dataCoord.(*mocks.MockDataCoord).EXPECT().Flush(mock.Anything, mock.Anything).
Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe()
node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything).
Expand Down Expand Up @@ -451,6 +453,9 @@ func TestProxy_FlushAll(t *testing.T) {
node.dataCoord = mocks.NewMockDataCoord(t)
node.rootCoord = mocks.NewRootCoord(t)

cacheBak := globalMetaCache
defer func() { globalMetaCache = cacheBak }()

// set expectations
cache := NewMockCache(t)
cache.On("GetCollectionID",
Expand Down Expand Up @@ -595,6 +600,86 @@ func TestProxy_GetFlushAllState(t *testing.T) {
})
}

func TestProxy_GetFlushState(t *testing.T) {
factory := dependency.NewDefaultFactory(true)
ctx := context.Background()

node, err := NewProxy(ctx, factory)
assert.NoError(t, err)
node.stateCode.Store(commonpb.StateCode_Healthy)
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
node.dataCoord = mocks.NewMockDataCoord(t)
node.rootCoord = mocks.NewRootCoord(t)

// set expectations
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything).
Return(&milvuspb.GetFlushStateResponse{Status: successStatus}, nil).Maybe()

t.Run("GetFlushState success", func(t *testing.T) {
resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
})

t.Run("GetFlushState failed, server is abnormal", func(t *testing.T) {
node.stateCode.Store(commonpb.StateCode_Abnormal)
resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_NotReadyServe)
node.stateCode.Store(commonpb.StateCode_Healthy)
})

t.Run("GetFlushState with collection name", func(t *testing.T) {
resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
CollectionName: "*",
})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)

cacheBak := globalMetaCache
defer func() { globalMetaCache = cacheBak }()
cache := NewMockCache(t)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(UniqueID(0), nil).Maybe()
globalMetaCache = cache

resp, err = node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
CollectionName: "collection1",
})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
})

t.Run("DataCoord GetFlushState failed", func(t *testing.T) {
node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil
node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything).
Return(&milvuspb.GetFlushStateResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mock err",
},
}, nil)
resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)
})

t.Run("GetFlushState return error", func(t *testing.T) {
node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil
node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything).
Return(nil, errors.New("fake error"))
resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{})
assert.NoError(t, err)
assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError)
})
}

func TestProxy_GetReplicas(t *testing.T) {
factory := dependency.NewDefaultFactory(true)
ctx := context.Background()
Expand Down

0 comments on commit fc9a9a7

Please sign in to comment.