diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index cd5a15f7c4310..ad06089c85193 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -228,6 +228,25 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu return nodeReq } +func (sd *shardDelegator) getTargetPartitions(reqPartitions []int64) (searchPartitions []int64, err error) { + existPartitions := sd.collection.GetPartitions() + + // search all loaded partitions if req partition ids not provided + if len(reqPartitions) == 0 { + searchPartitions = existPartitions + return searchPartitions, nil + } + + // use brute search to avoid map struct cost + for _, partition := range reqPartitions { + if !funcutil.SliceContain(existPartitions, partition) { + return nil, merr.WrapErrPartitionNotLoaded(reqPartitions) + } + } + searchPartitions = reqPartitions + return searchPartitions, nil +} + // Search preforms search operation on shard. func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) { log := sd.getLogger(ctx) @@ -302,11 +321,6 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } - partitions := req.GetReq().GetPartitionIDs() - if !sd.collection.ExistPartition(partitions...) { - return nil, merr.WrapErrPartitionNotLoaded(partitions) - } - // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) @@ -327,9 +341,14 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable") } defer sd.distribution.Unpin(version) - existPartitions := sd.collection.GetPartitions() + targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs()) + if err != nil { + return nil, err + } + // set target partition ids to sub task request + req.Req.PartitionIDs = targetPartitions growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { - return funcutil.SliceContain(existPartitions, segment.PartitionID) + return funcutil.SliceContain(targetPartitions, segment.PartitionID) }) if req.GetReq().GetIsAdvanced() { @@ -418,11 +437,6 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq return fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } - partitions := req.GetReq().GetPartitionIDs() - if !sd.collection.ExistPartition(partitions...) { - return merr.WrapErrPartitionNotLoaded(partitions) - } - // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) @@ -443,9 +457,16 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable") } defer sd.distribution.Unpin(version) - existPartitions := sd.collection.GetPartitions() + + targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs()) + if err != nil { + return err + } + // set target partition ids to sub task request + req.Req.PartitionIDs = targetPartitions + growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { - return funcutil.SliceContain(existPartitions, segment.PartitionID) + return funcutil.SliceContain(targetPartitions, segment.PartitionID) }) if req.Req.IgnoreGrowing { growing = []SegmentEntry{} @@ -489,11 +510,6 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) } - partitions := req.GetReq().GetPartitionIDs() - if !sd.collection.ExistPartition(partitions...) { - return nil, merr.WrapErrPartitionNotLoaded(partitions) - } - // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) @@ -514,12 +530,19 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable") } defer sd.distribution.Unpin(version) + + targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs()) + if err != nil { + return nil, err + } + // set target partition ids to sub task request + req.Req.PartitionIDs = targetPartitions + if req.Req.IgnoreGrowing { growing = []SegmentEntry{} } else { - existPartitions := sd.collection.GetPartitions() growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { - return funcutil.SliceContain(existPartitions, segment.PartitionID) + return funcutil.SliceContain(targetPartitions, segment.PartitionID) }) } diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 6629106257dd0..b8fa826b8caea 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -156,10 +156,10 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu if req.GetScope() == querypb.DataScope_Historical { SegType = SegmentTypeSealed - retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs) + retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) } else { SegType = SegmentTypeGrowing - retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs) + retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) } if err != nil { @@ -181,10 +181,10 @@ func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, r if req.GetScope() == querypb.DataScope_Historical { SegType = SegmentTypeSealed - retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs) + retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) } else { SegType = SegmentTypeGrowing - retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs) + retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) } if err != nil { diff --git a/internal/querynodev2/segments/validate.go b/internal/querynodev2/segments/validate.go index 08ee40d268e1a..c421bc1129744 100644 --- a/internal/querynodev2/segments/validate.go +++ b/internal/querynodev2/segments/validate.go @@ -42,17 +42,16 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti if len(partitionIDs) == 0 { searchPartIDs = collection.GetPartitions() } else { - if collection.ExistPartition(partitionIDs...) { - searchPartIDs = partitionIDs - } + // use request partition ids directly, ignoring meta partition ids + // partitions shall be controlled by delegator distribution + searchPartIDs = partitionIDs } log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) // all partitions have been released if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadPartition { - return nil, errors.New("partitions have been released , collectionID = " + - fmt.Sprintln(collectionID) + "target partitionIDs = " + fmt.Sprintln(searchPartIDs)) + return nil, errors.Newf("partitions have been released , collectionID = %d target partitionIDs = %v", collectionID, searchPartIDs) } if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadCollection { diff --git a/internal/querynodev2/tasks/search_task.go b/internal/querynodev2/tasks/search_task.go index a7423ac716d39..0a2118a787290 100644 --- a/internal/querynodev2/tasks/search_task.go +++ b/internal/querynodev2/tasks/search_task.go @@ -160,7 +160,7 @@ func (t *SearchTask) Execute() error { t.segmentManager, searchReq, req.GetReq().GetCollectionID(), - nil, + req.GetReq().GetPartitionIDs(), req.GetSegmentIDs(), ) } else if req.GetScope() == querypb.DataScope_Streaming { @@ -169,7 +169,7 @@ func (t *SearchTask) Execute() error { t.segmentManager, searchReq, req.GetReq().GetCollectionID(), - nil, + req.GetReq().GetPartitionIDs(), req.GetSegmentIDs(), ) } @@ -475,7 +475,7 @@ func (t *StreamingSearchTask) Execute() error { t.segmentManager, searchReq, req.GetReq().GetCollectionID(), - nil, + req.GetReq().GetPartitionIDs(), req.GetSegmentIDs(), ) defer segments.DeleteSearchResults(results)