diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index f02bedc545c30..41f34836f8337 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -154,8 +154,8 @@ func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string } if metricType, ok := indexParamsMap["metric_type"]; !ok { - indexParamsMap["metric_type"] = "BM25" - } else if metricType != "BM25" { + indexParamsMap["metric_type"] = metric.BM25 + } else if metricType != metric.BM25 { return fmt.Errorf("index metric type of BM25 function output field must be BM25, got %s", metricType) } diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 8fd8447914150..06f98608ecbe3 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -263,8 +263,13 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest }() } - // build idf for bm25 search - if req.GetReq().GetMetricType() == metric.BM25 || (req.GetReq().GetMetricType() == metric.EMPTY && sd.isBM25Field[req.GetReq().GetFieldId()]) { + searchAgainstBM25Field := sd.isBM25Field[req.GetReq().GetFieldId()] + + if searchAgainstBM25Field { + if req.GetReq().GetMetricType() != metric.BM25 && req.GetReq().GetMetricType() != metric.EMPTY { + return nil, merr.WrapErrParameterInvalid("BM25", req.GetReq().GetMetricType(), "must use BM25 metric type when searching against BM25 Function output field") + } + // build idf for bm25 search avgdl, err := sd.buildBM25IDF(req.GetReq()) if err != nil { return nil, err diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index cb541aa6e98c1..1bd61173c2aee 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -1267,3 +1267,23 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { assert.Equal(t, sd.Serviceable(), false) assert.Equal(t, sd.Stopped(), true) } + +func TestDelegatorSearchBM25InvalidMetricType(t *testing.T) { + paramtable.Init() + searchReq := &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: commonpbutil.NewMsgBase(), + }, + } + + searchReq.Req.FieldId = 101 + searchReq.Req.MetricType = metric.IP + + sd := &shardDelegator{ + isBM25Field: map[int64]bool{101: true}, + } + + _, err := sd.search(context.Background(), searchReq, []SnapshotItem{}, []SegmentEntry{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "must use BM25 metric type when searching against BM25 Function output field") +}