From bc57f7d568c48ed6d54e2717e11abfb766c89c1d Mon Sep 17 00:00:00 2001 From: SimFG Date: Sun, 1 Dec 2024 13:08:53 +0800 Subject: [PATCH] fix: unit test case Signed-off-by: SimFG --- internal/proxy/impl_test.go | 5 +- internal/proxy/task_test.go | 20 ++++- internal/querycoordv2/server_test.go | 2 + internal/querycoordv2/task/task_test.go | 2 + internal/querynodev2/metrics_info_test.go | 3 +- .../querynodev2/pipeline/filter_node_test.go | 4 +- internal/querynodev2/pipeline/manager_test.go | 4 +- internal/querynodev2/segments/collection.go | 7 +- .../rootcoord/alter_database_task_test.go | 2 +- pkg/common/common_test.go | 81 +++++++++++++++++++ pkg/mq/msgdispatcher/dispatcher.go | 7 -- pkg/mq/msgstream/mock_msgstream.go | 33 -------- pkg/mq/msgstream/mq_msgstream.go | 9 --- pkg/mq/msgstream/msgstream.go | 1 - 14 files changed, 120 insertions(+), 60 deletions(-) diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index ad8c885d5a19c..a944f29248a27 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -1974,6 +1974,7 @@ func TestReplicateMessageForCollectionMode(t *testing.T) { mockCache := NewMockCache(t) mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice() + mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice() globalMetaCache = mockCache { @@ -2185,8 +2186,8 @@ func TestAlterCollectionReplicateProperty(t *testing.T) { CollectionName: "foo_collection", Properties: []*commonpb.KeyValuePair{ { - Key: "replicate.enable", - Value: "false", + Key: "replicate.endTS", + Value: "1", }, }, }) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index b043da8f88d50..94e89f4504b63 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -4276,7 +4276,7 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { assert.Error(t, err) }) - t.Run("fail to alloc ts", func(t *testing.T) { + t.Run("invalid replicate id", func(t *testing.T) { task := &alterCollectionTask{ AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ CollectionName: "test", @@ -4290,6 +4290,24 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) { rootCoord: mockRootcoord, } + err := task.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("fail to alloc ts", func(t *testing.T) { + task := &alterCollectionTask{ + AlterCollectionRequest: &milvuspb.AlterCollectionRequest{ + CollectionName: "test", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.ReplicateEndTSKey, + Value: "100", + }, + }, + }, + rootCoord: mockRootcoord, + } + mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once() err := task.PreExecute(ctx) assert.Error(t, err) diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 948d0d7a9277b..d3e7a8dd1b6cd 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -36,6 +36,7 @@ import ( coordMocks "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/checkers" "github.com/milvus-io/milvus/internal/querycoordv2/dist" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -602,6 +603,7 @@ func (suite *ServerSuite) hackServer() { ) suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe() + suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe() suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil).Maybe() for _, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 2eb53639d5f9d..6de6a70d49175 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" @@ -230,6 +231,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { }, }, nil }) + suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil) for channel, segment := range suite.growingSegments { suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment). Return([]*datapb.SegmentInfo{ diff --git a/internal/querynodev2/metrics_info_test.go b/internal/querynodev2/metrics_info_test.go index be00650d5723c..28ff974960ce1 100644 --- a/internal/querynodev2/metrics_info_test.go +++ b/internal/querynodev2/metrics_info_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" @@ -49,7 +50,7 @@ func TestGetPipelineJSON(t *testing.T) { collectionManager := segments.NewMockCollectionManager(t) segmentManager := segments.NewMockSegmentManager(t) - collectionManager.EXPECT().Get(mock.Anything).Return(&segments.Collection{}) + collectionManager.EXPECT().Get(mock.Anything).Return(segments.NewTestCollection(1, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{})) manager := &segments.Manager{ Collection: collectionManager, Segment: segmentManager, diff --git a/internal/querynodev2/pipeline/filter_node_test.go b/internal/querynodev2/pipeline/filter_node_test.go index 001ca4cef21bb..738a6e0399e53 100644 --- a/internal/querynodev2/pipeline/filter_node_test.go +++ b/internal/querynodev2/pipeline/filter_node_test.go @@ -72,7 +72,7 @@ func (suite *FilterNodeSuite) TestWithLoadCollection() { suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} // mock - collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection) + collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadCollection, nil) for _, partitionID := range suite.partitionIDs { collection.AddPartition(partitionID) } @@ -111,7 +111,7 @@ func (suite *FilterNodeSuite) TestWithLoadPartation() { suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} // mock - collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition) + collection := segments.NewTestCollection(suite.collectionID, querypb.LoadType_LoadPartition, nil) collection.AddPartition(suite.partitionIDs[0]) mockCollectionManager := segments.NewMockCollectionManager(suite.T()) diff --git a/internal/querynodev2/pipeline/manager_test.go b/internal/querynodev2/pipeline/manager_test.go index dccb4a3597de1..07fcc13a58024 100644 --- a/internal/querynodev2/pipeline/manager_test.go +++ b/internal/querynodev2/pipeline/manager_test.go @@ -24,6 +24,8 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" @@ -77,7 +79,7 @@ func (suite *PipelineManagerTestSuite) SetupTest() { func (suite *PipelineManagerTestSuite) TestBasic() { // init mock // mock collection manager - suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{}) + suite.collectionManager.EXPECT().Get(suite.collectionID).Return(segments.NewTestCollection(suite.collectionID, querypb.LoadType_UnKnownType, &schemapb.CollectionSchema{})) // mock mq factory suite.msgDispatcher.EXPECT().Register(mock.Anything, mock.Anything).Return(suite.msgChan, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 029f4d3ca832f..1810cd9403aab 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -296,13 +296,16 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM return coll } -func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection { - return &Collection{ +// Only for test +func NewTestCollection(collectionID int64, loadType querypb.LoadType, schema *schemapb.CollectionSchema) *Collection { + col := &Collection{ id: collectionID, partitions: typeutil.NewConcurrentSet[int64](), loadType: loadType, refCount: atomic.NewUint32(0), } + col.schema.Store(schema) + return col } // new collection without segcore prepare diff --git a/internal/rootcoord/alter_database_task_test.go b/internal/rootcoord/alter_database_task_test.go index 99855d0d6e8c9..69ad9ecfaa832 100644 --- a/internal/rootcoord/alter_database_task_test.go +++ b/internal/rootcoord/alter_database_task_test.go @@ -154,7 +154,7 @@ func Test_alterDatabaseTask_Execute(t *testing.T) { mock.Anything, ).Return(nil) - core := newTestCore(withMeta(meta)) + core := newTestCore(withMeta(meta), withValidProxyManager()) task := &alterDatabaseTask{ baseTask: newBaseTask(context.Background(), core), Req: &rootcoordpb.AlterDatabaseRequest{ diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 7e77b782f38eb..422c020589731 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -177,3 +177,84 @@ func TestShouldFieldBeLoaded(t *testing.T) { }) } } + +func TestReplicateProperty(t *testing.T) { + t.Run("ReplicateID", func(t *testing.T) { + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateIDKey, + Value: "1001", + }, + } + e, ok := IsReplicateEnabled(p) + assert.True(t, e) + assert.True(t, ok) + i, ok := GetReplicateID(p) + assert.True(t, ok) + assert.Equal(t, "1001", i) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateIDKey, + Value: "", + }, + } + e, ok := IsReplicateEnabled(p) + assert.False(t, e) + assert.True(t, ok) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: "foo", + Value: "1001", + }, + } + e, ok := IsReplicateEnabled(p) + assert.False(t, e) + assert.False(t, ok) + } + }) + + t.Run("ReplicateTS", func(t *testing.T) { + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateEndTSKey, + Value: "1001", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.True(t, ok) + assert.EqualValues(t, 1001, ts) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: ReplicateEndTSKey, + Value: "foo", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.False(t, ok) + assert.EqualValues(t, 0, ts) + } + + { + p := []*commonpb.KeyValuePair{ + { + Key: "foo", + Value: "1001", + }, + } + ts, ok := GetReplicateEndTS(p) + assert.False(t, ok) + assert.EqualValues(t, 0, ts) + } + }) +} diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index a64441efb2473..d0ef5a4243e5e 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -336,13 +336,6 @@ func (d *Dispatcher) groupingMsgs(pack *MsgPack) map[string]*MsgPack { } newMsgs = append(newMsgs, msg) } - log.Info("fubang messages", - zap.String("vchannel", vchannel), - zap.Any("beginTs", beginTs), - zap.Any("endTs", endTs), - zap.Any("oldMsgs", targetPacks[vchannel].Msgs), - zap.Any("newMsgs", newMsgs), - ) targetPacks[vchannel].Msgs = newMsgs d.resetMsgPackTS(targetPacks[vchannel], beginTs, endTs) } diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index 6257743e8515f..47a1b9cb6db93 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -558,39 +558,6 @@ func (_c *MockMsgStream_SetRepackFunc_Call) RunAndReturn(run func(RepackFunc)) * return _c } -// SetReplicate provides a mock function with given fields: config -func (_m *MockMsgStream) SetReplicate(config *ReplicateConfig) { - _m.Called(config) -} - -// MockMsgStream_SetReplicate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetReplicate' -type MockMsgStream_SetReplicate_Call struct { - *mock.Call -} - -// SetReplicate is a helper method to define mock.On call -// - config *ReplicateConfig -func (_e *MockMsgStream_Expecter) SetReplicate(config interface{}) *MockMsgStream_SetReplicate_Call { - return &MockMsgStream_SetReplicate_Call{Call: _e.mock.On("SetReplicate", config)} -} - -func (_c *MockMsgStream_SetReplicate_Call) Run(run func(config *ReplicateConfig)) *MockMsgStream_SetReplicate_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*ReplicateConfig)) - }) - return _c -} - -func (_c *MockMsgStream_SetReplicate_Call) Return() *MockMsgStream_SetReplicate_Call { - _c.Call.Return() - return _c -} - -func (_c *MockMsgStream_SetReplicate_Call) RunAndReturn(run func(*ReplicateConfig)) *MockMsgStream_SetReplicate_Call { - _c.Call.Return(run) - return _c -} - // NewMockMsgStream creates a new instance of MockMsgStream. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockMsgStream(t interface { diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index abb30bcaa65d9..9975480408863 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -272,15 +272,6 @@ func (ms *mqMsgStream) EnableProduce(can bool) { ms.enableProduce.Store(can) } -// SetReplicate not safe, please call it only onece before produce or consume -func (ms *mqMsgStream) SetReplicate(config *ReplicateConfig) { - if config == nil { - return - } - ms.replicateID = config.ReplicateID - ms.checkFunc = config.CheckFunc -} - func (ms *mqMsgStream) isEnabledProduce() bool { return ms.enableProduce.Load().(bool) } diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 413a103e334b5..9393baa660955 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -75,7 +75,6 @@ type MsgStream interface { CheckTopicValid(channel string) error EnableProduce(can bool) - SetReplicate(config *ReplicateConfig) } type ReplicateConfig struct {