Skip to content

Commit

Permalink
enhance: add unit test case for the replicate message feature
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed Dec 15, 2024
1 parent 7b11450 commit 2f56374
Show file tree
Hide file tree
Showing 23 changed files with 944 additions and 94 deletions.
2 changes: 1 addition & 1 deletion internal/datacoord/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (ch *channelMeta) String() string {
return fmt.Sprintf("Name: %s, CollectionID: %d, StartPositions: %v", ch.Name, ch.CollectionID, ch.StartPositions)
}

func (channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
func (ch *channelMeta) GetDBProperties() []*commonpb.KeyValuePair {
return nil
}

Expand Down
1 change: 1 addition & 0 deletions internal/proto/root_coord.proto
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ service RootCoord {
message AllocTimestampRequest {
common.MsgBase base = 1;
uint32 count = 3;
uint64 blockTimestamp = 4;
}

message AllocTimestampResponse {
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2026,13 +2026,13 @@ func TestAlterCollectionReplicateProperty(t *testing.T) {
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().Close().Return().Maybe()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"alter_property": {mockMsgID1, mockMsgID2},
}, nil).Maybe()

Expand Down
81 changes: 81 additions & 0 deletions internal/proxy/meta_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,87 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
wg.Wait()
}

func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
ctx := context.Background()
rootCoord := mocks.NewMockRootCoordClient(t)
queryCoord := mocks.NewMockQueryCoordClient(t)
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, queryCoord, mgr)
assert.NoError(t, err)
t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "bar", 1)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})

t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionID: 1,
Schema: &schemapb.CollectionSchema{
Name: "bar",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1,
Name: "p",
},
{
FieldID: 100,
Name: "pk",
},
},
},
ShardsNum: 1,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
VirtualChannelNames: []string{"by-dev-rootcoord-dml_1_1v0"},
}, nil).Once()
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []typeutil.UniqueID{11},
PartitionNames: []string{"p1"},
CreatedTimestamps: []uint64{11},
CreatedUtcTimestamps: []uint64{11},
}, nil).Once()
queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Once()
c, err := globalMetaCache.GetCollectionInfo(ctx, "foo", "hoo", 0)
assert.NoError(t, err)
assert.Equal(t, c.collID, int64(1))
assert.Equal(t, c.schema.Name, "bar")
})
}

func TestMetaCache_GetCollectionName(t *testing.T) {
ctx := context.Background()
rootCoord := &MockRootCoordClientInterface{}
Expand Down
24 changes: 10 additions & 14 deletions internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1088,20 +1088,16 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error {
}
endTS, ok := common.GetReplicateEndTS(t.Properties)
if ok && collBasicInfo.replicateID != "" {
var rootcoordTS uint64
for {
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
rootcoordTS = allocResp.GetTimestamp()
if rootcoordTS > endTS {
break
}
log.Info("wait for rootcoord ts", zap.Uint64("rootcoord ts", rootcoordTS), zap.Uint64("end ts", endTS))
time.Sleep(500 * time.Millisecond)
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
BlockTimestamp: endTS,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
if allocResp.GetTimestamp() <= endTS {
return merr.WrapErrServiceInternal("alter collection: alloc timestamp failed, timestamp is not greater than endTS",
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
}
}

Expand Down
28 changes: 12 additions & 16 deletions internal/proxy/task_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package proxy

import (
"context"
"time"
"fmt"

"go.uber.org/zap"

Expand Down Expand Up @@ -290,22 +290,18 @@ func (t *alterDatabaseTask) PreExecute(ctx context.Context) error {
}
oldReplicateEnable, _ := common.IsReplicateEnabled(cacheInfo.properties)
if !oldReplicateEnable { // old replicate enable is false
return nil
return merr.WrapErrParameterInvalidMsg("can't set the replicate end ts property in alter database request when db replicate is disabled")
}
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
BlockTimestamp: endTS,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
var rootcoordTS uint64
for {
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
rootcoordTS = allocResp.GetTimestamp()
if rootcoordTS > endTS {
break
}
log.Info("wait for rootcoord ts", zap.Uint64("rootcoord ts", rootcoordTS), zap.Uint64("end ts", endTS))
time.Sleep(500 * time.Millisecond)
if allocResp.GetTimestamp() <= endTS {
return merr.WrapErrServiceInternal("alter database: alloc timestamp failed, timestamp is not greater than endTS",
fmt.Sprintf("timestamp = %d, endTS = %d", allocResp.GetTimestamp(), endTS))
}

return nil
Expand Down
158 changes: 158 additions & 0 deletions internal/proxy/task_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"
"testing"

"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -201,6 +202,163 @@ func TestAlterDatabase(t *testing.T) {
assert.Nil(t, err1)
}

func TestAlterDatabaseTaskForReplicateProperty(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
globalMetaCache = mockCache

t.Run("replicate id", func(t *testing.T) {
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.MmapEnabledKey,
Value: "true",
},
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})

t.Run("fail to get database info", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})

t.Run("not enable replicate", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{},
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})

t.Run("fail to alloc ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(nil, errors.New("err")).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})

t.Run("alloc wrong ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Success(),
Timestamp: 999,
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.Error(t, err)
})

t.Run("alloc wrong ts", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
},
}, nil).Once()
rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{
Status: merr.Success(),
Timestamp: 1001,
}, nil).Once()
task := &alterDatabaseTask{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{},
DbName: "test_alter_database",
Properties: []*commonpb.KeyValuePair{
{
Key: common.ReplicateEndTSKey,
Value: "1000",
},
},
},
rootCoord: rc,
}
err := task.PreExecute(context.Background())
assert.NoError(t, err)
})
}

func TestDescribeDatabaseTask(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)

Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
replicateID, err := GetReplicateID(it.ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
return merr.WrapErrAsInputError(err)
}
if replicateID != "" {
return merr.WrapErrCollectionReplicateMode("insert")
Expand Down
Loading

0 comments on commit 2f56374

Please sign in to comment.