Skip to content

Commit

Permalink
test: update case: hybrid search support groupby
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Sep 10, 2024
1 parent 878e052 commit 2557156
Showing 1 changed file with 88 additions and 9 deletions.
97 changes: 88 additions & 9 deletions test/testcases/groupby_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ func genGroupByVectorIndex(metricType entity.MetricType) []entity.Index {
idxFlat, _ := entity.NewIndexFlat(metricType)
idxIvfFlat, _ := entity.NewIndexIvfFlat(metricType, nlist)
idxHnsw, _ := entity.NewIndexHNSW(metricType, 8, 96)
idxIvfSq8, _ := entity.NewIndexIvfSQ8(metricType, 128)

allFloatIndex := []entity.Index{
idxFlat,
idxIvfFlat,
idxHnsw,
idxIvfSq8,
}
return allFloatIndex
}
Expand All @@ -47,12 +49,10 @@ func genGroupByBinaryIndex(metricType entity.MetricType) []entity.Index {
}

func genUnsupportedFloatGroupByIndex() []entity.Index {
// idxIvfSq8, _ := entity.NewIndexIvfSQ8(entity.L2, 128)
idxIvfPq, _ := entity.NewIndexIvfPQ(entity.L2, 128, 16, 8)
idxScann, _ := entity.NewIndexSCANN(entity.L2, 16, false)
idxDiskAnn, _ := entity.NewIndexDISKANN(entity.L2)
return []entity.Index{
// idxIvfSq8,
idxIvfPq,
idxScann,
idxDiskAnn,
Expand Down Expand Up @@ -463,21 +463,100 @@ func TestSearchGroupByRangeSearch(t *testing.T) {
common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
}

// groupBy + advanced search
func TestSearchGroupByHybridSearch(t *testing.T) {
indexHnsw, _ := entity.NewIndexHNSW(entity.COSINE, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, indexHnsw, false)

// search params
queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
sp, _ := entity.NewIndexHNSWSearchParam(20)

collection, _ := mc.DescribeCollection(ctx, collName)
common.PrintAllFieldNames(collName, collection.Schema)

// search with groupBy field
groupByField := common.DefaultVarcharFieldName

expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName)
queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
sReqs := []*client.ANNSearchRequest{
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2)),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, expr, queryVec2, sp, common.DefaultTopK),
}
resGroupBy, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField(groupByField))
common.CheckErr(t, errSearch, true)

// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName,
groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, entity.COSINE, 1, sp)
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
//log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
// groupByField, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}

// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate)
log.Println(_str)
require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
}

// groupBy + advanced search
func TestHybridSearchDifferentGroupByField(t *testing.T) {
t.Skip("TODO: 2.5 test hybrid search with groupBy")
// prepare data
indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, indexHnsw, false)
mc, ctx, collName := prepareDataForGroupBySearch(t, 5, 1000, indexHnsw, false)

// hybrid search with groupBy field
sp, _ := entity.NewIndexHNSWSearchParam(20)
expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName)
queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
queryVec1 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector)
queryVec2 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector)
sReqs := []*client.ANNSearchRequest{
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2)),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2), client.WithGroupByField("int64")),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK, client.WithGroupByField("varchar")),
}
_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs)
common.CheckErr(t, errSearch, false, "not support search_group_by operation in the hybrid search")
_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField("int8"))
common.CheckErr(t, errSearch, true)
// TODO check the true groupBy field
}

// groupBy field not existed
func TestSearchNotExistedGroupByField(t *testing.T) {
// prepare data
indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 2, 1000, indexHnsw, false)

// hybrid search with groupBy field
sp, _ := entity.NewIndexHNSWSearchParam(20)
expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName)
queryVec1 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector)
queryVec2 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector)
sReqs := []*client.ANNSearchRequest{
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2), client.WithGroupByField("aaa")),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK, client.WithGroupByField("bbb")),
}
_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField("ccc"))
common.CheckErr(t, errSearch, false, "groupBy field not found in schema: field not found[field=ccc]")

// search
_, err := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName},
queryVec1, common.DefaultFloatVecFieldName, entity.L2, common.DefaultTopK, sp, client.WithGroupByField("ddd"))
common.CheckErr(t, err, false, "groupBy field not found in schema: field not found[field=ddd]")
}

0 comments on commit 2557156

Please sign in to comment.