From 5b00816e74ce8d9ffd8cddefdebadff902c06be7 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Thu, 18 Apr 2024 16:40:37 +0800 Subject: [PATCH] fix search case Signed-off-by: ThreadDao --- test/testcases/search_test.go | 37 ++++++++++++++++------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/test/testcases/search_test.go b/test/testcases/search_test.go index 93b9b7c5..df642b91 100644 --- a/test/testcases/search_test.go +++ b/test/testcases/search_test.go @@ -165,7 +165,7 @@ func TestSearchEmptyCollection(t *testing.T) { // search vector sp, _ := entity.NewIndexHNSWSearchParam(74) - searchRes, _ := mc.Search( + searchRes, errSearch := mc.Search( ctx, collName, []string{common.DefaultPartition}, "", @@ -177,10 +177,8 @@ func TestSearchEmptyCollection(t *testing.T) { common.DefaultTopK, sp, ) - require.Len(t, searchRes, common.DefaultNq) - for _, resultSet := range searchRes { - assert.EqualValues(t, 0, resultSet.ResultCount) - } + common.CheckErr(t, errSearch, true) + common.CheckSearchResult(t, searchRes, common.DefaultNq, 0) } } @@ -216,9 +214,7 @@ func TestSearchEmptyCollection2(t *testing.T) { resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, nv.queryVec, nv.fieldName, nv.metricType, common.DefaultTopK, sp) common.CheckErr(t, errSearch, true) - for _, res := range resSearch { - require.Nil(t, res) - } + common.CheckSearchResult(t, resSearch, common.DefaultNq, 0) } } @@ -278,8 +274,6 @@ func TestSearchEmptyPartitions(t *testing.T) { nq0IDs := searchResult[0].IDs.(*entity.ColumnInt64).Data() nq1IDs := searchResult[1].IDs.(*entity.ColumnInt64).Data() common.CheckSearchResult(t, searchResult, 2, common.DefaultTopK) - log.Println(nq0IDs) - log.Println(nq1IDs) require.Contains(t, nq0IDs, vecColumnDefault.IdsColumn.(*entity.ColumnInt64).Data()[0]) require.Contains(t, nq1IDs, vecColumnPartition.IdsColumn.(*entity.ColumnInt64).Data()[0]) } @@ -1600,7 +1594,6 @@ func TestSearchMultiVectors(t *testing.T) { } func TestSearchSparseVector(t *testing.T) { - t.Skip("https://github.com/milvus-io/milvus-sdk-go/issues/725") t.Parallel() idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"}) idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"}) @@ -1630,10 +1623,15 @@ func TestSearchSparseVector(t *testing.T) { resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName, entity.IP, common.DefaultTopK, sp) common.CheckErr(t, errSearch, true) - common.CheckSearchResult(t, resSearch, common.DefaultNq, common.DefaultTopK) + require.Len(t, resSearch, common.DefaultNq) outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName, common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName} - common.CheckOutputFields(t, resSearch[0].Fields, outputFields) + for _, res := range resSearch { + require.LessOrEqual(t, res.ResultCount, common.DefaultTopK) + if res.ResultCount == common.DefaultTopK { + common.CheckOutputFields(t, resSearch[0].Fields, outputFields) + } + } } } @@ -1688,7 +1686,6 @@ func TestSearchInvalidSparseVector(t *testing.T) { } } -// TODO https://github.com/milvus-io/milvus-sdk-go/issues/725 func TestSearchEmptySparseCollection(t *testing.T) { t.Parallel() idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"}) @@ -1717,11 +1714,7 @@ func TestSearchEmptySparseCollection(t *testing.T) { resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName, entity.IP, common.DefaultTopK, sp) common.CheckErr(t, errSearch, true) - require.Empty(t, resSearch) - //require.Len(t, resSearch, common.DefaultNq) - //for _, res := range resSearch { - // require.Nil(t, res) - //} + common.CheckSearchResult(t, resSearch, common.DefaultNq, 0) } } @@ -1755,12 +1748,16 @@ func TestSearchSparseVectorPagination(t *testing.T) { resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName, entity.IP, common.DefaultTopK, sp) common.CheckErr(t, errSearch, true) + require.Len(t, resSearch, common.DefaultNq) pageSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName, entity.IP, 5, sp, client.WithOffset(5)) common.CheckErr(t, errSearch, true) + require.Len(t, pageSearch, common.DefaultNq) for i := 0; i < len(resSearch); i++ { - require.Equal(t, resSearch[i].IDs.(*entity.ColumnInt64).Data()[5:], pageSearch[i].IDs.(*entity.ColumnInt64).Data()) + if resSearch[i].ResultCount == common.DefaultTopK && pageSearch[i].ResultCount == 5 { + require.Equal(t, resSearch[i].IDs.(*entity.ColumnInt64).Data()[5:], pageSearch[i].IDs.(*entity.ColumnInt64).Data()) + } } } }