diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index f7f3c1e62ad84..abcfff069abf3 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -3604,11 +3604,11 @@ func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStat log.Debug("received get flush state request", zap.Any("request", req)) var err error - resp := &milvuspb.GetFlushStateResponse{} + failResp := &milvuspb.GetFlushStateResponse{} if err := merr.CheckHealthy(node.stateCode.Load().(commonpb.StateCode)); err != nil { - resp.Status = merr.Status(err) + failResp.Status = merr.Status(err) log.Warn("unable to get flush state because of closed server") - return resp, nil + return failResp, nil } stateReq := &datapb.GetFlushStateRequest{ @@ -3618,23 +3618,23 @@ func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStat if len(req.GetCollectionName()) > 0 { // For compatibility with old client if err = validateCollectionName(req.GetCollectionName()); err != nil { - resp.Status = merr.Status(err) - return resp, nil + failResp.Status = merr.Status(err) + return failResp, nil } collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { - resp.Status = merr.Status(err) - return resp, nil + failResp.Status = merr.Status(err) + return failResp, nil } stateReq.CollectionID = collectionID } - resp, err = node.dataCoord.GetFlushState(ctx, stateReq) + resp, err := node.dataCoord.GetFlushState(ctx, stateReq) if err != nil { log.Warn("failed to get flush state response", zap.Error(err)) - resp.Status = merr.Status(err) - return resp, nil + failResp.Status = merr.Status(err) + return failResp, nil } log.Debug("received get flush state response", zap.Any("response", resp)) diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 78fbe2ee32e33..947fd658fbde7 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -374,6 +374,24 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) { {"flushAll set db", &milvuspb.FlushAllRequest{DbName: "default"}, true}, {"flushAll set db, db not exist", &milvuspb.FlushAllRequest{DbName: "default2"}, false}, } + + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() + // set expectations + cache := NewMockCache(t) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(UniqueID(0), nil).Maybe() + + cache.On("RemoveDatabase", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + ).Maybe() + + globalMetaCache = cache + for _, test := range tests { factory := dependency.NewDefaultFactory(true) ctx := context.Background() @@ -395,22 +413,6 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) { node.dataCoord = mocks.NewMockDataCoord(t) node.rootCoord = mocks.NewRootCoord(t) successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - - // set expectations - cache := NewMockCache(t) - cache.On("GetCollectionID", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(UniqueID(0), nil).Maybe() - - cache.On("RemoveDatabase", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - ).Maybe() - - globalMetaCache = cache - node.dataCoord.(*mocks.MockDataCoord).EXPECT().Flush(mock.Anything, mock.Anything). Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe() node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). @@ -451,6 +453,9 @@ func TestProxy_FlushAll(t *testing.T) { node.dataCoord = mocks.NewMockDataCoord(t) node.rootCoord = mocks.NewRootCoord(t) + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() + // set expectations cache := NewMockCache(t) cache.On("GetCollectionID", @@ -595,6 +600,86 @@ func TestProxy_GetFlushAllState(t *testing.T) { }) } +func TestProxy_GetFlushState(t *testing.T) { + factory := dependency.NewDefaultFactory(true) + ctx := context.Background() + + node, err := NewProxy(ctx, factory) + assert.NoError(t, err) + node.stateCode.Store(commonpb.StateCode_Healthy) + node.tsoAllocator = ×tampAllocator{ + tso: newMockTimestampAllocatorInterface(), + } + node.dataCoord = mocks.NewMockDataCoord(t) + node.rootCoord = mocks.NewRootCoord(t) + + // set expectations + successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything). + Return(&milvuspb.GetFlushStateResponse{Status: successStatus}, nil).Maybe() + + t.Run("GetFlushState success", func(t *testing.T) { + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + }) + + t.Run("GetFlushState failed, server is abnormal", func(t *testing.T) { + node.stateCode.Store(commonpb.StateCode_Abnormal) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_NotReadyServe) + node.stateCode.Store(commonpb.StateCode_Healthy) + }) + + t.Run("GetFlushState with collection name", func(t *testing.T) { + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + CollectionName: "*", + }) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) + + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() + cache := NewMockCache(t) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(UniqueID(0), nil).Maybe() + globalMetaCache = cache + + resp, err = node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + CollectionName: "collection1", + }) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + }) + + t.Run("DataCoord GetFlushState failed", func(t *testing.T) { + node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything). + Return(&milvuspb.GetFlushStateResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock err", + }, + }, nil) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) + }) + + t.Run("GetFlushState return error", func(t *testing.T) { + node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushState(mock.Anything, mock.Anything). + Return(nil, errors.New("fake error")) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) + }) +} + func TestProxy_GetReplicas(t *testing.T) { factory := dependency.NewDefaultFactory(true) ctx := context.Background()