diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 42555652f26ca..659ce08e596eb 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -371,7 +371,7 @@ queryCoord: channelExclusiveNodeFactor: 4 # the least node number for enable channel's exclusive mode collectionObserverInterval: 200 # the interval of collection observer checkExecutedFlagInterval: 100 # the interval of check executed flag to force to pull dist - updateCollectionLoadStatusInterval: 5 # 5m, max interval for updating collection loaded status + updateCollectionLoadStatusInterval: 300 # 300s, max interval of updating collection loaded status for check health cleanExcludeSegmentInterval: 60 # the time duration of clean pipeline exclude segment which used for filter invalid data, in seconds ip: # TCP/IP address of queryCoord. If not specified, use the first unicastable address port: 19531 # TCP port of queryCoord @@ -806,7 +806,7 @@ common: readwrite: privileges: ListDatabases,SelectOwnership,SelectUser,DescribeResourceGroup,ListResourceGroups,FlushAll,TransferNode,TransferReplica,UpdateResourceGroups # Cluster level readwrite privileges admin: - privileges: ListDatabases,SelectOwnership,SelectUser,DescribeResourceGroup,ListResourceGroups,FlushAll,TransferNode,TransferReplica,UpdateResourceGroups,BackupRBAC,RestoreRBAC,CreateDatabase,DropDatabase,CreateOwnership,DropOwnership,ManageOwnership,CreateResourceGroup,DropResourceGroup,UpdateUser # Cluster level admin privileges + privileges: ListDatabases,SelectOwnership,SelectUser,DescribeResourceGroup,ListResourceGroups,FlushAll,TransferNode,TransferReplica,UpdateResourceGroups,BackupRBAC,RestoreRBAC,CreateDatabase,DropDatabase,CreateOwnership,DropOwnership,ManageOwnership,CreateResourceGroup,DropResourceGroup,UpdateUser,RenameCollection # Cluster level admin privileges database: readonly: privileges: ShowCollections,DescribeDatabase # Database level readonly privileges @@ -818,9 +818,9 @@ common: readonly: privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases # Collection level readonly privileges readwrite: - privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition # Collection level readwrite privileges + privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,CreateIndex,DropIndex,CreatePartition,DropPartition # Collection level readwrite privileges admin: - privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,RenameCollection,CreateIndex,DropIndex,CreatePartition,DropPartition,CreateAlias,DropAlias # Collection level admin privileges + privileges: Query,Search,IndexDetail,GetFlushState,GetLoadState,GetLoadingProgress,HasPartition,ShowPartitions,DescribeCollection,DescribeAlias,GetStatistics,ListAliases,Load,Release,Insert,Delete,Upsert,Import,Flush,Compaction,LoadBalance,CreateIndex,DropIndex,CreatePartition,DropPartition,CreateAlias,DropAlias # Collection level admin privileges tlsMode: 0 session: ttl: 30 # ttl value when session granting a lease to register service diff --git a/internal/datacoord/compaction_task_clustering.go b/internal/datacoord/compaction_task_clustering.go index 4f5ae946d9781..62417fc8268a5 100644 --- a/internal/datacoord/compaction_task_clustering.go +++ b/internal/datacoord/compaction_task_clustering.go @@ -193,16 +193,6 @@ func (t *clusteringCompactionTask) processPipelining() error { log.Debug("wait for the node to be assigned before proceeding with the subsequent steps") return nil } - var operators []UpdateOperator - for _, segID := range t.InputSegments { - operators = append(operators, UpdateSegmentLevelOperator(segID, datapb.SegmentLevel_L2)) - } - err := t.meta.UpdateSegmentsInfo(operators...) - if err != nil { - log.Warn("fail to set segment level to L2", zap.Error(err)) - return merr.WrapErrClusteringCompactionMetaError("UpdateSegmentsInfo before compaction executing", err) - } - if typeutil.IsVectorType(t.GetClusteringKeyField().DataType) { err := t.doAnalyze() if err != nil { @@ -309,34 +299,49 @@ func (t *clusteringCompactionTask) processIndexing() error { return nil } +func (t *clusteringCompactionTask) markInputSegmentsDropped() error { + var operators []UpdateOperator + // mark + for _, segID := range t.GetInputSegments() { + operators = append(operators, UpdateStatusOperator(segID, commonpb.SegmentState_Dropped)) + } + err := t.meta.UpdateSegmentsInfo(operators...) + if err != nil { + log.Warn("markInputSegmentsDropped UpdateSegmentsInfo fail", zap.Error(err)) + return merr.WrapErrClusteringCompactionMetaError("markInputSegmentsDropped UpdateSegmentsInfo", err) + } + return nil +} + // indexed is the final state of a clustering compaction task // one task should only run this once func (t *clusteringCompactionTask) completeTask() error { - err := t.meta.GetPartitionStatsMeta().SavePartitionStatsInfo(&datapb.PartitionStatsInfo{ + var err error + // update current partition stats version + // at this point, the segment view includes both the input segments and the result segments. + if err = t.meta.GetPartitionStatsMeta().SavePartitionStatsInfo(&datapb.PartitionStatsInfo{ CollectionID: t.GetCollectionID(), PartitionID: t.GetPartitionID(), VChannel: t.GetChannel(), Version: t.GetPlanID(), SegmentIDs: t.GetResultSegments(), CommitTime: time.Now().Unix(), - }) - if err != nil { + }); err != nil { return merr.WrapErrClusteringCompactionMetaError("SavePartitionStatsInfo", err) } - var operators []UpdateOperator - for _, segID := range t.GetResultSegments() { - operators = append(operators, UpdateSegmentPartitionStatsVersionOperator(segID, t.GetPlanID())) - } - err = t.meta.UpdateSegmentsInfo(operators...) - if err != nil { - return merr.WrapErrClusteringCompactionMetaError("UpdateSegmentPartitionStatsVersion", err) - } - err = t.meta.GetPartitionStatsMeta().SaveCurrentPartitionStatsVersion(t.GetCollectionID(), t.GetPartitionID(), t.GetChannel(), t.GetPlanID()) if err != nil { return merr.WrapErrClusteringCompactionMetaError("SaveCurrentPartitionStatsVersion", err) } + + // mark input segments as dropped + // now, the segment view only includes the result segments. + if err = t.markInputSegmentsDropped(); err != nil { + log.Warn("mark input segments as Dropped failed, skip it and wait retry", + zap.Int64("planID", t.GetPlanID()), zap.Error(err)) + } + return t.updateAndSaveTaskMeta(setState(datapb.CompactionTaskState_completed)) } @@ -376,25 +381,50 @@ func (t *clusteringCompactionTask) processFailedOrTimeout() error { }); err != nil { log.Warn("clusteringCompactionTask processFailedOrTimeout unable to drop compaction plan", zap.Int64("planID", t.GetPlanID()), zap.Error(err)) } - - // revert segments meta - var operators []UpdateOperator - // revert level of input segments - // L1 : L1 ->(processPipelining)-> L2 ->(processFailedOrTimeout)-> L1 - // L2 : L2 ->(processPipelining)-> L2 ->(processFailedOrTimeout)-> L2 - for _, segID := range t.InputSegments { - operators = append(operators, RevertSegmentLevelOperator(segID)) - } - // if result segments are generated but task fail in the other steps, mark them as L1 segments without partitions stats - for _, segID := range t.ResultSegments { - operators = append(operators, UpdateSegmentLevelOperator(segID, datapb.SegmentLevel_L1)) - operators = append(operators, UpdateSegmentPartitionStatsVersionOperator(segID, 0)) + isInputDropped := false + for _, segID := range t.GetInputSegments() { + if t.meta.GetHealthySegment(segID) == nil { + isInputDropped = true + break + } } - err := t.meta.UpdateSegmentsInfo(operators...) - if err != nil { - log.Warn("UpdateSegmentsInfo fail", zap.Error(err)) - return merr.WrapErrClusteringCompactionMetaError("UpdateSegmentsInfo", err) + if isInputDropped { + log.Info("input segments dropped, doing for compatibility", + zap.Int64("triggerID", t.GetTriggerID()), zap.Int64("planID", t.GetPlanID())) + // this task must be generated by v2.4, just for compatibility + // revert segments meta + var operators []UpdateOperator + // revert level of input segments + // L1 : L1 ->(processPipelining)-> L2 ->(processFailedOrTimeout)-> L1 + // L2 : L2 ->(processPipelining)-> L2 ->(processFailedOrTimeout)-> L2 + for _, segID := range t.GetInputSegments() { + operators = append(operators, RevertSegmentLevelOperator(segID)) + } + // if result segments are generated but task fail in the other steps, mark them as L1 segments without partitions stats + for _, segID := range t.GetResultSegments() { + operators = append(operators, UpdateSegmentLevelOperator(segID, datapb.SegmentLevel_L1)) + operators = append(operators, UpdateSegmentPartitionStatsVersionOperator(segID, 0)) + } + err := t.meta.UpdateSegmentsInfo(operators...) + if err != nil { + log.Warn("UpdateSegmentsInfo fail", zap.Error(err)) + return merr.WrapErrClusteringCompactionMetaError("UpdateSegmentsInfo", err) + } + } else { + // after v2.4.16, mark the results segment as dropped + var operators []UpdateOperator + for _, segID := range t.GetResultSegments() { + // Don't worry about them being loaded; they are all invisible. + operators = append(operators, UpdateStatusOperator(segID, commonpb.SegmentState_Dropped)) + } + + err := t.meta.UpdateSegmentsInfo(operators...) + if err != nil { + log.Warn("UpdateSegmentsInfo fail", zap.Error(err)) + return merr.WrapErrClusteringCompactionMetaError("UpdateSegmentsInfo", err) + } } + t.resetSegmentCompacting() // drop partition stats if uploaded @@ -405,7 +435,7 @@ func (t *clusteringCompactionTask) processFailedOrTimeout() error { Version: t.GetPlanID(), SegmentIDs: t.GetResultSegments(), } - err = t.meta.CleanPartitionStatsInfo(partitionStatsInfo) + err := t.meta.CleanPartitionStatsInfo(partitionStatsInfo) if err != nil { log.Warn("gcPartitionStatsInfo fail", zap.Error(err)) } diff --git a/internal/datacoord/compaction_task_clustering_test.go b/internal/datacoord/compaction_task_clustering_test.go index 8072b270e80ff..a9a4f16ff5b23 100644 --- a/internal/datacoord/compaction_task_clustering_test.go +++ b/internal/datacoord/compaction_task_clustering_test.go @@ -112,7 +112,7 @@ func (s *ClusteringCompactionTaskSuite) TestClusteringCompactionSegmentMetaChang task.processPipelining() seg11 := s.meta.GetSegment(101) - s.Equal(datapb.SegmentLevel_L2, seg11.Level) + s.Equal(datapb.SegmentLevel_L1, seg11.Level) seg21 := s.meta.GetSegment(102) s.Equal(datapb.SegmentLevel_L2, seg21.Level) s.Equal(int64(10000), seg21.PartitionStatsVersion) @@ -147,11 +147,13 @@ func (s *ClusteringCompactionTaskSuite) TestClusteringCompactionSegmentMetaChang s.Equal(int64(10000), seg22.PartitionStatsVersion) seg32 := s.meta.GetSegment(103) - s.Equal(datapb.SegmentLevel_L1, seg32.Level) - s.Equal(int64(0), seg32.PartitionStatsVersion) + s.Equal(datapb.SegmentLevel_L2, seg32.Level) + s.Equal(int64(10001), seg32.PartitionStatsVersion) + s.Equal(commonpb.SegmentState_Dropped, seg32.GetState()) seg42 := s.meta.GetSegment(104) - s.Equal(datapb.SegmentLevel_L1, seg42.Level) - s.Equal(int64(0), seg42.PartitionStatsVersion) + s.Equal(datapb.SegmentLevel_L2, seg42.Level) + s.Equal(int64(10001), seg42.PartitionStatsVersion) + s.Equal(commonpb.SegmentState_Dropped, seg42.GetState()) } func (s *ClusteringCompactionTaskSuite) generateBasicTask(vectorClusteringKey bool) *clusteringCompactionTask { diff --git a/internal/datacoord/handler.go b/internal/datacoord/handler.go index 60bec831c306c..4f40575136489 100644 --- a/internal/datacoord/handler.go +++ b/internal/datacoord/handler.go @@ -119,10 +119,11 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . } var ( - flushedIDs = make(typeutil.UniqueSet) - droppedIDs = make(typeutil.UniqueSet) - growingIDs = make(typeutil.UniqueSet) - levelZeroIDs = make(typeutil.UniqueSet) + flushedIDs = make(typeutil.UniqueSet) + droppedIDs = make(typeutil.UniqueSet) + growingIDs = make(typeutil.UniqueSet) + levelZeroIDs = make(typeutil.UniqueSet) + newFlushedIDs = make(typeutil.UniqueSet) ) // cannot use GetSegmentsByChannel since dropped segments are needed here @@ -132,7 +133,6 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . indexedSegments := FilterInIndexedSegments(h, h.s.meta, false, segments...) indexed := typeutil.NewUniqueSet(lo.Map(indexedSegments, func(segment *SegmentInfo, _ int) int64 { return segment.GetID() })...) - unIndexedIDs := make(typeutil.UniqueSet) for _, s := range segments { if s.GetStartPosition() == nil && s.GetDmlPosition() == nil { continue @@ -142,36 +142,17 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . continue } - currentPartitionStatsVersion := h.s.meta.partitionStatsMeta.GetCurrentPartitionStatsVersion(channel.GetCollectionID(), s.GetPartitionID(), channel.GetName()) - if s.GetLevel() == datapb.SegmentLevel_L2 && s.GetPartitionStatsVersion() != currentPartitionStatsVersion { - // in the process of L2 compaction, newly generated segment may be visible before the whole L2 compaction Plan - // is finished, we have to skip these fast-finished segment because all segments in one L2 Batch must be - // seen atomically, otherwise users will see intermediate result - continue - } - validSegmentInfos[s.GetID()] = s switch { case s.GetState() == commonpb.SegmentState_Dropped: - if s.GetLevel() == datapb.SegmentLevel_L2 && s.GetPartitionStatsVersion() == currentPartitionStatsVersion { - // if segment.partStatsVersion is equal to currentPartitionStatsVersion, - // it must have been indexed, this is guaranteed by clustering compaction process - // this is to ensure that the current valid L2 compaction produce is available to search/query - // to avoid insufficient data - flushedIDs.Insert(s.GetID()) - continue - } droppedIDs.Insert(s.GetID()) case !isFlushState(s.GetState()): growingIDs.Insert(s.GetID()) case s.GetLevel() == datapb.SegmentLevel_L0: levelZeroIDs.Insert(s.GetID()) - case indexed.Contain(s.GetID()) || s.GetNumOfRows() < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64(): - // fill in indexed segments into flushed directly - flushedIDs.Insert(s.GetID()) + default: - // unIndexed segments will be checked if it's parents are all indexed - unIndexedIDs.Insert(s.GetID()) + flushedIDs.Insert(s.GetID()) } } @@ -203,36 +184,55 @@ func (h *ServerHandler) GetQueryVChanPositions(channel RWChannel, partitionIDs . return true } - retrieveUnIndexed := func() bool { + var compactionFromExist func(segID UniqueID) bool + + compactionFromExist = func(segID UniqueID) bool { + compactionFrom := validSegmentInfos[segID].GetCompactionFrom() + if len(compactionFrom) == 0 || !isValid(compactionFrom...) { + return false + } + for _, fromID := range compactionFrom { + if flushedIDs.Contain(fromID) || newFlushedIDs.Contain(fromID) { + return true + } + if compactionFromExist(fromID) { + return true + } + } + return false + } + + segmentIndexed := func(segID UniqueID) bool { + return indexed.Contain(segID) || validSegmentInfos[segID].GetNumOfRows() < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() + } + + retrieve := func() bool { continueRetrieve := false - for id := range unIndexedIDs { + for id := range flushedIDs { compactionFrom := validSegmentInfos[id].GetCompactionFrom() - compactTos := []UniqueID{} // neighbors and itself - if len(compactionFrom) > 0 && isValid(compactionFrom...) { + if len(compactionFrom) == 0 || !isValid(compactionFrom...) { + newFlushedIDs.Insert(id) + continue + } + if segmentIndexed(id) && !compactionFromExist(id) { + newFlushedIDs.Insert(id) + } else { for _, fromID := range compactionFrom { - if len(compactTos) == 0 { - compactToInfo, _ := h.s.meta.GetCompactionTo(fromID) - compactTos = lo.Map(compactToInfo, func(s *SegmentInfo, _ int) UniqueID { return s.GetID() }) - } - if indexed.Contain(fromID) { - flushedIDs.Insert(fromID) - } else { - unIndexedIDs.Insert(fromID) - continueRetrieve = true - } + newFlushedIDs.Insert(fromID) + continueRetrieve = true + droppedIDs.Remove(fromID) } - unIndexedIDs.Remove(compactTos...) - flushedIDs.Remove(compactTos...) - droppedIDs.Remove(compactionFrom...) } } return continueRetrieve } - for retrieveUnIndexed() { + + for retrieve() { + flushedIDs = newFlushedIDs + newFlushedIDs = make(typeutil.UniqueSet) } - // unindexed is flushed segments as well - flushedIDs.Insert(unIndexedIDs.Collect()...) + flushedIDs = newFlushedIDs log.Info("GetQueryVChanPositions", zap.Int64("collectionID", channel.GetCollectionID()), diff --git a/internal/datacoord/handler_test.go b/internal/datacoord/handler_test.go index d85705e7c0a71..643a46b74a064 100644 --- a/internal/datacoord/handler_test.go +++ b/internal/datacoord/handler_test.go @@ -560,6 +560,312 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { assert.EqualValues(t, 0, len(vchan.UnflushedSegmentIds)) assert.ElementsMatch(t, []int64{e.GetID()}, vchan.FlushedSegmentIds) // expected e }) + + t.Run("complex derivation", func(t *testing.T) { + // numbers indicate segmentID, letters indicate segment information + // i: indexed, u: unindexed, g: gced + // 1i, 2i, 3g 4i, 5i, 6i + // | | | | | | + // \ | / \ | / + // \ | / \ | / + // 7u, [8i,9i,10i] [11u, 12i] + // | | | | | | + // \ | / \ / | + // \ | / \ / | + // [13u] [14i, 15u] 12i + // | | | | + // \ / \ / + // \ / \ / + // [16u] [17u] + // all leaf nodes are [1,2,3,4,5,6,7], but because segment3 has been gced, the leaf node becomes [7,8,9,10,4,5,6] + // should be returned: flushed: [7, 8, 9, 10, 4, 5, 6] + svr := newTestServer(t) + defer closeTestServer(t, svr) + schema := newTestSchema() + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Schema: schema, + }) + err := svr.meta.indexMeta.CreateIndex(&model.Index{ + TenantID: "", + CollectionID: 0, + FieldID: 2, + IndexID: 1, + }) + assert.NoError(t, err) + seg1 := &datapb.SegmentInfo{ + ID: 1, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) + assert.NoError(t, err) + seg2 := &datapb.SegmentInfo{ + ID: 2, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) + assert.NoError(t, err) + // seg3 was GCed + seg4 := &datapb.SegmentInfo{ + ID: 4, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg4)) + assert.NoError(t, err) + seg5 := &datapb.SegmentInfo{ + ID: 5, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5)) + assert.NoError(t, err) + seg6 := &datapb.SegmentInfo{ + ID: 6, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg6)) + assert.NoError(t, err) + seg7 := &datapb.SegmentInfo{ + ID: 7, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2048, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg7)) + assert.NoError(t, err) + seg8 := &datapb.SegmentInfo{ + ID: 8, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{1, 2, 3}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg8)) + assert.NoError(t, err) + seg9 := &datapb.SegmentInfo{ + ID: 9, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{1, 2, 3}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg9)) + assert.NoError(t, err) + seg10 := &datapb.SegmentInfo{ + ID: 10, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{1, 2, 3}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg10)) + assert.NoError(t, err) + seg11 := &datapb.SegmentInfo{ + ID: 11, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2048, + CompactionFrom: []int64{4, 5, 6}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg11)) + assert.NoError(t, err) + seg12 := &datapb.SegmentInfo{ + ID: 12, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{4, 5, 6}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg12)) + assert.NoError(t, err) + seg13 := &datapb.SegmentInfo{ + ID: 13, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2047, + CompactionFrom: []int64{7, 8, 9}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg13)) + assert.NoError(t, err) + seg14 := &datapb.SegmentInfo{ + ID: 14, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 100, + CompactionFrom: []int64{10, 11}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg14)) + assert.NoError(t, err) + seg15 := &datapb.SegmentInfo{ + ID: 15, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Dropped, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2048, + CompactionFrom: []int64{10, 11}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg15)) + assert.NoError(t, err) + seg16 := &datapb.SegmentInfo{ + ID: 16, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2048, + CompactionFrom: []int64{13, 14}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg16)) + assert.NoError(t, err) + seg17 := &datapb.SegmentInfo{ + ID: 17, + CollectionID: 0, + PartitionID: 0, + InsertChannel: "ch1", + State: commonpb.SegmentState_Flushed, + DmlPosition: &msgpb.MsgPosition{ + ChannelName: "ch1", + MsgID: []byte{1, 2, 3}, + MsgGroup: "", + Timestamp: 1, + }, + NumOfRows: 2048, + CompactionFrom: []int64{12, 15}, + } + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg17)) + assert.NoError(t, err) + + vchan := svr.handler.GetQueryVChanPositions(&channelMeta{Name: "ch1", CollectionID: 0}) + assert.ElementsMatch(t, []int64{7, 8, 9, 10, 4, 5, 6}, vchan.FlushedSegmentIds) + assert.ElementsMatch(t, []int64{1, 2}, vchan.DroppedSegmentIds) + }) + } func TestShouldDropChannel(t *testing.T) { diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index 574ff69856edc..5fdb4dee3ea9a 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -710,7 +710,7 @@ func (p *updateSegmentPack) Get(segmentID int64) *SegmentInfo { } segment := p.meta.segments.GetSegment(segmentID) - if segment == nil || !isSegmentHealthy(segment) { + if segment == nil { log.Warn("meta update: get segment failed - segment not found", zap.Int64("segmentID", segmentID), zap.Bool("segment nil", segment == nil), @@ -835,23 +835,13 @@ func RevertSegmentLevelOperator(segmentID int64) UpdateOperator { zap.Int64("segmentID", segmentID)) return false } - segment.Level = segment.LastLevel - log.Debug("revert segment level", zap.Int64("segmentID", segmentID), zap.String("LastLevel", segment.LastLevel.String())) - return true - } -} - -func RevertSegmentPartitionStatsVersionOperator(segmentID int64) UpdateOperator { - return func(modPack *updateSegmentPack) bool { - segment := modPack.Get(segmentID) - if segment == nil { - log.Warn("meta update: revert level fail - segment not found", - zap.Int64("segmentID", segmentID)) - return false + // just for compatibility, + if segment.GetLevel() != segment.GetLastLevel() && segment.GetLastLevel() != datapb.SegmentLevel_Legacy { + segment.Level = segment.LastLevel + log.Debug("revert segment level", zap.Int64("segmentID", segmentID), zap.String("LastLevel", segment.LastLevel.String())) + return true } - segment.PartitionStatsVersion = segment.LastPartitionStatsVersion - log.Debug("revert segment partition stats version", zap.Int64("segmentID", segmentID), zap.Int64("LastPartitionStatsVersion", segment.LastPartitionStatsVersion)) - return true + return false } } @@ -1413,8 +1403,14 @@ func (m *meta) completeClusterCompactionMutation(t *datapb.CompactionTask, resul metricMutation := &segMetricMutation{stateChange: make(map[string]map[string]int)} compactFromSegIDs := make([]int64, 0) compactToSegIDs := make([]int64, 0) - compactFromSegInfos := make([]*SegmentInfo, 0) compactToSegInfos := make([]*SegmentInfo, 0) + var ( + collectionID int64 + partitionID int64 + maxRowNum int64 + startPosition *msgpb.MsgPosition + dmlPosition *msgpb.MsgPosition + ) for _, segmentID := range t.GetInputSegments() { segment := m.segments.GetSegment(segmentID) @@ -1422,38 +1418,35 @@ func (m *meta) completeClusterCompactionMutation(t *datapb.CompactionTask, resul return nil, nil, merr.WrapErrSegmentNotFound(segmentID) } - cloned := segment.Clone() - cloned.DroppedAt = uint64(time.Now().UnixNano()) - cloned.Compacted = true - - compactFromSegInfos = append(compactFromSegInfos, cloned) - compactFromSegIDs = append(compactFromSegIDs, cloned.GetID()) + collectionID = segment.GetCollectionID() + partitionID = segment.GetPartitionID() + maxRowNum = segment.GetMaxRowNum() + if startPosition == nil || segment.GetStartPosition().GetTimestamp() < startPosition.GetTimestamp() { + startPosition = segment.GetStartPosition() + } - // metrics mutation for compaction from segments - updateSegStateAndPrepareMetrics(cloned, commonpb.SegmentState_Dropped, metricMutation) + if dmlPosition == nil || segment.GetDmlPosition().GetTimestamp() < dmlPosition.GetTimestamp() { + dmlPosition = segment.GetDmlPosition() + } } for _, seg := range result.GetSegments() { segmentInfo := &datapb.SegmentInfo{ ID: seg.GetSegmentID(), - CollectionID: compactFromSegInfos[0].CollectionID, - PartitionID: compactFromSegInfos[0].PartitionID, + CollectionID: collectionID, + PartitionID: partitionID, InsertChannel: t.GetChannel(), NumOfRows: seg.NumOfRows, State: commonpb.SegmentState_Flushed, - MaxRowNum: compactFromSegInfos[0].MaxRowNum, + MaxRowNum: maxRowNum, Binlogs: seg.GetInsertLogs(), Statslogs: seg.GetField2StatslogPaths(), CreatedByCompaction: true, CompactionFrom: compactFromSegIDs, LastExpireTime: tsoutil.ComposeTSByTime(time.Unix(t.GetStartTime(), 0), 0), Level: datapb.SegmentLevel_L2, - StartPosition: getMinPosition(lo.Map(compactFromSegInfos, func(info *SegmentInfo, _ int) *msgpb.MsgPosition { - return info.GetStartPosition() - })), - DmlPosition: getMinPosition(lo.Map(compactFromSegInfos, func(info *SegmentInfo, _ int) *msgpb.MsgPosition { - return info.GetDmlPosition() - })), + StartPosition: startPosition, + DmlPosition: dmlPosition, } segment := NewSegmentInfo(segmentInfo) compactToSegInfos = append(compactToSegInfos, segment) @@ -1464,10 +1457,6 @@ func (m *meta) completeClusterCompactionMutation(t *datapb.CompactionTask, resul log = log.With(zap.Int64s("compact from", compactFromSegIDs), zap.Int64s("compact to", compactToSegIDs)) log.Debug("meta update: prepare for meta mutation - complete") - compactFromInfos := lo.Map(compactFromSegInfos, func(info *SegmentInfo, _ int) *datapb.SegmentInfo { - return info.SegmentInfo - }) - compactToInfos := lo.Map(compactToSegInfos, func(info *SegmentInfo, _ int) *datapb.SegmentInfo { return info.SegmentInfo }) @@ -1476,18 +1465,11 @@ func (m *meta) completeClusterCompactionMutation(t *datapb.CompactionTask, resul for _, seg := range compactToInfos { binlogs = append(binlogs, metastore.BinlogsIncrement{Segment: seg}) } - // alter compactTo before compactFrom segments to avoid data lost if service crash during AlterSegments + // only add new segments if err := m.catalog.AlterSegments(m.ctx, compactToInfos, binlogs...); err != nil { log.Warn("fail to alter compactTo segments", zap.Error(err)) return nil, nil, err } - if err := m.catalog.AlterSegments(m.ctx, compactFromInfos); err != nil { - log.Warn("fail to alter compactFrom segments", zap.Error(err)) - return nil, nil, err - } - lo.ForEach(compactFromSegInfos, func(info *SegmentInfo, _ int) { - m.segments.SetSegment(info.GetID(), info) - }) lo.ForEach(compactToSegInfos, func(info *SegmentInfo, _ int) { m.segments.SetSegment(info.GetID(), info) }) diff --git a/internal/datacoord/mock_segment_manager.go b/internal/datacoord/mock_segment_manager.go index 37164ae979cfd..8a61177baa829 100644 --- a/internal/datacoord/mock_segment_manager.go +++ b/internal/datacoord/mock_segment_manager.go @@ -53,11 +53,11 @@ type MockManager_AllocSegment_Call struct { } // AllocSegment is a helper method to define mock.On call -// - ctx context.Context -// - collectionID int64 -// - partitionID int64 -// - channelName string -// - requestRows int64 +// - ctx context.Context +// - collectionID int64 +// - partitionID int64 +// - channelName string +// - requestRows int64 func (_e *MockManager_Expecter) AllocSegment(ctx interface{}, collectionID interface{}, partitionID interface{}, channelName interface{}, requestRows interface{}) *MockManager_AllocSegment_Call { return &MockManager_AllocSegment_Call{Call: _e.mock.On("AllocSegment", ctx, collectionID, partitionID, channelName, requestRows)} } @@ -90,8 +90,8 @@ type MockManager_DropSegment_Call struct { } // DropSegment is a helper method to define mock.On call -// - ctx context.Context -// - segmentID int64 +// - ctx context.Context +// - segmentID int64 func (_e *MockManager_Expecter) DropSegment(ctx interface{}, segmentID interface{}) *MockManager_DropSegment_Call { return &MockManager_DropSegment_Call{Call: _e.mock.On("DropSegment", ctx, segmentID)} } @@ -124,8 +124,8 @@ type MockManager_DropSegmentsOfChannel_Call struct { } // DropSegmentsOfChannel is a helper method to define mock.On call -// - ctx context.Context -// - channel string +// - ctx context.Context +// - channel string func (_e *MockManager_Expecter) DropSegmentsOfChannel(ctx interface{}, channel interface{}) *MockManager_DropSegmentsOfChannel_Call { return &MockManager_DropSegmentsOfChannel_Call{Call: _e.mock.On("DropSegmentsOfChannel", ctx, channel)} } @@ -167,8 +167,8 @@ type MockManager_ExpireAllocations_Call struct { } // ExpireAllocations is a helper method to define mock.On call -// - channel string -// - ts uint64 +// - channel string +// - ts uint64 func (_e *MockManager_Expecter) ExpireAllocations(channel interface{}, ts interface{}) *MockManager_ExpireAllocations_Call { return &MockManager_ExpireAllocations_Call{Call: _e.mock.On("ExpireAllocations", channel, ts)} } @@ -222,9 +222,9 @@ type MockManager_GetFlushableSegments_Call struct { } // GetFlushableSegments is a helper method to define mock.On call -// - ctx context.Context -// - channel string -// - ts uint64 +// - ctx context.Context +// - channel string +// - ts uint64 func (_e *MockManager_Expecter) GetFlushableSegments(ctx interface{}, channel interface{}, ts interface{}) *MockManager_GetFlushableSegments_Call { return &MockManager_GetFlushableSegments_Call{Call: _e.mock.On("GetFlushableSegments", ctx, channel, ts)} } @@ -278,9 +278,9 @@ type MockManager_SealAllSegments_Call struct { } // SealAllSegments is a helper method to define mock.On call -// - ctx context.Context -// - collectionID int64 -// - segIDs []int64 +// - ctx context.Context +// - collectionID int64 +// - segIDs []int64 func (_e *MockManager_Expecter) SealAllSegments(ctx interface{}, collectionID interface{}, segIDs interface{}) *MockManager_SealAllSegments_Call { return &MockManager_SealAllSegments_Call{Call: _e.mock.On("SealAllSegments", ctx, collectionID, segIDs)} } diff --git a/internal/datacoord/mock_session_manager.go b/internal/datacoord/mock_session_manager.go index a63f90ca94354..6221db4f1d456 100644 --- a/internal/datacoord/mock_session_manager.go +++ b/internal/datacoord/mock_session_manager.go @@ -6,6 +6,8 @@ import ( context "context" datapb "github.com/milvus-io/milvus/internal/proto/datapb" + healthcheck "github.com/milvus-io/milvus/internal/util/healthcheck" + mock "github.com/stretchr/testify/mock" typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -113,44 +115,44 @@ func (_c *MockSessionManager_CheckChannelOperationProgress_Call) RunAndReturn(ru return _c } -// CheckHealth provides a mock function with given fields: ctx -func (_m *MockSessionManager) CheckHealth(ctx context.Context) error { +// CheckDNHealth provides a mock function with given fields: ctx +func (_m *MockSessionManager) CheckDNHealth(ctx context.Context) *healthcheck.Result { ret := _m.Called(ctx) - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { + var r0 *healthcheck.Result + if rf, ok := ret.Get(0).(func(context.Context) *healthcheck.Result); ok { r0 = rf(ctx) } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(*healthcheck.Result) } return r0 } -// MockSessionManager_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' -type MockSessionManager_CheckHealth_Call struct { +// MockSessionManager_CheckDNHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckDNHealth' +type MockSessionManager_CheckDNHealth_Call struct { *mock.Call } -// CheckHealth is a helper method to define mock.On call +// CheckDNHealth is a helper method to define mock.On call // - ctx context.Context -func (_e *MockSessionManager_Expecter) CheckHealth(ctx interface{}) *MockSessionManager_CheckHealth_Call { - return &MockSessionManager_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx)} +func (_e *MockSessionManager_Expecter) CheckDNHealth(ctx interface{}) *MockSessionManager_CheckDNHealth_Call { + return &MockSessionManager_CheckDNHealth_Call{Call: _e.mock.On("CheckDNHealth", ctx)} } -func (_c *MockSessionManager_CheckHealth_Call) Run(run func(ctx context.Context)) *MockSessionManager_CheckHealth_Call { +func (_c *MockSessionManager_CheckDNHealth_Call) Run(run func(ctx context.Context)) *MockSessionManager_CheckDNHealth_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context)) }) return _c } -func (_c *MockSessionManager_CheckHealth_Call) Return(_a0 error) *MockSessionManager_CheckHealth_Call { +func (_c *MockSessionManager_CheckDNHealth_Call) Return(_a0 healthcheck.Result) *MockSessionManager_CheckDNHealth_Call { _c.Call.Return(_a0) return _c } -func (_c *MockSessionManager_CheckHealth_Call) RunAndReturn(run func(context.Context) error) *MockSessionManager_CheckHealth_Call { +func (_c *MockSessionManager_CheckDNHealth_Call) RunAndReturn(run func(context.Context) healthcheck.Result) *MockSessionManager_CheckDNHealth_Call { _c.Call.Return(run) return _c } diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index 5012e309fca51..e513bfe5c4a87 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -343,6 +343,22 @@ func (c *mockDataNodeClient) Stop() error { return nil } +func (c *mockDataNodeClient) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + if c.state == commonpb.StateCode_Healthy { + return &milvuspb.CheckHealthResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + IsHealthy: true, + Reasons: []string{}, + }, nil + } else { + return &milvuspb.CheckHealthResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_NotReadyServe}, + IsHealthy: false, + Reasons: []string{"fails"}, + }, nil + } +} + type mockRootCoordClient struct { state commonpb.StateCode cnt int64 diff --git a/internal/datacoord/partition_stats_meta_test.go b/internal/datacoord/partition_stats_meta_test.go index 904f6b3d2ce6f..0c67f2d4424b9 100644 --- a/internal/datacoord/partition_stats_meta_test.go +++ b/internal/datacoord/partition_stats_meta_test.go @@ -42,6 +42,7 @@ func (s *PartitionStatsMetaSuite) SetupTest() { catalog := mocks.NewDataCoordCatalog(s.T()) catalog.EXPECT().SavePartitionStatsInfo(mock.Anything, mock.Anything).Return(nil).Maybe() catalog.EXPECT().ListPartitionStatsInfos(mock.Anything).Return(nil, nil).Maybe() + catalog.EXPECT().SaveCurrentPartitionStatsVersion(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() s.catalog = catalog } @@ -74,8 +75,11 @@ func (s *PartitionStatsMetaSuite) TestGetPartitionStats() { ps := partitionStatsMeta.GetPartitionStats(1, 2, "ch-1", 100) s.NotNil(ps) + err = partitionStatsMeta.SaveCurrentPartitionStatsVersion(1, 2, "ch-1", 100) + s.NoError(err) + currentVersion := partitionStatsMeta.GetCurrentPartitionStatsVersion(1, 2, "ch-1") - s.Equal(emptyPartitionStatsVersion, currentVersion) + s.Equal(int64(100), currentVersion) currentVersion2 := partitionStatsMeta.GetCurrentPartitionStatsVersion(1, 2, "ch-2") s.Equal(emptyPartitionStatsVersion, currentVersion2) diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index ecff0a56e3d26..359400f08c6fb 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -46,6 +46,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" @@ -162,6 +163,8 @@ type Server struct { // manage ways that data coord access other coord broker broker.Broker + + healthChecker *healthcheck.Checker } type CollectionNameInfo struct { @@ -407,6 +410,8 @@ func (s *Server) initDataCoord() error { s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(s.ctx) + interval := Params.CommonCfg.HealthCheckInterval.GetAsDuration(time.Second) + s.healthChecker = healthcheck.NewChecker(interval, s.healthCheckFn) log.Info("init datacoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", s.address)) return nil } @@ -725,6 +730,7 @@ func (s *Server) startServerLoop() { go s.importChecker.Start() s.garbageCollector.start() s.syncSegmentsScheduler.Start() + s.healthChecker.Start() } // startDataNodeTtLoop start a goroutine to recv data node tt msg from msgstream @@ -1107,6 +1113,9 @@ func (s *Server) Stop() error { if !s.stateCode.CompareAndSwap(commonpb.StateCode_Healthy, commonpb.StateCode_Abnormal) { return nil } + if s.healthChecker != nil { + s.healthChecker.Close() + } logutil.Logger(s.ctx).Info("datacoord server shutdown") s.garbageCollector.close() logutil.Logger(s.ctx).Info("datacoord garbage collector stopped") diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index c830d831c8771..501aaa7c9c9d8 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -51,6 +51,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -2527,12 +2528,12 @@ func Test_CheckHealth(t *testing.T) { return sm } - getChannelManager := func(t *testing.T, findWatcherOk bool) ChannelManager { + getChannelManager := func(findWatcherOk bool) ChannelManager { channelManager := NewMockChannelManager(t) if findWatcherOk { - channelManager.EXPECT().FindWatcher(mock.Anything).Return(0, nil) + channelManager.EXPECT().FindWatcher(mock.Anything).Return(0, nil).Maybe() } else { - channelManager.EXPECT().FindWatcher(mock.Anything).Return(0, errors.New("error")) + channelManager.EXPECT().FindWatcher(mock.Anything).Return(0, errors.New("error")).Maybe() } return channelManager } @@ -2545,6 +2546,21 @@ func Test_CheckHealth(t *testing.T) { 2: nil, } + newServer := func(isHealthy bool, findWatcherOk bool, meta *meta) *Server { + svr := &Server{ + ctx: context.TODO(), + sessionManager: getSessionManager(isHealthy), + channelManager: getChannelManager(findWatcherOk), + meta: meta, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, + } + svr.stateCode.Store(commonpb.StateCode_Healthy) + svr.healthChecker = healthcheck.NewChecker(20*time.Millisecond, svr.healthCheckFn) + svr.healthChecker.Start() + time.Sleep(30 * time.Millisecond) // wait for next cycle for health checker + return svr + } + t.Run("not healthy", func(t *testing.T) { ctx := context.Background() s := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} @@ -2556,9 +2572,8 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("data node health check is fail", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} - svr.stateCode.Store(commonpb.StateCode_Healthy) - svr.sessionManager = getSessionManager(false) + svr := newServer(false, true, &meta{channelCPs: newChannelCps()}) + defer svr.healthChecker.Close() ctx := context.Background() resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -2567,11 +2582,8 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("check channel watched fail", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} - svr.stateCode.Store(commonpb.StateCode_Healthy) - svr.sessionManager = getSessionManager(true) - svr.channelManager = getChannelManager(t, false) - svr.meta = &meta{collections: collections} + svr := newServer(true, false, &meta{collections: collections, channelCPs: newChannelCps()}) + defer svr.healthChecker.Close() ctx := context.Background() resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -2580,11 +2592,7 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("check checkpoint fail", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} - svr.stateCode.Store(commonpb.StateCode_Healthy) - svr.sessionManager = getSessionManager(true) - svr.channelManager = getChannelManager(t, true) - svr.meta = &meta{ + svr := newServer(true, true, &meta{ collections: collections, channelCPs: &channelCPs{ checkpoints: map[string]*msgpb.MsgPosition{ @@ -2594,8 +2602,8 @@ func Test_CheckHealth(t *testing.T) { }, }, }, - } - + }) + defer svr.healthChecker.Close() ctx := context.Background() resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -2604,11 +2612,7 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("ok", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} - svr.stateCode.Store(commonpb.StateCode_Healthy) - svr.sessionManager = getSessionManager(true) - svr.channelManager = getChannelManager(t, true) - svr.meta = &meta{ + svr := newServer(true, true, &meta{ collections: collections, channelCPs: &channelCPs{ checkpoints: map[string]*msgpb.MsgPosition{ @@ -2626,7 +2630,8 @@ func Test_CheckHealth(t *testing.T) { }, }, }, - } + }) + defer svr.healthChecker.Close() ctx := context.Background() resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 9b77d6a6a43e8..fd46a5cb2e363 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -35,7 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/internal/util/segmentutil" "github.com/milvus-io/milvus/pkg/common" @@ -1550,20 +1550,24 @@ func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthReque }, nil } - err := s.sessionManager.CheckHealth(ctx) - if err != nil { - return componentutil.CheckHealthRespWithErr(err), nil - } + latestCheckResult := s.healthChecker.GetLatestCheckResult() + return healthcheck.GetCheckHealthResponseFromResult(latestCheckResult), nil +} - if err = CheckAllChannelsWatched(s.meta, s.channelManager); err != nil { - return componentutil.CheckHealthRespWithErr(err), nil - } +func (s *Server) healthCheckFn() *healthcheck.Result { + timeout := Params.CommonCfg.HealthCheckRPCTimeout.GetAsDuration(time.Second) + ctx, cancel := context.WithTimeout(s.ctx, timeout) + defer cancel() - if err = CheckCheckPointsHealth(s.meta); err != nil { - return componentutil.CheckHealthRespWithErr(err), nil + checkResults := s.sessionManager.CheckDNHealth(ctx) + for collectionID, failReason := range CheckAllChannelsWatched(s.meta, s.channelManager) { + checkResults.AppendUnhealthyCollectionMsgs(healthcheck.NewUnhealthyCollectionMsg(collectionID, failReason, healthcheck.ChannelsWatched)) } - return componentutil.CheckHealthRespWithErr(nil), nil + for collectionID, failReason := range CheckCheckPointsHealth(s.meta) { + checkResults.AppendUnhealthyCollectionMsgs(healthcheck.NewUnhealthyCollectionMsg(collectionID, failReason, healthcheck.CheckpointLagExceed)) + } + return checkResults } func (s *Server) GcConfirm(ctx context.Context, request *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go index c962bf9797ce7..45a698b8a4e4b 100644 --- a/internal/datacoord/session_manager.go +++ b/internal/datacoord/session_manager.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "sync" "time" "github.com/cockroachdb/errors" @@ -31,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/kv/binlog" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -69,7 +71,7 @@ type SessionManager interface { QueryPreImport(nodeID int64, in *datapb.QueryPreImportRequest) (*datapb.QueryPreImportResponse, error) QueryImport(nodeID int64, in *datapb.QueryImportRequest) (*datapb.QueryImportResponse, error) DropImport(nodeID int64, in *datapb.DropImportRequest) error - CheckHealth(ctx context.Context) error + CheckDNHealth(ctx context.Context) *healthcheck.Result QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) DropCompactionPlan(nodeID int64, req *datapb.DropCompactionPlanRequest) error Close() @@ -508,28 +510,44 @@ func (c *SessionManagerImpl) DropImport(nodeID int64, in *datapb.DropImportReque return VerifyResponse(status, err) } -func (c *SessionManagerImpl) CheckHealth(ctx context.Context) error { - group, ctx := errgroup.WithContext(ctx) - +func (c *SessionManagerImpl) CheckDNHealth(ctx context.Context) *healthcheck.Result { + result := healthcheck.NewResult() + wg := sync.WaitGroup{} + wlock := sync.Mutex{} ids := c.GetSessionIDs() + for _, nodeID := range ids { nodeID := nodeID - group.Go(func() error { - cli, err := c.getClient(ctx, nodeID) + wg.Add(1) + go func() { + defer wg.Done() + + datanodeClient, err := c.getClient(ctx, nodeID) if err != nil { - return fmt.Errorf("failed to get DataNode %d: %v", nodeID, err) + err = fmt.Errorf("failed to get node:%d: %v", nodeID, err) + return } - sta, err := cli.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) - if err != nil { - return err + checkHealthResp, err := datanodeClient.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + if err = merr.CheckRPCCall(checkHealthResp, err); err != nil && !errors.Is(err, merr.ErrServiceUnimplemented) { + err = fmt.Errorf("CheckHealth fails for datanode:%d, %w", nodeID, err) + wlock.Lock() + result.AppendUnhealthyClusterMsg( + healthcheck.NewUnhealthyClusterMsg(typeutil.DataNodeRole, nodeID, err.Error(), healthcheck.NodeHealthCheck)) + wlock.Unlock() + return } - err = merr.AnalyzeState("DataNode", nodeID, sta) - return err - }) + + if len(checkHealthResp.Reasons) > 0 { + wlock.Lock() + result.AppendResult(healthcheck.GetHealthCheckResultFromResp(checkHealthResp)) + wlock.Unlock() + } + }() } - return group.Wait() + wg.Wait() + return result } func (c *SessionManagerImpl) QuerySlot(nodeID int64) (*datapb.QuerySlotResponse, error) { diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go index 5ceaa9b014dc2..59e2c4ad8ee59 100644 --- a/internal/datacoord/util.go +++ b/internal/datacoord/util.go @@ -271,7 +271,8 @@ func getCompactionMergeInfo(task *datapb.CompactionTask) *milvuspb.CompactionMer } } -func CheckCheckPointsHealth(meta *meta) error { +func CheckCheckPointsHealth(meta *meta) map[int64]string { + checkResult := make(map[int64]string) for channel, cp := range meta.GetChannelCheckpoints() { collectionID := funcutil.GetCollectionIDFromVChannel(channel) if collectionID == -1 { @@ -285,31 +286,30 @@ func CheckCheckPointsHealth(meta *meta) error { ts, _ := tsoutil.ParseTS(cp.Timestamp) lag := time.Since(ts) if lag > paramtable.Get().DataCoordCfg.ChannelCheckpointMaxLag.GetAsDuration(time.Second) { - return merr.WrapErrChannelCPExceededMaxLag(channel, fmt.Sprintf("checkpoint lag: %f(min)", lag.Minutes())) + checkResult[collectionID] = fmt.Sprintf("exceeds max lag:%s on channel:%s checkpoint", lag, channel) } } - return nil + return checkResult } -func CheckAllChannelsWatched(meta *meta, channelManager ChannelManager) error { +func CheckAllChannelsWatched(meta *meta, channelManager ChannelManager) map[int64]string { collIDs := meta.ListCollections() + checkResult := make(map[int64]string) for _, collID := range collIDs { collInfo := meta.GetCollection(collID) if collInfo == nil { - log.Warn("collection info is nil, skip it", zap.Int64("collectionID", collID)) + log.RatedWarn(60, "collection info is nil, skip it", zap.Int64("collectionID", collID)) continue } for _, channelName := range collInfo.VChannelNames { _, err := channelManager.FindWatcher(channelName) if err != nil { - log.Warn("find watcher for channel failed", zap.Int64("collectionID", collID), - zap.String("channelName", channelName), zap.Error(err)) - return err + checkResult[collID] = fmt.Sprintf("channel:%s is not watched", channelName) } } } - return nil + return checkResult } func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 { diff --git a/internal/datanode/services.go b/internal/datanode/services.go index 24436df1de9ab..429634013962b 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -22,6 +22,7 @@ package datanode import ( "context" "fmt" + "time" "github.com/samber/lo" "go.uber.org/zap" @@ -37,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -45,7 +47,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // WatchDmChannels is not in use @@ -583,3 +588,20 @@ func (node *DataNode) DropCompactionPlan(ctx context.Context, req *datapb.DropCo log.Ctx(ctx).Info("DropCompactionPlans success", zap.Int64("planID", req.GetPlanID())) return merr.Success(), nil } + +func (node *DataNode) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.CheckHealthResponse{ + Status: merr.Status(err), + Reasons: []string{err.Error()}, + }, nil + } + + maxDelay := paramtable.Get().QuotaConfig.MaxTimeTickDelay.GetAsDuration(time.Second) + minFGChannel, minFGTt := rateCol.getMinFlowGraphTt() + if err := ratelimitutil.CheckTimeTickDelay(minFGChannel, minFGTt, maxDelay); err != nil { + msg := healthcheck.NewUnhealthyClusterMsg(typeutil.DataNodeRole, node.GetNodeID(), err.Error(), healthcheck.TimeTickLagExceed) + return healthcheck.GetCheckHealthResponseFromClusterMsg(msg), nil + } + return healthcheck.OK(), nil +} diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index 1a1820bf804cc..d69ab10713a96 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -1114,6 +1114,40 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { }) } +func (s *DataNodeServicesSuite) TestCheckHealth() { + s.Run("node not healthy", func() { + s.SetupTest() + s.node.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := s.node.CheckHealth(ctx, nil) + s.NoError(err) + s.False(merr.Ok(resp.GetStatus())) + s.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + }) + + s.Run("exceeded timetick lag on pipeline", func() { + s.SetupTest() + rateCol.updateFlowGraphTt("timetick-lag-ch", 1) + ctx := context.Background() + resp, err := s.node.CheckHealth(ctx, nil) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + s.False(resp.GetIsHealthy()) + s.NotEmpty(resp.Reasons) + }) + + s.Run("ok", func() { + s.SetupTest() + rateCol.removeFlowGraphChannel("timetick-lag-ch") + ctx := context.Background() + resp, err := s.node.CheckHealth(ctx, nil) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + s.True(resp.GetIsHealthy()) + s.Empty(resp.Reasons) + }) +} + func (s *DataNodeServicesSuite) TestDropCompactionPlan() { s.Run("node not healthy", func() { s.SetupTest() diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 67d5081a19e8c..3944d4039a85b 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -267,3 +267,9 @@ func (c *Client) DropCompactionPlan(ctx context.Context, req *datapb.DropCompact return client.DropCompactionPlan(ctx, req) }) } + +func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*milvuspb.CheckHealthResponse, error) { + return client.CheckHealth(ctx, req) + }) +} diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 5e4ae6f0095e9..fe564e5bafe60 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -407,3 +407,7 @@ func (s *Server) QuerySlot(ctx context.Context, req *datapb.QuerySlotRequest) (* func (s *Server) DropCompactionPlan(ctx context.Context, req *datapb.DropCompactionPlanRequest) (*commonpb.Status, error) { return s.datanode.DropCompactionPlan(ctx, req) } + +func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + return s.datanode.CheckHealth(ctx, req) +} diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go index 640ee87e28916..364347f8c29e5 100644 --- a/internal/distributed/datanode/service_test.go +++ b/internal/distributed/datanode/service_test.go @@ -185,6 +185,10 @@ func (m *MockDataNode) DropCompactionPlan(ctx context.Context, req *datapb.DropC return m.status, m.err } +func (m *MockDataNode) CheckHealth(ctx context.Context, request *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + return &milvuspb.CheckHealthResponse{}, m.err +} + // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// func Test_NewServer(t *testing.T) { paramtable.Init() diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index ef39d46ef7c1f..d56ce8e7ff38e 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -9,6 +9,7 @@ import ( // v2 const ( // --- category --- + DataBaseCategory = "/databases/" CollectionCategory = "/collections/" EntityCategory = "/entities/" PartitionCategory = "/partitions/" @@ -74,6 +75,8 @@ const ( HTTPCollectionName = "collectionName" HTTPCollectionID = "collectionID" HTTPDbName = "dbName" + HTTPDbID = "dbID" + HTTPProperties = "properties" HTTPPartitionName = "partitionName" HTTPPartitionNames = "partitionNames" HTTPUserName = "userName" diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 3784847d4190e..4faf7fba0ffc8 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -517,6 +517,13 @@ func (h *HandlersV1) query(c *gin.Context) { username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest)) }) if err == RestRequestInterceptorErr { @@ -588,6 +595,13 @@ func (h *HandlersV1) get(c *gin.Context) { return nil, RestRequestInterceptorErr } queryReq := req.(*milvuspb.QueryRequest) + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } queryReq.Expr = filter return h.proxy.Query(reqCtx, queryReq) }) @@ -661,6 +675,13 @@ func (h *HandlersV1) delete(c *gin.Context) { } deleteReq.Expr = filter } + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } return h.proxy.Delete(ctx, deleteReq) }) if err == RestRequestInterceptorErr { @@ -737,6 +758,13 @@ func (h *HandlersV1) insert(c *gin.Context) { }) return nil, RestRequestInterceptorErr } + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } return h.proxy.Insert(ctx, insertReq) }) if err == RestRequestInterceptorErr { @@ -836,6 +864,13 @@ func (h *HandlersV1) upsert(c *gin.Context) { }) return nil, RestRequestInterceptorErr } + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } return h.proxy.Upsert(ctx, upsertReq) }) if err == RestRequestInterceptorErr { @@ -932,6 +967,13 @@ func (h *HandlersV1) search(c *gin.Context) { username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { + if _, err := CheckLimiter(ctx, &req, h.proxy); err != nil { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error() + ", error: " + err.Error(), + }) + return nil, err + } return h.proxy.Search(ctx, req.(*milvuspb.SearchRequest)) }) if err == RestRequestInterceptorErr { diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index 3ca26d2e45ad8..6eb25c6b291d1 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -517,6 +517,10 @@ func TestQuery(t *testing.T) { expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"word_count\":3000}]}", }) + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + for _, tt := range testCases { reqs := []*http.Request{genQueryRequest(), genGetRequest()} t.Run(tt.name, func(t *testing.T) { @@ -601,6 +605,10 @@ func TestDelete(t *testing.T) { expectedBody: "{\"code\":200,\"data\":{}}", }) + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) @@ -624,11 +632,15 @@ func TestDelete(t *testing.T) { } func TestDeleteForFilter(t *testing.T) { + paramtable.Init() jsonBodyList := [][]byte{ []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`), []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "id in [1,2,3]"}`), []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3], "filter": "id in [1,2,3]"}`), } + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") for _, jsonBody := range jsonBodyList { t.Run("delete success", func(t *testing.T) { mp := mocks.NewMockProxy(t) @@ -726,6 +738,10 @@ func TestInsert(t *testing.T) { HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows[0], }) + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") + for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) @@ -771,6 +787,9 @@ func TestInsertForDataType(t *testing.T) { "[success]with dynamic field": withDynamicField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false, true))), "[success]with array fields": withArrayField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false, true))), } + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") for name, schema := range schemas { t.Run(name, func(t *testing.T) { mp := mocks.NewMockProxy(t) @@ -838,6 +857,9 @@ func TestReturnInt64(t *testing.T) { schemapb.DataType_Int64: "1,2,3", schemapb.DataType_VarChar: "\"1\",\"2\",\"3\"", } + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") for _, dataType := range schemas { t.Run("[insert]httpCfg.allow: false", func(t *testing.T) { schema := newCollectionSchema(generateCollectionSchema(dataType, false, true)) @@ -1167,6 +1189,9 @@ func TestUpsert(t *testing.T) { HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows[0], }) + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) @@ -1265,6 +1290,9 @@ func TestSearch(t *testing.T) { exceptCode: 200, expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"distance\":0.01,\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"distance\":0.04,\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"distance\":0.09,\"word_count\":3000}]}", }) + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true") for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 38f3984a1aa21..5ab7a2be52f1f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -50,70 +50,75 @@ func NewHandlersV2(proxyClient types.ProxyComponent) *HandlersV2 { } func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { - router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listCollections))))) - router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasCollection))))) + router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listCollections)))) + router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.hasCollection)))) // todo review the return data - router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionDetails))))) - router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionStats))))) - router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionLoadState))))) - router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createCollection))))) - router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropCollection))))) - router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.renameCollection))))) - router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadCollection))))) - router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releaseCollection))))) - + router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionDetails)))) + router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionStats)))) + router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionLoadState)))) + router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.createCollection)))) + router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.dropCollection)))) + router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.renameCollection)))) + router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection)))) + router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection)))) + + router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase)))) + router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.dropDatabase)))) + router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases)))) + router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.describeDatabase)))) + router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase)))) // Query router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &QueryReqV2{ Limit: 100, OutputFields: []string{DefaultOutputFields}, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.query)))), true)) + }, wrapperTraceLog(h.query))), true)) // Get router.POST(EntityCategory+GetAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionIDReq{ OutputFields: []string{DefaultOutputFields}, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.get)))), true)) + }, wrapperTraceLog(h.get))), true)) // Delete router.POST(EntityCategory+DeleteAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionFilterReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.delete)))), false)) + }, wrapperTraceLog(h.delete))), false)) // Insert router.POST(EntityCategory+InsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionDataReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.insert)))), false)) + }, wrapperTraceLog(h.insert))), false)) // Upsert router.POST(EntityCategory+UpsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionDataReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.upsert)))), false)) + }, wrapperTraceLog(h.upsert))), false)) // Search router.POST(EntityCategory+SearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &SearchReqV2{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))), true)) + }, wrapperTraceLog(h.search))), true)) // advanced_search, backward compatible uri router.POST(EntityCategory+AdvancedSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &HybridSearchReq{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true)) + }, wrapperTraceLog(h.advancedSearch))), true)) // HybridSearch router.POST(EntityCategory+HybridSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &HybridSearchReq{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true)) + }, wrapperTraceLog(h.advancedSearch))), true)) - router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions))))) - router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions))))) - router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.statsPartition))))) + router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listPartitions)))) + router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.hasPartitions)))) + router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.statsPartition)))) - router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createPartition))))) - router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropPartition))))) - router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadPartitions))))) - router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releasePartitions))))) + router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.createPartition)))) + router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.dropPartition)))) + router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.loadPartitions)))) + router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.releasePartitions)))) router.POST(UserCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listUsers)))) router.POST(UserCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &UserReq{} }, wrapperTraceLog(h.describeUser)))) @@ -141,24 +146,24 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.addPrivilegesToGroup)))) router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup)))) - router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes))))) - router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex))))) + router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listIndexes)))) + router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.describeIndex)))) - router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createIndex))))) + router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.createIndex)))) // todo cannot drop index before release it ? - router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropIndex))))) + router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.dropIndex)))) - router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listAlias))))) - router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeAlias))))) + router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listAlias)))) + router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.describeAlias)))) - router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createAlias))))) - router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropAlias))))) - router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.alterAlias))))) + router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.createAlias)))) + router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.dropAlias)))) + router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.alterAlias)))) - router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob))))) - router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob))))) - router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) - router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listImportJob)))) + router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.createImportJob)))) + router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) + router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) } type ( @@ -191,13 +196,15 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc { return } dbName := "" - if getter, ok := req.(requestutil.DBNameGetter); ok { - dbName = getter.GetDbName() - } - if dbName == "" { - dbName = c.Request.Header.Get(HTTPHeaderDBName) + if req != nil { + if getter, ok := req.(requestutil.DBNameGetter); ok { + dbName = getter.GetDbName() + } if dbName == "" { - dbName = DefaultDbName + dbName = c.Request.Header.Get(HTTPHeaderDBName) + if dbName == "" { + dbName = DefaultDbName + } } } username, _ := c.Get(ContextUsername) @@ -261,7 +268,7 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 { if err != nil { log.Ctx(ctx).Info("trace info: all, error", zap.Error(err)) } else { - log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp)) + log.Ctx(ctx).Info("trace info: all, unknown") } } return resp, err @@ -334,31 +341,6 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu return response, err } -func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 { - return func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - if dbName == DefaultDbName || proxy.CheckDatabase(ctx, dbName) { - return v2(ctx, c, req, dbName) - } - resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { - return h.proxy.ListDatabases(reqCtx, &milvuspb.ListDatabasesRequest{}) - }) - if err != nil { - return resp, err - } - for _, db := range resp.(*milvuspb.ListDatabasesResponse).DbNames { - if db == dbName { - return v2(ctx, c, req, dbName) - } - } - log.Ctx(ctx).Warn("high level restful api, non-exist database", zap.String("database", dbName)) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), - HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, - }) - return nil, merr.ErrDatabaseNotFound - } -} - func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { getter, _ := anyReq.(requestutil.CollectionNameGetter) collectionName := getter.GetCollectionName() @@ -1381,6 +1363,99 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe return statusResponse, err } +func (h *HandlersV2) createDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.CreateDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.CreateDatabase(reqCtx, req.(*milvuspb.CreateDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DropDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DropDatabase(reqCtx, req.(*milvuspb.DropDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +// todo: use a more flexible way to handle the number of input parameters of req +func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListDatabasesRequest{} + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.ListDatabases(reqCtx, req.(*milvuspb.ListDatabasesRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListDatabasesResponse).DbNames)) + } + return resp, err +} + +func (h *HandlersV2) describeDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DescribeDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DescribeDatabase(reqCtx, req.(*milvuspb.DescribeDatabaseRequest)) + }) + if err != nil { + return nil, err + } + info, _ := resp.(*milvuspb.DescribeDatabaseResponse) + if info.Properties == nil { + info.Properties = []*commonpb.KeyValuePair{} + } + dataBaseInfo := map[string]any{ + HTTPDbName: info.DbName, + HTTPDbID: info.DbID, + HTTPProperties: info.Properties, + } + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: dataBaseInfo}) + return resp, err +} + +func (h *HandlersV2) alterDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.AlterDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) req := &milvuspb.ShowPartitionsRequest{ diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 5248f9c9a69b8..9cb86a26cbe4f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -376,99 +376,47 @@ func TestTimeout(t *testing.T) { } } -func TestDatabaseWrapper(t *testing.T) { +func TestCreateIndex(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + postTestCases := []requestBodyTestCase{} mp := mocks.NewMockProxy(t) - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "exist"}, - }, nil).Twice() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once() - h := NewHandlersV2(mp) - ginHandler := gin.Default() - app := ginHandler.Group("", genAuthMiddleWare(false)) - path := "/wrapper/database" - app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, h.wrapperCheckDatabase(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - return nil, nil - }))) - postTestCases = append(postTestCases, requestBodyTestCase{ - path: path, - requestBody: []byte(`{}`), - }) - postTestCases = append(postTestCases, requestBodyTestCase{ - path: path, - requestBody: []byte(`{"dbName": "exist"}`), - }) + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(IndexCategory, CreateAction) + // the previous format postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"dbName": "non-exist"}`), - errMsg: "database not found, database: non-exist", - errCode: 800, // ErrDatabaseNotFound + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`), }) + // the current format postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"dbName": "test"}`), - errMsg: "", - errCode: 65535, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`), }) for _, testcase := range postTestCases { t.Run("post"+testcase.path, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) w := httptest.NewRecorder() - ginHandler.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - fmt.Println(w.Body.String()) - if testcase.errCode != 0 { - returnBody := &ReturnErrMsg{} - err := json.Unmarshal(w.Body.Bytes(), returnBody) - assert.Nil(t, err) - assert.Equal(t, testcase.errCode, returnBody.Code) - assert.Equal(t, testcase.errMsg, returnBody.Message) - } - }) - } - - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "default"}, - }, nil).Once() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "test"}, - }, nil).Once() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once() - rawTestCases := []rawTestCase{ - { - errMsg: "database not found, database: test", - errCode: 800, // ErrDatabaseNotFound - }, - {}, - { - errMsg: "", - errCode: 65535, - }, - } - for _, testcase := range rawTestCases { - t.Run("post with db"+testcase.path, func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`))) - req.Header.Set(HTTPHeaderDBName, "test") - w := httptest.NewRecorder() - ginHandler.ServeHTTP(w, req) + testEngine.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) if testcase.errCode != 0 { - returnBody := &ReturnErrMsg{} - err := json.Unmarshal(w.Body.Bytes(), returnBody) - assert.Nil(t, err) - assert.Equal(t, testcase.errCode, returnBody.Code) assert.Equal(t, testcase.errMsg, returnBody.Message) } }) } } -func TestCreateIndex(t *testing.T) { +func TestDatabase(t *testing.T) { paramtable.Init() // disable rate limit paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") @@ -476,18 +424,107 @@ func TestCreateIndex(t *testing.T) { postTestCases := []requestBodyTestCase{} mp := mocks.NewMockProxy(t) - mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() testEngine := initHTTPServerV2(mp, false) - path := versionalV2(IndexCategory, CreateAction) - // the previous format + path := versionalV2(DataBaseCategory, CreateAction) + // success postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`), + requestBody: []byte(`{"dbName":"test"}`), }) - // the current format + // mock fail postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`), + requestBody: []byte(`{"dbName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, DropAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}}, nil).Once() + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, ListAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once() + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, DescribeAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, AlterAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid }) for _, testcase := range postTestCases { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 0ef3c2045e0f6..5ae7babd346a0 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -9,12 +9,23 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +type EmptyReq struct{} + +func (req *EmptyReq) GetDbName() string { return "" } + type DatabaseReq struct { DbName string `json:"dbName"` } func (req *DatabaseReq) GetDbName() string { return req.DbName } +type DatabaseReqWithProperties struct { + DbName string `json:"dbName" binding:"required"` + Properties map[string]interface{} `json:"properties"` +} + +func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName } + type CollectionNameReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index abd4b714d7dc1..ac200a5afdaed 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -345,3 +345,9 @@ func (c *Client) DeleteBatch(ctx context.Context, req *querypb.DeleteBatchReques return client.DeleteBatch(ctx, req) }) } + +func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.CheckHealthResponse, error) { + return client.CheckHealth(ctx, req) + }) +} diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 31a3074b12309..794fe1ce4ad28 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -389,3 +389,7 @@ func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commo func (s *Server) DeleteBatch(ctx context.Context, req *querypb.DeleteBatchRequest) (*querypb.DeleteBatchResponse, error) { return s.querynode.DeleteBatch(ctx, req) } + +func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + return s.querynode.CheckHealth(ctx, req) +} diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index 61cd74ad2b1e0..15a3bfde4096f 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ b/internal/metastore/kv/datacoord/kv_catalog.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" + "github.com/cockroachdb/errors" "go.uber.org/zap" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" @@ -42,6 +43,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -891,6 +893,9 @@ func (kc *Catalog) GetCurrentPartitionStatsVersion(ctx context.Context, collID, key := buildCurrentPartitionStatsVersionPath(collID, partID, vChannel) valueStr, err := kc.MetaKv.Load(key) if err != nil { + if errors.Is(err, merr.ErrIoKeyNotFound) { + return 0, nil + } return 0, err } diff --git a/internal/metastore/mocks/mock_rootcoord_catalog.go b/internal/metastore/mocks/mock_rootcoord_catalog.go index 646eb849ae756..8c35d288c1143 100644 --- a/internal/metastore/mocks/mock_rootcoord_catalog.go +++ b/internal/metastore/mocks/mock_rootcoord_catalog.go @@ -1879,7 +1879,7 @@ func (_c *RootCoordCatalog_GetPrivilegeGroup_Call) Return(_a0 *milvuspb.Privileg return _c } -func (_c *RootCoordCatalog_GetPrivilegeGroup_Call) RunAndReturn(run func(context.Context, string) (*milvuspb.PrivilegeGroupInfo,error)) *RootCoordCatalog_GetPrivilegeGroup_Call { +func (_c *RootCoordCatalog_GetPrivilegeGroup_Call) RunAndReturn(run func(context.Context, string) (*milvuspb.PrivilegeGroupInfo, error)) *RootCoordCatalog_GetPrivilegeGroup_Call { _c.Call.Return(run) return _c } diff --git a/internal/mocks/mock_datanode.go b/internal/mocks/mock_datanode.go index 190da75be0885..5e2836fae8af9 100644 --- a/internal/mocks/mock_datanode.go +++ b/internal/mocks/mock_datanode.go @@ -87,6 +87,61 @@ func (_c *MockDataNode_CheckChannelOperationProgress_Call) RunAndReturn(run func return _c } +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNode_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockDataNode_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *MockDataNode_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockDataNode_CheckHealth_Call { + return &MockDataNode_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} +} + +func (_c *MockDataNode_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockDataNode_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) + }) + return _c +} + +func (_c *MockDataNode_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockDataNode_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNode_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)) *MockDataNode_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // CompactionV2 provides a mock function with given fields: _a0, _a1 func (_m *MockDataNode) CompactionV2(_a0 context.Context, _a1 *datapb.CompactionPlan) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/mocks/mock_datanode_client.go b/internal/mocks/mock_datanode_client.go index 91661051c390b..633cf103d2d53 100644 --- a/internal/mocks/mock_datanode_client.go +++ b/internal/mocks/mock_datanode_client.go @@ -101,6 +101,76 @@ func (_c *MockDataNodeClient_CheckChannelOperationProgress_Call) RunAndReturn(ru return _c } +// CheckHealth provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockDataNodeClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CheckHealthRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) CheckHealth(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_CheckHealth_Call { + return &MockDataNodeClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_CheckHealth_Call) Run(run func(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption)) *MockDataNodeClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataNodeClient_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockDataNodeClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)) *MockDataNodeClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockDataNodeClient) Close() error { ret := _m.Called() diff --git a/internal/mocks/mock_querynode.go b/internal/mocks/mock_querynode.go index abcf83fade8d0..0332ef6aec183 100644 --- a/internal/mocks/mock_querynode.go +++ b/internal/mocks/mock_querynode.go @@ -30,6 +30,61 @@ func (_m *MockQueryNode) EXPECT() *MockQueryNode_Expecter { return &MockQueryNode_Expecter{mock: &_m.Mock} } +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNode_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockQueryNode_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *MockQueryNode_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockQueryNode_CheckHealth_Call { + return &MockQueryNode_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} +} + +func (_c *MockQueryNode_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockQueryNode_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) + }) + return _c +} + +func (_c *MockQueryNode_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockQueryNode_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNode_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)) *MockQueryNode_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // Delete provides a mock function with given fields: _a0, _a1 func (_m *MockQueryNode) Delete(_a0 context.Context, _a1 *querypb.DeleteRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/mocks/mock_querynode_client.go b/internal/mocks/mock_querynode_client.go index e7777eb6ab7c6..e2b04295d9176 100644 --- a/internal/mocks/mock_querynode_client.go +++ b/internal/mocks/mock_querynode_client.go @@ -31,6 +31,76 @@ func (_m *MockQueryNodeClient) EXPECT() *MockQueryNodeClient_Expecter { return &MockQueryNodeClient_Expecter{mock: &_m.Mock} } +// CheckHealth provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockQueryNodeClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CheckHealthRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) CheckHealth(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_CheckHealth_Call { + return &MockQueryNodeClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_CheckHealth_Call) Run(run func(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockQueryNodeClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)) *MockQueryNodeClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockQueryNodeClient) Close() error { ret := _m.Called() diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index 3bd0a0abe9687..7294914f03e64 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -131,6 +131,8 @@ service DataNode { rpc QuerySlot(QuerySlotRequest) returns(QuerySlotResponse) {} rpc DropCompactionPlan(DropCompactionPlanRequest) returns(common.Status) {} + + rpc CheckHealth(milvus.CheckHealthRequest)returns (milvus.CheckHealthResponse) {} } message FlushRequest { diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 2f52c2ce47f33..4c3f212599d7e 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -175,7 +175,9 @@ service QueryNode { // DeleteBatch is the API to apply same delete data into multiple segments. // it's basically same as `Delete` but cost less memory pressure. rpc DeleteBatch(DeleteBatchRequest) returns (DeleteBatchResponse) { - } + } + + rpc CheckHealth(milvus.CheckHealthRequest)returns (milvus.CheckHealthResponse) {} } // --------------------QueryCoord grpc request and response proto------------------ diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index fc0f3a91d4028..be2a32fa441f0 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -5254,9 +5254,9 @@ func (node *Proxy) validPrivilegeParams(req *milvuspb.OperatePrivilegeRequest) e return nil } -func (node *Proxy) validOperatePrivilegeV2Params(req *milvuspb.OperatePrivilegeV2Request) error { +func (node *Proxy) validateOperatePrivilegeV2Params(req *milvuspb.OperatePrivilegeV2Request) error { if req.Role == nil { - return fmt.Errorf("the role in the request is nil") + return merr.WrapErrParameterInvalidMsg("the role in the request is nil") } if err := ValidateRoleName(req.Role.Name); err != nil { return err @@ -5264,11 +5264,17 @@ func (node *Proxy) validOperatePrivilegeV2Params(req *milvuspb.OperatePrivilegeV if err := ValidatePrivilege(req.Grantor.Privilege.Name); err != nil { return err } + // validate built-in privilege group params + if err := ValidateBuiltInPrivilegeGroup(req.Grantor.Privilege.Name, req.DbName, req.CollectionName); err != nil { + return err + } if req.Type != milvuspb.OperatePrivilegeType_Grant && req.Type != milvuspb.OperatePrivilegeType_Revoke { - return fmt.Errorf("the type in the request not grant or revoke") + return merr.WrapErrParameterInvalidMsg("the type in the request not grant or revoke") } - if err := ValidateObjectName(req.DbName); err != nil { - return err + if req.DbName != "" && !util.IsAnyWord(req.DbName) { + if err := ValidateDatabaseName(req.DbName); err != nil { + return err + } } if err := ValidateObjectName(req.CollectionName); err != nil { return err @@ -5287,7 +5293,7 @@ func (node *Proxy) OperatePrivilegeV2(ctx context.Context, req *milvuspb.Operate if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return merr.Status(err), nil } - if err := node.validOperatePrivilegeV2Params(req); err != nil { + if err := node.validateOperatePrivilegeV2Params(req); err != nil { return merr.Status(err), nil } curUser, err := GetCurUserFromContext(ctx) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 3817a643d3c1a..a3aadeb433b2f 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -157,25 +157,25 @@ func validateCollectionNameOrAlias(entity, entityType string) error { func ValidatePrivilegeGroupName(groupName string) error { if groupName == "" { - return merr.WrapErrDatabaseNameInvalid(groupName, "privilege group name couldn't be empty") + return merr.WrapErrPrivilegeGroupNameInvalid("privilege group name should not be empty") } if len(groupName) > Params.ProxyCfg.MaxNameLength.GetAsInt() { - return merr.WrapErrDatabaseNameInvalid(groupName, - fmt.Sprintf("the length of a privilege group name must be less than %d characters", Params.ProxyCfg.MaxNameLength.GetAsInt())) + return merr.WrapErrPrivilegeGroupNameInvalid( + "the length of a privilege group name %s must be less than %s characters", groupName, Params.ProxyCfg.MaxNameLength.GetValue()) } firstChar := groupName[0] if firstChar != '_' && !isAlpha(firstChar) { - return merr.WrapErrDatabaseNameInvalid(groupName, - "the first character of a privilege group name must be an underscore or letter") + return merr.WrapErrPrivilegeGroupNameInvalid( + "the first character of a privilege group name %s must be an underscore or letter", groupName) } for i := 1; i < len(groupName); i++ { c := groupName[i] if c != '_' && !isAlpha(c) && !isNumber(c) { - return merr.WrapErrDatabaseNameInvalid(groupName, - "privilege group name can only contain numbers, letters and underscores") + return merr.WrapErrParameterInvalidMsg( + "privilege group name %s can only contain numbers, letters and underscores", groupName) } } return nil @@ -925,7 +925,7 @@ func ValidateObjectName(entity string) error { if util.IsAnyWord(entity) { return nil } - return validateName(entity, "role name") + return validateName(entity, "object name") } func ValidateObjectType(entity string) error { @@ -939,6 +939,31 @@ func ValidatePrivilege(entity string) error { return validateName(entity, "Privilege") } +func ValidateBuiltInPrivilegeGroup(entity string, dbName string, collectionName string) error { + if !util.IsBuiltinPrivilegeGroup(entity) { + return nil + } + switch { + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Cluster.String()): + if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) { + return merr.WrapErrParameterInvalidMsg("dbName and collectionName should be * for the cluster level privilege: %s", entity) + } + return nil + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Database.String()): + if collectionName != "" && collectionName != util.AnyWord { + return merr.WrapErrParameterInvalidMsg("collectionName should be * for the database level privilege: %s", entity) + } + return nil + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Collection.String()): + if util.IsAnyWord(dbName) && !util.IsAnyWord(collectionName) && collectionName != "" { + return merr.WrapErrParameterInvalidMsg("please specify database name for the collection level privilege: %s", entity) + } + return nil + default: + return nil + } +} + func GetCurUserFromContext(ctx context.Context) (string, error) { return contextutil.GetCurUserFromContext(ctx) } @@ -962,13 +987,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { dbKey := strings.ToLower(util.HeaderDBName) - if username == "" { - return contextutil.AppendToIncomingContext(ctx, dbKey, dbName) + if dbName != "" { + ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName) } - originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) - authKey := strings.ToLower(util.HeaderAuthorize) - authValue := crypto.Base64Encode(originValue) - return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName) + if username != "" { + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue) + } + return ctx } func AppendUserInfoForRPC(ctx context.Context) context.Context { diff --git a/internal/querycoordv2/balance/channel_level_score_balancer.go b/internal/querycoordv2/balance/channel_level_score_balancer.go index bee771d58f6ff..ba0a3398bf700 100644 --- a/internal/querycoordv2/balance/channel_level_score_balancer.go +++ b/internal/querycoordv2/balance/channel_level_score_balancer.go @@ -42,7 +42,7 @@ func NewChannelLevelScoreBalancer(scheduler task.Scheduler, nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, ) *ChannelLevelScoreBalancer { return &ChannelLevelScoreBalancer{ ScoreBasedBalancer: NewScoreBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr), diff --git a/internal/querycoordv2/balance/multi_target_balance.go b/internal/querycoordv2/balance/multi_target_balance.go index 9f1de2b026cfc..7466c0e782f10 100644 --- a/internal/querycoordv2/balance/multi_target_balance.go +++ b/internal/querycoordv2/balance/multi_target_balance.go @@ -453,7 +453,7 @@ func (g *randomPlanGenerator) generatePlans() []SegmentAssignPlan { type MultiTargetBalancer struct { *ScoreBasedBalancer dist *meta.DistributionManager - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface } func (b *MultiTargetBalancer) BalanceReplica(replica *meta.Replica) (segmentPlans []SegmentAssignPlan, channelPlans []ChannelAssignPlan) { @@ -561,7 +561,7 @@ func (b *MultiTargetBalancer) genPlanByDistributions(nodeSegments, globalNodeSeg return plans } -func NewMultiTargetBalancer(scheduler task.Scheduler, nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, targetMgr *meta.TargetManager) *MultiTargetBalancer { +func NewMultiTargetBalancer(scheduler task.Scheduler, nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, targetMgr meta.TargetManagerInterface) *MultiTargetBalancer { return &MultiTargetBalancer{ ScoreBasedBalancer: NewScoreBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr), dist: dist, diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 53d09d89f3fa4..a664c5885ac63 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -37,7 +37,7 @@ type RowCountBasedBalancer struct { *RoundRobinBalancer dist *meta.DistributionManager meta *meta.Meta - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface } // AssignSegment, when row count based balancer assign segments, it will assign segment to node with least global row count. @@ -366,7 +366,7 @@ func NewRowCountBasedBalancer( nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, ) *RowCountBasedBalancer { return &RowCountBasedBalancer{ RoundRobinBalancer: NewRoundRobinBalancer(scheduler, nodeManager), diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 0e3aad1f78efd..872d273c43d6e 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -42,7 +42,7 @@ func NewScoreBasedBalancer(scheduler task.Scheduler, nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, ) *ScoreBasedBalancer { return &ScoreBasedBalancer{ RowCountBasedBalancer: NewRowCountBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr), diff --git a/internal/querycoordv2/dist/dist_controller.go b/internal/querycoordv2/dist/dist_controller.go index 687e16fe5cfed..5661eaae33413 100644 --- a/internal/querycoordv2/dist/dist_controller.go +++ b/internal/querycoordv2/dist/dist_controller.go @@ -99,7 +99,7 @@ func NewDistController( client session.Cluster, nodeManager *session.NodeManager, dist *meta.DistributionManager, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, scheduler task.Scheduler, syncTargetVersionFn TriggerUpdateTargetVersion, ) *ControllerImpl { diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index c441714a99804..3ae1fdddee728 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -51,7 +51,7 @@ type LoadCollectionJob struct { meta *meta.Meta broker meta.Broker cluster session.Cluster - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *observers.TargetObserver collectionObserver *observers.CollectionObserver nodeMgr *session.NodeManager @@ -64,7 +64,7 @@ func NewLoadCollectionJob( meta *meta.Meta, broker meta.Broker, cluster session.Cluster, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, targetObserver *observers.TargetObserver, collectionObserver *observers.CollectionObserver, nodeMgr *session.NodeManager, @@ -265,7 +265,7 @@ type LoadPartitionJob struct { meta *meta.Meta broker meta.Broker cluster session.Cluster - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *observers.TargetObserver collectionObserver *observers.CollectionObserver nodeMgr *session.NodeManager @@ -278,7 +278,7 @@ func NewLoadPartitionJob( meta *meta.Meta, broker meta.Broker, cluster session.Cluster, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, targetObserver *observers.TargetObserver, collectionObserver *observers.CollectionObserver, nodeMgr *session.NodeManager, diff --git a/internal/querycoordv2/job/job_release.go b/internal/querycoordv2/job/job_release.go index ea6289ba8a8a9..ca903159a5698 100644 --- a/internal/querycoordv2/job/job_release.go +++ b/internal/querycoordv2/job/job_release.go @@ -42,7 +42,7 @@ type ReleaseCollectionJob struct { meta *meta.Meta broker meta.Broker cluster session.Cluster - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *observers.TargetObserver checkerController *checkers.CheckerController proxyManager proxyutil.ProxyClientManagerInterface @@ -54,7 +54,7 @@ func NewReleaseCollectionJob(ctx context.Context, meta *meta.Meta, broker meta.Broker, cluster session.Cluster, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, targetObserver *observers.TargetObserver, checkerController *checkers.CheckerController, proxyManager proxyutil.ProxyClientManagerInterface, @@ -82,8 +82,6 @@ func (job *ReleaseCollectionJob) Execute() error { return nil } - job.meta.CollectionManager.SetReleasing(req.GetCollectionID()) - loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()) toRelease := lo.Map(loadedPartitions, func(partition *meta.Partition, _ int) int64 { return partition.GetPartitionID() @@ -130,7 +128,7 @@ type ReleasePartitionJob struct { meta *meta.Meta broker meta.Broker cluster session.Cluster - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *observers.TargetObserver checkerController *checkers.CheckerController proxyManager proxyutil.ProxyClientManagerInterface @@ -142,7 +140,7 @@ func NewReleasePartitionJob(ctx context.Context, meta *meta.Meta, broker meta.Broker, cluster session.Cluster, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, targetObserver *observers.TargetObserver, checkerController *checkers.CheckerController, proxyManager proxyutil.ProxyClientManagerInterface, diff --git a/internal/querycoordv2/job/undo.go b/internal/querycoordv2/job/undo.go index 21d29538639ac..e1314f0aec6e0 100644 --- a/internal/querycoordv2/job/undo.go +++ b/internal/querycoordv2/job/undo.go @@ -38,12 +38,12 @@ type UndoList struct { ctx context.Context meta *meta.Meta cluster session.Cluster - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *observers.TargetObserver } func NewUndoList(ctx context.Context, meta *meta.Meta, - cluster session.Cluster, targetMgr *meta.TargetManager, targetObserver *observers.TargetObserver, + cluster session.Cluster, targetMgr meta.TargetManagerInterface, targetObserver *observers.TargetObserver, ) *UndoList { return &UndoList{ ctx: ctx, diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index 5df4286417d3b..f88c664169a38 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -50,7 +50,6 @@ type Collection struct { mut sync.RWMutex refreshNotifier chan struct{} LoadSpan trace.Span - isReleasing bool } func (collection *Collection) SetRefreshNotifier(notifier chan struct{}) { @@ -60,18 +59,6 @@ func (collection *Collection) SetRefreshNotifier(notifier chan struct{}) { collection.refreshNotifier = notifier } -func (collection *Collection) SetReleasing() { - collection.mut.Lock() - defer collection.mut.Unlock() - collection.isReleasing = true -} - -func (collection *Collection) IsReleasing() bool { - collection.mut.RLock() - defer collection.mut.RUnlock() - return collection.isReleasing -} - func (collection *Collection) IsRefreshed() bool { collection.mut.RLock() notifier := collection.refreshNotifier @@ -439,15 +426,6 @@ func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { return ok } -func (m *CollectionManager) SetReleasing(collectionID typeutil.UniqueID) { - m.rwmutex.Lock() - defer m.rwmutex.Unlock() - coll, ok := m.collections[collectionID] - if ok { - coll.SetReleasing() - } -} - // GetAll returns the collection ID of all loaded collections func (m *CollectionManager) GetAll() []int64 { m.rwmutex.RLock() diff --git a/internal/querycoordv2/mocks/mock_querynode.go b/internal/querycoordv2/mocks/mock_querynode.go index 961d0b64f4970..ae935b1304b60 100644 --- a/internal/querycoordv2/mocks/mock_querynode.go +++ b/internal/querycoordv2/mocks/mock_querynode.go @@ -29,6 +29,61 @@ func (_m *MockQueryNodeServer) EXPECT() *MockQueryNodeServer_Expecter { return &MockQueryNodeServer_Expecter{mock: &_m.Mock} } +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNodeServer) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeServer_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockQueryNodeServer_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *MockQueryNodeServer_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_CheckHealth_Call { + return &MockQueryNodeServer_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} +} + +func (_c *MockQueryNodeServer_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockQueryNodeServer_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) + }) + return _c +} + +func (_c *MockQueryNodeServer_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockQueryNodeServer_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeServer_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)) *MockQueryNodeServer_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // Delete provides a mock function with given fields: _a0, _a1 func (_m *MockQueryNodeServer) Delete(_a0 context.Context, _a1 *querypb.DeleteRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index e641d1c245b2c..b71c56175b8c5 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -47,7 +47,7 @@ type CollectionObserver struct { dist *meta.DistributionManager meta *meta.Meta - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface targetObserver *TargetObserver checkerController *checkers.CheckerController partitionLoadedCount map[int64]int @@ -69,7 +69,7 @@ type LoadTask struct { func NewCollectionObserver( dist *meta.DistributionManager, meta *meta.Meta, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, targetObserver *TargetObserver, checherController *checkers.CheckerController, proxyManager proxyutil.ProxyClientManagerInterface, diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index 3bb1c31924a40..45e6345488b45 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -75,7 +75,7 @@ type TargetObserver struct { cancel context.CancelFunc wg sync.WaitGroup meta *meta.Meta - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface distMgr *meta.DistributionManager broker meta.Broker cluster session.Cluster @@ -101,7 +101,7 @@ type TargetObserver struct { func NewTargetObserver( meta *meta.Meta, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, distMgr *meta.DistributionManager, broker meta.Broker, cluster session.Cluster, diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 0b6aa3d36df49..c3a6f29502fc6 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -53,6 +53,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/tsoutil" @@ -94,7 +95,7 @@ type Server struct { store metastore.QueryCoordCatalog meta *meta.Meta dist *meta.DistributionManager - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface broker meta.Broker // Session @@ -134,6 +135,8 @@ type Server struct { proxyCreator proxyutil.ProxyCreator proxyWatcher proxyutil.ProxyWatcherInterface proxyClientManager proxyutil.ProxyClientManagerInterface + + healthChecker *healthcheck.Checker } func NewQueryCoord(ctx context.Context) (*Server, error) { @@ -346,6 +349,8 @@ func (s *Server) initQueryCoord() error { // Init load status cache meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() + interval := Params.CommonCfg.HealthCheckInterval.GetAsDuration(time.Second) + s.healthChecker = healthcheck.NewChecker(interval, s.healthCheckFn) log.Info("init querycoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", s.address)) return err } @@ -489,6 +494,7 @@ func (s *Server) startQueryCoord() error { s.startServerLoop() s.afterStart() + s.healthChecker.Start() s.UpdateStateCode(commonpb.StateCode_Healthy) sessionutil.SaveServerInfo(typeutil.QueryCoordRole, s.session.GetServerID()) return nil @@ -525,7 +531,9 @@ func (s *Server) Stop() error { // FOLLOW the dependence graph: // job scheduler -> checker controller -> task scheduler -> dist controller -> cluster -> session // observers -> dist controller - + if s.healthChecker != nil { + s.healthChecker.Close() + } if s.jobScheduler != nil { log.Info("stop job scheduler...") s.jobScheduler.Stop() diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 5f5ecc8c4b5e6..e97aaa9c729f6 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/cockroachdb/errors" "github.com/samber/lo" @@ -35,7 +36,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/job" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/utils" - "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" @@ -935,16 +936,20 @@ func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthReque return &milvuspb.CheckHealthResponse{Status: merr.Status(err), IsHealthy: false, Reasons: []string{err.Error()}}, nil } - errReasons, err := s.checkNodeHealth(ctx) - if err != nil || len(errReasons) != 0 { - return componentutil.CheckHealthRespWithErrMsg(errReasons...), nil - } + latestCheckResult := s.healthChecker.GetLatestCheckResult() + return healthcheck.GetCheckHealthResponseFromResult(latestCheckResult), nil +} - if err := utils.CheckCollectionsQueryable(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr); err != nil { - log.Warn("some collection is not queryable during health check", zap.Error(err)) - } +func (s *Server) healthCheckFn() *healthcheck.Result { + timeout := Params.CommonCfg.HealthCheckRPCTimeout.GetAsDuration(time.Second) + ctx, cancel := context.WithTimeout(s.ctx, timeout) + defer cancel() - return componentutil.CheckHealthRespWithErr(nil), nil + checkResults := s.broadcastCheckHealth(ctx) + for collectionID, failReason := range utils.CheckCollectionsQueryable(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr) { + checkResults.AppendUnhealthyCollectionMsgs(healthcheck.NewUnhealthyCollectionMsg(collectionID, failReason, healthcheck.CollectionQueryable)) + } + return checkResults } func (s *Server) checkNodeHealth(ctx context.Context) ([]string, error) { @@ -975,6 +980,39 @@ func (s *Server) checkNodeHealth(ctx context.Context) ([]string, error) { return errReasons, err } +func (s *Server) broadcastCheckHealth(ctx context.Context) *healthcheck.Result { + result := healthcheck.NewResult() + wg := sync.WaitGroup{} + wlock := sync.Mutex{} + + for _, node := range s.nodeMgr.GetAll() { + node := node + wg.Add(1) + go func() { + defer wg.Done() + + checkHealthResp, err := s.cluster.CheckHealth(ctx, node.ID()) + if err = merr.CheckRPCCall(checkHealthResp, err); err != nil && !errors.Is(err, merr.ErrServiceUnimplemented) { + err = fmt.Errorf("CheckHealth fails for querynode:%d, %w", node.ID(), err) + wlock.Lock() + result.AppendUnhealthyClusterMsg( + healthcheck.NewUnhealthyClusterMsg(typeutil.QueryNodeRole, node.ID(), err.Error(), healthcheck.NodeHealthCheck)) + wlock.Unlock() + return + } + + if len(checkHealthResp.Reasons) > 0 { + wlock.Lock() + result.AppendResult(healthcheck.GetHealthCheckResultFromResp(checkHealthResp)) + wlock.Unlock() + } + }() + } + + wg.Wait() + return result +} + func (s *Server) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With( zap.String("rgName", req.GetResourceGroup()), diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 7066e22de2238..5d1d73e2ff291 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -48,6 +48,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" @@ -171,6 +172,11 @@ func (suite *ServiceSuite) SetupTest() { } suite.cluster = session.NewMockCluster(suite.T()) suite.cluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() + suite.cluster.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(&milvuspb.CheckHealthResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + IsHealthy: true, + Reasons: []string{}, + }, nil).Maybe() suite.jobScheduler = job.NewScheduler() suite.taskScheduler = task.NewMockScheduler(suite.T()) suite.taskScheduler.EXPECT().GetSegmentTaskDelta(mock.Anything, mock.Anything).Return(0).Maybe() @@ -1630,42 +1636,82 @@ func (suite *ServiceSuite) TestCheckHealth() { suite.loadAll() ctx := context.Background() server := suite.server + server.healthChecker = healthcheck.NewChecker(40*time.Millisecond, suite.server.healthCheckFn) + server.healthChecker.Start() + defer server.healthChecker.Close() + + assertCheckHealthResult := func(isHealthy bool) { + resp, err := server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + suite.NoError(err) + suite.Equal(resp.IsHealthy, isHealthy) + if !isHealthy { + suite.NotEmpty(resp.Reasons) + } else { + suite.Empty(resp.Reasons) + } + } + + setNodeSate := func(isHealthy bool, isRPCFail bool) { + var resp *milvuspb.CheckHealthResponse + if isHealthy { + resp = healthcheck.OK() + } else { + resp = healthcheck.GetCheckHealthResponseFromClusterMsg(healthcheck.NewUnhealthyClusterMsg("dn", 1, "check fails", healthcheck.NodeHealthCheck)) + } + resp.Status = &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + if isRPCFail { + resp.Status = &commonpb.Status{ErrorCode: commonpb.ErrorCode_ForceDeny} + } + suite.cluster.EXPECT().CheckHealth(mock.Anything, mock.Anything).Unset() + suite.cluster.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(resp, nil).Maybe() + time.Sleep(50 * time.Millisecond) + } // Test for server is not healthy server.UpdateStateCode(commonpb.StateCode_Initializing) - resp, err := server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - suite.NoError(err) - suite.Equal(resp.IsHealthy, false) - suite.NotEmpty(resp.Reasons) + assertCheckHealthResult(false) - // Test for components state fail - for _, node := range suite.nodes { - suite.cluster.EXPECT().GetComponentStates(mock.Anything, node).Return( - &milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, - nil).Once() - } + // Test for check health has some error reasons + setNodeSate(false, false) server.UpdateStateCode(commonpb.StateCode_Healthy) - resp, err = server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - suite.NoError(err) - suite.Equal(resp.IsHealthy, false) - suite.NotEmpty(resp.Reasons) + assertCheckHealthResult(false) + + // Test for check health rpc fail + setNodeSate(true, true) + server.UpdateStateCode(commonpb.StateCode_Healthy) + assertCheckHealthResult(false) - // Test for server is healthy + // Test for check load percentage fail + setNodeSate(true, false) + assertCheckHealthResult(true) + + // Test for check channel ok + for _, collection := range suite.collections { + suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateChannelDist(collection) + } + assertCheckHealthResult(true) + + // Test for check channel fail + tm := meta.NewMockTargetManager(suite.T()) + tm.EXPECT().GetDmChannelsByCollection(mock.Anything, mock.Anything).Return(nil).Maybe() + otm := server.targetMgr + server.targetMgr = tm + assertCheckHealthResult(true) + + // Test for get shard leader fail + server.targetMgr = otm for _, node := range suite.nodes { - suite.cluster.EXPECT().GetComponentStates(mock.Anything, node).Return( - &milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, - nil).Once() - } - resp, err = server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + suite.nodeMgr.Stopping(node) + } + + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.Key, "1") + defer paramtable.Get().Reset(paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.Key) + time.Sleep(1500 * time.Millisecond) + resp, err := server.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) suite.NoError(err) suite.Equal(resp.IsHealthy, true) - suite.Empty(resp.Reasons) + suite.NotEmpty(resp.Reasons) } func (suite *ServiceSuite) TestGetShardLeaders() { diff --git a/internal/querycoordv2/session/cluster.go b/internal/querycoordv2/session/cluster.go index 7b6bc316ebe25..569dbb0029469 100644 --- a/internal/querycoordv2/session/cluster.go +++ b/internal/querycoordv2/session/cluster.go @@ -52,6 +52,7 @@ type Cluster interface { GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) GetComponentStates(ctx context.Context, nodeID int64) (*milvuspb.ComponentStates, error) + CheckHealth(ctx context.Context, nodeID int64) (*milvuspb.CheckHealthResponse, error) Start() Stop() } @@ -272,6 +273,20 @@ func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types return nil } +func (c *QueryCluster) CheckHealth(ctx context.Context, nodeID int64) (*milvuspb.CheckHealthResponse, error) { + var ( + resp *milvuspb.CheckHealthResponse + err error + ) + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { + resp, err = cli.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + }) + if err1 != nil { + return nil, err1 + } + return resp, err +} + type clients struct { sync.RWMutex clients map[int64]types.QueryNodeClient // nodeID -> client diff --git a/internal/querycoordv2/session/mock_cluster.go b/internal/querycoordv2/session/mock_cluster.go index dbc14c720ce98..136f6c4e23da0 100644 --- a/internal/querycoordv2/session/mock_cluster.go +++ b/internal/querycoordv2/session/mock_cluster.go @@ -27,6 +27,61 @@ func (_m *MockCluster) EXPECT() *MockCluster_Expecter { return &MockCluster_Expecter{mock: &_m.Mock} } +// CheckHealth provides a mock function with given fields: ctx, nodeID +func (_m *MockCluster) CheckHealth(ctx context.Context, nodeID int64) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(ctx, nodeID) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, nodeID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, nodeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCluster_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockCluster_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - nodeID int64 +func (_e *MockCluster_Expecter) CheckHealth(ctx interface{}, nodeID interface{}) *MockCluster_CheckHealth_Call { + return &MockCluster_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, nodeID)} +} + +func (_c *MockCluster_CheckHealth_Call) Run(run func(ctx context.Context, nodeID int64)) *MockCluster_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockCluster_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockCluster_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCluster_CheckHealth_Call) RunAndReturn(run func(context.Context, int64) (*milvuspb.CheckHealthResponse, error)) *MockCluster_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + // GetComponentStates provides a mock function with given fields: ctx, nodeID func (_m *MockCluster) GetComponentStates(ctx context.Context, nodeID int64) (*milvuspb.ComponentStates, error) { ret := _m.Called(ctx, nodeID) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 474dc9a701683..db3acf44779f8 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -58,7 +58,7 @@ type Executor struct { meta *meta.Meta dist *meta.DistributionManager broker meta.Broker - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface cluster session.Cluster nodeMgr *session.NodeManager @@ -70,7 +70,7 @@ type Executor struct { func NewExecutor(meta *meta.Meta, dist *meta.DistributionManager, broker meta.Broker, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, cluster session.Cluster, nodeMgr *session.NodeManager, ) *Executor { diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 4cc328959a733..416d3ea1d4675 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -157,7 +157,7 @@ type taskScheduler struct { distMgr *meta.DistributionManager meta *meta.Meta - targetMgr *meta.TargetManager + targetMgr meta.TargetManagerInterface broker meta.Broker cluster session.Cluster nodeMgr *session.NodeManager @@ -172,7 +172,7 @@ type taskScheduler struct { func NewScheduler(ctx context.Context, meta *meta.Meta, distMgr *meta.DistributionManager, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, broker meta.Broker, cluster session.Cluster, nodeMgr *session.NodeManager, diff --git a/internal/querycoordv2/utils/util.go b/internal/querycoordv2/utils/util.go index e37961c8bc0b4..6bcb80677b5d4 100644 --- a/internal/querycoordv2/utils/util.go +++ b/internal/querycoordv2/utils/util.go @@ -73,13 +73,13 @@ func CheckDelegatorDataReady(nodeMgr *session.NodeManager, targetMgr meta.Target for segmentID, info := range segmentDist { _, exist := leader.Segments[segmentID] if !exist { - log.RatedInfo(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID)) + log.RatedWarn(10, "leader is not available due to lack of segment", zap.Int64("segmentID", segmentID)) return merr.WrapErrSegmentLack(segmentID) } l0WithWrongLocation := info.GetLevel() == datapb.SegmentLevel_L0 && leader.Segments[segmentID].GetNodeID() != leader.ID if l0WithWrongLocation { - log.RatedInfo(10, "leader is not available due to lack of L0 segment", zap.Int64("segmentID", segmentID)) + log.RatedWarn(10, "leader is not available due to lack of L0 segment", zap.Int64("segmentID", segmentID)) return merr.WrapErrSegmentLack(segmentID) } } @@ -108,13 +108,11 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro return nil } -func GetShardLeadersWithChannels(m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, +func GetShardLeadersWithChannels(m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel, ) ([]*querypb.ShardLeadersList, error) { ret := make([]*querypb.ShardLeadersList, 0) for _, channel := range channels { - log := log.With(zap.String("channel", channel.GetChannelName())) - var channelErr error leaders := dist.LeaderViewManager.GetByFilter(meta.WithChannelName2LeaderView(channel.GetChannelName())) if len(leaders) == 0 { @@ -132,7 +130,7 @@ func GetShardLeadersWithChannels(m *meta.Meta, targetMgr *meta.TargetManager, di if len(readableLeaders) == 0 { msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName()) - log.Warn(msg, zap.Error(channelErr)) + log.RatedWarn(60, msg, zap.Error(channelErr)) err := merr.WrapErrChannelNotAvailable(channel.GetChannelName(), channelErr.Error()) return nil, err } @@ -169,7 +167,7 @@ func GetShardLeadersWithChannels(m *meta.Meta, targetMgr *meta.TargetManager, di return ret, nil } -func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) { +func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) { if err := checkLoadStatus(ctx, m, collectionID); err != nil { return nil, err } @@ -185,19 +183,24 @@ func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr *meta.TargetMa } // CheckCollectionsQueryable check all channels are watched and all segments are loaded for this collection -func CheckCollectionsQueryable(ctx context.Context, m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, nodeMgr *session.NodeManager) error { - maxInterval := paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Minute) +func CheckCollectionsQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager) map[int64]string { + maxInterval := paramtable.Get().QueryCoordCfg.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Second) + checkResult := make(map[int64]string) for _, coll := range m.GetAllCollections() { err := checkCollectionQueryable(ctx, m, targetMgr, dist, nodeMgr, coll) - if err != nil && !coll.IsReleasing() && time.Since(coll.UpdatedAt) >= maxInterval { - return err + // the collection is not queryable, if meet following conditions: + // 1. Some segments are not loaded + // 2. Collection is not starting to release + // 3. The load percentage has not been updated in the last 5 minutes. + if err != nil && m.Exist(coll.CollectionID) && time.Since(coll.UpdatedAt) >= maxInterval { + checkResult[coll.CollectionID] = err.Error() } } - return nil + return checkResult } // checkCollectionQueryable check all channels are watched and all segments are loaded for this collection -func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr *meta.TargetManager, dist *meta.DistributionManager, nodeMgr *session.NodeManager, coll *meta.Collection) error { +func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, coll *meta.Collection) error { collectionID := coll.GetCollectionID() if err := checkLoadStatus(ctx, m, collectionID); err != nil { return err diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 1fa77d3d4ec74..2fd75fd9508ba 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -42,6 +42,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -53,6 +54,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -1438,6 +1440,25 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( return merr.Success(), nil } +func (node *QueryNode) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + return &milvuspb.CheckHealthResponse{ + Status: merr.Status(err), + Reasons: []string{err.Error()}, + }, nil + } + defer node.lifetime.Done() + + maxDelay := paramtable.Get().QuotaConfig.MaxTimeTickDelay.GetAsDuration(time.Second) + minTsafeChannel, minTsafe := node.tSafeManager.Min() + if err := ratelimitutil.CheckTimeTickDelay(minTsafeChannel, minTsafe, maxDelay); err != nil { + msg := healthcheck.NewUnhealthyClusterMsg(typeutil.QueryNodeRole, node.GetNodeID(), err.Error(), healthcheck.TimeTickLagExceed) + return healthcheck.GetCheckHealthResponseFromClusterMsg(msg), nil + } + + return healthcheck.OK(), nil +} + // DeleteBatch is the API to apply same delete data into multiple segments. // it's basically same as `Delete` but cost less memory pressure. func (node *QueryNode) DeleteBatch(ctx context.Context, req *querypb.DeleteBatchRequest) (*querypb.DeleteBatchResponse, error) { diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 7c94af1ce861a..ec5bcf54a2e48 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -44,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" + "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/streamrpc" @@ -97,7 +98,7 @@ func (suite *ServiceSuite) SetupSuite() { paramtable.Init() paramtable.Get().Save(paramtable.Get().CommonCfg.GCEnabled.Key, "false") - suite.rootPath = suite.T().Name() + suite.rootPath = path.Join("/tmp/milvus/test", suite.T().Name()) suite.collectionID = 111 suite.collectionName = "test-collection" suite.partitionIDs = []int64{222} @@ -2197,6 +2198,41 @@ func (suite *ServiceSuite) TestLoadPartition() { suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) } +func (suite *ServiceSuite) TestCheckHealth() { + suite.Run("node not healthy", func() { + suite.node.UpdateStateCode(commonpb.StateCode_Abnormal) + + ctx := context.Background() + resp, err := suite.node.CheckHealth(ctx, nil) + suite.NoError(err) + suite.False(merr.Ok(resp.GetStatus())) + suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + }) + + suite.Run("exceeded timetick lag on pipeline", func() { + suite.node.tSafeManager = tsafe.NewTSafeReplica() + suite.node.tSafeManager.Add(context.TODO(), "timetick-lag-ch", 1) + ctx := context.Background() + suite.node.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err := suite.node.CheckHealth(ctx, nil) + suite.NoError(err) + suite.True(merr.Ok(resp.GetStatus())) + suite.False(resp.GetIsHealthy()) + suite.NotEmpty(resp.Reasons) + }) + + suite.Run("ok", func() { + ctx := context.Background() + suite.node.UpdateStateCode(commonpb.StateCode_Healthy) + suite.node.tSafeManager = tsafe.NewTSafeReplica() + resp, err := suite.node.CheckHealth(ctx, nil) + suite.NoError(err) + suite.True(merr.Ok(resp.GetStatus())) + suite.True(resp.GetIsHealthy()) + suite.Empty(resp.Reasons) + }) +} + func TestQueryNodeService(t *testing.T) { suite.Run(t, new(ServiceSuite)) } diff --git a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index 4d443a1ad3ad7..a5a65b1cd1d85 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -57,9 +57,9 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { // dropping collection with `ts1` but a collection exists in catalog with newer ts which is bigger than `ts1`. // fortunately, if ddls are promised to execute in sequence, then everything is OK. The `ts1` will always be latest. collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp) - if errors.Is(err, merr.ErrCollectionNotFound) { + if errors.Is(err, merr.ErrCollectionNotFound) || errors.Is(err, merr.ErrDatabaseNotFound) { // make dropping collection idempotent. - log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName())) + log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName()), zap.String("database", t.Req.GetDbName())) return nil } diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index ca58353f85c85..f866353ba70ae 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -609,6 +609,11 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str dbName = util.DefaultDBName } + db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp) + if err != nil { + return nil, err + } + collectionID, ok := mt.aliases.get(dbName, collectionName) if ok { return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, false) @@ -623,11 +628,6 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } - db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp) - if err != nil { - return nil, err - } - // travel meta information from catalog. No need to check time travel logic again, since catalog already did. ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue()) coll, err := mt.catalog.GetCollectionByName(ctx1, db.ID, collectionName, ts) @@ -871,8 +871,6 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum - metrics.RootCoordNumOfPartitions.WithLabelValues().Inc() - log.Ctx(ctx).Info("add partition to meta table", zap.Int64("collection", partition.CollectionID), zap.String("partition", partition.PartitionName), zap.Int64("partitionid", partition.PartitionID), zap.Uint64("ts", partition.PartitionCreatedTimestamp)) @@ -1543,6 +1541,10 @@ func (mt *MetaTable) OperatePrivilegeGroup(groupName string, privileges []*milvu mt.permissionLock.Lock() defer mt.permissionLock.Unlock() + if util.IsBuiltinPrivilegeGroup(groupName) { + return merr.WrapErrParameterInvalidMsg("the privilege group name [%s] is defined by built in privilege groups in system", groupName) + } + // validate input params definedByUsers, err := mt.IsCustomPrivilegeGroup(groupName) if err != nil { diff --git a/internal/rootcoord/meta_table_test.go b/internal/rootcoord/meta_table_test.go index 4192b655c03be..3bf93aeb31cdb 100644 --- a/internal/rootcoord/meta_table_test.go +++ b/internal/rootcoord/meta_table_test.go @@ -537,6 +537,24 @@ func TestMetaTable_getCollectionByIDInternal(t *testing.T) { } func TestMetaTable_GetCollectionByName(t *testing.T) { + t.Run("db not found", func(t *testing.T) { + meta := &MetaTable{ + aliases: newNameDb(), + collID2Meta: map[typeutil.UniqueID]*model.Collection{ + 100: { + State: pb.CollectionState_CollectionCreated, + CreateTime: 99, + Partitions: []*model.Partition{}, + }, + }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, + } + ctx := context.Background() + _, err := meta.GetCollectionByName(ctx, "not_exist", "name", 101) + assert.Error(t, err) + }) t.Run("get by alias", func(t *testing.T) { meta := &MetaTable{ aliases: newNameDb(), @@ -550,6 +568,9 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { }, }, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } meta.aliases.insert(util.DefaultDBName, "alias", 100) ctx := context.Background() @@ -574,6 +595,9 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { }, }, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } meta.names.insert(util.DefaultDBName, "name", 100) ctx := context.Background() @@ -661,7 +685,13 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { t.Run("get latest version", func(t *testing.T) { ctx := context.Background() - meta := &MetaTable{names: newNameDb(), aliases: newNameDb()} + meta := &MetaTable{ + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, + names: newNameDb(), + aliases: newNameDb(), + } _, err := meta.GetCollectionByName(ctx, "", "not_exist", typeutil.MaxTimestamp) assert.Error(t, err) assert.ErrorIs(t, err, merr.ErrCollectionNotFound) @@ -1880,6 +1910,9 @@ func TestMetaTable_EmtpyDatabaseName(t *testing.T) { collID2Meta: map[typeutil.UniqueID]*model.Collection{ 1: {CollectionID: 1}, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } mt.aliases.insert(util.DefaultDBName, "aliases", 1) @@ -2105,6 +2138,8 @@ func TestMetaTable_PrivilegeGroup(t *testing.T) { assert.NoError(t, err) err = mt.OperatePrivilegeGroup("", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup) assert.Error(t, err) + err = mt.OperatePrivilegeGroup("ClusterReadOnly", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup) + assert.Error(t, err) err = mt.OperatePrivilegeGroup("pg3", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup) assert.Error(t, err) _, err = mt.GetPrivilegeGroupRoles("") diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 5cc0b97b72f70..c03a887bf4d6c 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -404,6 +404,7 @@ func newMockProxy() *mockProxy { func newTestCore(opts ...Opt) *Core { c := &Core{ + ctx: context.TODO(), session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}}, } executor := newMockStepExecutor() diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 798eb1b2ec38a..612b3838fa1b6 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -22,7 +22,6 @@ import ( "math/rand" "os" "strconv" - "strings" "sync" "time" @@ -32,7 +31,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/atomic" "go.uber.org/zap" - "golang.org/x/sync/errgroup" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -50,6 +48,7 @@ import ( tso2 "github.com/milvus-io/milvus/internal/tso" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" tsoutil2 "github.com/milvus-io/milvus/internal/util/tsoutil" @@ -128,6 +127,7 @@ type Core struct { enableActiveStandBy bool activateFunc func() error + healthChecker *healthcheck.Checker } // --------------------- function -------------------------- @@ -482,6 +482,8 @@ func (c *Core) initInternal() error { return err } + interval := Params.CommonCfg.HealthCheckInterval.GetAsDuration(time.Second) + c.healthChecker = healthcheck.NewChecker(interval, c.healthCheckFn) log.Info("init rootcoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", c.address)) return nil } @@ -758,6 +760,7 @@ func (c *Core) startInternal() error { }() c.startServerLoop() + c.healthChecker.Start() c.UpdateStateCode(commonpb.StateCode_Healthy) sessionutil.SaveServerInfo(typeutil.RootCoordRole, c.session.ServerID) logutil.Logger(c.ctx).Info("rootcoord startup successfully") @@ -816,6 +819,10 @@ func (c *Core) revokeSession() { // Stop stops rootCoord. func (c *Core) Stop() error { c.UpdateStateCode(commonpb.StateCode_Abnormal) + if c.healthChecker != nil { + c.healthChecker.Close() + } + c.stopExecutor() c.stopScheduler() if c.proxyWatcher != nil { @@ -2572,43 +2579,21 @@ func (c *Core) isValidPrivilege(ctx context.Context, privilegeName string, objec return fmt.Errorf("not found the privilege name[%s] in object[%s]", privilegeName, object) } -func (c *Core) isValidPrivilegeV2(ctx context.Context, privilegeName, dbName, collectionName string) error { +func (c *Core) isValidPrivilegeV2(ctx context.Context, privilegeName string) error { if util.IsAnyWord(privilegeName) { return nil } - var privilegeLevel string - for group, privileges := range util.BuiltinPrivilegeGroups { - if privilegeName == group || lo.Contains(privileges, privilegeName) { - privilegeLevel = group - break - } - } - if privilegeLevel == "" { - customPrivGroup, err := c.meta.IsCustomPrivilegeGroup(privilegeName) - if err != nil { - return err - } - if customPrivGroup { - return nil - } - return fmt.Errorf("not found the privilege name[%s] in the custom privilege groups", privilegeName) + customPrivGroup, err := c.meta.IsCustomPrivilegeGroup(privilegeName) + if err != nil { + return err } - switch { - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Cluster.String()): - if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) { - return fmt.Errorf("dbName and collectionName should be * for the cluster level privilege: %s", privilegeName) - } - return nil - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Database.String()): - if collectionName != "" && collectionName != util.AnyWord { - return fmt.Errorf("collectionName should be empty or * for the database level privilege: %s", privilegeName) - } - return nil - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Collection.String()): + if customPrivGroup { return nil - default: + } + if util.IsPrivilegeNameDefined(privilegeName) { return nil } + return fmt.Errorf("not found the privilege name[%s]", privilegeName) } // OperatePrivilege operate the privilege, including grant and revoke @@ -2629,26 +2614,27 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil } + privName := in.Entity.Grantor.Privilege.Name switch in.Version { case "v2": - if err := c.isValidPrivilegeV2(ctx, in.Entity.Grantor.Privilege.Name, - in.Entity.DbName, in.Entity.ObjectName); err != nil { + if err := c.isValidPrivilegeV2(ctx, privName); err != nil { ctxLog.Error("", zap.Error(err)) return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil } + // set up object type for metastore, to be compatible with v1 version + in.Entity.Object.Name = util.GetObjectType(privName) default: - if err := c.isValidPrivilege(ctx, in.Entity.Grantor.Privilege.Name, in.Entity.Object.Name); err != nil { + if err := c.isValidPrivilege(ctx, privName, in.Entity.Object.Name); err != nil { ctxLog.Error("", zap.Error(err)) return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil } // set up object name if it is global object type and not built in privilege group - if in.Entity.Object.Name == commonpb.ObjectType_Global.String() && !lo.Contains(lo.Keys(util.BuiltinPrivilegeGroups), in.Entity.Grantor.Privilege.Name) { + if in.Entity.Object.Name == commonpb.ObjectType_Global.String() && !util.IsBuiltinPrivilegeGroup(in.Entity.Grantor.Privilege.Name) { in.Entity.ObjectName = util.AnyWord } } - // set up privilege name for metastore - privName := in.Entity.Grantor.Privilege.Name + privName = in.Entity.Grantor.Privilege.Name redoTask := newBaseRedoTask(c.stepExecutor) redoTask.AddSyncStep(NewSimpleStep("operate privilege meta data", func(ctx context.Context) ([]nestedStep, error) { @@ -3030,53 +3016,40 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) }, nil } - group, ctx := errgroup.WithContext(ctx) - errs := typeutil.NewConcurrentSet[error]() + latestCheckResult := c.healthChecker.GetLatestCheckResult() + return healthcheck.GetCheckHealthResponseFromResult(latestCheckResult), nil +} + +func (c *Core) healthCheckFn() *healthcheck.Result { + timeout := Params.CommonCfg.HealthCheckRPCTimeout.GetAsDuration(time.Second) + ctx, cancel := context.WithTimeout(c.ctx, timeout) + defer cancel() proxyClients := c.proxyClientManager.GetProxyClients() + wg := sync.WaitGroup{} + lock := sync.Mutex{} + result := healthcheck.NewResult() + proxyClients.Range(func(key int64, value types.ProxyClient) bool { nodeID := key proxyClient := value - group.Go(func() error { - sta, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) - if err != nil { - errs.Insert(err) - return err - } + wg.Add(1) + go func() { + defer wg.Done() + resp, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + err = merr.AnalyzeComponentStateResp(typeutil.ProxyRole, nodeID, resp, err) - err = merr.AnalyzeState("Proxy", nodeID, sta) + lock.Lock() + defer lock.Unlock() if err != nil { - errs.Insert(err) + result.AppendUnhealthyClusterMsg(healthcheck.NewUnhealthyClusterMsg(typeutil.ProxyRole, nodeID, err.Error(), healthcheck.NodeHealthCheck)) } - - return err - }) + }() return true }) - maxDelay := Params.QuotaConfig.MaxTimeTickDelay.GetAsDuration(time.Second) - if maxDelay > 0 { - group.Go(func() error { - err := CheckTimeTickLagExceeded(ctx, c.queryCoord, c.dataCoord, maxDelay) - if err != nil { - errs.Insert(err) - } - return err - }) - } - - err := group.Wait() - if err != nil { - return &milvuspb.CheckHealthResponse{ - Status: merr.Success(), - IsHealthy: false, - Reasons: lo.Map(errs.Collect(), func(e error, i int) string { - return err.Error() - }), - }, nil - } - - return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: []string{}}, nil + wg.Wait() + return result } func (c *Core) CreatePrivilegeGroup(ctx context.Context, in *milvuspb.CreatePrivilegeGroupRequest) (*commonpb.Status, error) { @@ -3087,12 +3060,12 @@ func (c *Core) CreatePrivilegeGroup(ctx context.Context, in *milvuspb.CreatePriv ctxLog.Debug(method) if err := merr.CheckHealthy(c.GetStateCode()); err != nil { - return merr.Status(err), nil + return merr.StatusWithErrorCode(err, commonpb.ErrorCode_CreatePrivilegeGroupFailure), nil } if err := c.meta.CreatePrivilegeGroup(in.GroupName); err != nil { ctxLog.Warn("fail to create privilege group", zap.Error(err)) - return merr.Status(err), nil + return merr.StatusWithErrorCode(err, commonpb.ErrorCode_CreatePrivilegeGroupFailure), nil } ctxLog.Debug(method + " success") @@ -3110,12 +3083,12 @@ func (c *Core) DropPrivilegeGroup(ctx context.Context, in *milvuspb.DropPrivileg ctxLog.Debug(method) if err := merr.CheckHealthy(c.GetStateCode()); err != nil { - return merr.Status(err), nil + return merr.StatusWithErrorCode(err, commonpb.ErrorCode_DropPrivilegeGroupFailure), nil } if err := c.meta.DropPrivilegeGroup(in.GroupName); err != nil { ctxLog.Warn("fail to drop privilege group", zap.Error(err)) - return merr.Status(err), nil + return merr.StatusWithErrorCode(err, commonpb.ErrorCode_DropPrivilegeGroupFailure), nil } ctxLog.Debug(method + " success") @@ -3302,8 +3275,7 @@ func (c *Core) OperatePrivilegeGroup(ctx context.Context, in *milvuspb.OperatePr if err != nil { errMsg := "fail to execute task when operate privilege group" ctxLog.Warn(errMsg, zap.Error(err)) - status := merr.StatusWithErrorCode(errors.New(errMsg), commonpb.ErrorCode_OperatePrivilegeGroupFailure) - return status, nil + return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeGroupFailure), nil } ctxLog.Debug(method + " success") @@ -3319,13 +3291,17 @@ func (c *Core) expandPrivilegeGroups(grants []*milvuspb.GrantEntity, groups map[ if err != nil { return nil, err } - if objectType := util.GetObjectType(privilegeName); objectType != "" { - grant.Object.Name = objectType + objectType := &milvuspb.ObjectEntity{ + Name: util.GetObjectType(privilegeName), + } + objectName := grant.ObjectName + if objectType.Name == commonpb.ObjectType_Global.String() { + objectName = util.AnyWord } return &milvuspb.GrantEntity{ Role: grant.Role, - Object: grant.Object, - ObjectName: grant.ObjectName, + Object: objectType, + ObjectName: objectName, Grantor: &milvuspb.GrantorEntity{ User: grant.Grantor.User, Privilege: &milvuspb.PrivilegeEntity{ @@ -3338,20 +3314,16 @@ func (c *Core) expandPrivilegeGroups(grants []*milvuspb.GrantEntity, groups map[ for _, grant := range grants { privName := grant.Grantor.Privilege.Name - if privGroup, exists := groups[privName]; !exists { - newGrant, err := createGrantEntity(grant, privName) + privGroup, exists := groups[privName] + if !exists { + privGroup = []*milvuspb.PrivilegeEntity{{Name: privName}} + } + for _, priv := range privGroup { + newGrant, err := createGrantEntity(grant, priv.Name) if err != nil { return nil, err } newGrants = append(newGrants, newGrant) - } else { - for _, priv := range privGroup { - newGrant, err := createGrantEntity(grant, priv.Name) - if err != nil { - return nil, err - } - newGrants = append(newGrants, newGrant) - } } } // uniq by role + object + object name + grantor user + privilege name + db name diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index 1ea416a4ce4b6..59cd1d95409d4 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -32,7 +32,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" @@ -40,6 +39,7 @@ import ( mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" "github.com/milvus-io/milvus/internal/util/dependency" kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" + "github.com/milvus-io/milvus/internal/util/healthcheck" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/util" @@ -49,7 +49,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -1453,65 +1452,6 @@ func TestRootCoord_AlterCollection(t *testing.T) { } func TestRootCoord_CheckHealth(t *testing.T) { - getQueryCoordMetricsFunc := func(tt typeutil.Timestamp) (*milvuspb.GetMetricsResponse, error) { - clusterTopology := metricsinfo.QueryClusterTopology{ - ConnectedNodes: []metricsinfo.QueryNodeInfos{ - { - QuotaMetrics: &metricsinfo.QueryNodeQuotaMetrics{ - Fgm: metricsinfo.FlowGraphMetric{ - MinFlowGraphChannel: "ch1", - MinFlowGraphTt: tt, - NumFlowGraph: 1, - }, - }, - }, - }, - } - - resp, _ := metricsinfo.MarshalTopology(metricsinfo.QueryCoordTopology{Cluster: clusterTopology}) - return &milvuspb.GetMetricsResponse{ - Status: merr.Success(), - Response: resp, - ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryCoordRole, 0), - }, nil - } - - getDataCoordMetricsFunc := func(tt typeutil.Timestamp) (*milvuspb.GetMetricsResponse, error) { - clusterTopology := metricsinfo.DataClusterTopology{ - ConnectedDataNodes: []metricsinfo.DataNodeInfos{ - { - QuotaMetrics: &metricsinfo.DataNodeQuotaMetrics{ - Fgm: metricsinfo.FlowGraphMetric{ - MinFlowGraphChannel: "ch1", - MinFlowGraphTt: tt, - NumFlowGraph: 1, - }, - }, - }, - }, - } - - resp, _ := metricsinfo.MarshalTopology(metricsinfo.DataCoordTopology{Cluster: clusterTopology}) - return &milvuspb.GetMetricsResponse{ - Status: merr.Success(), - Response: resp, - ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, 0), - }, nil - } - - querynodeTT := tsoutil.ComposeTSByTime(time.Now().Add(-1*time.Minute), 0) - datanodeTT := tsoutil.ComposeTSByTime(time.Now().Add(-2*time.Minute), 0) - - dcClient := mocks.NewMockDataCoordClient(t) - dcClient.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(getDataCoordMetricsFunc(datanodeTT)) - qcClient := mocks.NewMockQueryCoordClient(t) - qcClient.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(getQueryCoordMetricsFunc(querynodeTT)) - - errDataCoordClient := mocks.NewMockDataCoordClient(t) - errDataCoordClient.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, errors.New("error")) - errQueryCoordClient := mocks.NewMockQueryCoordClient(t) - errQueryCoordClient.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, errors.New("error")) - t.Run("not healthy", func(t *testing.T) { ctx := context.Background() c := newTestCore(withAbnormalCode()) @@ -1521,25 +1461,13 @@ func TestRootCoord_CheckHealth(t *testing.T) { assert.NotEmpty(t, resp.Reasons) }) - t.Run("ok with disabled tt lag configuration", func(t *testing.T) { - v := Params.QuotaConfig.MaxTimeTickDelay.GetValue() - Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, "-1") - defer Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, v) - - c := newTestCore(withHealthyCode(), withValidProxyManager()) - ctx := context.Background() - resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - assert.NoError(t, err) - assert.Equal(t, true, resp.IsHealthy) - assert.Empty(t, resp.Reasons) - }) - t.Run("proxy health check fail with invalid proxy", func(t *testing.T) { - v := Params.QuotaConfig.MaxTimeTickDelay.GetValue() - Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, "6000") - defer Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, v) + c := newTestCore(withHealthyCode(), withInvalidProxyManager()) + c.healthChecker = healthcheck.NewChecker(40*time.Millisecond, c.healthCheckFn) + c.healthChecker.Start() + defer c.healthChecker.Close() - c := newTestCore(withHealthyCode(), withInvalidProxyManager(), withDataCoord(dcClient), withQueryCoord(qcClient)) + time.Sleep(50 * time.Millisecond) ctx := context.Background() resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) @@ -1548,55 +1476,14 @@ func TestRootCoord_CheckHealth(t *testing.T) { assert.NotEmpty(t, resp.Reasons) }) - t.Run("proxy health check fail with get metrics error", func(t *testing.T) { - v := Params.QuotaConfig.MaxTimeTickDelay.GetValue() - Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, "6000") - defer Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, v) - - { - c := newTestCore(withHealthyCode(), - withValidProxyManager(), withDataCoord(dcClient), withQueryCoord(errQueryCoordClient)) - - ctx := context.Background() - resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - assert.NoError(t, err) - assert.Equal(t, false, resp.IsHealthy) - assert.NotEmpty(t, resp.Reasons) - } - - { - c := newTestCore(withHealthyCode(), - withValidProxyManager(), withDataCoord(errDataCoordClient), withQueryCoord(qcClient)) - - ctx := context.Background() - resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - assert.NoError(t, err) - assert.Equal(t, false, resp.IsHealthy) - assert.NotEmpty(t, resp.Reasons) - } - }) - - t.Run("ok with tt lag exceeded", func(t *testing.T) { - v := Params.QuotaConfig.MaxTimeTickDelay.GetValue() - Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, "90") - defer Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, v) - - c := newTestCore(withHealthyCode(), - withValidProxyManager(), withDataCoord(dcClient), withQueryCoord(qcClient)) - ctx := context.Background() - resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) - assert.NoError(t, err) - assert.Equal(t, false, resp.IsHealthy) - assert.NotEmpty(t, resp.Reasons) - }) + t.Run("ok", func(t *testing.T) { + c := newTestCore(withHealthyCode(), withValidProxyManager()) + c.healthChecker = healthcheck.NewChecker(40*time.Millisecond, c.healthCheckFn) + c.healthChecker.Start() + defer c.healthChecker.Close() - t.Run("ok with tt lag checking", func(t *testing.T) { - v := Params.QuotaConfig.MaxTimeTickDelay.GetValue() - Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, "600") - defer Params.Save(Params.QuotaConfig.MaxTimeTickDelay.Key, v) + time.Sleep(50 * time.Millisecond) - c := newTestCore(withHealthyCode(), - withValidProxyManager(), withDataCoord(dcClient), withQueryCoord(qcClient)) ctx := context.Background() resp, err := c.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index 88a27e758dac1..c2588bb3b5efe 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -21,10 +21,8 @@ import ( "encoding/json" "fmt" "strconv" - "time" "go.uber.org/zap" - "golang.org/x/sync/errgroup" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/types" @@ -34,7 +32,6 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -284,83 +281,3 @@ func getProxyMetrics(ctx context.Context, proxies proxyutil.ProxyClientManagerIn return ret, nil } - -func CheckTimeTickLagExceeded(ctx context.Context, queryCoord types.QueryCoordClient, dataCoord types.DataCoordClient, maxDelay time.Duration) error { - ctx, cancel := context.WithTimeout(ctx, GetMetricsTimeout) - defer cancel() - - now := time.Now() - group := &errgroup.Group{} - queryNodeTTDelay := typeutil.NewConcurrentMap[string, time.Duration]() - dataNodeTTDelay := typeutil.NewConcurrentMap[string, time.Duration]() - - group.Go(func() error { - queryCoordTopology, err := getQueryCoordMetrics(ctx, queryCoord) - if err != nil { - return err - } - - for _, queryNodeMetric := range queryCoordTopology.Cluster.ConnectedNodes { - qm := queryNodeMetric.QuotaMetrics - if qm != nil { - if qm.Fgm.NumFlowGraph > 0 && qm.Fgm.MinFlowGraphChannel != "" { - minTt, _ := tsoutil.ParseTS(qm.Fgm.MinFlowGraphTt) - delay := now.Sub(minTt) - - if delay.Milliseconds() >= maxDelay.Milliseconds() { - queryNodeTTDelay.Insert(qm.Fgm.MinFlowGraphChannel, delay) - } - } - } - } - return nil - }) - - // get Data cluster metrics - group.Go(func() error { - dataCoordTopology, err := getDataCoordMetrics(ctx, dataCoord) - if err != nil { - return err - } - - for _, dataNodeMetric := range dataCoordTopology.Cluster.ConnectedDataNodes { - dm := dataNodeMetric.QuotaMetrics - if dm != nil { - if dm.Fgm.NumFlowGraph > 0 && dm.Fgm.MinFlowGraphChannel != "" { - minTt, _ := tsoutil.ParseTS(dm.Fgm.MinFlowGraphTt) - delay := now.Sub(minTt) - - if delay.Milliseconds() >= maxDelay.Milliseconds() { - dataNodeTTDelay.Insert(dm.Fgm.MinFlowGraphChannel, delay) - } - } - } - } - return nil - }) - - err := group.Wait() - if err != nil { - return err - } - - var maxLagChannel string - var maxLag time.Duration - findMaxLagChannel := func(params ...*typeutil.ConcurrentMap[string, time.Duration]) { - for _, param := range params { - param.Range(func(k string, v time.Duration) bool { - if v > maxLag { - maxLag = v - maxLagChannel = k - } - return true - }) - } - } - findMaxLagChannel(queryNodeTTDelay, dataNodeTTDelay) - - if maxLag > 0 && len(maxLagChannel) != 0 { - return fmt.Errorf("max timetick lag execced threhold, max timetick lag:%s on channel:%s", maxLag, maxLagChannel) - } - return nil -} diff --git a/internal/util/componentutil/componentutil.go b/internal/util/componentutil/componentutil.go index d89c9db72bd67..93537d24451d8 100644 --- a/internal/util/componentutil/componentutil.go +++ b/internal/util/componentutil/componentutil.go @@ -84,17 +84,3 @@ func WaitForComponentHealthy[T interface { }](ctx context.Context, client T, serviceName string, attempts uint, sleep time.Duration) error { return WaitForComponentStates(ctx, client, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep) } - -func CheckHealthRespWithErr(err error) *milvuspb.CheckHealthResponse { - if err != nil { - return CheckHealthRespWithErrMsg(err.Error()) - } - return CheckHealthRespWithErrMsg() -} - -func CheckHealthRespWithErrMsg(errMsg ...string) *milvuspb.CheckHealthResponse { - if len(errMsg) != 0 { - return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: false, Reasons: errMsg} - } - return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: []string{}} -} diff --git a/internal/util/healthcheck/checker.go b/internal/util/healthcheck/checker.go new file mode 100644 index 0000000000000..c1f06a8e105d5 --- /dev/null +++ b/internal/util/healthcheck/checker.go @@ -0,0 +1,276 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcheck + +import ( + "fmt" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/json" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// UnHealthyLevel represents the health level of a system. +type UnHealthyLevel int + +const ( + // Healthy means the system is operating normally. + Healthy UnHealthyLevel = iota + // Warning indicates minor issues that might escalate. + Warning + // Critical indicates major issues that need immediate attention. + Critical + // Fatal indicates system failure. + Fatal +) + +// String returns the string representation of the UnHealthyLevel. +func (u UnHealthyLevel) String() string { + switch u { + case Healthy: + return "Healthy" + case Warning: + return "Warning" + case Critical: + return "Critical" + case Fatal: + return "Fatal" + default: + return "Unknown" + } +} + +type Item int + +const ( + ChannelsWatched Item = iota + CheckpointLagExceed + CollectionQueryable + TimeTickLagExceed + NodeHealthCheck +) + +func getUnhealthyLevel(item Item) UnHealthyLevel { + switch item { + case ChannelsWatched: + return Fatal + case CheckpointLagExceed: + return Fatal + case TimeTickLagExceed: + return Fatal + case NodeHealthCheck: + return Fatal + case CollectionQueryable: + return Critical + default: + panic(fmt.Sprintf("unknown health check item: %d", int(item))) + } +} + +type Result struct { + UnhealthyClusterMsgs []*UnhealthyClusterMsg `json:"unhealthy_cluster_msgs"` + UnhealthyCollectionMsgs []*UnhealthyCollectionMsg `json:"unhealthy_collection_msgs"` +} + +func NewResult() *Result { + return &Result{} +} + +func (r *Result) AppendUnhealthyClusterMsg(unm *UnhealthyClusterMsg) { + r.UnhealthyClusterMsgs = append(r.UnhealthyClusterMsgs, unm) +} + +func (r *Result) AppendUnhealthyCollectionMsgs(udm *UnhealthyCollectionMsg) { + r.UnhealthyCollectionMsgs = append(r.UnhealthyCollectionMsgs, udm) +} + +func (r *Result) AppendResult(other *Result) { + if other == nil { + return + } + r.UnhealthyClusterMsgs = append(r.UnhealthyClusterMsgs, other.UnhealthyClusterMsgs...) + r.UnhealthyCollectionMsgs = append(r.UnhealthyCollectionMsgs, other.UnhealthyCollectionMsgs...) +} + +func (r *Result) IsEmpty() bool { + return len(r.UnhealthyClusterMsgs) == 0 && len(r.UnhealthyCollectionMsgs) == 0 +} + +func (r *Result) IsHealthy() bool { + if len(r.UnhealthyClusterMsgs) == 0 && len(r.UnhealthyCollectionMsgs) == 0 { + return true + } + + for _, unm := range r.UnhealthyClusterMsgs { + if unm.Reason.UnhealthyLevel == Fatal { + return false + } + } + + for _, ucm := range r.UnhealthyCollectionMsgs { + if ucm.Reason.UnhealthyLevel == Fatal { + return false + } + } + + return true +} + +type UnhealthyReason struct { + UnhealthyMsg string `json:"unhealthy_msg"` + UnhealthyLevel UnHealthyLevel `json:"unhealthy_level"` +} + +type UnhealthyClusterMsg struct { + Role string `json:"role"` + NodeID int64 `json:"node_id"` + Reason *UnhealthyReason `json:"reason"` +} + +func NewUnhealthyClusterMsg(role string, nodeID int64, unhealthyMsg string, item Item) *UnhealthyClusterMsg { + return &UnhealthyClusterMsg{ + Role: role, + NodeID: nodeID, + Reason: &UnhealthyReason{ + UnhealthyMsg: unhealthyMsg, + UnhealthyLevel: getUnhealthyLevel(item), + }, + } +} + +type UnhealthyCollectionMsg struct { + DatabaseID int64 `json:"database_id"` + CollectionID int64 `json:"collection_id"` + Reason *UnhealthyReason `json:"reason"` +} + +func NewUnhealthyCollectionMsg(collectionID int64, unhealthyMsg string, item Item) *UnhealthyCollectionMsg { + return &UnhealthyCollectionMsg{ + CollectionID: collectionID, + Reason: &UnhealthyReason{ + UnhealthyMsg: unhealthyMsg, + UnhealthyLevel: getUnhealthyLevel(item), + }, + } +} + +type Checker struct { + sync.RWMutex + interval time.Duration + done chan struct{} + checkFn func() *Result + latestResult *Result + once sync.Once +} + +func NewChecker(interval time.Duration, checkFn func() *Result) *Checker { + checker := &Checker{ + interval: interval, + checkFn: checkFn, + latestResult: NewResult(), + done: make(chan struct{}, 1), + once: sync.Once{}, + } + return checker +} + +func (hc *Checker) Start() { + go func() { + ticker := time.NewTicker(hc.interval) + defer ticker.Stop() + log.Info("start health checker") + for { + select { + case <-ticker.C: + hc.Lock() + hc.latestResult = hc.checkFn() + hc.Unlock() + case <-hc.done: + log.Info("stop health checker") + return + } + } + }() +} + +func (hc *Checker) GetLatestCheckResult() *Result { + hc.RLock() + defer hc.RUnlock() + return hc.latestResult +} + +func (hc *Checker) Close() { + hc.once.Do(func() { + close(hc.done) + }) +} + +func GetHealthCheckResultFromResp(resp *milvuspb.CheckHealthResponse) *Result { + var r Result + if len(resp.Reasons) == 0 { + return &r + } + if len(resp.Reasons) > 1 { + log.Error("invalid check result", zap.Any("reasons", resp.Reasons)) + return &r + } + + err := json.Unmarshal([]byte(resp.Reasons[0]), &r) + if err != nil { + log.Error("unmarshal check result error", zap.String("error", err.Error())) + } + return &r +} + +func GetCheckHealthResponseFromClusterMsg(msg ...*UnhealthyClusterMsg) *milvuspb.CheckHealthResponse { + r := &Result{UnhealthyClusterMsgs: msg} + reasons, err := json.Marshal(r) + if err != nil { + log.Error("marshal check result error", zap.String("error", err.Error())) + } + return &milvuspb.CheckHealthResponse{ + Status: merr.Success(), + IsHealthy: r.IsHealthy(), + Reasons: []string{string(reasons)}, + } +} + +func GetCheckHealthResponseFromResult(checkResult *Result) *milvuspb.CheckHealthResponse { + if checkResult.IsEmpty() { + return OK() + } + + reason, err := json.Marshal(checkResult) + if err != nil { + log.Error("marshal check result error", zap.String("error", err.Error())) + } + + return &milvuspb.CheckHealthResponse{ + Status: merr.Success(), + IsHealthy: checkResult.IsHealthy(), + Reasons: []string{string(reason)}, + } +} + +func OK() *milvuspb.CheckHealthResponse { + return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: []string{}} +} diff --git a/internal/util/healthcheck/checker_test.go b/internal/util/healthcheck/checker_test.go new file mode 100644 index 0000000000000..7fdcb8cd6e8d1 --- /dev/null +++ b/internal/util/healthcheck/checker_test.go @@ -0,0 +1,60 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcheck + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func TestChecker(t *testing.T) { + expected1 := NewResult() + expected1.AppendUnhealthyClusterMsg(NewUnhealthyClusterMsg("role1", 1, "msg1", ChannelsWatched)) + expected1.AppendUnhealthyClusterMsg(NewUnhealthyClusterMsg("role1", 1, "msg1", ChannelsWatched)) + + expected1.AppendUnhealthyCollectionMsgs(&UnhealthyCollectionMsg{ + CollectionID: 1, + Reason: &UnhealthyReason{ + UnhealthyMsg: "msg2", + UnhealthyLevel: Critical, + }, + }) + + checkFn := func() *Result { + return expected1 + } + checker := NewChecker(100*time.Millisecond, checkFn) + go checker.Start() + + time.Sleep(150 * time.Millisecond) + actual1 := checker.GetLatestCheckResult() + assert.Equal(t, expected1, actual1) + assert.False(t, actual1.IsHealthy()) + + chr := GetCheckHealthResponseFromResult(actual1) + assert.Equal(t, merr.Success(), chr.Status) + assert.Equal(t, actual1.IsHealthy(), chr.IsHealthy) + assert.Equal(t, 1, len(chr.Reasons)) + + actualResult := GetHealthCheckResultFromResp(chr) + assert.Equal(t, actual1, actualResult) + checker.Close() +} diff --git a/internal/util/mock/grpc_datanode_client.go b/internal/util/mock/grpc_datanode_client.go index 13ae355738d80..621a2317e61e8 100644 --- a/internal/util/mock/grpc_datanode_client.go +++ b/internal/util/mock/grpc_datanode_client.go @@ -112,3 +112,7 @@ func (m *GrpcDataNodeClient) QuerySlot(ctx context.Context, req *datapb.QuerySlo func (m *GrpcDataNodeClient) DropCompactionPlan(ctx context.Context, req *datapb.DropCompactionPlanRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcDataNodeClient) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + return &milvuspb.CheckHealthResponse{}, m.Err +} diff --git a/internal/util/mock/grpc_querynode_client.go b/internal/util/mock/grpc_querynode_client.go index dadfb3157897d..5db4bee2a4984 100644 --- a/internal/util/mock/grpc_querynode_client.go +++ b/internal/util/mock/grpc_querynode_client.go @@ -134,6 +134,10 @@ func (m *GrpcQueryNodeClient) DeleteBatch(ctx context.Context, in *querypb.Delet return &querypb.DeleteBatchResponse{}, m.Err } +func (m *GrpcQueryNodeClient) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + return &milvuspb.CheckHealthResponse{}, m.Err +} + func (m *GrpcQueryNodeClient) Close() error { return m.Err } diff --git a/internal/util/wrappers/qn_wrapper.go b/internal/util/wrappers/qn_wrapper.go index def2e64f01bef..c186fdf1596dd 100644 --- a/internal/util/wrappers/qn_wrapper.go +++ b/internal/util/wrappers/qn_wrapper.go @@ -152,6 +152,10 @@ func (qn *qnServerWrapper) DeleteBatch(ctx context.Context, in *querypb.DeleteBa return qn.QueryNode.DeleteBatch(ctx, in) } +func (qn *qnServerWrapper) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + return qn.QueryNode.CheckHealth(ctx, req) +} + func WrapQueryNodeServerAsClient(qn types.QueryNode) types.QueryNodeClient { return &qnServerWrapper{ QueryNode: qn, diff --git a/pkg/util/constant.go b/pkg/util/constant.go index a1323f7ab1e77..c762009d51fd9 100644 --- a/pkg/util/constant.go +++ b/pkg/util/constant.go @@ -332,7 +332,6 @@ var ( MetaStore2API(commonpb.ObjectPrivilege_PrivilegeFlush.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeCompaction.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeLoadBalance.String()), - MetaStore2API(commonpb.ObjectPrivilege_PrivilegeRenameCollection.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeCreateIndex.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeDropIndex.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeCreatePartition.String()), @@ -384,6 +383,7 @@ var ( MetaStore2API(commonpb.ObjectPrivilege_PrivilegeCreateResourceGroup.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeDropResourceGroup.String()), MetaStore2API(commonpb.ObjectPrivilege_PrivilegeUpdateUser.String()), + MetaStore2API(commonpb.ObjectPrivilege_PrivilegeRenameCollection.String()), ) ) @@ -475,5 +475,5 @@ func GetObjectType(privName string) string { return objectType } } - return "" + return commonpb.ObjectType_Global.String() } diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index 1a9862c9c371e..82f79560cb525 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -24,7 +24,6 @@ import ( "fmt" "net" "reflect" - "regexp" "strconv" "strings" "time" @@ -254,13 +253,18 @@ func ConvertChannelName(chanName string, tokenFrom string, tokenTo string) (stri } func GetCollectionIDFromVChannel(vChannelName string) int64 { - re := regexp.MustCompile(`.*_(\d+)v\d+`) - matches := re.FindStringSubmatch(vChannelName) - if len(matches) > 1 { - number, err := strconv.ParseInt(matches[1], 0, 64) - if err == nil { - return number - } + end := strings.LastIndexByte(vChannelName, 'v') + if end <= 0 { + return -1 + } + start := strings.LastIndexByte(vChannelName, '_') + if start <= 0 { + return -1 + } + + collectionIDStr := vChannelName[start+1 : end] + if collectionID, err := strconv.ParseInt(collectionIDStr, 0, 64); err == nil { + return collectionID } return -1 } diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 872ddf19b5eb7..d47fe7399aaaf 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -147,7 +147,8 @@ var ( // this operation is denied because the user not authorized, user need to login in first ErrPrivilegeNotAuthenticated = newMilvusError("not authenticated", 1400, false) // this operation is denied because the user has no permission to do this, user need higher privilege - ErrPrivilegeNotPermitted = newMilvusError("privilege not permitted", 1401, false) + ErrPrivilegeNotPermitted = newMilvusError("privilege not permitted", 1401, false) + ErrPrivilegeGroupInvalidName = newMilvusError("invalid privilege group name", 1402, false) // Alias related ErrAliasNotFound = newMilvusError("alias not found", 1600, false) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index c5041aeaf6e0c..4bf2f4a7293e6 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -299,6 +299,13 @@ func IsHealthyOrStopping(stateCode commonpb.StateCode) error { return CheckHealthy(stateCode) } +func AnalyzeComponentStateResp(role string, nodeID int64, resp *milvuspb.ComponentStates, err error) error { + if err != nil { + return errors.Wrap(err, "service is unhealthy") + } + return AnalyzeState(role, nodeID, resp) +} + func AnalyzeState(role string, nodeID int64, state *milvuspb.ComponentStates) error { if err := Error(state.GetStatus()); err != nil { return errors.Wrapf(err, "%s=%d not healthy", role, nodeID) @@ -455,6 +462,14 @@ func WrapErrDatabaseNameInvalid(database any, msg ...string) error { return err } +func WrapErrPrivilegeGroupNameInvalid(privilegeGroup any, msg ...string) error { + err := wrapFields(ErrPrivilegeGroupInvalidName, value("privilegeGroup", privilegeGroup)) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "->")) + } + return err +} + // Collection related func WrapErrCollectionNotFound(collection any, msg ...string) error { err := wrapFields(ErrCollectionNotFound, value("collection", collection)) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index af858cc0f4bba..86467ac47488c 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -266,6 +266,9 @@ type commonConfig struct { ReadOnlyPrivileges ParamItem `refreshable:"false"` ReadWritePrivileges ParamItem `refreshable:"false"` AdminPrivileges ParamItem `refreshable:"false"` + + HealthCheckInterval ParamItem `refreshable:"true"` + HealthCheckRPCTimeout ParamItem `refreshable:"true"` } func (p *commonConfig) init(base *BaseTable) { @@ -915,6 +918,22 @@ This helps Milvus-CDC synchronize incremental data`, Doc: `use to override the default value of admin privileges, example: "PrivilegeCreateOwnership,PrivilegeDropOwnership"`, } p.AdminPrivileges.Init(base.mgr) + + p.HealthCheckInterval = ParamItem{ + Key: "common.healthcheck.interval.seconds", + Version: "2.4.8", + DefaultValue: "30", + Doc: `health check interval in seconds, default 30s`, + } + p.HealthCheckInterval.Init(base.mgr) + + p.HealthCheckRPCTimeout = ParamItem{ + Key: "common.healthcheck.timeout.seconds", + Version: "2.4.8", + DefaultValue: "10", + Doc: `RPC timeout for health check request`, + } + p.HealthCheckRPCTimeout.Init(base.mgr) } type gpuConfig struct { @@ -2169,9 +2188,9 @@ If this parameter is set false, Milvus simply searches the growing segments with p.UpdateCollectionLoadStatusInterval = ParamItem{ Key: "queryCoord.updateCollectionLoadStatusInterval", Version: "2.4.7", - DefaultValue: "5", + DefaultValue: "300", PanicIfEmpty: true, - Doc: "5m, max interval for updating collection loaded status", + Doc: "300s, max interval of updating collection loaded status for check health", Export: true, } diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index a05f2cabaa42a..3e5b77504e6f1 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -126,6 +126,11 @@ func TestComponentParam(t *testing.T) { params.Save("common.gchelper.minimumGoGC", "80") assert.Equal(t, 80, Params.MinimumGOGCConfig.GetAsInt()) + params.Save("common.healthcheck.interval.seconds", "60") + assert.Equal(t, time.Second*60, Params.HealthCheckInterval.GetAsDuration(time.Second)) + params.Save("common.healthcheck.timeout.seconds", "5") + assert.Equal(t, 5, Params.HealthCheckRPCTimeout.GetAsInt()) + assert.Equal(t, 0, len(Params.ReadOnlyPrivileges.GetAsStrings())) assert.Equal(t, 0, len(Params.ReadWritePrivileges.GetAsStrings())) assert.Equal(t, 0, len(Params.AdminPrivileges.GetAsStrings())) @@ -304,8 +309,8 @@ func TestComponentParam(t *testing.T) { checkHealthRPCTimeout := Params.CheckHealthRPCTimeout.GetAsInt() assert.Equal(t, 2000, checkHealthRPCTimeout) - updateInterval := Params.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Minute) - assert.Equal(t, updateInterval, time.Minute*5) + updateInterval := Params.UpdateCollectionLoadStatusInterval.GetAsDuration(time.Second) + assert.Equal(t, updateInterval, time.Second*300) assert.Equal(t, 0.1, Params.GlobalRowCountFactor.GetAsFloat()) params.Save("queryCoord.globalRowCountFactor", "0.4") diff --git a/pkg/util/ratelimitutil/utils.go b/pkg/util/ratelimitutil/utils.go index ae65eb14c7429..1c70049ca88fd 100644 --- a/pkg/util/ratelimitutil/utils.go +++ b/pkg/util/ratelimitutil/utils.go @@ -16,7 +16,13 @@ package ratelimitutil -import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +import ( + "fmt" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) var QuotaErrorString = map[commonpb.ErrorCode]string{ commonpb.ErrorCode_ForceDeny: "access has been disabled by the administrator", @@ -28,3 +34,14 @@ var QuotaErrorString = map[commonpb.ErrorCode]string{ func GetQuotaErrorString(errCode commonpb.ErrorCode) string { return QuotaErrorString[errCode] } + +func CheckTimeTickDelay(channel string, minTT uint64, maxDelay time.Duration) error { + if channel != "" && maxDelay > 0 { + minTt, _ := tsoutil.ParseTS(minTT) + delay := time.Since(minTt) + if delay.Milliseconds() >= maxDelay.Milliseconds() { + return fmt.Errorf("max timetick lag execced threhold, lag:%s on channel:%s", delay, channel) + } + } + return nil +} diff --git a/tests/integration/rbac/privilege_group_test.go b/tests/integration/rbac/privilege_group_test.go index 1f1b13d41b2f4..7e44c5301e119 100644 --- a/tests/integration/rbac/privilege_group_test.go +++ b/tests/integration/rbac/privilege_group_test.go @@ -69,13 +69,9 @@ func (s *PrivilegeGroupTestSuite) TestBuiltinPrivilegeGroup() { s.True(merr.Ok(resp)) for _, builtinGroup := range lo.Keys(util.BuiltinPrivilegeGroups) { - fmt.Println("!!! builtinGroup: ", builtinGroup) resp, _ = s.operatePrivilege(ctx, roleName, builtinGroup, commonpb.ObjectType_Global.String(), milvuspb.OperatePrivilegeType_Grant) s.False(merr.Ok(resp)) } - - s.validateGrants(ctx, roleName, commonpb.ObjectType_Global.String(), 1) - s.validateGrants(ctx, roleName, commonpb.ObjectType_Collection.String(), 2) } func (s *PrivilegeGroupTestSuite) TestInvalidPrivilegeGroup() { @@ -135,6 +131,22 @@ func (s *PrivilegeGroupTestSuite) TestInvalidPrivilegeGroup() { s.False(merr.Ok(operateResp)) } +func (s *PrivilegeGroupTestSuite) TestInvalidGrantV2() { + ctx := GetContext(context.Background(), "root:123456") + + // invalid operate privilege type + resp, _ := s.operatePrivilegeV2(ctx, "role", "Insert", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType(-1)) + s.False(merr.Ok(resp)) + + // invlaid database name + resp, _ = s.operatePrivilegeV2(ctx, "role", "Insert", "%$", util.AnyWord, milvuspb.OperatePrivilegeType_Grant) + s.False(merr.Ok(resp)) + + // invalid collection name + resp, _ = s.operatePrivilegeV2(ctx, "role", "Insert", util.AnyWord, "%$", milvuspb.OperatePrivilegeType_Grant) + s.False(merr.Ok(resp)) +} + func (s *PrivilegeGroupTestSuite) TestGrantV2BuiltinPrivilegeGroup() { ctx := GetContext(context.Background(), "root:123456") @@ -145,26 +157,12 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2BuiltinPrivilegeGroup() { s.NoError(err) s.True(merr.Ok(createRoleResp)) - resp, _ := s.operatePrivilegeV2(ctx, roleName, "ClusterReadOnly", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "ClusterReadWrite", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "ClusterAdmin", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "DatabaseReadOnly", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "DatabaseReadWrite", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "DatabaseAdmin", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionReadOnly", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionReadWrite", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) - resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionAdmin", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) + for _, builtinGroup := range lo.Keys(util.BuiltinPrivilegeGroups) { + resp, _ := s.operatePrivilegeV2(ctx, roleName, builtinGroup, util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) + s.True(merr.Ok(resp)) + } - resp, _ = s.operatePrivilegeV2(ctx, roleName, "ClusterAdmin", "db1", util.AnyWord, milvuspb.OperatePrivilegeType_Grant) + resp, _ := s.operatePrivilegeV2(ctx, roleName, "ClusterAdmin", "db1", util.AnyWord, milvuspb.OperatePrivilegeType_Grant) s.False(merr.Ok(resp)) resp, _ = s.operatePrivilegeV2(ctx, roleName, "ClusterAdmin", "db1", "col1", milvuspb.OperatePrivilegeType_Grant) s.False(merr.Ok(resp)) @@ -181,7 +179,7 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2BuiltinPrivilegeGroup() { resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionAdmin", "db1", "col1", milvuspb.OperatePrivilegeType_Grant) s.True(merr.Ok(resp)) resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionAdmin", util.AnyWord, "col1", milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) + s.False(merr.Ok(resp)) } func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { @@ -233,12 +231,14 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { s.True(merr.Ok(createRoleResp)) resp, _ := s.operatePrivilegeV2(ctx, role, "Insert", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) s.True(merr.Ok(resp)) - s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), 1) + selectResp, _ := s.validateGrants(ctx, role, commonpb.ObjectType_Collection.String(), util.AnyWord, util.AnyWord) + s.Len(selectResp.GetEntities(), 1) // grant group1 to role -> role: insert, group1(query, search) resp, _ = s.operatePrivilegeV2(ctx, role, "group1", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) s.True(merr.Ok(resp)) - s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), 2) + selectResp, _ = s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), util.AnyWord, util.AnyWord) + s.Len(selectResp.GetEntities(), 1) // create group2: query, delete createResp2, err := s.Cluster.Proxy.CreatePrivilegeGroup(ctx, &milvuspb.CreatePrivilegeGroupRequest{ @@ -256,7 +256,8 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { // grant group2 to role -> role: insert, group1(query, search), group2(query, delete) resp, _ = s.operatePrivilegeV2(ctx, role, "group2", util.AnyWord, util.AnyWord, milvuspb.OperatePrivilegeType_Grant) s.True(merr.Ok(resp)) - s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), 3) + selectResp, _ = s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), util.AnyWord, util.AnyWord) + s.Len(selectResp.GetEntities(), 2) // add query, load to group1 -> group1: query, search, load -> role: insert, group1(query, search, load), group2(query, delete) operatePrivilegeGroup("group1", milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup, []*milvuspb.PrivilegeEntity{ @@ -264,7 +265,8 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { {Name: "Load"}, }) validatePrivilegeGroup("group1", 3) - s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), 3) + selectResp, _ = s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), util.AnyWord, util.AnyWord) + s.Len(selectResp.GetEntities(), 2) // add different object type privileges to group1 is not allowed resp, _ = s.Cluster.Proxy.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{ @@ -279,7 +281,8 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { {Name: "Query"}, }) validatePrivilegeGroup("group1", 2) - s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), 3) + selectResp, _ = s.validateGrants(ctx, role, commonpb.ObjectType_Global.String(), util.AnyWord, util.AnyWord) + s.Len(selectResp.GetEntities(), 2) // Drop the group during any role usage will cause error dropResp, _ := s.Cluster.Proxy.DropPrivilegeGroup(ctx, &milvuspb.DropPrivilegeGroupRequest{ @@ -334,42 +337,91 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() { s.True(merr.Ok(dropRoleResp)) } -func (s *PrivilegeGroupTestSuite) operatePrivilege(ctx context.Context, role, privilege, objectType string, operateType milvuspb.OperatePrivilegeType) (*commonpb.Status, error) { +func (s *PrivilegeGroupTestSuite) TestVersionCrossed() { + ctx := GetContext(context.Background(), "root:123456") + + role := "role1" + createRoleResp, err := s.Cluster.Proxy.CreateRole(ctx, &milvuspb.CreateRoleRequest{ + Entity: &milvuspb.RoleEntity{Name: role}, + }) + s.NoError(err) + s.True(merr.Ok(createRoleResp)) resp, err := s.Cluster.Proxy.OperatePrivilege(ctx, &milvuspb.OperatePrivilegeRequest{ - Type: operateType, + Type: milvuspb.OperatePrivilegeType_Grant, Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{Name: role}, - Object: &milvuspb.ObjectEntity{Name: objectType}, - ObjectName: util.AnyWord, - DbName: util.AnyWord, + Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Collection.String()}, + ObjectName: "collection1", + DbName: "", Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: util.UserRoot}, - Privilege: &milvuspb.PrivilegeEntity{Name: privilege}, + Privilege: &milvuspb.PrivilegeEntity{Name: "Insert"}, }, }, }) - return resp, err + s.NoError(err) + s.True(merr.Ok(resp)) + selectResp, err := s.Cluster.Proxy.SelectGrant(ctx, &milvuspb.SelectGrantRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: role}, + Object: nil, + ObjectName: "", + DbName: "", + }, + }) + s.NoError(err) + s.True(merr.Ok(selectResp.GetStatus())) + s.Len(selectResp.GetEntities(), 1) + + revoke, err := s.operatePrivilegeV2(ctx, role, "Insert", "default", "collection1", milvuspb.OperatePrivilegeType_Revoke) + s.NoError(err) + s.True(merr.Ok(revoke)) + + selectResp, err = s.Cluster.Proxy.SelectGrant(ctx, &milvuspb.SelectGrantRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: role}, + Object: nil, + ObjectName: "", + DbName: "", + }, + }) + s.NoError(err) + s.True(merr.Ok(selectResp.GetStatus())) + s.Len(selectResp.GetEntities(), 0) } -func (s *PrivilegeGroupTestSuite) operatePrivilegeV2(ctx context.Context, role, privilege, dbName, collectionName string, operateType milvuspb.OperatePrivilegeType) (*commonpb.Status, error) { +func (s *PrivilegeGroupTestSuite) operatePrivilege(ctx context.Context, role, privilege, objectType string, operateType milvuspb.OperatePrivilegeType) (*commonpb.Status, error) { resp, err := s.Cluster.Proxy.OperatePrivilege(ctx, &milvuspb.OperatePrivilegeRequest{ Type: operateType, Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{Name: role}, - Object: &milvuspb.ObjectEntity{Name: commonpb.ObjectType_Global.String()}, - ObjectName: collectionName, - DbName: dbName, + Object: &milvuspb.ObjectEntity{Name: objectType}, + ObjectName: util.AnyWord, + DbName: util.AnyWord, Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: util.UserRoot}, Privilege: &milvuspb.PrivilegeEntity{Name: privilege}, }, }, - Version: "v2", }) return resp, err } -func (s *PrivilegeGroupTestSuite) validateGrants(ctx context.Context, roleName, objectType string, expectedCount int) { +func (s *PrivilegeGroupTestSuite) operatePrivilegeV2(ctx context.Context, role, privilege, dbName, collectionName string, operateType milvuspb.OperatePrivilegeType) (*commonpb.Status, error) { + resp, err := s.Cluster.Proxy.OperatePrivilegeV2(ctx, &milvuspb.OperatePrivilegeV2Request{ + Role: &milvuspb.RoleEntity{Name: role}, + Grantor: &milvuspb.GrantorEntity{ + User: &milvuspb.UserEntity{Name: util.UserRoot}, + Privilege: &milvuspb.PrivilegeEntity{Name: privilege}, + }, + Type: operateType, + DbName: dbName, + CollectionName: collectionName, + }) + return resp, err +} + +func (s *PrivilegeGroupTestSuite) validateGrants(ctx context.Context, roleName, objectType, database, resource string) (*milvuspb.SelectGrantResponse, error) { resp, err := s.Cluster.Proxy.SelectGrant(ctx, &milvuspb.SelectGrantRequest{ Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{Name: roleName}, @@ -380,7 +432,13 @@ func (s *PrivilegeGroupTestSuite) validateGrants(ctx context.Context, roleName, }) s.NoError(err) s.True(merr.Ok(resp.GetStatus())) - s.Len(resp.GetEntities(), expectedCount) + return resp, err +} + +func (s *PrivilegeGroupTestSuite) marshalGrants(selectResp *milvuspb.SelectGrantResponse) map[string]*milvuspb.GrantEntity { + return lo.SliceToMap(selectResp.GetEntities(), func(e *milvuspb.GrantEntity) (string, *milvuspb.GrantEntity) { + return fmt.Sprintf("%s-%s-%s-%s", e.Object.Name, e.Grantor.Privilege.Name, e.DbName, e.ObjectName), e + }) } func TestPrivilegeGroup(t *testing.T) {