diff --git a/test/testcases/groupby_search_test.go b/test/testcases/groupby_search_test.go index 92e58290..40fa9246 100644 --- a/test/testcases/groupby_search_test.go +++ b/test/testcases/groupby_search_test.go @@ -24,11 +24,13 @@ func genGroupByVectorIndex(metricType entity.MetricType) []entity.Index { idxFlat, _ := entity.NewIndexFlat(metricType) idxIvfFlat, _ := entity.NewIndexIvfFlat(metricType, nlist) idxHnsw, _ := entity.NewIndexHNSW(metricType, 8, 96) + idxIvfSq8, _ := entity.NewIndexIvfSQ8(metricType, 128) allFloatIndex := []entity.Index{ idxFlat, idxIvfFlat, idxHnsw, + idxIvfSq8, } return allFloatIndex } @@ -47,12 +49,10 @@ func genGroupByBinaryIndex(metricType entity.MetricType) []entity.Index { } func genUnsupportedFloatGroupByIndex() []entity.Index { - // idxIvfSq8, _ := entity.NewIndexIvfSQ8(entity.L2, 128) idxIvfPq, _ := entity.NewIndexIvfPQ(entity.L2, 128, 16, 8) idxScann, _ := entity.NewIndexSCANN(entity.L2, 16, false) idxDiskAnn, _ := entity.NewIndexDISKANN(entity.L2) return []entity.Index{ - // idxIvfSq8, idxIvfPq, idxScann, idxDiskAnn, @@ -463,21 +463,100 @@ func TestSearchGroupByRangeSearch(t *testing.T) { common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by") } -// groupBy + advanced search func TestSearchGroupByHybridSearch(t *testing.T) { + indexHnsw, _ := entity.NewIndexHNSW(entity.COSINE, 8, 96) + mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, indexHnsw, false) + + // search params + queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + sp, _ := entity.NewIndexHNSWSearchParam(20) + + collection, _ := mc.DescribeCollection(ctx, collName) + common.PrintAllFieldNames(collName, collection.Schema) + + // search with groupBy field + groupByField := common.DefaultVarcharFieldName + + expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName) + queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) + sReqs := []*client.ANNSearchRequest{ + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2)), + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, expr, queryVec2, sp, common.DefaultTopK), + } + resGroupBy, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField(groupByField)) + common.CheckErr(t, errSearch, true) + + // verify each topK entity is the top1 of the whole group + hitsNum := 0 + total := 0 + for i := 0; i < common.DefaultNq; i++ { + for j := 0; j < resGroupBy[i].ResultCount; j++ { + groupByValue, _ := resGroupBy[i].GroupByValue.Get(j) + pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) + expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue) + // search filter with groupByValue is the top1 + resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, + groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, entity.COSINE, 1, sp) + filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) + //log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", + // groupByField, groupByValue, pkValue, filterTop1Pk) + if filterTop1Pk == pkValue { + hitsNum += 1 + } + total += 1 + } + } + + // verify hits rate + hitsRate := float32(hitsNum) / float32(total) + _str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n", + groupByField, common.DefaultNq, common.DefaultTopK, hitsNum, hitsRate) + log.Println(_str) + require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str) +} + +// groupBy + advanced search +func TestHybridSearchDifferentGroupByField(t *testing.T) { + t.Skip("TODO: 2.5 test hybrid search with groupBy") // prepare data indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) - mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, indexHnsw, false) + mc, ctx, collName := prepareDataForGroupBySearch(t, 5, 1000, indexHnsw, false) // hybrid search with groupBy field sp, _ := entity.NewIndexHNSWSearchParam(20) expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName) - queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) - queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec1 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector) sReqs := []*client.ANNSearchRequest{ - client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2)), + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2), client.WithGroupByField("int64")), client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK, client.WithGroupByField("varchar")), } - _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs) - common.CheckErr(t, errSearch, false, "not support search_group_by operation in the hybrid search") + _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField("int8")) + common.CheckErr(t, errSearch, true) + // TODO check the true groupBy field +} + +// groupBy field not existed +func TestSearchNotExistedGroupByField(t *testing.T) { + // prepare data + indexHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) + mc, ctx, collName := prepareDataForGroupBySearch(t, 2, 1000, indexHnsw, false) + + // hybrid search with groupBy field + sp, _ := entity.NewIndexHNSWSearchParam(20) + expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName) + queryVec1 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector) + queryVec2 := common.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeFloatVector) + sReqs := []*client.ANNSearchRequest{ + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(2), client.WithGroupByField("aaa")), + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK, client.WithGroupByField("bbb")), + } + _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithGroupByField("ccc")) + common.CheckErr(t, errSearch, false, "groupBy field not found in schema: field not found[field=ccc]") + + // search + _, err := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName}, + queryVec1, common.DefaultFloatVecFieldName, entity.L2, common.DefaultTopK, sp, client.WithGroupByField("ddd")) + common.CheckErr(t, err, false, "groupBy field not found in schema: field not found[field=ddd]") }