Skip to content

Commit

Permalink
enhance: add ts support for iterator
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Sep 29, 2024
1 parent a6545b2 commit b8af35d
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 62 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ replace (
github.com/milvus-io/milvus/pkg => ./pkg
github.com/streamnative/pulsarctl => github.com/xiaofan-luan/pulsarctl v0.5.1
github.com/tecbot/gorocksdb => github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b // indirect
github.com/milvus-io/milvus-proto/go-api/v2 => /home/hanchun/Documents/project/milvus-proto/go-api
)

exclude github.com/apache/pulsar-client-go/oauth2 v0.0.0-20211108044248-fe3b7c4e445b
62 changes: 37 additions & 25 deletions internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,29 @@ func (r *rankParams) String() string {
return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
}

type SearchInfo struct {
planInfo *planpb.QueryInfo
offset int64
parseError error
isIterator bool
}

// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
var topK int64
isAdvanced := rankParams != nil
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s is required", TopKKey)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s is required", TopKKey)}
}
topK = externalLimit
} else {
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)}
}
topK = externalLimit
} else {
Expand All @@ -106,7 +113,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
} else {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)}
}
}

Expand All @@ -117,20 +124,20 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)}
}

if offset != 0 {
if err := validateLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)}
}
}
}
}

queryTopK := topK + offset
if err := validateLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)}
}

// 2. parse metrics type
Expand All @@ -147,11 +154,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
}

if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
}

// 4. parse search param str
Expand All @@ -168,30 +175,35 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
} else {
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
if groupByInfo.err != nil {
return nil, 0, groupByInfo.err
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
}
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
}

// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")}
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
}

return &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
}, offset, nil
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")}
}

return &SearchInfo{
planInfo: &planpb.QueryInfo{
Topk: queryTopK,
MetricType: metricType,
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
},
offset: offset,
isIterator: isIterator == "True",
parseError: nil,
}
}

func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
Expand Down
14 changes: 13 additions & 1 deletion internal/proxy/task_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type queryParams struct {
limit int64
offset int64
reduceType reduce.IReduceType
isIterator bool
}

// translateToOutputFieldIDs translates output fields name to output fields id.
Expand Down Expand Up @@ -178,7 +179,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
// if limit is not provided
if err != nil {
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil
return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType, isIterator: isIterator}, nil
}
limit, err = strconv.ParseInt(limitStr, 0, 64)
if err != nil {
Expand All @@ -203,6 +204,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limit: limit,
offset: offset,
reduceType: reduceType,
isIterator: isIterator,
}, nil
}

Expand Down Expand Up @@ -461,6 +463,12 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
}
}
t.GuaranteeTimestamp = guaranteeTs
// need modify mvccTs and guaranteeTs for iterator specially
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
log.Info("hc===Set ts", zap.Uint64("t.MvccTimestamp", t.MvccTimestamp), zap.Uint64("t.GuaranteeTimestamp", t.GuaranteeTimestamp))
}

deadline, ok := t.TraceCtx().Deadline()
if ok {
Expand Down Expand Up @@ -542,6 +550,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
t.result.OutputFields = t.userOutputFields
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = t.BeginTs()
}
log.Debug("Query PostExecute done")
return nil
}
Expand Down
42 changes: 27 additions & 15 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ type searchTask struct {
reScorers []reScorer
rankParams *rankParams
groupScorer func(group *Group) error

isIterator bool
}

func (t *searchTask) CanSkipAllocTimestamp() bool {
Expand Down Expand Up @@ -249,6 +251,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
t.SearchRequest.ConsistencyLevel = consistencyLevel
if t.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
log.Info("hc===Set ts for searchRequest", zap.Uint64("t.MvccTimestamp", t.MvccTimestamp), zap.Uint64("t.GuaranteeTimestamp", t.GuaranteeTimestamp))
}

if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
Expand Down Expand Up @@ -351,7 +358,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
for index, subReq := range t.request.GetSubReqs() {
plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl())
if err != nil {
return err
}
Expand Down Expand Up @@ -443,11 +450,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
// fetch search_growing from search param

plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl())
if err != nil {
return err
}

t.isIterator = isIterator
t.SearchRequest.Offset = offset

if t.partitionKeyMode {
Expand Down Expand Up @@ -490,38 +498,38 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
return nil
}

func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) {
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
if err != nil || len(annsFieldName) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema")
return nil, nil, 0, false, errors.New(AnnsFieldKey + " not found in schema")
}

if enableMultipleVectorFields && len(vecFields) > 1 {
return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
return nil, nil, 0, false, errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if parseErr != nil {
return nil, nil, 0, parseErr
searchInfo := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if searchInfo.parseError != nil {
return nil, nil, 0, false, searchInfo.parseError
}
annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column")
}
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo)
if planErr != nil {
log.Warn("failed to create query plan", zap.Error(planErr),
zap.String("dsl", dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
}
log.Debug("create query plan",
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo))
return plan, queryInfo, offset, nil
zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
return plan, searchInfo.planInfo, searchInfo.offset, searchInfo.isIterator, nil
}

func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
Expand Down Expand Up @@ -714,6 +722,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName()
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = t.BeginTs()
}

metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

Expand Down
38 changes: 19 additions & 19 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2235,9 +2235,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
searchInfo := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
if test.description == "offsetParam" {
assert.Equal(t, targetOffset, offset)
}
Expand All @@ -2256,11 +2256,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
limit: externalLimit,
}

info, offset, err := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, int64(10), info.GetTopk())
assert.Equal(t, int64(0), offset)
searchInfo := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.Equal(t, int64(10), searchInfo.planInfo.GetTopk())
assert.Equal(t, int64(0), searchInfo.offset)
})

t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) {
Expand Down Expand Up @@ -2309,15 +2309,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: "true",
})

info, _, err := parseSearchInfo(params, schema, testRankParams)
assert.NoError(t, err)
assert.NotNil(t, info)
searchInfo := parseSearchInfo(params, schema, testRankParams)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)

// all group_by related parameters should be aligned to parameters
// set by main request rather than inner sub request
assert.Equal(t, int64(101), info.GetGroupByFieldId())
assert.Equal(t, int64(3), info.GetGroupSize())
assert.False(t, info.GetGroupStrictSize())
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
assert.False(t, searchInfo.planInfo.GetGroupStrictSize())
})

t.Run("parseSearchInfo error", func(t *testing.T) {
Expand Down Expand Up @@ -2399,12 +2399,12 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
searchInfo := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.Nil(t, searchInfo.planInfo)
assert.Zero(t, searchInfo.offset)

t.Logf("err=%s", err.Error())
t.Logf("err=%s", searchInfo.parseError)
})
}
})
Expand Down
4 changes: 2 additions & 2 deletions internal/querynodev2/delegator/delegator.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return err
Expand Down Expand Up @@ -473,7 +473,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return nil, err
Expand Down

0 comments on commit b8af35d

Please sign in to comment.