feat: add replicate.enable property to collection and related config
Signed-off-by: SimFG <[email protected]>
SimFG committed Oct 24, 2024
1 parent 39a91eb commit 5a673d4
Showing 12 changed files with 202 additions and 4 deletions.
51 changes: 50 additions & 1 deletion internal/proxy/impl.go
@@ -1238,6 +1238,8 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC
ctx: ctx,
Condition: NewTaskCondition(ctx),
AlterCollectionRequest: request,
replicateTargetTSMap: node.replicateTargetTSMap,
replicateCurrentTSMap: node.replicateCurrentTSMap,
rootCoord: node.rootCoord,
queryCoord: node.queryCoord,
dataCoord: node.dataCoord,
@@ -6088,12 +6090,39 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest
}, nil
}

func (node *Proxy) storeCurrentReplicateTS(channelName string, ts uint64) {
current, ok := node.replicateCurrentTSMap.Get(channelName)
if !ok || ts > current {
node.replicateCurrentTSMap.Insert(channelName, ts)
}
}

func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}

if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() {
if _, ok := node.replicateTargetTSMap.Get(req.GetChannelName()); ok {
log.Ctx(ctx).Warn("the related collection is altering properties, deny to replicate message")
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
}

collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()

// The replicate message API can only be used in two configurations; any other combination returns an error:
// 1. collectionReplicateEnable is false and ttMsgEnabled is false: active/standby mode
// 2. collectionReplicateEnable is true and ttMsgEnabled is true: data migration mode
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
}

if !paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool() &&
paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
@@ -6137,6 +6166,18 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
StartPositions: req.StartPositions,
EndPositions: req.EndPositions,
}
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
if !collectionReplicateEnable {
return true
}
info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collectionName, 0)
if err != nil {
log.Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
return false
}
return info.replicateMode
}

// getTsMsgFromConsumerMsg
for i, msgBytes := range req.Msgs {
header := commonpb.MsgHeader{}
@@ -6154,8 +6195,12 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
node.storeCurrentReplicateTS(req.GetChannelName(), tsMsg.EndTs())
switch realMsg := tsMsg.(type) {
case *msgstream.InsertMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
if err != nil {
@@ -6170,6 +6215,10 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
realMsg.SegmentID = assignSegmentID
break
}
case *msgstream.DeleteMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
}
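
The two flags checked above must be set consistently before ReplicateMessage accepts traffic. As a minimal, self-contained sketch (not the proxy code itself; the helper and main function below are illustrative only), the rule reduces to requiring the two booleans to be equal:

package main

import "fmt"

// replicateMessageAllowed mirrors the gate enforced in ReplicateMessage:
//   - collectionReplicateEnable=false, ttMsgEnabled=false -> active/standby mode
//   - collectionReplicateEnable=true,  ttMsgEnabled=true  -> data migration mode
// Any mixed combination is denied with ErrDenyReplicateMessage.
func replicateMessageAllowed(collectionReplicateEnable, ttMsgEnabled bool) bool {
	return collectionReplicateEnable == ttMsgEnabled
}

func main() {
	for _, c := range [][2]bool{{false, false}, {true, true}, {false, true}, {true, false}} {
		fmt.Printf("collectionReplicateEnable=%v ttMsgEnabled=%v allowed=%v\n",
			c[0], c[1], replicateMessageAllowed(c[0], c[1]))
	}
}
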
11 changes: 10 additions & 1 deletion internal/proxy/meta_cache.go
@@ -98,6 +98,8 @@ type collectionBasicInfo struct {
createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel
partitionKeyIsolation bool
replicateMode bool
pchannels []string
}

type collectionInfo struct {
@@ -108,6 +110,8 @@ type collectionInfo struct {
createdUtcTimestamp uint64
consistencyLevel commonpb.ConsistencyLevel
partitionKeyIsolation bool
replicateMode bool
pchannels []string
}

type databaseInfo struct {
@@ -277,6 +281,8 @@ func (info *collectionInfo) getBasicInfo() *collectionBasicInfo {
createdUtcTimestamp: info.createdUtcTimestamp,
consistencyLevel: info.consistencyLevel,
partitionKeyIsolation: info.partitionKeyIsolation,
replicateMode: info.replicateMode,
pchannels: info.pchannels,
}

return basicInfo
@@ -493,6 +499,8 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
createdUtcTimestamp: collection.CreatedUtcTimestamp,
consistencyLevel: collection.ConsistencyLevel,
partitionKeyIsolation: isolation,
replicateMode: common.IsReplicateEnabled(collection.Properties),
pchannels: collection.PhysicalChannelNames,
}

log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collection.CollectionID))
@@ -570,7 +578,8 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
method := "GetCollectionInfo"
// if collInfo.collID != collectionID, means that the cache is not trustable
// try to get collection according to collectionID
if !ok || collInfo.collID != collectionID {
// Only a non-zero collectionID can invalidate the cache, because callers in the proxy do not always provide one.
if !ok || (collectionID != 0 && collInfo.collID != collectionID) {
tr := timerecord.NewTimeRecorder("UpdateCache")
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()

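
The GetCollectionInfo change above treats a collectionID of 0 as "not provided", so a cached entry is only invalidated when an explicit ID disagrees with it. A small sketch of that decision (the helper below is hypothetical, for illustration only):

// cacheUsable reports whether a cached collection entry can be trusted for a request.
// requestedID == 0 means the caller did not know the ID (insert/upsert pass 0),
// so it cannot be used to invalidate the cache.
func cacheUsable(found bool, cachedID, requestedID int64) bool {
	if !found {
		return false
	}
	return requestedID == 0 || cachedID == requestedID
}
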
4 changes: 4 additions & 0 deletions internal/proxy/proxy.go
@@ -124,6 +124,8 @@ type Proxy struct {
// resource manager
resourceManager resource.Manager
replicateStreamManager *ReplicateStreamManager
replicateTargetTSMap *typeutil.ConcurrentMap[string, uint64]
replicateCurrentTSMap *typeutil.ConcurrentMap[string, uint64]

// materialized view
enableMaterializedView bool
@@ -152,6 +154,8 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
lbPolicy: lbPolicy,
resourceManager: resourceManager,
replicateStreamManager: replicateStreamManager,
replicateTargetTSMap: typeutil.NewConcurrentMap[string, uint64](),
replicateCurrentTSMap: typeutil.NewConcurrentMap[string, uint64](),
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
82 changes: 82 additions & 0 deletions internal/proxy/task.go
@@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/ctokenizer"
"github.com/milvus-io/milvus/pkg/common"
@@ -881,6 +882,10 @@ type alterCollectionTask struct {
result *commonpb.Status
queryCoord types.QueryCoordClient
dataCoord types.DataCoordClient

pchannels []string
replicateTargetTSMap *typeutil.ConcurrentMap[string, uint64]
replicateCurrentTSMap *typeutil.ConcurrentMap[string, uint64]
}

func (t *alterCollectionTask) TraceCtx() context.Context {
@@ -1041,6 +1046,80 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error {
}
}

newReplicateMode := common.IsReplicateEnabled(t.Properties)
if newReplicateMode {
return merr.WrapErrParameterInvalidMsg("can not enable replicate mode in alter collection")
}
if !newReplicateMode && collBasicInfo.replicateMode {
t.pchannels = collBasicInfo.pchannels
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
for _, pchannel := range collBasicInfo.pchannels {
t.replicateTargetTSMap.Insert(pchannel, allocResp.GetTimestamp())
}
err = t.waitReplicateTs(ctx, collBasicInfo.pchannels)
if err != nil {
for _, pchannel := range collBasicInfo.pchannels {
t.replicateTargetTSMap.Remove(pchannel)
}
}
}

return nil
}

func (t *alterCollectionTask) waitReplicateTs(ctx context.Context, pchannels []string) error {
// wait until the current replicate ts of every channel is stable
cnt := 0
currentReplicateTS := make(map[string]uint64)
for {
// t.replicateCurrentTSMap.Len() == 0: proceed with the alter collection because Milvus has just started
// cnt > 3: proceed with the alter collection because the current ts has been stable for several rounds
if t.replicateCurrentTSMap.Len() == 0 || cnt > 3 {
break
}
sameCnt := 0
for _, pchannel := range pchannels {
ts, ok := t.replicateCurrentTSMap.Get(pchannel)
if !ok {
continue
}
if currentReplicateTS[pchannel] == ts {
sameCnt++
} else {
currentReplicateTS[pchannel] = ts
}
}
if sameCnt == len(pchannels) {
cnt++
}
time.Sleep(500 * time.Millisecond)
}

// make sure the allocated ts is larger than the current replicate ts of every channel
for {
allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
Count: 1,
})
if err = merr.CheckRPCCall(allocResp, err); err != nil {
return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
}
allocTS := allocResp.GetTimestamp()
needWait := false
for _, replicateTS := range currentReplicateTS {
if allocTS <= replicateTS {
needWait = true
break
}
}
if !needWait {
break
}
}
return nil
}

@@ -1051,6 +1130,9 @@ func (t *alterCollectionTask) Execute(ctx context.Context) error {
}

func (t *alterCollectionTask) PostExecute(ctx context.Context) error {
for _, pchannel := range t.pchannels {
t.replicateTargetTSMap.Remove(pchannel)
}
return nil
}

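
waitReplicateTs above polls the per-channel replicate timestamps until they stop changing for a few consecutive rounds, then keeps allocating until it gets a timestamp larger than everything already replicated. A simplified, self-contained sketch of that strategy (the function names, poll interval and round count are illustrative, not the production values):

package main

import "time"

// waitStableThenAllocate samples per-channel replicate timestamps via sample() until every
// channel has reported an unchanged value for several consecutive rounds, then calls alloc()
// until the allocated timestamp exceeds all of the observed replicate timestamps.
func waitStableThenAllocate(sample func() map[string]uint64, alloc func() uint64) uint64 {
	observed := map[string]uint64{}
	stableRounds := 0
	for stableRounds <= 3 {
		cur := sample()
		if len(cur) == 0 {
			// nothing has been replicated yet (e.g. the cluster just started), proceed immediately
			break
		}
		changed := false
		for ch, ts := range cur {
			if observed[ch] != ts {
				observed[ch] = ts
				changed = true
			}
		}
		if changed {
			stableRounds = 0
		} else {
			stableRounds++
		}
		time.Sleep(500 * time.Millisecond)
	}
	// allocate until the new timestamp is beyond every observed replicate timestamp
	for {
		ts := alloc()
		ok := true
		for _, replicateTS := range observed {
			if ts <= replicateTS {
				ok = false
				break
			}
		}
		if ok {
			return ts
		}
	}
}

func main() {
	ts := waitStableThenAllocate(
		func() map[string]uint64 { return map[string]uint64{} }, // no replicate traffic observed yet
		func() uint64 { return uint64(time.Now().UnixNano()) },
	)
	_ = ts
}
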
9 changes: 9 additions & 0 deletions internal/proxy/task_delete.go
@@ -331,6 +331,15 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound))
}

info, err := globalMetaCache.GetCollectionInfo(ctx, dr.req.GetDbName(), collName, dr.collectionID)
if err != nil {
log.Warn("get collection info failed", zap.String("collectionName", collName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
if info.replicateMode {
return merr.WrapErrCollectionReplicateMode("delete")
}

dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName)
if err != nil {
return ErrWithLog(log, "Failed to get collection schema", err)
9 changes: 9 additions & 0 deletions internal/proxy/task_insert.go
@@ -125,6 +125,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return merr.WrapErrAsInputError(merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize"))
}

info, err := globalMetaCache.GetCollectionInfo(it.ctx, it.insertMsg.GetDbName(), collectionName, 0)
if err != nil {
log.Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
if info.replicateMode {
return merr.WrapErrCollectionReplicateMode("insert")
}

schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName)
if err != nil {
log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
5 changes: 3 additions & 2 deletions internal/proxy/task_test.go
@@ -3899,8 +3899,9 @@ func TestTaskPartitionKeyIsolation(t *testing.T) {
CollectionName: colName,
Properties: []*commonpb.KeyValuePair{{Key: common.PartitionKeyIsolationKey, Value: isoStr}},
},
queryCoord: qc,
dataCoord: dc,
queryCoord: qc,
dataCoord: dc,
replicateTargetTSMap: typeutil.NewConcurrentMap[string, uint64](),
}
}

9 changes: 9 additions & 0 deletions internal/proxy/task_upsert.go
@@ -292,6 +292,15 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
Timestamp: it.EndTs(),
}

info, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, 0)
if err != nil {
log.Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
if info.replicateMode {
return merr.WrapErrCollectionReplicateMode("upsert")
}

schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
if err != nil {
log.Warn("Failed to get collection schema",
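
task_delete.go, task_insert.go and task_upsert.go above all apply the same rule: while a collection still carries replicate.enable=true, direct DML through the proxy is refused so that only the replication/migration path can write to it. A sketch of that check folded into one helper (the helper itself is hypothetical; each task currently inlines it), using only identifiers that already appear in this diff:

// denyDMLIfReplicating rejects proxy-side DML for a collection that is still in replicate mode.
// Passing 0 as the collectionID lets the meta cache resolve the collection by name alone.
func denyDMLIfReplicating(ctx context.Context, dbName, collectionName, operation string) error {
	info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collectionName, 0)
	if err != nil {
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
	}
	if info.replicateMode {
		return merr.WrapErrCollectionReplicateMode(operation)
	}
	return nil
}
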
11 changes: 11 additions & 0 deletions pkg/common/common.go
@@ -183,6 +183,7 @@ const (
PartitionKeyIsolationKey = "partitionkey.isolation"
FieldSkipLoadKey = "field.skipLoad"
IndexOffsetCacheEnabledKey = "indexoffsetcache.enabled"
ReplicateEnableKey = "replicate.enable"
)

const (
@@ -387,3 +388,13 @@ func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) {
}
return true, nil
}

func IsReplicateEnabled(kvs []*commonpb.KeyValuePair) bool {
for _, kv := range kvs {
if kv.GetKey() == ReplicateEnableKey {
val, _ := strconv.ParseBool(kv.GetValue())
return val
}
}
return false
}
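
The new ReplicateEnableKey is an ordinary collection property, so any component holding a property list can check it with IsReplicateEnabled. A minimal usage sketch (import paths assumed from the milvus and milvus-proto modules referenced elsewhere in this change):

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/pkg/common"
)

func main() {
	props := []*commonpb.KeyValuePair{
		{Key: common.ReplicateEnableKey, Value: "true"},
	}
	// prints true: proxy-side insert/delete/upsert against this collection would be rejected
	fmt.Println(common.IsReplicateEnabled(props))
}
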
1 change: 1 addition & 0 deletions pkg/util/merr/errors.go
@@ -70,6 +70,7 @@ var (
ErrCollectionIllegalSchema = newMilvusError("illegal collection schema", 105, false)
ErrCollectionOnRecovering = newMilvusError("collection on recovering", 106, true)
ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false)
ErrCollectionReplicateMode = newMilvusError("can't operate when the collection is replicate mode", 108, false)

// Partition related
ErrPartitionNotFound = newMilvusError("partition not found", 200, false)
4 changes: 4 additions & 0 deletions pkg/util/merr/utils.go
@@ -330,6 +330,10 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
return err
}

func WrapErrCollectionReplicateMode(operation string) error {
return wrapFields(ErrCollectionReplicateMode, value("operation", operation))
}

func GetErrorType(err error) ErrorType {
if merr, ok := err.(milvusError); ok {
return merr.errType
10 changes: 10 additions & 0 deletions pkg/util/paramtable/component_param.go
@@ -260,6 +260,7 @@ type commonConfig struct {
MaxBloomFalsePositive ParamItem `refreshable:"true"`
BloomFilterApplyBatchSize ParamItem `refreshable:"true"`
PanicWhenPluginFail ParamItem `refreshable:"false"`
CollectionReplicateEnable ParamItem `refreshable:"true"`

UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
UseVectorAsClusteringKey ParamItem `refreshable:"true"`
@@ -773,6 +774,15 @@ This helps Milvus-CDC synchronize incremental data`,
}
p.TTMsgEnabled.Init(base.mgr)

p.CollectionReplicateEnable = ParamItem{
Key: "common.collectionReplicateEnable",
Version: "2.4.14",
DefaultValue: "false",
Doc: `Whether to enable collection replication.`,
Export: true,
}
p.CollectionReplicateEnable.Init(base.mgr)

p.TraceLogMode = ParamItem{
Key: "common.traceLogMode",
Version: "2.3.4",
