From 8514063559dce4b8f3a5011125a9212f17f62449 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Thu, 18 Apr 2024 20:22:36 +0800 Subject: [PATCH] Fix groupby search case Signed-off-by: ThreadDao --- test/testcases/groupby_search_test.go | 53 ++++++++++++++------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/test/testcases/groupby_search_test.go b/test/testcases/groupby_search_test.go index 04ee32fd..0b263213 100644 --- a/test/testcases/groupby_search_test.go +++ b/test/testcases/groupby_search_test.go @@ -134,20 +134,19 @@ func TestSearchGroupByFloatDefault(t *testing.T) { // verify each topK entity is the top1 of the whole group hitsNum := 0 total := 0 - for _, rs := range resGroupBy { - for i := 0; i < rs.ResultCount; i++ { - groupByValue, _ := rs.GroupByValue.Get(i) - pkValue, _ := rs.IDs.GetAsInt64(i) + for i := 0; i < common.DefaultNq; i++ { + for j := 0; j < resGroupBy[i].ResultCount; j++ { + groupByValue, _ := resGroupBy[i].GroupByValue.Get(j) + pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) var expr string if groupByField == "varchar" { expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue) } else { expr = fmt.Sprintf("%s == %v", groupByField, groupByValue) } - // search filter with groupByValue is the top1 resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - groupByField}, queryVec, common.DefaultFloatVecFieldName, metricType, 1, sp) + groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp) filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) //log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", // groupByField, groupByValue, pkValue, filterTop1Pk) @@ -209,27 +208,30 @@ func TestGroupBySearchSparseVector(t *testing.T) { // verify each topK entity is the top1 of the whole group hitsNum := 0 total := 0 - for _, rs := range resGroupBy { - for i := 0; i < rs.ResultCount; i++ { - groupByValue, _ := rs.GroupByValue.Get(i) - pkValue, _ := rs.IDs.GetAsInt64(i) - expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue) - - // search filter with groupByValue is the top1 - resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - common.DefaultVarcharFieldName}, queryVec, common.DefaultSparseVecFieldName, entity.IP, 1, sp) - filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) - if filterTop1Pk == pkValue { - hitsNum += 1 + for i := 0; i < common.DefaultNq; i++ { + if resGroupBy[i].ResultCount > 0 { + for j := 0; j < resGroupBy[i].ResultCount; j++ { + groupByValue, _ := resGroupBy[i].GroupByValue.Get(j) + pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) + expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue) + // search filter with groupByValue is the top1 + resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, + common.DefaultVarcharFieldName}, []entity.Vector{queryVec[i]}, common.DefaultSparseVecFieldName, entity.IP, 1, sp) + filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) + log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", + common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk) + if filterTop1Pk == pkValue { + hitsNum += 1 + } + total += 1 } - total += 1 } } // verify hits rate hitsRate := float32(hitsNum) / float32(total) _str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n", - common.DefaultSparseVecFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate) + common.DefaultVarcharFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate) log.Println(_str) require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str) } @@ -340,19 +342,18 @@ func TestSearchGroupByFloatGrowing(t *testing.T) { client.WithSearchQueryConsistencyLevel(entity.ClStrong)) // verify each topK entity is the top1 in the group - for _, rs := range resGroupBy { - for i := 0; i < rs.ResultCount; i++ { - groupByValue, _ := rs.GroupByValue.Get(i) - pkValue, _ := rs.IDs.GetAsInt64(i) + for i := 0; i < common.DefaultNq; i++ { + for j := 0; j < resGroupBy[i].ResultCount; j++ { + groupByValue, _ := resGroupBy[i].GroupByValue.Get(j) + pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) var expr string if groupByField == "varchar" { expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue) } else { expr = fmt.Sprintf("%s == %v", groupByField, groupByValue) } - resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - groupByField}, queryVec, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong)) + groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong)) // search filter with groupByValue is the top1 filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)