Skip to content

Commit

Permalink
Fix groupby search case
Browse files (browse the repository at this point in the history)
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Apr 18, 2024
1 parent 5b00816 commit 8514063
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions test/testcases/groupby_search_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -134,20 +134,19 @@ func TestSearchGroupByFloatDefault(t *testing.T) {
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for _, rs := range resGroupBy {
for i := 0; i < rs.ResultCount; i++ {
groupByValue, _ := rs.GroupByValue.Get(i)
pkValue, _ := rs.IDs.GetAsInt64(i)
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}

// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
groupByField}, queryVec, common.DefaultFloatVecFieldName, metricType, 1, sp)
groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp)
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
//log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
// groupByField, groupByValue, pkValue, filterTop1Pk)
Expand Down Expand Up @@ -209,27 +208,30 @@ func TestGroupBySearchSparseVector(t *testing.T) {
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for _, rs := range resGroupBy {
for i := 0; i < rs.ResultCount; i++ {
groupByValue, _ := rs.GroupByValue.Get(i)
pkValue, _ := rs.IDs.GetAsInt64(i)
expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)

// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
common.DefaultVarcharFieldName}, queryVec, common.DefaultSparseVecFieldName, entity.IP, 1, sp)
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
for i := 0; i < common.DefaultNq; i++ {
if resGroupBy[i].ResultCount > 0 {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
common.DefaultVarcharFieldName}, []entity.Vector{queryVec[i]}, common.DefaultSparseVecFieldName, entity.IP, 1, sp)
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
total += 1
}
}

// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
common.DefaultSparseVecFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate)
common.DefaultVarcharFieldName, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate)
log.Println(_str)
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
Expand Down Expand Up @@ -340,19 +342,18 @@ func TestSearchGroupByFloatGrowing(t *testing.T) {
client.WithSearchQueryConsistencyLevel(entity.ClStrong))

// verify each topK entity is the top1 in the group
for _, rs := range resGroupBy {
for i := 0; i < rs.ResultCount; i++ {
groupByValue, _ := rs.GroupByValue.Get(i)
pkValue, _ := rs.IDs.GetAsInt64(i)
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}

resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
groupByField}, queryVec, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong))
groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong))

// search filter with groupByValue is the top1
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
Expand Down

0 comments on commit 8514063

Please sign in to comment.