From 06b5e186a781183e8ba66e94f46e5e4649201199 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Wed, 16 Oct 2024 19:45:23 +0800 Subject: [PATCH] fix: return error if searching against BM25 output field with incorrect metric type (#36910) issue: https://github.com/milvus-io/milvus/issues/36835 Currently, searching against a BM25 function output field using the IP metric type ends in a segcore error that is hard to understand. This change returns the error from the query node delegator instead and provides a more useful error message. Signed-off-by: Buqian Zheng --- internal/proxy/task_index.go | 4 ++-- internal/querynodev2/delegator/delegator.go | 9 +++++++-- .../querynodev2/delegator/delegator_test.go | 20 +++++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index f02bedc545c30..41f34836f8337 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -154,8 +154,8 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string } if metricType, ok := indexParamsMap["metric_type"]; !ok { - indexParamsMap["metric_type"] = "BM25" - } else if metricType != "BM25" { + indexParamsMap["metric_type"] = metric.BM25 + } else if metricType != metric.BM25 { return fmt.Errorf("index metric type of BM25 function output field must be BM25, got %s", metricType) } diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 8fd8447914150..06f98608ecbe3 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -263,8 +263,13 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest }() } - // build idf for bm25 search - if req.GetReq().GetMetricType() == metric.BM25 || (req.GetReq().GetMetricType() == metric.EMPTY && sd.isBM25Field[req.GetReq().GetFieldId()]) { + searchAgainstBM25Field := sd.isBM25Field[req.GetReq().GetFieldId()] + + if searchAgainstBM25Field { + if 
req.GetReq().GetMetricType() != metric.BM25 && req.GetReq().GetMetricType() != metric.EMPTY { + return nil, merr.WrapErrParameterInvalid("BM25", req.GetReq().GetMetricType(), "must use BM25 metric type when searching against BM25 Function output field") + } + // build idf for bm25 search avgdl, err := sd.buildBM25IDF(req.GetReq()) if err != nil { return nil, err diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index cb541aa6e98c1..1bd61173c2aee 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -1267,3 +1267,23 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { assert.Equal(t, sd.Serviceable(), false) assert.Equal(t, sd.Stopped(), true) } + +func TestDelegatorSearchBM25InvalidMetricType(t *testing.T) { + paramtable.Init() + searchReq := &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, + } + + searchReq.Req.FieldId = 101 + searchReq.Req.MetricType = metric.IP + + sd := &shardDelegator{ + isBM25Field: map[int64]bool{101: true}, + } + + _, err := sd.search(context.Background(), searchReq, []SnapshotItem{}, []SegmentEntry{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "must use BM25 metric type when searching against BM25 Function output field") +}