diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go
index 01e6f70c53675..902a13866d12e 100644
--- a/internal/querynodev2/delegator/delegator.go
+++ b/internal/querynodev2/delegator/delegator.go
@@ -380,22 +380,18 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
 		fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).
 		Observe(float64(waitTr.ElapseSpan().Milliseconds()))
 
-	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
-	if err != nil {
-		log.Warn("delegator failed to search, current distribution is not serviceable")
-		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
-	}
-	defer sd.distribution.Unpin(version)
 	targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
 	if err != nil {
 		return nil, err
 	}
 	// set target partition ids to sub task request
 	req.Req.PartitionIDs = targetPartitions
-	growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
-		return funcutil.SliceContain(targetPartitions, segment.PartitionID)
-	})
-
+	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
+	if err != nil {
+		log.Warn("delegator failed to search, current distribution is not serviceable")
+		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
+	}
+	defer sd.distribution.Unpin(version)
 	if req.GetReq().GetIsAdvanced() {
 		futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
 		for index, subReq := range req.GetReq().GetSubReqs() {
@@ -497,13 +493,6 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
 		fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
 		Observe(float64(waitTr.ElapseSpan().Milliseconds()))
 
-	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
-	if err != nil {
-		log.Warn("delegator failed to query, current distribution is not serviceable")
-		return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
-	}
-	defer sd.distribution.Unpin(version)
-
 	targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
 	if err != nil {
 		return err
@@ -511,9 +500,13 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
 	// set target partition ids to sub task request
 	req.Req.PartitionIDs = targetPartitions
 
-	growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
-		return funcutil.SliceContain(targetPartitions, segment.PartitionID)
-	})
+	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
+	if err != nil {
+		log.Warn("delegator failed to query, current distribution is not serviceable")
+		return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
+	}
+	defer sd.distribution.Unpin(version)
+
 	if req.Req.IgnoreGrowing {
 		growing = []SegmentEntry{}
 	}
@@ -570,13 +563,6 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
 		fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
 		Observe(float64(waitTr.ElapseSpan().Milliseconds()))
 
-	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
-	if err != nil {
-		log.Warn("delegator failed to query, current distribution is not serviceable")
-		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
-	}
-	defer sd.distribution.Unpin(version)
-
 	targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
 	if err != nil {
 		return nil, err
@@ -584,12 +570,15 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
 	// set target partition ids to sub task request
 	req.Req.PartitionIDs = targetPartitions
 
+	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
+	if err != nil {
+		log.Warn("delegator failed to query, current distribution is not serviceable")
+		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
+	}
+	defer sd.distribution.Unpin(version)
+
 	if req.Req.IgnoreGrowing {
 		growing = []SegmentEntry{}
-	} else {
-		growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
-			return funcutil.SliceContain(targetPartitions, segment.PartitionID)
-		})
 	}
 
 	if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go
index 5b16778a49424..6dc57287e8348 100644
--- a/internal/querynodev2/handlers.go
+++ b/internal/querynodev2/handlers.go
@@ -426,19 +426,17 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge
 
 	if req.GetFromShardLeader() {
 		var (
-			results      []segments.SegmentStats
-			readSegments []segments.Segment
-			err          error
+			results []segments.SegmentStats
+			err     error
 		)
 
 		switch req.GetScope() {
 		case querypb.DataScope_Historical:
-			results, readSegments, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
+			results, err = segments.StatisticsHistorical(ctx, node.manager, req.GetSegmentIDs())
 		case querypb.DataScope_Streaming:
-			results, readSegments, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
+			results, err = segments.StatisticStreaming(ctx, node.manager, req.GetSegmentIDs())
 		}
-		defer node.manager.Segment.Unpin(readSegments)
 
 		if err != nil {
 			log.Warn("get segments statistics failed", zap.Error(err))
 			return nil, err
diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go
index f7c30bfa4ed62..1df866c71c2d9 100644
--- a/internal/querynodev2/segments/retrieve.go
+++ b/internal/querynodev2/segments/retrieve.go
@@ -141,9 +141,9 @@ func retrieveOnSegmentsWithStream(ctx context.Context, mgr *Manager, segments []
 }
 
 // retrieve will retrieve all the validate target segments
-func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]RetrieveSegmentResult, []Segment, error) {
+func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]RetrieveSegmentResult, error) {
 	if ctx.Err() != nil {
-		return nil, nil, ctx.Err()
+		return nil, ctx.Err()
 	}
 
 	var err error
@@ -156,41 +156,38 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu
 
 	if req.GetScope() == querypb.DataScope_Historical {
 		SegType = SegmentTypeSealed
-		retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
 	} else {
 		SegType = SegmentTypeGrowing
-		retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
 	}
-
+	retrieveSegments, err = manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return nil, retrieveSegments, err
+		return nil, err
 	}
+	defer manager.Segment.Unpin(retrieveSegments)
 
 	result, err := retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan, req)
-	return result, retrieveSegments, err
+	return result, err
 }
 
 // retrieveStreaming will retrieve all the validate target segments and return by stream
-func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) ([]Segment, error) {
+func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
 	var err error
 	var SegType commonpb.SegmentState
 	var retrieveSegments []Segment
 
 	segIDs := req.GetSegmentIDs()
-	collID := req.Req.GetCollectionID()
 
 	if req.GetScope() == querypb.DataScope_Historical {
 		SegType = SegmentTypeSealed
-		retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
 	} else {
 		SegType = SegmentTypeGrowing
-		retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
 	}
-
+	retrieveSegments, err = manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return retrieveSegments, err
+		return err
 	}
+	defer manager.Segment.Unpin(retrieveSegments)
 
 	err = retrieveOnSegmentsWithStream(ctx, manager, retrieveSegments, SegType, plan, srv)
-	return retrieveSegments, err
+	return err
 }
diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go
index 3e12eb334b68a..50208e9c3eac9 100644
--- a/internal/querynodev2/segments/retrieve_test.go
+++ b/internal/querynodev2/segments/retrieve_test.go
@@ -161,10 +161,9 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
 		Scope: querypb.DataScope_Historical,
 	}
 
-	res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
+	res, err := Retrieve(context.TODO(), suite.manager, plan, req)
 	suite.NoError(err)
 	suite.Len(res[0].Result.Offset, 3)
-	suite.manager.Segment.Unpin(segments)
 
 	resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
 		RetrievePlan: plan,
@@ -187,10 +186,9 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
 		Scope: querypb.DataScope_Streaming,
 	}
 
-	res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
+	res, err := Retrieve(context.TODO(), suite.manager, plan, req)
 	suite.NoError(err)
 	suite.Len(res[0].Result.Offset, 3)
-	suite.manager.Segment.Unpin(segments)
 
 	resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
 		RetrievePlan: plan,
@@ -220,9 +218,8 @@ func (suite *RetrieveSuite) TestRetrieveStreamSealed() {
 	server := client.CreateServer()
 
 	go func() {
-		segments, err := RetrieveStream(ctx, suite.manager, plan, req, server)
+		err := RetrieveStream(ctx, suite.manager, plan, req, server)
 		suite.NoError(err)
-		suite.manager.Segment.Unpin(segments)
 		server.FinishSend(err)
 	}()
 
@@ -257,10 +254,9 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
 		Scope: querypb.DataScope_Streaming,
 	}
 
-	res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
+	res, err := Retrieve(context.TODO(), suite.manager, plan, req)
 	suite.Error(err)
 	suite.Len(res, 0)
-	suite.manager.Segment.Unpin(segments)
 }
 
 func (suite *RetrieveSuite) TestRetrieveNilSegment() {
@@ -277,10 +273,9 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
 		Scope: querypb.DataScope_Historical,
 	}
 
-	res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
+	res, err := Retrieve(context.TODO(), suite.manager, plan, req)
 	suite.ErrorIs(err, merr.ErrSegmentNotLoaded)
 	suite.Len(res, 0)
-	suite.manager.Segment.Unpin(segments)
 }
 
 func TestRetrieve(t *testing.T) {
diff --git a/internal/querynodev2/segments/search.go b/internal/querynodev2/segments/search.go
index ca5aa76a1ebe7..f9f76c06d7518 100644
--- a/internal/querynodev2/segments/search.go
+++ b/internal/querynodev2/segments/search.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"sync"
 
+	"github.com/samber/lo"
 	"go.uber.org/atomic"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
@@ -200,51 +201,73 @@ func searchSegmentsStreamly(ctx context.Context,
 }
 
 // search will search on the historical segments the target segments in historical.
-// if segIDs is not specified, it will search on all the historical segments speficied by partIDs.
-// if segIDs is specified, it will only search on the segments specified by the segIDs.
-// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
-func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
+func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64) ([]*SearchResult, int64, error) {
 	if ctx.Err() != nil {
-		return nil, nil, ctx.Err()
+		return nil, 0, ctx.Err()
 	}
 
-	segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
+	segments, err := manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return nil, nil, err
+		return nil, 0, err
 	}
+	defer manager.Segment.Unpin(segments)
+
 	searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeSealed, searchReq)
-	return searchResults, segments, err
+	if err != nil {
+		return nil, 0, err
+	}
+
+	relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
+		return acc + GetSegmentRelatedDataSize(seg)
+	}, 0)
+	return searchResults, relatedDataSize, nil
 }
 
 // searchStreaming will search all the target segments in streaming
 // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
-func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
+func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64) ([]*SearchResult, int64, error) {
 	if ctx.Err() != nil {
-		return nil, nil, ctx.Err()
+		return nil, 0, ctx.Err()
 	}
 
-	segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
+	segments, err := manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return nil, nil, err
+		return nil, 0, err
 	}
+	defer manager.Segment.Unpin(segments)
+
 	searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeGrowing, searchReq)
-	return searchResults, segments, err
+	if err != nil {
+		return nil, 0, err
+	}
+
+	relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
+		return acc + GetSegmentRelatedDataSize(seg)
+	}, 0)
+
+	return searchResults, relatedDataSize, nil
 }
 
-func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest,
-	collID int64, partIDs []int64, segIDs []int64, streamReduce func(result *SearchResult) error,
-) ([]Segment, error) {
+func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64, streamReduce func(result *SearchResult) error,
+) (int64, error) {
 	if ctx.Err() != nil {
-		return nil, ctx.Err()
+		return 0, ctx.Err()
 	}
 
-	segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
+	segments, err := manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return segments, err
+		return 0, err
 	}
+	defer manager.Segment.Unpin(segments)
+
 	err = searchSegmentsStreamly(ctx, manager, segments, searchReq, streamReduce)
 	if err != nil {
-		return segments, err
+		return 0, err
 	}
-	return segments, nil
+
+	relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
+		return acc + GetSegmentRelatedDataSize(seg)
+	}, 0)
+
+	return relatedDataSize, nil
 }
diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go
index 11d003769a87e..e97af9cb352b5 100644
--- a/internal/querynodev2/segments/search_test.go
+++ b/internal/querynodev2/segments/search_test.go
@@ -145,23 +145,17 @@ func (suite *SearchSuite) TestSearchSealed() {
 	searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.sealed.ID()}, mock_segcore.IndexFaissIDMap, nq)
 	suite.NoError(err)
 
-	_, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()})
+	_, _, err = SearchHistorical(ctx, suite.manager, searchReq, []int64{suite.sealed.ID()})
 	suite.NoError(err)
-	suite.manager.Segment.Unpin(segments)
 }
 
 func (suite *SearchSuite) TestSearchGrowing() {
 	searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.growing.ID()}, mock_segcore.IndexFaissIDMap, 1)
 	suite.NoError(err)
 
-	res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq,
-		suite.collectionID,
-		[]int64{suite.partitionID},
-		[]int64{suite.growing.ID()},
-	)
+	res, _, err := SearchStreaming(context.TODO(), suite.manager, searchReq, []int64{suite.growing.ID()})
 	suite.NoError(err)
 	suite.Len(res, 1)
-	suite.manager.Segment.Unpin(segments)
 }
 
 func TestSearch(t *testing.T) {
diff --git a/internal/querynodev2/segments/statistics.go b/internal/querynodev2/segments/statistics.go
index 0b8f0b8a7b145..a65a06515fb2f 100644
--- a/internal/querynodev2/segments/statistics.go
+++ b/internal/querynodev2/segments/statistics.go
@@ -56,25 +56,24 @@ func statisticOnSegments(ctx context.Context, segments []Segment, segType Segmen
 }
 
 // statistic will do statistics on the historical segments the target segments in historical.
-// if segIDs is not specified, it will search on all the historical segments specified by partIDs.
-// if segIDs is specified, it will only search on the segments specified by the segIDs.
-// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
-func StatisticsHistorical(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []Segment, error) {
-	segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
+func StatisticsHistorical(ctx context.Context, manager *Manager, segIDs []int64) ([]SegmentStats, error) {
+	segments, err := manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
+	defer manager.Segment.Unpin(segments)
 	result, err := statisticOnSegments(ctx, segments, SegmentTypeSealed)
-	return result, segments, err
+	return result, err
 }
 
 // StatisticStreaming will do statistics all the target segments in streaming
-// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
-func StatisticStreaming(ctx context.Context, manager *Manager, collID int64, partIDs []int64, segIDs []int64) ([]SegmentStats, []Segment, error) {
-	segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
+func StatisticStreaming(ctx context.Context, manager *Manager, segIDs []int64) ([]SegmentStats, error) {
+	segments, err := manager.Segment.GetAndPin(segIDs)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
+	defer manager.Segment.Unpin(segments)
+
 	result, err := statisticOnSegments(ctx, segments, SegmentTypeGrowing)
-	return result, segments, err
+	return result, err
 }
diff --git a/internal/querynodev2/segments/validate.go b/internal/querynodev2/segments/validate.go
deleted file mode 100644
index c421bc1129744..0000000000000
--- a/internal/querynodev2/segments/validate.go
+++ /dev/null
@@ -1,97 +0,0 @@
-// Licensed to the LF AI & Data foundation under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package segments
-
-import (
-	"context"
-	"fmt"
-
-	"github.com/cockroachdb/errors"
-	"go.uber.org/zap"
-
-	"github.com/milvus-io/milvus/internal/proto/querypb"
-	"github.com/milvus-io/milvus/pkg/log"
-	"github.com/milvus-io/milvus/pkg/util/funcutil"
-	"github.com/milvus-io/milvus/pkg/util/merr"
-)
-
-func validate(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64, segmentFilter SegmentFilter) ([]Segment, error) {
-	var searchPartIDs []int64
-
-	collection := manager.Collection.Get(collectionID)
-	if collection == nil {
-		return nil, merr.WrapErrCollectionNotFound(collectionID)
-	}
-
-	// validate partition
-	// no partition id specified, get all partition ids in collection
-	if len(partitionIDs) == 0 {
-		searchPartIDs = collection.GetPartitions()
-	} else {
-		// use request partition ids directly, ignoring meta partition ids
-		// partitions shall be controlled by delegator distribution
-		searchPartIDs = partitionIDs
-	}
-
-	log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
-
-	// all partitions have been released
-	if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadPartition {
-		return nil, errors.Newf("partitions have been released , collectionID = %d target partitionIDs = %v", collectionID, searchPartIDs)
-	}
-
-	if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadCollection {
-		return []Segment{}, nil
-	}
-
-	// validate segment
-	segments := make([]Segment, 0, len(segmentIDs))
-	var err error
-	defer func() {
-		if err != nil {
-			manager.Segment.Unpin(segments)
-		}
-	}()
-	if len(segmentIDs) == 0 {
-		for _, partID := range searchPartIDs {
-			segments, err = manager.Segment.GetAndPinBy(WithPartition(partID), segmentFilter)
-			if err != nil {
-				return nil, err
-			}
-		}
-	} else {
-		segments, err = manager.Segment.GetAndPin(segmentIDs, segmentFilter)
-		if err != nil {
-			return nil, err
-		}
-		for _, segment := range segments {
-			if !funcutil.SliceContain(searchPartIDs, segment.Partition()) {
-				err = fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segment.ID(), segment.Partition(), searchPartIDs)
-				return nil, err
-			}
-		}
-	}
-	return segments, nil
-}
-
-func validateOnHistorical(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]Segment, error) {
-	return validate(ctx, manager, collectionID, partitionIDs, segmentIDs, WithType(SegmentTypeSealed))
-}
-
-func validateOnStream(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64) ([]Segment, error) {
-	return validate(ctx, manager, collectionID, partitionIDs, segmentIDs, WithType(SegmentTypeGrowing))
-}
diff --git a/internal/querynodev2/tasks/query_stream_task.go b/internal/querynodev2/tasks/query_stream_task.go
index cdb3c3c2e99ba..864f31b36d1b4 100644
--- a/internal/querynodev2/tasks/query_stream_task.go
+++ b/internal/querynodev2/tasks/query_stream_task.go
@@ -74,8 +74,7 @@ func (t *QueryStreamTask) Execute() error {
 	srv := streamrpc.NewResultCacheServer(t.srv, t.minMsgSize, t.maxMsgSize)
 	defer srv.Flush()
 
-	segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, srv)
-	defer t.segmentManager.Segment.Unpin(segments)
+	err = segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, srv)
 	if err != nil {
 		return err
 	}
diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go
index 7099e83defc22..7392ecaa01c31 100644
--- a/internal/querynodev2/tasks/query_task.go
+++ b/internal/querynodev2/tasks/query_task.go
@@ -111,8 +111,7 @@ func (t *QueryTask) Execute() error {
 		return err
 	}
 	defer retrievePlan.Delete()
-	results, pinnedSegments, err := segments.Retrieve(t.ctx, t.segmentManager, retrievePlan, t.req)
-	defer t.segmentManager.Segment.Unpin(pinnedSegments)
+	results, err := segments.Retrieve(t.ctx, t.segmentManager, retrievePlan, t.req)
 	if err != nil {
 		return err
 	}
diff --git a/internal/querynodev2/tasks/search_task.go b/internal/querynodev2/tasks/search_task.go
index 9c16c13253dfd..69d52bb7e1c3e 100644
--- a/internal/querynodev2/tasks/search_task.go
+++ b/internal/querynodev2/tasks/search_task.go
@@ -10,7 +10,6 @@ import (
 	"fmt"
 	"strconv"
 
-	"github.com/samber/lo"
 	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/trace"
 	"go.uber.org/zap"
@@ -153,29 +152,26 @@ func (t *SearchTask) Execute() error {
 	defer searchReq.Delete()
 
 	var (
-		results          []*segments.SearchResult
-		searchedSegments []segments.Segment
+		results         []*segments.SearchResult
+		relatedDataSize int64
 	)
+
 	if req.GetScope() == querypb.DataScope_Historical {
-		results, searchedSegments, err = segments.SearchHistorical(
+		results, relatedDataSize, err = segments.SearchHistorical(
 			t.ctx,
 			t.segmentManager,
 			searchReq,
-			req.GetReq().GetCollectionID(),
-			req.GetReq().GetPartitionIDs(),
 			req.GetSegmentIDs(),
 		)
 	} else if req.GetScope() == querypb.DataScope_Streaming {
-		results, searchedSegments, err = segments.SearchStreaming(
+		results, relatedDataSize, err = segments.SearchStreaming(
 			t.ctx,
 			t.segmentManager,
 			searchReq,
-			req.GetReq().GetCollectionID(),
-			req.GetReq().GetPartitionIDs(),
 			req.GetSegmentIDs(),
 		)
 	}
-	defer t.segmentManager.Segment.Unpin(searchedSegments)
+
 	if err != nil {
 		return err
 	}
@@ -210,11 +206,6 @@ func (t *SearchTask) Execute() error {
 		}
 		return nil
 	}
-
-	relatedDataSize := lo.Reduce(searchedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
-		return acc + segments.GetSegmentRelatedDataSize(seg)
-	}, 0)
-
 	tr.RecordSpan()
 	blobs, err := segcore.ReduceSearchResultsAndFillData(
 		t.ctx,
@@ -448,16 +439,13 @@ func (t *StreamingSearchTask) Execute() error {
 			reduceErr := t.streamReduce(t.ctx, searchReq.Plan(), result, t.originNqs, t.originTopks)
 			return reduceErr
 		}
-		pinnedSegments, err := segments.SearchHistoricalStreamly(
+		relatedDataSize, err = segments.SearchHistoricalStreamly(
 			t.ctx,
 			t.segmentManager,
 			searchReq,
-			req.GetReq().GetCollectionID(),
-			nil,
 			req.GetSegmentIDs(),
 			streamReduceFunc)
 		defer segcore.DeleteStreamReduceHelper(t.streamReducer)
-		defer t.segmentManager.Segment.Unpin(pinnedSegments)
 		if err != nil {
 			log.Error("Failed to search sealed segments streamly", zap.Error(err))
 			return err
@@ -468,20 +456,15 @@ func (t *StreamingSearchTask) Execute() error {
 			log.Error("Failed to get stream-reduced search result")
 			return err
 		}
-		relatedDataSize = lo.Reduce(pinnedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
-			return acc + segments.GetSegmentRelatedDataSize(seg)
-		}, 0)
 	} else if req.GetScope() == querypb.DataScope_Streaming {
-		results, pinnedSegments, err := segments.SearchStreaming(
+		var results []*segments.SearchResult
+		results, relatedDataSize, err = segments.SearchStreaming(
 			t.ctx,
 			t.segmentManager,
 			searchReq,
-			req.GetReq().GetCollectionID(),
-			req.GetReq().GetPartitionIDs(),
 			req.GetSegmentIDs(),
 		)
 		defer segments.DeleteSearchResults(results)
-		defer t.segmentManager.Segment.Unpin(pinnedSegments)
 		if err != nil {
 			return err
 		}
@@ -508,9 +491,6 @@ func (t *StreamingSearchTask) Execute() error {
 				metrics.ReduceSegments, metrics.BatchReduce).
 			Observe(float64(tr.RecordSpan().Milliseconds()))
-		relatedDataSize = lo.Reduce(pinnedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
-			return acc + segments.GetSegmentRelatedDataSize(seg)
-		}, 0)
 	}
 
 	// 2. reorganize blobs to original search request
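
Below is a minimal, hypothetical caller-side sketch (not part of the patch) of the reworked segments API shown above: callers now pass only segment IDs, pinning and unpinning happen inside the segments package through manager.Segment.GetAndPin/Unpin, and the aggregated related data size is returned in place of the pinned segment list. The function name runHistoricalSearch and the surrounding wrapper are illustrative assumptions; the segments package signatures are the ones introduced in this diff.

```go
// sketch.go: illustrative example only, assuming the signatures introduced in this diff.
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/querynodev2/segments"
)

// runHistoricalSearch is a hypothetical caller. It hands SearchHistorical only
// the segment IDs and receives the search results plus the aggregated related
// data size. Segment pinning is no longer the caller's responsibility:
// SearchHistorical pins via manager.Segment.GetAndPin(segIDs) and defers
// manager.Segment.Unpin(segments) internally.
func runHistoricalSearch(ctx context.Context, mgr *segments.Manager, searchReq *segments.SearchRequest, segIDs []int64) (int64, error) {
	results, relatedDataSize, err := segments.SearchHistorical(ctx, mgr, searchReq, segIDs)
	if err != nil {
		return 0, err
	}
	defer segments.DeleteSearchResults(results)

	// relatedDataSize replaces the lo.Reduce over pinned segments that the
	// task layer used to compute itself (see search_task.go above).
	return relatedDataSize, nil
}
```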