enhance: Remove unnecessary segment validation on worker node
This PR removes unnecessary segment validation on the worker node and also
refines the segment `pin`/`unpin` function calls.

Signed-off-by: Wei Liu <[email protected]>
weiliu1031 committed Dec 2, 2024
1 parent 4c623ce commit 5ffacaf
Showing 11 changed files with 108 additions and 232 deletions.
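
Across the touched read paths, the worker node no longer validates segments against collection and partition IDs and no longer returns the pinned segments for the caller to release; each function pins by segment ID with `GetAndPin` and releases the pin in its own scope via `defer Unpin`. Below is a minimal, self-contained sketch of that ownership pattern; `SegmentManager`, `Segment`, and `retrieve` are hypothetical stand-ins rather than the actual Milvus types.

package main

import "fmt"

// Segment and SegmentManager are hypothetical stand-ins used only to
// illustrate the pin/unpin ownership pattern adopted in this commit.
type Segment struct{ ID int64 }

type SegmentManager struct{ loaded map[int64]*Segment }

// GetAndPin looks up segments by ID and pins them for the duration of a read.
func (m *SegmentManager) GetAndPin(ids []int64) ([]*Segment, error) {
	pinned := make([]*Segment, 0, len(ids))
	for _, id := range ids {
		seg, ok := m.loaded[id]
		if !ok {
			return nil, fmt.Errorf("segment %d is not loaded", id)
		}
		pinned = append(pinned, seg)
	}
	return pinned, nil
}

// Unpin releases the pins taken by GetAndPin (reference counting in a real system).
func (m *SegmentManager) Unpin(segs []*Segment) {}

// retrieve mirrors the refactored shape: pin by segment ID, defer the unpin,
// and return only results and an error; the caller never sees the pinned set.
func retrieve(m *SegmentManager, segIDs []int64) ([]string, error) {
	segs, err := m.GetAndPin(segIDs)
	if err != nil {
		return nil, err
	}
	defer m.Unpin(segs)

	results := make([]string, 0, len(segs))
	for _, s := range segs {
		results = append(results, fmt.Sprintf("rows from segment %d", s.ID))
	}
	return results, nil
}

func main() {
	m := &SegmentManager{loaded: map[int64]*Segment{1: {ID: 1}, 2: {ID: 2}}}
	fmt.Println(retrieve(m, []int64{1, 2}))
}

Because the unpin is deferred inside the same function, callers such as the query handlers and the tests below no longer need their own `Unpin` calls.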
51 changes: 20 additions & 31 deletions internal/querynodev2/delegator/delegator.go
@@ -380,22 +380,18 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))

sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to search, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)
targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})

sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to search, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)
if req.GetReq().GetIsAdvanced() {
futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))
for index, subReq := range req.GetReq().GetSubReqs() {
@@ -497,23 +493,20 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))

sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable")
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions

growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable")
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)

if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
}
@@ -570,26 +563,22 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))

sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions

sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)

if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
} else {
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
}

if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() {
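
In the delegator, the remaining refinement is ordering: the target partitions are resolved and written back onto the request before `PinReadableSegments` is called, so the pin already operates on the resolved partition set. A rough, self-contained sketch of that ordering follows; `distribution`, `queryRequest`, and `getTargetPartitions` are hypothetical simplifications, not the delegator's real types.

package main

import "fmt"

// Hypothetical stand-ins for the delegator's request and distribution types;
// only the ordering of the calls is the point here.
type queryRequest struct{ PartitionIDs []int64 }

type distribution struct{}

// PinReadableSegments would return the sealed/growing segments readable for
// exactly the given partitions, plus a version used to release the pin.
func (d *distribution) PinReadableSegments(partitions ...int64) (sealed, growing []int64, version int64, err error) {
	return []int64{100}, []int64{200}, 1, nil
}

func (d *distribution) Unpin(version int64) {}

// getTargetPartitions resolves the requested partitions, e.g. falling back to
// every loaded partition when none are specified.
func getTargetPartitions(requested []int64) ([]int64, error) {
	if len(requested) == 0 {
		return []int64{1, 2}, nil
	}
	return requested, nil
}

func query(dist *distribution, req *queryRequest) error {
	// 1. Resolve target partitions before touching the distribution.
	target, err := getTargetPartitions(req.PartitionIDs)
	if err != nil {
		return err
	}
	// 2. Stamp them onto the request so sub-tasks carry the resolved set.
	req.PartitionIDs = target

	// 3. Pin readable segments for exactly those partitions and release the
	//    pin in the same scope.
	sealed, growing, version, err := dist.PinReadableSegments(target...)
	if err != nil {
		return fmt.Errorf("distribution is not serviceable: %w", err)
	}
	defer dist.Unpin(version)

	fmt.Println("sealed:", sealed, "growing:", growing)
	return nil
}

func main() {
	if err := query(&distribution{}, &queryRequest{}); err != nil {
		fmt.Println(err)
	}
}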
10 changes: 4 additions & 6 deletions internal/querynodev2/handlers.go
@@ -426,19 +426,17 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge

if req.GetFromShardLeader() {
var (
results []segments.SegmentStats
readSegments []segments.Segment
err error
results []segments.SegmentStats
err error
)

switch req.GetScope() {
case querypb.DataScope_Historical:
results, readSegments, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
results, err = segments.StatisticsHistorical(ctx, node.manager, req.GetSegmentIDs())

[Codecov] Added line #L435 in internal/querynodev2/handlers.go was not covered by tests.
case querypb.DataScope_Streaming:
results, readSegments, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
results, err = segments.StatisticStreaming(ctx, node.manager, req.GetSegmentIDs())
}

defer node.manager.Segment.Unpin(readSegments)
if err != nil {
log.Warn("get segments statistics failed", zap.Error(err))
return nil, err
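
With the statistics helpers presumably pinning and unpinning internally (consistent with the retrieve and search changes in this commit), `getChannelStatistics` drops the `readSegments` variable and the deferred `Unpin`; it only switches on the data scope and forwards results or an error. A hedged sketch of that caller shape, with hypothetical stand-ins for the statistics helpers:

package main

import (
	"errors"
	"fmt"
)

type segmentStats struct{ RowCount int64 }

type dataScope int

const (
	historical dataScope = iota
	streaming
)

// statisticsHistorical and statisticsStreaming stand in for the refactored
// StatisticsHistorical / StatisticStreaming: each pins and unpins internally
// and returns only the stats and an error.
func statisticsHistorical(segIDs []int64) ([]segmentStats, error) {
	if len(segIDs) == 0 {
		return nil, errors.New("no sealed segments requested")
	}
	return []segmentStats{{RowCount: int64(len(segIDs)) * 1000}}, nil
}

func statisticsStreaming(segIDs []int64) ([]segmentStats, error) {
	return []segmentStats{{RowCount: int64(len(segIDs))}}, nil
}

// getChannelStatistics mirrors the simplified caller: no readSegments slice,
// no deferred Unpin, just a scope switch over (results, err).
func getChannelStatistics(scope dataScope, segIDs []int64) ([]segmentStats, error) {
	var (
		results []segmentStats
		err     error
	)
	switch scope {
	case historical:
		results, err = statisticsHistorical(segIDs)
	case streaming:
		results, err = statisticsStreaming(segIDs)
	}
	if err != nil {
		return nil, fmt.Errorf("get segments statistics failed: %w", err)
	}
	return results, nil
}

func main() {
	fmt.Println(getChannelStatistics(historical, []int64{1, 2, 3}))
}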
25 changes: 11 additions & 14 deletions internal/querynodev2/segments/retrieve.go
@@ -141,9 +141,9 @@ func retrieveOnSegmentsWithStream(ctx context.Context, mgr *Manager, segments []
}

// retrieve will retrieve all the validate target segments
func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]RetrieveSegmentResult, []Segment, error) {
func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]RetrieveSegmentResult, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
return nil, ctx.Err()

[Codecov] Added line #L146 in internal/querynodev2/segments/retrieve.go was not covered by tests.
}

var err error
@@ -156,41 +156,38 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

retrieveSegments, err = manager.Segment.GetAndPin(segIDs)
if err != nil {
return nil, retrieveSegments, err
return nil, err
}
defer manager.Segment.Unpin(retrieveSegments)

result, err := retrieveOnSegments(ctx, manager, retrieveSegments, SegType, plan, req)
return result, retrieveSegments, err
return result, err
}

// retrieveStreaming will retrieve all the validate target segments and return by stream
func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) ([]Segment, error) {
func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
var err error
var SegType commonpb.SegmentState
var retrieveSegments []Segment

segIDs := req.GetSegmentIDs()
collID := req.Req.GetCollectionID()

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

retrieveSegments, err = manager.Segment.GetAndPin(segIDs)
if err != nil {
return retrieveSegments, err
return err

[Codecov] Added line #L187 in internal/querynodev2/segments/retrieve.go was not covered by tests.
}
defer manager.Segment.Unpin(retrieveSegments)

err = retrieveOnSegmentsWithStream(ctx, manager, retrieveSegments, SegType, plan, srv)
return retrieveSegments, err
return err
}
15 changes: 5 additions & 10 deletions internal/querynodev2/segments/retrieve_test.go
@@ -161,10 +161,9 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
Scope: querypb.DataScope_Historical,
}

res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
res, err := Retrieve(context.TODO(), suite.manager, plan, req)
suite.NoError(err)
suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments)

resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
RetrievePlan: plan,
@@ -187,10 +186,9 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
Scope: querypb.DataScope_Streaming,
}

res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
res, err := Retrieve(context.TODO(), suite.manager, plan, req)
suite.NoError(err)
suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments)

resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
RetrievePlan: plan,
@@ -220,9 +218,8 @@ func (suite *RetrieveSuite) TestRetrieveStreamSealed() {
server := client.CreateServer()

go func() {
segments, err := RetrieveStream(ctx, suite.manager, plan, req, server)
err := RetrieveStream(ctx, suite.manager, plan, req, server)
suite.NoError(err)
suite.manager.Segment.Unpin(segments)
server.FinishSend(err)
}()

@@ -257,10 +254,9 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
Scope: querypb.DataScope_Streaming,
}

res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
res, err := Retrieve(context.TODO(), suite.manager, plan, req)
suite.Error(err)
suite.Len(res, 0)
suite.manager.Segment.Unpin(segments)
}

func (suite *RetrieveSuite) TestRetrieveNilSegment() {
@@ -277,10 +273,9 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
Scope: querypb.DataScope_Historical,
}

res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req)
res, err := Retrieve(context.TODO(), suite.manager, plan, req)
suite.ErrorIs(err, merr.ErrSegmentNotLoaded)
suite.Len(res, 0)
suite.manager.Segment.Unpin(segments)
}

func TestRetrieve(t *testing.T) {
65 changes: 44 additions & 21 deletions internal/querynodev2/segments/search.go
@@ -21,6 +21,7 @@ import (
"fmt"
"sync"

"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
@@ -200,51 +201,73 @@ func searchSegmentsStreamly(ctx context.Context,
}

// search will search on the historical segments the target segments in historical.
// if segIDs is not specified, it will search on all the historical segments speficied by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
func SearchHistorical(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64) ([]*SearchResult, int64, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
return nil, 0, ctx.Err()

[Codecov] Added line #L206 in internal/querynodev2/segments/search.go was not covered by tests.
}

segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
segments, err := manager.Segment.GetAndPin(segIDs)
if err != nil {
return nil, nil, err
return nil, 0, err
}
defer manager.Segment.Unpin(segments)

searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeSealed, searchReq)
return searchResults, segments, err
if err != nil {
return nil, 0, err
}

relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
return acc + GetSegmentRelatedDataSize(seg)
}, 0)
return searchResults, relatedDataSize, nil
}

// searchStreaming will search all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, collID int64, partIDs []int64, segIDs []int64) ([]*SearchResult, []Segment, error) {
func SearchStreaming(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64) ([]*SearchResult, int64, error) {
if ctx.Err() != nil {
return nil, nil, ctx.Err()
return nil, 0, ctx.Err()

[Codecov] Added line #L230 in internal/querynodev2/segments/search.go was not covered by tests.
}

segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs)
segments, err := manager.Segment.GetAndPin(segIDs)
if err != nil {
return nil, nil, err
return nil, 0, err

[Codecov] Added line #L235 in internal/querynodev2/segments/search.go was not covered by tests.
}
defer manager.Segment.Unpin(segments)

searchResults, err := searchSegments(ctx, manager, segments, SegmentTypeGrowing, searchReq)
return searchResults, segments, err
if err != nil {
return nil, 0, err
}

[Codecov] Added lines #L241-L242 in internal/querynodev2/segments/search.go were not covered by tests.

relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
return acc + GetSegmentRelatedDataSize(seg)
}, 0)

return searchResults, relatedDataSize, nil
}

func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest,
collID int64, partIDs []int64, segIDs []int64, streamReduce func(result *SearchResult) error,
) ([]Segment, error) {
func SearchHistoricalStreamly(ctx context.Context, manager *Manager, searchReq *SearchRequest, segIDs []int64, streamReduce func(result *SearchResult) error,
) (int64, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
return 0, ctx.Err()

[Codecov] Added line #L254 in internal/querynodev2/segments/search.go was not covered by tests.
}

segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs)
segments, err := manager.Segment.GetAndPin(segIDs)
if err != nil {
return segments, err
return 0, err

[Codecov] Added line #L259 in internal/querynodev2/segments/search.go was not covered by tests.
}
defer manager.Segment.Unpin(segments)

err = searchSegmentsStreamly(ctx, manager, segments, searchReq, streamReduce)
if err != nil {
return segments, err
return 0, err

[Codecov] Added line #L265 in internal/querynodev2/segments/search.go was not covered by tests.
}
return segments, nil

relatedDataSize := lo.Reduce(segments, func(acc int64, seg Segment, _ int) int64 {
return acc + GetSegmentRelatedDataSize(seg)
}, 0)

return relatedDataSize, nil
}
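
The refactored search helpers also return the related data size of the segments they touched, accumulated with `lo.Reduce` from `github.com/samber/lo` while the segments are still pinned. A small runnable illustration of that accumulation follows; the `relatedSize` field is a hypothetical stand-in for whatever `GetSegmentRelatedDataSize` reports.

package main

import (
	"fmt"

	"github.com/samber/lo"
)

// segment is a hypothetical stand-in; relatedSize plays the role of
// GetSegmentRelatedDataSize(seg) in the real code.
type segment struct {
	id          int64
	relatedSize int64
}

func main() {
	pinned := []segment{
		{id: 1, relatedSize: 1 << 20},
		{id: 2, relatedSize: 4 << 20},
		{id: 3, relatedSize: 2 << 20},
	}

	// Sum the related data size over the pinned segments, as SearchHistorical
	// and SearchStreaming now do before unpinning and returning.
	relatedDataSize := lo.Reduce(pinned, func(acc int64, seg segment, _ int) int64 {
		return acc + seg.relatedSize
	}, int64(0))

	fmt.Println("related data size (bytes):", relatedDataSize)
}

Returning a plain int64 instead of the pinned slice keeps the pin lifetime entirely inside the helper.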
10 changes: 2 additions & 8 deletions internal/querynodev2/segments/search_test.go
@@ -145,23 +145,17 @@ func (suite *SearchSuite) TestSearchSealed() {
searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.sealed.ID()}, mock_segcore.IndexFaissIDMap, nq)
suite.NoError(err)

_, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()})
_, _, err = SearchHistorical(ctx, suite.manager, searchReq, []int64{suite.sealed.ID()})
suite.NoError(err)
suite.manager.Segment.Unpin(segments)
}

func (suite *SearchSuite) TestSearchGrowing() {
searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.growing.ID()}, mock_segcore.IndexFaissIDMap, 1)
suite.NoError(err)

res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.growing.ID()},
)
res, _, err := SearchStreaming(context.TODO(), suite.manager, searchReq, []int64{suite.growing.ID()})
suite.NoError(err)
suite.Len(res, 1)
suite.manager.Segment.Unpin(segments)
}

func TestSearch(t *testing.T) {
(Diffs for the remaining changed files are not shown here.)
