Skip to content

Commit

Permalink
enhance: support recall estimation
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg committed Dec 5, 2024
1 parent 1f8299f commit e205a4d
Show file tree
Hide file tree
Showing 12 changed files with 323 additions and 29 deletions.
2 changes: 2 additions & 0 deletions internal/proto/internal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ message SearchRequest {
int64 offset = 21;
common.ConsistencyLevel consistency_level = 22;
bool is_topk_reduce = 26;
bool is_recall_evaluation = 27;
}

message SubSearchResults {
Expand Down Expand Up @@ -158,6 +159,7 @@ message SearchResults {
bool is_advanced = 16;
int64 all_search_count = 17;
bool is_topk_reduce = 18;
bool is_recall_evaluation = 19;
}

message CostAggregation {
Expand Down
41 changes: 30 additions & 11 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2887,12 +2887,13 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
optimizedSearch := true
resultSizeInsufficient := false
isTopkReduce := false
isRecallEvaluation := false
err2 := retry.Handle(ctx, func() (bool, error) {
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
if merr.Ok(rsp.GetStatus()) && resultSizeInsufficient && isTopkReduce && optimizedSearch && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)
if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
// without optimize search
optimizedSearch = false
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)

Check warning on line 2896 in internal/proxy/impl.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/impl.go#L2896

Added line #L2896 was not covered by tests
metrics.ProxyRetrySearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
Expand All @@ -2910,6 +2911,23 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rsp.GetStatus())
}
// search for ground truth and compute recall
if isRecallEvaluation && merr.Ok(rsp.GetStatus()) {
var rspGT *milvuspb.SearchResults
rspGT, _, _, _, err = node.search(ctx, request, false, true)
metrics.ProxyRecallSearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Inc()
if merr.Ok(rspGT.GetStatus()) {
return false, computeRecall(rsp.GetResults(), rspGT.GetResults())
}
if errors.Is(merr.Error(rspGT.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rspGT.GetStatus())
}
return false, merr.Error(rspGT.GetStatus())

Check warning on line 2929 in internal/proxy/impl.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/impl.go#L2916-L2929

Added lines #L2916 - L2929 were not covered by tests
}
return false, nil
})
if err2 != nil {
Expand All @@ -2918,7 +2936,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
return rsp, err
}

func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool, isRecallEvaluation bool) (*milvuspb.SearchResults, bool, bool, bool, error) {
metrics.GetStats(ctx).
SetNodeID(paramtable.GetNodeID()).
SetInboundLabel(metrics.SearchLabel).
Expand All @@ -2933,7 +2951,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}

method := "Search"
Expand All @@ -2954,7 +2972,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
if err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil

Check warning on line 2975 in internal/proxy/impl.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/impl.go#L2975

Added line #L2975 was not covered by tests
}

request.PlaceholderGroup = placeholderGroupBytes
Expand All @@ -2968,8 +2986,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
IsRecallEvaluation: isRecallEvaluation,
},
request: request,
tr: timerecord.NewTimeRecorder("search"),
Expand Down Expand Up @@ -3023,7 +3042,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}
tr.CtxRecord(ctx, "search request enqueue")

Expand All @@ -3049,7 +3068,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}

span := tr.CtxRecord(ctx, "wait search result")
Expand Down Expand Up @@ -3106,7 +3125,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
}
}
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, qt.isRecallEvaluation, nil
}

func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
Expand Down
6 changes: 6 additions & 0 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type searchTask struct {
mustUsePartitionKey bool
resultSizeInsufficient bool
isTopkReduce bool
isRecallEvaluation bool

userOutputFields []string
userDynamicFields []string
Expand Down Expand Up @@ -620,10 +621,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.queryChannelsTs = make(map[string]uint64)
t.relatedDataSize = 0
isTopkReduce := false
isRecallEvaluation := false
for _, r := range toReduceResults {
if r.GetIsTopkReduce() {
isTopkReduce = true
}
if r.GetIsRecallEvaluation() {
isRecallEvaluation = true
}

Check warning on line 631 in internal/proxy/task_search.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_search.go#L630-L631

Added lines #L630 - L631 were not covered by tests
t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
Expand Down Expand Up @@ -702,6 +707,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
t.resultSizeInsufficient = resultSizeInsufficient
t.isTopkReduce = isTopkReduce
t.isRecallEvaluation = isRecallEvaluation
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()

Expand Down
69 changes: 69 additions & 0 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,75 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
return pkNames, fieldIDs
}

func recallCal[T string | int64](results []T, gts []T) float32 {
hit := 0
total := 0
for _, r := range results {
total++
for _, gt := range gts {
if r == gt {
hit++
break
}
}
}
return float32(hit) / float32(total)
}

func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error {
if results.GetNumQueries() != gts.GetNumQueries() {
return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries())
}

switch results.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
switch gts.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
currentResultIndex := int64(0)
currentGTIndex := int64(0)
recalls := make([]float32, 0, results.GetNumQueries())
for i := 0; i < int(results.GetNumQueries()); i++ {
currentResultTopk := results.GetTopks()[i]
currentGTTopk := gts.GetTopks()[i]
recalls = append(recalls, recallCal(results.GetIds().GetIntId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
gts.GetIds().GetIntId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
currentResultIndex += currentResultTopk
currentGTIndex += currentGTTopk
}
results.Recalls = recalls
return nil
case *schemapb.IDs_StrId:
return fmt.Errorf("pk type is inconsistent between search results(int64) and ground truth(string)")
default:
return fmt.Errorf("unsupported pk type")

Check warning on line 1122 in internal/proxy/util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/util.go#L1121-L1122

Added lines #L1121 - L1122 were not covered by tests
}

case *schemapb.IDs_StrId:
switch gts.GetIds().GetIdField().(type) {
case *schemapb.IDs_StrId:
currentResultIndex := int64(0)
currentGTIndex := int64(0)
recalls := make([]float32, 0, results.GetNumQueries())
for i := 0; i < int(results.GetNumQueries()); i++ {
currentResultTopk := results.GetTopks()[i]
currentGTTopk := gts.GetTopks()[i]
recalls = append(recalls, recallCal(results.GetIds().GetStrId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
gts.GetIds().GetStrId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
currentResultIndex += currentResultTopk
currentGTIndex += currentGTTopk
}
results.Recalls = recalls
return nil
case *schemapb.IDs_IntId:
return fmt.Errorf("pk type is inconsistent between search results(string) and ground truth(int64)")
default:
return fmt.Errorf("unsupported pk type")

Check warning on line 1144 in internal/proxy/util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/util.go#L1143-L1144

Added lines #L1143 - L1144 were not covered by tests
}
default:
return fmt.Errorf("unsupported pk type")

Check warning on line 1147 in internal/proxy/util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/util.go#L1146-L1147

Added lines #L1146 - L1147 were not covered by tests
}
}

// Support wildcard in output fields:
//
// "*" - all fields
Expand Down
162 changes: 162 additions & 0 deletions internal/proxy/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2672,3 +2672,165 @@ func TestValidateLoadFieldsList(t *testing.T) {
})
}
}

func TestComputeRecall(t *testing.T) {
t.Run("normal case1", func(t *testing.T) {
result1 := &schemapb.SearchResultData{
NumQueries: 3,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"11", "9", "8", "5", "3", "1"},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.1},
Topks: []int64{2, 2, 2},
}

gt := &schemapb.SearchResultData{
NumQueries: 3,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"11", "10", "8", "5", "3", "1"},
},
},
},
Scores: []float32{1.1, 0.98, 0.8, 0.5, 0.3, 0.1},
Topks: []int64{2, 2, 2},
}

err := computeRecall(result1, gt)
assert.NoError(t, err)
assert.Equal(t, result1.Recalls[0], float32(0.5))
assert.Equal(t, result1.Recalls[1], float32(1.0))
assert.Equal(t, result1.Recalls[2], float32(1.0))
})

t.Run("normal case2", func(t *testing.T) {
result1 := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

gt := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 6, 5, 4, 1, 34, 23, 22, 20},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

err := computeRecall(result1, gt)
assert.NoError(t, err)
assert.Equal(t, result1.Recalls[0], float32(0.6))
assert.Equal(t, result1.Recalls[1], float32(0.8))
})

t.Run("not match size", func(t *testing.T) {
result1 := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

gt := &schemapb.SearchResultData{
NumQueries: 1,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 6, 5, 4},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3},
Topks: []int64{5},
}

err := computeRecall(result1, gt)
assert.Error(t, err)
})

t.Run("not match type1", func(t *testing.T) {
result1 := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

gt := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

err := computeRecall(result1, gt)
assert.Error(t, err)
})

t.Run("not match type2", func(t *testing.T) {
result1 := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

gt := &schemapb.SearchResultData{
NumQueries: 2,
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
},
},
},
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
Topks: []int64{5, 5},
}

err := computeRecall(result1, gt)
assert.Error(t, err)
})
}
Loading

0 comments on commit e205a4d

Please sign in to comment.