diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index ed7ab17641a81..7fdf65bd2b58d 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -1238,6 +1238,8 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC
 		ctx:                    ctx,
 		Condition:              NewTaskCondition(ctx),
 		AlterCollectionRequest: request,
+		replicateTargetTSMap:   node.replicateTargetTSMap,
+		replicateCurrentTSMap:  node.replicateCurrentTSMap,
 		rootCoord:              node.rootCoord,
 		queryCoord:             node.queryCoord,
 		dataCoord:              node.dataCoord,
@@ -6088,12 +6090,36 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest
 	}, nil
 }
 
+// storeCurrentReplicateTS records ts as the latest replicated timestamp of channelName when
+// nothing is recorded yet or ts is larger than the recorded value.
+// NOTE(review): Get followed by Insert is not atomic; concurrent calls for the same channel
+// could overwrite a newer value with an older one. Confirm callers are serialized per channel.
+func (node *Proxy) storeCurrentReplicateTS(channelName string, ts uint64) {
+	current, ok := node.replicateCurrentTSMap.Get(channelName)
+	if !ok || ts > current {
+		node.replicateCurrentTSMap.Insert(channelName, ts)
+	}
+}
+
 func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
 	if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
 		return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
 	}
 
-	if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() {
+	if _, ok := node.replicateTargetTSMap.Get(req.GetChannelName()); ok {
+		log.Ctx(ctx).Warn("the related collection is altering properties, deny to replicate message")
+		return &milvuspb.ReplicateMessageResponse{
+			Status: merr.Status(merr.ErrDenyReplicateMessage),
+		}, nil
+	}
+
+	collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
+	ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
+
+	// replicate message can be used in two ways; any other combination is denied:
+	// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
+	// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
+	if collectionReplicateEnable != ttMsgEnabled {
 		return &milvuspb.ReplicateMessageResponse{
 			Status: merr.Status(merr.ErrDenyReplicateMessage),
 		}, nil
@@ -6137,6 +6163,18 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
 		StartPositions: req.StartPositions,
 		EndPositions:   req.EndPositions,
 	}
+	checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
+		if !collectionReplicateEnable {
+			return true
+		}
+		info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collectionName, 0)
+		if err != nil {
+			log.Ctx(ctx).Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
+			return false
+		}
+		return info.replicateMode
+	}
+
 	// getTsMsgFromConsumerMsg
 	for i, msgBytes := range req.Msgs {
 		header := commonpb.MsgHeader{}
@@ -6154,8 +6192,12 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
 			log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
 			return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
 		}
+		node.storeCurrentReplicateTS(req.GetChannelName(), tsMsg.EndTs())
 		switch realMsg := tsMsg.(type) {
 		case *msgstream.InsertMsg:
+			if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
+				return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
+			}
 			assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
 				realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
 			if err != nil {
@@ -6170,6 +6212,10 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
 				realMsg.SegmentID = assignSegmentID
 				break
 			}
+		case *msgstream.DeleteMsg:
+			if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
+				return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
+			}
 		}
 		msgPack.Msgs = append(msgPack.Msgs, tsMsg)
 	}
diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index 7dda1c7047002..4808f077a5277 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -98,6 +98,8 @@ type collectionBasicInfo struct {
 	createdUtcTimestamp   uint64
 	consistencyLevel      commonpb.ConsistencyLevel
 	partitionKeyIsolation bool
+	replicateMode         bool
+	pchannels             []string
 }
 
 type collectionInfo struct {
@@ -108,6 +110,8 @@ type collectionInfo struct {
 	createdUtcTimestamp   uint64
 	consistencyLevel      commonpb.ConsistencyLevel
 	partitionKeyIsolation bool
+	replicateMode         bool
+	pchannels             []string
 }
 
 type databaseInfo struct {
@@ -277,6 +281,8 @@ func (info *collectionInfo) getBasicInfo() *collectionBasicInfo {
 		createdUtcTimestamp:   info.createdUtcTimestamp,
 		consistencyLevel:      info.consistencyLevel,
 		partitionKeyIsolation: info.partitionKeyIsolation,
+		replicateMode:         info.replicateMode,
+		pchannels:             info.pchannels,
 	}
 
 	return basicInfo
@@ -493,6 +499,8 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
 		createdUtcTimestamp:   collection.CreatedUtcTimestamp,
 		consistencyLevel:      collection.ConsistencyLevel,
 		partitionKeyIsolation: isolation,
+		replicateMode:         common.IsReplicateEnabled(collection.Properties),
+		pchannels:             collection.PhysicalChannelNames,
 	}
 
 	log.Info("meta update success", zap.String("database", database), zap.String("collectionName", collectionName), zap.Int64("collectionID", collection.CollectionID))
@@ -570,7 +578,8 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, coll
 	method := "GetCollectionInfo"
 	// if collInfo.collID != collectionID, means that the cache is not trustable
 	// try to get collection according to collectionID
-	if !ok || collInfo.collID != collectionID {
+	// Callers that do not have the collection ID pass 0; the ID comparison is skipped in that case.
+	if !ok || (collectionID != 0 && collInfo.collID != collectionID) {
 		tr := timerecord.NewTimeRecorder("UpdateCache")
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index acb4222294cb1..28f07a407cf92 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -124,6 +124,8 @@ type Proxy struct {
 	// resource manager
 	resourceManager        resource.Manager
 	replicateStreamManager *ReplicateStreamManager
+	replicateTargetTSMap   *typeutil.ConcurrentMap[string, uint64]
+	replicateCurrentTSMap  *typeutil.ConcurrentMap[string, uint64]
 
 	// materialized view
 	enableMaterializedView bool
@@ -152,6 +154,8 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
 		lbPolicy:               lbPolicy,
 		resourceManager:        resourceManager,
 		replicateStreamManager: replicateStreamManager,
+		replicateTargetTSMap:   typeutil.NewConcurrentMap[string, uint64](),
+		replicateCurrentTSMap:  typeutil.NewConcurrentMap[string, uint64](),
 	}
 	node.UpdateStateCode(commonpb.StateCode_Abnormal)
 	expr.Register("proxy", node)
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index 850bbd25fdfd5..f9f7814ac9536 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -31,6 +31,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/indexpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/ctokenizer"
 	"github.com/milvus-io/milvus/pkg/common"
@@ -881,6 +882,10 @@ type
alterCollectionTask struct {
 	result     *commonpb.Status
 	queryCoord types.QueryCoordClient
 	dataCoord  types.DataCoordClient
+
+	pchannels             []string
+	replicateTargetTSMap  *typeutil.ConcurrentMap[string, uint64]
+	replicateCurrentTSMap *typeutil.ConcurrentMap[string, uint64]
 }
 
 func (t *alterCollectionTask) TraceCtx() context.Context {
@@ -1041,6 +1046,104 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error {
 		}
 	}
 
+	newReplicateMode := common.IsReplicateEnabled(t.Properties)
+	if newReplicateMode {
+		return merr.WrapErrParameterInvalidMsg("can not enable replicate mode in alter collection")
+	}
+	// The request does not enable replicate mode here. If the collection currently has it enabled,
+	// mark its pchannels so ReplicateMessage is denied, then wait for replication to settle.
+	if collBasicInfo.replicateMode {
+		t.pchannels = collBasicInfo.pchannels
+		allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
+			Count: 1,
+		})
+		if err = merr.CheckRPCCall(allocResp, err); err != nil {
+			return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
+		}
+		for _, pchannel := range collBasicInfo.pchannels {
+			t.replicateTargetTSMap.Insert(pchannel, allocResp.GetTimestamp())
+		}
+		if err = t.waitReplicateTs(ctx, collBasicInfo.pchannels); err != nil {
+			// Undo the deny marks and report the failure instead of continuing as if the wait succeeded.
+			for _, pchannel := range collBasicInfo.pchannels {
+				t.replicateTargetTSMap.Remove(pchannel)
+			}
+			return err
+		}
+	}
+
+	return nil
+}
+
+// waitReplicateTs waits until the replicated timestamps recorded for pchannels stop changing and
+// a newly allocated timestamp is larger than every timestamp observed while waiting.
+// It returns ctx.Err() when ctx is done during a wait, or a wrapped error when AllocTimestamp fails.
+func (t *alterCollectionTask) waitReplicateTs(ctx context.Context, pchannels []string) error {
+	const (
+		pollInterval  = 500 * time.Millisecond
+		retryInterval = 10 * time.Millisecond
+	)
+
+	// wait until the current replicate timestamps are stable
+	stableRounds := 0
+	currentReplicateTS := make(map[string]uint64)
+	for {
+		// t.replicateCurrentTSMap.Len() == 0: force to alter collection because no timestamp is recorded yet
+		// stableRounds > 3: force to alter collection because the current ts was seen unchanged in more than 3 polls
+		if t.replicateCurrentTSMap.Len() == 0 || stableRounds > 3 {
+			break
+		}
+		sameCnt := 0
+		for _, pchannel := range pchannels {
+			ts, ok := t.replicateCurrentTSMap.Get(pchannel)
+			if !ok {
+				// No timestamp is recorded for this channel, so there is nothing to wait for. Counting it
+				// as stable keeps the loop from running forever when only some channels are recorded.
+				sameCnt++
+				continue
+			}
+			if currentReplicateTS[pchannel] == ts {
+				sameCnt++
+			} else {
+				currentReplicateTS[pchannel] = ts
+			}
+		}
+		if sameCnt == len(pchannels) {
+			stableRounds++
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(pollInterval):
+		}
+	}
+
+	// make sure the allocated ts is larger than the current replicate ts of every channel
+	for {
+		allocResp, err := t.rootCoord.AllocTimestamp(ctx, &rootcoordpb.AllocTimestampRequest{
+			Count: 1,
+		})
+		if err = merr.CheckRPCCall(allocResp, err); err != nil {
+			return merr.WrapErrServiceInternal("alloc timestamp failed", err.Error())
+		}
+		allocTS := allocResp.GetTimestamp()
+		needWait := false
+		for _, replicateTS := range currentReplicateTS {
+			if allocTS <= replicateTS {
+				needWait = true
+				break
+			}
+		}
+		if !needWait {
+			break
+		}
+		// Back off before allocating again instead of calling AllocTimestamp in a tight loop.
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(retryInterval):
+		}
+	}
 	return nil
 }
 
@@ -1051,6 +1154,11 @@ func (t *alterCollectionTask) Execute(ctx context.Context) error {
 }
 
 func (t *alterCollectionTask) PostExecute(ctx context.Context) error {
+	// Remove the deny marks that PreExecute inserted for this collection's pchannels.
+	// NOTE(review): if PostExecute is skipped when Execute fails, the marks would remain; confirm with the scheduler.
+	for _, pchannel := range t.pchannels {
+		t.replicateTargetTSMap.Remove(pchannel)
+	}
 	return nil
 }
 
diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go
index 494fcf4256d6b..b5c556767f089 100644
--- a/internal/proxy/task_delete.go
+++ b/internal/proxy/task_delete.go
@@ -331,6 +331,15 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
 		return ErrWithLog(log, "Failed to get collection id", merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound))
 	}
 
+	info, err := globalMetaCache.GetCollectionInfo(ctx, dr.req.GetDbName(), collName, dr.collectionID)
+	if err != nil {
+		log.Warn("get collection info failed", zap.String("collectionName", collName), zap.Error(err))
+		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
+	}
+	if info.replicateMode {
+		return merr.WrapErrCollectionReplicateMode("delete")
+	}
+
 	dr.schema, err = globalMetaCache.GetCollectionSchema(ctx, dr.req.GetDbName(), collName)
 	if err != nil {
 		return ErrWithLog(log, "Failed to get collection schema", err)
diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go
index fd86fc9d3c343..3e0816caf1f67 100644
--- a/internal/proxy/task_insert.go
+++ b/internal/proxy/task_insert.go
@@
-125,6 +125,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
 		return merr.WrapErrAsInputError(merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize"))
 	}
 
+	info, err := globalMetaCache.GetCollectionInfo(ctx, it.insertMsg.GetDbName(), collectionName, 0)
+	if err != nil {
+		log.Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
+		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
+	}
+	if info.replicateMode {
+		return merr.WrapErrCollectionReplicateMode("insert")
+	}
+
 	schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName)
 	if err != nil {
 		log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index e60d046f2d50c..ad38f45ee5c77 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -3899,8 +3899,10 @@ func TestTaskPartitionKeyIsolation(t *testing.T) {
 				CollectionName: colName,
 				Properties:     []*commonpb.KeyValuePair{{Key: common.PartitionKeyIsolationKey, Value: isoStr}},
 			},
-			queryCoord: qc,
-			dataCoord:  dc,
+			queryCoord:            qc,
+			dataCoord:             dc,
+			replicateTargetTSMap:  typeutil.NewConcurrentMap[string, uint64](),
+			replicateCurrentTSMap: typeutil.NewConcurrentMap[string, uint64](),
 		}
 	}
 
diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go
index 154bbba8753b7..d9f9568f824c8 100644
--- a/internal/proxy/task_upsert.go
+++ b/internal/proxy/task_upsert.go
@@ -292,6 +292,15 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
 		Timestamp: it.EndTs(),
 	}
 
+	info, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, 0)
+	if err != nil {
+		log.Warn("get collection info failed", zap.String("collectionName", collectionName), zap.Error(err))
+		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
+	}
+	if info.replicateMode {
+		return merr.WrapErrCollectionReplicateMode("upsert")
+	}
+
 	schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
 	if err != nil {
 		log.Warn("Failed to get collection schema",
diff --git a/pkg/common/common.go b/pkg/common/common.go
index 768577ccf9c75..53c152c051139 100644
--- a/pkg/common/common.go
+++ b/pkg/common/common.go
@@ -183,6 +183,7 @@ const (
 	PartitionKeyIsolationKey   = "partitionkey.isolation"
 	FieldSkipLoadKey           = "field.skipLoad"
 	IndexOffsetCacheEnabledKey = "indexoffsetcache.enabled"
+	ReplicateEnableKey         = "replicate.enable"
 )
 
 const (
@@ -387,3 +388,16 @@ func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) {
 	}
 	return true, nil
 }
+
+// IsReplicateEnabled reports whether kvs contains ReplicateEnableKey with a value that
+// strconv.ParseBool parses as true. Only the first entry with that key is considered.
+func IsReplicateEnabled(kvs []*commonpb.KeyValuePair) bool {
+	for _, kv := range kvs {
+		if kv.GetKey() == ReplicateEnableKey {
+			// ParseBool returns false together with its error, so an unparsable value means "not enabled".
+			val, _ := strconv.ParseBool(kv.GetValue())
+			return val
+		}
+	}
+	return false
+}
diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go
index 30bd26ec49f8f..69b58f6530465 100644
--- a/pkg/util/merr/errors.go
+++ b/pkg/util/merr/errors.go
@@ -70,6 +70,7 @@ var (
 	ErrCollectionIllegalSchema                 = newMilvusError("illegal collection schema", 105, false)
 	ErrCollectionOnRecovering                  = newMilvusError("collection on recovering", 106, true)
 	ErrCollectionVectorClusteringKeyNotAllowed = newMilvusError("vector clustering key not allowed", 107, false)
+	ErrCollectionReplicateMode                 = newMilvusError("can't operate when the collection is replicate mode", 108, false)
 
 	// Partition related
 	ErrPartitionNotFound = newMilvusError("partition not found", 200, false)
diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go
index 2ad30564e570b..c4885e9c05206 100644
--- a/pkg/util/merr/utils.go
+++ b/pkg/util/merr/utils.go
@@ -330,6 +330,12 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error {
 	return err
 }
 
+// WrapErrCollectionReplicateMode returns ErrCollectionReplicateMode annotated with the
+// given operation name.
+func WrapErrCollectionReplicateMode(operation string) error {
+	return wrapFields(ErrCollectionReplicateMode, value("operation", operation))
+}
+
func GetErrorType(err error) ErrorType {
 	if merr, ok := err.(milvusError); ok {
 		return merr.errType
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index cf9f4b3247df0..e9bfbfb361a53 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -260,6 +260,7 @@ type commonConfig struct {
 	MaxBloomFalsePositive     ParamItem `refreshable:"true"`
 	BloomFilterApplyBatchSize ParamItem `refreshable:"true"`
 	PanicWhenPluginFail       ParamItem `refreshable:"false"`
+	CollectionReplicateEnable ParamItem `refreshable:"true"`
 
 	UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
 	UseVectorAsClusteringKey       ParamItem `refreshable:"true"`
@@ -773,6 +774,15 @@ This helps Milvus-CDC synchronize incremental data`,
 	}
 	p.TTMsgEnabled.Init(base.mgr)
 
+	p.CollectionReplicateEnable = ParamItem{
+		Key:          "common.collectionReplicateEnable",
+		Version:      "2.4.14",
+		DefaultValue: "false",
+		Doc:          `Whether to enable collection replication.`,
+		Export:       true,
+	}
+	p.CollectionReplicateEnable.Init(base.mgr)
+
 	p.TraceLogMode = ParamItem{
 		Key:          "common.traceLogMode",
 		Version:      "2.3.4",