Skip to content

Commit

Permalink
fix: Unify loaded partition check to delegator (#36879)
Browse files Browse the repository at this point in the history
Related to #36370

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Oct 15, 2024
1 parent 2291d0c commit ba25320
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 33 deletions.
65 changes: 44 additions & 21 deletions internal/querynodev2/delegator/delegator.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,25 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
return nodeReq
}

func (sd *shardDelegator) getTargetPartitions(reqPartitions []int64) (searchPartitions []int64, err error) {
existPartitions := sd.collection.GetPartitions()

// search all loaded partitions if req partition ids not provided
if len(reqPartitions) == 0 {
searchPartitions = existPartitions
return searchPartitions, nil
}

// use brute search to avoid map struct cost
for _, partition := range reqPartitions {
if !funcutil.SliceContain(existPartitions, partition) {
return nil, merr.WrapErrPartitionNotLoaded(reqPartitions)
}
}
searchPartitions = reqPartitions
return searchPartitions, nil
}

// Search preforms search operation on shard.
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
Expand Down Expand Up @@ -302,11 +321,6 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
Expand All @@ -327,9 +341,14 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()
targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})

if req.GetReq().GetIsAdvanced() {
Expand Down Expand Up @@ -418,11 +437,6 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
return fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
Expand All @@ -443,9 +457,16 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions

growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
Expand Down Expand Up @@ -489,11 +510,6 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
Expand All @@ -514,12 +530,19 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable")
}
defer sd.distribution.Unpin(version)

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition ids to sub task request
req.Req.PartitionIDs = targetPartitions

if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
} else {
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
}

Expand Down
8 changes: 4 additions & 4 deletions internal/querynodev2/segments/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

if err != nil {
Expand All @@ -181,10 +181,10 @@ func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, r

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

if err != nil {
Expand Down
9 changes: 4 additions & 5 deletions internal/querynodev2/segments/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,16 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti
if len(partitionIDs) == 0 {
searchPartIDs = collection.GetPartitions()
} else {
if collection.ExistPartition(partitionIDs...) {
searchPartIDs = partitionIDs
}
// use request partition ids directly, ignoring meta partition ids
// partitions shall be controlled by delegator distribution
searchPartIDs = partitionIDs
}

log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))

// all partitions have been released
if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadPartition {
return nil, errors.New("partitions have been released , collectionID = " +
fmt.Sprintln(collectionID) + "target partitionIDs = " + fmt.Sprintln(searchPartIDs))
return nil, errors.Newf("partitions have been released , collectionID = %d target partitionIDs = %v", collectionID, searchPartIDs)
}

if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadCollection {
Expand Down
6 changes: 3 additions & 3 deletions internal/querynodev2/tasks/search_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (t *SearchTask) Execute() error {
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
} else if req.GetScope() == querypb.DataScope_Streaming {
Expand All @@ -169,7 +169,7 @@ func (t *SearchTask) Execute() error {
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
}
Expand Down Expand Up @@ -475,7 +475,7 @@ func (t *StreamingSearchTask) Execute() error {
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
defer segments.DeleteSearchResults(results)
Expand Down

0 comments on commit ba25320

Please sign in to comment.