Skip to content

Commit

Permalink
enhance: add unit test case for the replicate message feature
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed Dec 5, 2024
1 parent f3c9ca7 commit 61e14f5
Show file tree
Hide file tree
Showing 14 changed files with 697 additions and 43 deletions.
2 changes: 1 addition & 1 deletion internal/datacoord/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (ch *channelMeta) String() string {
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions)
}

func (channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions internal/proxy/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2026,13 +2026,13 @@ func TestAlterCollectionReplicateProperty(t *testing.T) {
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().Close().Return().Maybe()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"alter_property": {mockMsgID1, mockMsgID2},
}, nil).Maybe()

Expand Down
81 changes: 81 additions & 0 deletions internal/proxy/meta_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
wg.Wait()
}

func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
ctx := context.Background()
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := mocks.NewMockQueryCoordClient(t)
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err)
t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})

t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})
}

func TestMetaCache_GetCollectionName(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
return merr.WrapErrAsInputError(err)
}
if replicateID != "" {
return merr.WrapErrCollectionReplicateMode("insert")
Expand Down
66 changes: 43 additions & 23 deletions internal/proxy/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4312,28 +4312,48 @@ func TestAlterCollectionForReplicateProperty(t *testing.T) {
err := task.PreExecute(ctx)
assert.Error(t, err)
})
}

// t.Run("fail to wait ts", func(t *testing.T) {
// task := &alterCollectionTask{
// AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
// CollectionName: "test",
// Properties: []*commonpb.KeyValuePair{
// {
// Key: common.ReplicateIDKey,
// Value: "",
// },
// },
// },
// rootCoord: mockRootcoord,
// }
//
// mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
// Status: merr.Success(),
// Timestamp: 100,
// Count: 1,
// }, nil).Once()
// mockRootcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
// err := task.PreExecute(ctx)
// assert.Error(t, err)
// })
func TestInsertForReplicate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
globalMetaCache = mockCache

t.Run("get replicate id fail", func(t *testing.T) {
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &insertTask{
insertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
CollectionName: "foo",
},
},
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
t.Run("insert with replicate id", func(t *testing.T) {
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
schema: &schemaInfo{
CollectionSchema: &schemapb.CollectionSchema{
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-mac",
},
},
},
},
replicateID: "local-mac",
}, nil).Once()
task := &insertTask{
insertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
CollectionName: "foo",
},
},
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
}
18 changes: 13 additions & 5 deletions internal/querynodev2/pipeline/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgdispatcher"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
Expand Down Expand Up @@ -111,6 +113,12 @@ func (suite *PipelineTestSuite) TestBasic() {
schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
DbProperties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
})
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)

Expand Down Expand Up @@ -143,16 +151,16 @@ func (suite *PipelineTestSuite) TestBasic() {
Collection: suite.collectionManager,
Segment: suite.segmentManager,
}
pipeline, err := NewPipeLine(collection, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator)
pipelineObj, err := NewPipeLine(collection, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator)
suite.NoError(err)

// Init Consumer
err = pipeline.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
err = pipelineObj.ConsumeMsgStream(context.Background(), &msgpb.MsgPosition{})
suite.NoError(err)

err = pipeline.Start()
err = pipelineObj.Start()
suite.NoError(err)
defer pipeline.Close()
defer pipelineObj.Close()

// watch tsafe manager
listener := suite.tSafeManager.WatchChannel(suite.channel)
Expand All @@ -161,7 +169,7 @@ func (suite *PipelineTestSuite) TestBasic() {
in := suite.buildMsgPack(schema)
suite.msgChan <- in

// wait pipeline work
// wait pipelineObj work
<-listener.On()

// check tsafe
Expand Down
41 changes: 36 additions & 5 deletions internal/rootcoord/alter_collection_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package rootcoord
import (
"context"
"testing"
"time"

"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
)

func Test_alterCollectionTask_Prepare(t *testing.T) {
Expand Down Expand Up @@ -217,14 +219,25 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
assert.NoError(t, err)
})

t.Run("alter successfully", func(t *testing.T) {
t.Run("alter successfully2", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{CollectionID: int64(1)}, nil)
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
Expand All @@ -237,19 +250,37 @@ func Test_alterCollectionTask_Execute(t *testing.T) {
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}

core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker))
packChan := make(chan *msgstream.MsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")

core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker))
newPros := append(properties, &commonpb.KeyValuePair{
Key: common.ReplicateEndTSKey,
Value: "10000",
})
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
Properties: newPros,
},
}

err := task.Execute(context.Background())
assert.NoError(t, err)
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].Type())
replicateMsg := pack.Msgs[0].(*msgstream.ReplicateMsg)
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
default:
assert.Fail(t, "no message sent")
}
})

t.Run("test update collection props", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 61e14f5

Please sign in to comment.