From 8794ec966e6732dfd4de31cda0536687f038be3c Mon Sep 17 00:00:00 2001
From: ThreadDao
Date: Mon, 16 Dec 2024 10:44:43 +0800
Subject: [PATCH] test: add go case for groupby search (#38411)

issue: #33419

---------

Signed-off-by: ThreadDao
---
 tests/go_client/common/response_checker.go      |  24 +-
 tests/go_client/testcases/delete_test.go        |   4 +-
 .../testcases/groupby_search_test.go            | 466 ++++++++++++++++++
 .../testcases/helper/field_helper.go            |  18 +-
 tests/go_client/testcases/insert_test.go        |   8 +-
 tests/go_client/testcases/main_test.go          |   3 ++-
 6 files changed, 498 insertions(+), 25 deletions(-)
 create mode 100644 tests/go_client/testcases/groupby_search_test.go

diff --git a/tests/go_client/common/response_checker.go b/tests/go_client/common/response_checker.go
index 4355b7f8d7080..96c03087e470c 100644
--- a/tests/go_client/common/response_checker.go
+++ b/tests/go_client/common/response_checker.go
@@ -44,7 +44,8 @@ func CheckErr(t *testing.T, actualErr error, expErrNil bool, expErrorMsg ...stri
 func EqualColumn(t *testing.T, columnA column.Column, columnB column.Column) {
 	require.Equal(t, columnA.Name(), columnB.Name())
 	require.Equal(t, columnA.Type(), columnB.Type())
-	switch columnA.Type() {
+	_type := columnA.Type()
+	switch _type {
 	case entity.FieldTypeBool:
 		require.ElementsMatch(t, columnA.(*column.ColumnBool).Data(), columnB.(*column.ColumnBool).Data())
 	case entity.FieldTypeInt8:
 		require.ElementsMatch(t, columnA.(*column.ColumnInt8).Data(), columnB.(*column.ColumnInt8).Data())
@@ -65,11 +66,13 @@ func EqualColumn(t *testing.T, columnA column.Column, columnB column.Column) {
 		log.Debug("data", zap.String("name", columnA.Name()), zap.Any("type", columnA.Type()), zap.Any("data", columnA.FieldData()))
 		log.Debug("data", zap.String("name", columnB.Name()), zap.Any("type", columnB.Type()), zap.Any("data", columnB.FieldData()))
 		require.Equal(t, reflect.TypeOf(columnA), reflect.TypeOf(columnB))
-		switch columnA.(type) {
+		switch _v := columnA.(type) {
 		case *column.ColumnDynamic:
 			require.ElementsMatch(t, columnA.(*column.ColumnDynamic).Data(), columnB.(*column.ColumnDynamic).Data())
 		case *column.ColumnJSONBytes:
 			require.ElementsMatch(t, columnA.(*column.ColumnJSONBytes).Data(), columnB.(*column.ColumnJSONBytes).Data())
+		default:
+			log.Warn("columnA type", zap.String("name", columnA.Name()), zap.Any("type", _v))
 		}
 	case entity.FieldTypeFloatVector:
 		require.ElementsMatch(t, columnA.(*column.ColumnFloatVector).Data(), columnB.(*column.ColumnFloatVector).Data())
@@ -98,7 +101,7 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
 	require.Equal(t, columnA.Name(), columnB.Name())
 	require.IsType(t, columnA.Type(), entity.FieldTypeArray)
 	require.IsType(t, columnB.Type(), entity.FieldTypeArray)
-	switch columnA.(type) {
+	switch _type := columnA.(type) {
 	case *column.ColumnBoolArray:
 		require.ElementsMatch(t, columnA.(*column.ColumnBoolArray).Data(), columnB.(*column.ColumnBoolArray).Data())
 	case *column.ColumnInt8Array:
 		require.ElementsMatch(t, columnA.(*column.ColumnInt8Array).Data(), columnB.(*column.ColumnInt8Array).Data())
 	case *column.ColumnInt16Array:
 		require.ElementsMatch(t, columnA.(*column.ColumnInt16Array).Data(), columnB.(*column.ColumnInt16Array).Data())
 	case *column.ColumnInt32Array:
 		require.ElementsMatch(t, columnA.(*column.ColumnInt32Array).Data(), columnB.(*column.ColumnInt32Array).Data())
 	case *column.ColumnInt64Array:
 		require.ElementsMatch(t, columnA.(*column.ColumnInt64Array).Data(), columnB.(*column.ColumnInt64Array).Data())
 	case *column.ColumnFloatArray:
 		require.ElementsMatch(t, columnA.(*column.ColumnFloatArray).Data(), columnB.(*column.ColumnFloatArray).Data())
 	case *column.ColumnDoubleArray:
 		require.ElementsMatch(t, columnA.(*column.ColumnDoubleArray).Data(), columnB.(*column.ColumnDoubleArray).Data())
@@ -116,6 +119,7 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
 	case *column.ColumnVarCharArray:
 		require.ElementsMatch(t, columnA.(*column.ColumnVarCharArray).Data(), columnB.(*column.ColumnVarCharArray).Data())
 	default:
+		log.Debug("columnA type is", zap.Any("type", _type))
 		log.Info("Support array element type is:", zap.Any("FieldType", []entity.FieldType{
 			entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32,
 			entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar,
@@ -124,16 +128,16 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
 }
 
 // CheckInsertResult check insert result, ids len (insert count), ids data (pks, but no auto ids)
-func CheckInsertResult(t *testing.T, expIds column.Column, insertRes client.InsertResult) {
-	require.Equal(t, expIds.Len(), insertRes.IDs.Len())
-	require.Equal(t, expIds.Len(), int(insertRes.InsertCount))
-	actualIds := insertRes.IDs
-	switch expIds.Type() {
+func CheckInsertResult(t *testing.T, expIDs column.Column, insertRes client.InsertResult) {
+	require.Equal(t, expIDs.Len(), insertRes.IDs.Len())
+	require.Equal(t, expIDs.Len(), int(insertRes.InsertCount))
+	actualIDs := insertRes.IDs
+	switch expIDs.Type() {
 	// pk field support int64 and varchar type
 	case entity.FieldTypeInt64:
-		require.ElementsMatch(t, actualIds.(*column.ColumnInt64).Data(), expIds.(*column.ColumnInt64).Data())
+		require.ElementsMatch(t, actualIDs.(*column.ColumnInt64).Data(), expIDs.(*column.ColumnInt64).Data())
 	case entity.FieldTypeVarChar:
-		require.ElementsMatch(t, actualIds.(*column.ColumnVarChar).Data(), expIds.(*column.ColumnVarChar).Data())
+		require.ElementsMatch(t, actualIDs.(*column.ColumnVarChar).Data(), expIDs.(*column.ColumnVarChar).Data())
 	default:
 		log.Info("The primary field only support ", zap.Any("type", []entity.FieldType{entity.FieldTypeInt64, entity.FieldTypeVarChar}))
 	}
diff --git a/tests/go_client/testcases/delete_test.go b/tests/go_client/testcases/delete_test.go
index cbf0977a7791a..0c9bdd215b50b 100644
--- a/tests/go_client/testcases/delete_test.go
+++ b/tests/go_client/testcases/delete_test.go
@@ -545,8 +545,8 @@ func TestDeleteDuplicatedPks(t *testing.T) {
 	prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
 
 	// delete
-	deleteIds := []int64{0, 0, 0, 0, 0}
-	delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIds))
+	deleteIDs := []int64{0, 0, 0, 0, 0}
+	delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIDs))
 	common.CheckErr(t, err, true)
 	require.Equal(t, 5, int(delRes.DeleteCount))
diff --git a/tests/go_client/testcases/groupby_search_test.go b/tests/go_client/testcases/groupby_search_test.go
new file mode 100644
index 0000000000000..dc6a4905b9630
--- /dev/null
+++ b/tests/go_client/testcases/groupby_search_test.go
@@ -0,0 +1,466 @@
+package testcases
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/milvus-io/milvus/client/v2/column"
+	"github.com/milvus-io/milvus/client/v2/entity"
+	"github.com/milvus-io/milvus/client/v2/index"
+	client "github.com/milvus-io/milvus/client/v2/milvusclient"
+	"github.com/milvus-io/milvus/tests/go_client/base"
+	"github.com/milvus-io/milvus/tests/go_client/common"
+	hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper"
+)
+
+// Generate groupBy-supported float vector indexes
+func genGroupByVectorIndex(metricType entity.MetricType) []index.Index {
+	nlist := 128
+	idxFlat := index.NewFlatIndex(metricType)
+	idxIvfFlat := index.NewIvfFlatIndex(metricType, nlist)
+	idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
+	idxIvfSq8 := index.NewIvfSQ8Index(metricType, nlist)
+
+	allFloatIndex := []index.Index{
+		idxFlat,
+		idxIvfFlat,
+		idxHnsw,
+		idxIvfSq8,
+	}
+	return allFloatIndex
+}
+
+// Generate groupBy-supported binary vector indexes
+func genGroupByBinaryIndex(metricType entity.MetricType) []index.Index {
+	nlist := 128
+	idxBinFlat := index.NewBinFlatIndex(metricType)
+	idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, nlist)
+
+	allBinIndex := []index.Index{
+		idxBinFlat,
+		idxBinIvfFlat,
+	}
+	return allBinIndex
+}
+
+func genUnsupportedFloatGroupByIndex() []index.Index {
+	idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 16, 8)
+	idxScann := index.NewSCANNIndex(entity.L2, 16, false)
+	return []index.Index{
+		idxIvfPq,
+		idxScann,
+	}
+}
+
+func prepareDataForGroupBySearch(t *testing.T, loopInsert int, insertNi int, idx index.Index, withGrowing bool) (*base.MilvusClient, context.Context, string) {
+	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*5)
+	mc := createDefaultMilvusClient(ctx, t)
+
+	// create collection with all datatypes
+	prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
+	for i := 0; i < loopInsert; i++ {
+		prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(insertNi))
+	}
+
+	if !withGrowing {
+		prepare.FlushData(ctx, t, mc, schema.CollectionName)
+	}
+	prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultFloatVecFieldName: idx}))
+
+	// create scalar index
+	supportedGroupByFields := []string{
+		common.DefaultInt64FieldName, common.DefaultInt8FieldName, common.DefaultInt16FieldName,
+		common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
+	}
+	for _, groupByField := range supportedGroupByFields {
+		idxTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, groupByField, index.NewAutoIndex(entity.L2)))
+		common.CheckErr(t, err, true)
+		err = idxTask.Await(ctx)
+		common.CheckErr(t, err, true)
+	}
+	prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
+	return mc, ctx, schema.CollectionName
+}
+
+// create coll with all datatypes -> build all supported indexes
+// -> search with WithGroupByField (int* + varchar + bool)
+// -> verify every returned entity is the top1 of its whole group
+// output_fields: pk + groupBy
+func TestSearchGroupByFloatDefault(t *testing.T) {
+	t.Skip("https://github.com/milvus-io/milvus/issues/38343")
+	t.Parallel()
+	for _, idx := range genGroupByVectorIndex(entity.L2) {
+		// prepare data
+		mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
+
+		// search params
+		queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+
+		// search with groupBy field
+		supportedGroupByFields := []string{
+			common.DefaultInt64FieldName, common.DefaultInt8FieldName,
+			common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
+		}
+		for _, groupByField := range supportedGroupByFields {
+			resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
+				WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
+
+			// verify each topK entity is the top1 of the whole group
+			hitsNum := 0
+			total := 0
+			for i := 0; i < common.DefaultNq; i++ {
+				for j := 0; j < resGroupBy[i].ResultCount; j++ {
+					groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
+					pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
+					var expr string
+					if groupByField == "varchar" {
+						expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
+					} else {
+						expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
+					}
+
+					// search filter with groupByValue is the top1
+					resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, []entity.Vector{queryVec[i]}).WithANNSField(common.DefaultFloatVecFieldName).
+						WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField))
+
+					filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
+					if filterTop1Pk == pkValue {
+						hitsNum += 1
+					}
+					total += 1
+				}
+			}
+
+			// verify hits rate
+			hitsRate := float32(hitsNum) / float32(total)
+			_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
+				groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
+			log.Println(_str)
+			if groupByField != "bool" {
+				// waiting for fix https://github.com/milvus-io/milvus/issues/32630
+				require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
+			}
+		}
+	}
+}
+
+func TestSearchGroupByFloatDefaultCosine(t *testing.T) {
+	t.Skip("https://github.com/milvus-io/milvus/issues/38343")
+	t.Parallel()
+	for _, idx := range genGroupByVectorIndex(entity.COSINE) {
+		// prepare data
+		mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
+
+		// search params
+		queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+
+		// search with groupBy field without varchar
+		supportedGroupByFields := []string{
+			common.DefaultInt64FieldName, common.DefaultInt8FieldName,
+			common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultBoolFieldName,
+		}
+		for _, groupByField := range supportedGroupByFields {
+			resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
+				WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
+
+			// verify each topK entity is the top1 of the whole group
+			hitsNum := 0
+			total := 0
+			for i := 0; i < common.DefaultNq; i++ {
+				for j := 0; j < resGroupBy[i].ResultCount; j++ {
+					groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
+					pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
+					expr := fmt.Sprintf("%s == %v", groupByField, groupByValue)
+
+					// search filter with groupByValue is the top1
+					resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, []entity.Vector{queryVec[i]}).WithANNSField(common.DefaultFloatVecFieldName).
+ WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField)) + + filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) + if filterTop1Pk == pkValue { + hitsNum += 1 + } + total += 1 + } + } + + // verify hits rate + hitsRate := float32(hitsNum) / float32(total) + _str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n", + groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate) + log.Println(_str) + if groupByField != "bool" { + // waiting for fix https://github.com/milvus-io/milvus/issues/32630 + require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str) + } + } + } +} + +// test groupBy search sparse vector +func TestGroupBySearchSparseVector(t *testing.T) { + t.Parallel() + idxInverted := index.NewSparseInvertedIndex(entity.IP, 0.3) + idxWand := index.NewSparseWANDIndex(entity.IP, 0.2) + for _, idx := range []index.Index{idxInverted, idxWand} { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := createDefaultMilvusClient(ctx, t) + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption().TWithMaxLen(common.TestMaxLen), + hp.TNewSchemaOption().TWithEnableDynamicField(true)) + for i := 0; i < 100; i++ { + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(200)) + } + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultSparseVecFieldName: idx})) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // groupBy search + queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector) + + resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultSparseVecFieldName). + WithGroupByField(common.DefaultVarcharFieldName).WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName)) + + // verify each topK entity is the top1 of the whole group + hitsNum := 0 + total := 0 + for i := 0; i < common.DefaultNq; i++ { + if resGroupBy[i].ResultCount > 0 { + for j := 0; j < resGroupBy[i].ResultCount; j++ { + groupByValue, _ := resGroupBy[i].GroupByValue.Get(j) + pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) + expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue) + // search filter with groupByValue is the top1 + resFilter, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, 1, []entity.Vector{queryVec[i]}). + WithANNSField(common.DefaultSparseVecFieldName). + WithGroupByField(common.DefaultVarcharFieldName). + WithFilter(expr). + WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName)) + + filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) + log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. 
The returned pk by filter search is: %d",
+						common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk)
+					if filterTop1Pk == pkValue {
+						hitsNum += 1
+					}
+					total += 1
+				}
+			}
+		}
+
+		// verify hits rate
+		hitsRate := float32(hitsNum) / float32(total)
+		_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
+			common.DefaultVarcharFieldName, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
+		log.Println(_str)
+		require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
+	}
+}
+
+// binary vector -> not supported
+func TestSearchGroupByBinaryDefault(t *testing.T) {
+	t.Parallel()
+	for _, metricType := range hp.SupportBinIvfFlatMetricType {
+		for _, idx := range genGroupByBinaryIndex(metricType) {
+			ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
+			mc := createDefaultMilvusClient(ctx, t)
+
+			prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
+				hp.TNewSchemaOption().TWithEnableDynamicField(true))
+			for i := 0; i < 2; i++ {
+				prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
+			}
+			prepare.FlushData(ctx, t, mc, schema.CollectionName)
+			prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idx}))
+			prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
+
+			// search params
+			queryVec := hp.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeBinaryVector)
+			t.Log("Waiting for support for specifying search parameters")
+			// sp, _ := index.NewBinIvfFlatIndexSearchParam(32)
+			groupByFields := []string{common.DefaultVarcharFieldName, common.DefaultBinaryVecFieldName}
+
+			// groupBy search on a binary vector collection is expected to fail
+			for _, groupByField := range groupByFields {
+				_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
+					WithOutputFields(common.DefaultVarcharFieldName, groupByField))
+				common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
+			}
+		}
+	}
+}
+
+// binary vector -> growing segments, maybe brute force
+// default Bounded consistency level -> may succeed
+// Strong consistency level -> error
+func TestSearchGroupByBinaryGrowing(t *testing.T) {
+	t.Skip("https://github.com/milvus-io/milvus/issues/38343")
+	t.Parallel()
+	for _, metricType := range hp.SupportBinIvfFlatMetricType {
+		idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, 128)
+		ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
+		mc := createDefaultMilvusClient(ctx, t)
+
+		prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
+			hp.TNewSchemaOption().TWithEnableDynamicField(true))
+		prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idxBinIvfFlat}))
+		prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
+		for i := 0; i < 2; i++ {
+			prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
+		}
+
+		// search params
+		queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeBinaryVector)
+		t.Log("Waiting for support for specifying search parameters")
+		// sp, _ := index.NewBinIvfFlatIndexSearchParam(64)
+		groupByFields := []string{common.DefaultVarcharFieldName}
+
+		// search with groupBy field
+		for _, groupByField := range groupByFields {
+			_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
+				WithOutputFields(common.DefaultVarcharFieldName, groupByField).WithConsistencyLevel(entity.ClStrong))
+			common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
+		}
+	}
+}
+
+// groupBy in growing segments, maybe growing index or brute force
+func TestSearchGroupByFloatGrowing(t *testing.T) {
+	for _, metricType := range hp.SupportFloatMetricType {
+		idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
+		mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idxHnsw, true)
+
+		// search params
+		queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+		supportedGroupByFields := []string{common.DefaultInt64FieldName, "int8", "int16", "int32", "varchar", "bool"}
+
+		// search with groupBy field; reset the hit counters per field
+		for _, groupByField := range supportedGroupByFields {
+			hitsNum := 0
+			total := 0
+			resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
+				WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithConsistencyLevel(entity.ClStrong))
+
+			// verify each topK entity is the top1 in the group
+			for i := 0; i < common.DefaultNq; i++ {
+				for j := 0; j < resGroupBy[i].ResultCount; j++ {
+					groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
+					pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
+					var expr string
+					if groupByField == "varchar" {
+						expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
+					} else {
+						expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
+					}
+					resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, []entity.Vector{queryVec[i]}).WithANNSField(common.DefaultFloatVecFieldName).
+						WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
+
+					// search filter with groupByValue is the top1
+					filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
+					log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. 
The returned pk by filter search is: %d",
+						groupByField, groupByValue, pkValue, filterTop1Pk)
+					if filterTop1Pk == pkValue {
+						hitsNum += 1
+					}
+					total += 1
+				}
+			}
+			// verify hits rate
+			hitsRate := float32(hitsNum) / float32(total)
+			_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
+				groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
+			log.Println(_str)
+			if groupByField != "bool" {
+				require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
+			}
+		}
+	}
+}
+
+// groupBy + pagination
+func TestSearchGroupByPagination(t *testing.T) {
+	// create index and load
+	idx := index.NewHNSWIndex(entity.COSINE, 8, 96)
+	mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, idx, false)
+
+	// search params
+	queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+	offset := 10
+
+	// search pagination & groupBy
+	resGroupByPagination, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithOffset(offset).
+		WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
+
+	common.CheckSearchResult(t, resGroupByPagination, common.DefaultNq, common.DefaultLimit)
+
+	// search with limit = original limit + offset
+	resGroupByDefault, _ := mc.Search(ctx, client.NewSearchOption(collName, offset+common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
+		WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
+
+	for i := 0; i < common.DefaultNq; i++ {
+		require.Equal(t, resGroupByDefault[i].IDs.(*column.ColumnInt64).Data()[offset:], resGroupByPagination[i].IDs.(*column.ColumnInt64).Data())
+	}
+}
+
+// groupBy only supports the indexes in genGroupByVectorIndex: "FLAT", "IVF_FLAT", "IVF_SQ8", "HNSW"
+func TestSearchGroupByUnsupportedIndex(t *testing.T) {
+	t.Parallel()
+	for _, idx := range genUnsupportedFloatGroupByIndex() {
+		t.Run(string(idx.IndexType()), func(t *testing.T) {
+			mc, ctx, collName := prepareDataForGroupBySearch(t, 3, 1000, idx, false)
+			// groupBy search
+			queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+			_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
+			common.CheckErr(t, err, false, "doesn't support")
+		})
+	}
+}
+
+// unsupported groupBy field types: FLOAT, DOUBLE, JSON, FLOAT_VECTOR, ARRAY
+func TestSearchGroupByUnsupportedDataType(t *testing.T) {
+	idxHnsw := index.NewHNSWIndex(entity.L2, 8, 96)
+	mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
+
+	// groupBy search with unsupported field type
+	queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+	for _, unsupportedField := range []string{
+		common.DefaultFloatFieldName, common.DefaultDoubleFieldName,
+		common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField,
+	} {
+		_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName))
+		common.CheckErr(t, err, false, "unsupported data type")
+	}
+}
+
+// groupBy + iterator -> not supported
+func TestSearchGroupByIterator(t *testing.T) {
+	// TODO: sdk support
+}
+
+// groupBy + range search -> not supported
+func TestSearchGroupByRangeSearch(t *testing.T) {
+	t.Skip("Waiting for support for specifying search parameters")
+	idxHnsw := index.NewHNSWIndex(entity.COSINE, 8, 96)
+	mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
+
+	// groupBy search with range
+	queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
+
+	// sp, _ := index.NewHNSWIndexSearchParam(50)
+	// sp.AddRadius(0)
+	// sp.AddRangeFilter(0.8)
+
+	// range search
+	_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
+	common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
+}
+
+// groupBy + advanced search
+func TestSearchGroupByHybridSearch(t *testing.T) {
+	t.Skip("Waiting for HybridSearch implementation")
+}
diff --git a/tests/go_client/testcases/helper/field_helper.go b/tests/go_client/testcases/helper/field_helper.go
index 07ecc01635065..a99f17e456cc1 100644
--- a/tests/go_client/testcases/helper/field_helper.go
+++ b/tests/go_client/testcases/helper/field_helper.go
@@ -272,9 +272,10 @@ func (cf FieldsAllFields) GenFields(option GenFieldsOption) []*entity.Field {
 	}
 	// scalar fields and array fields
 	for _, fieldType := range GetAllScalarFieldType() {
-		if fieldType == entity.FieldTypeInt64 {
+		switch fieldType {
+		case entity.FieldTypeInt64:
 			continue
-		} else if fieldType == entity.FieldTypeArray {
+		case entity.FieldTypeArray:
 			for _, eleType := range GetAllArrayElementType() {
 				arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity)
 				if eleType == entity.FieldTypeVarChar {
@@ -282,10 +283,10 @@ func (cf FieldsAllFields) GenFields(option GenFieldsOption) []*entity.Field {
 				}
 				fields = append(fields, arrayField)
 			}
-		} else if fieldType == entity.FieldTypeVarChar {
+		case entity.FieldTypeVarChar:
 			varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength)
 			fields = append(fields, varcharField)
-		} else {
+		default:
 			scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType)
 			fields = append(fields, scalarField)
 		}
@@ -312,9 +313,10 @@ func (cf FieldsInt64VecAllScalar) GenFields(option GenFieldsOption) []*entity.Fi
 	}
 	// scalar fields and array fields
 	for _, fieldType := range GetAllScalarFieldType() {
-		if fieldType == entity.FieldTypeInt64 {
+		switch fieldType {
+		case entity.FieldTypeInt64:
 			continue
-		} else if fieldType == entity.FieldTypeArray {
+		case entity.FieldTypeArray:
 			for _, eleType := range GetAllArrayElementType() {
 				arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity)
 				if eleType == entity.FieldTypeVarChar {
@@ -322,10 +324,10 @@ func (cf FieldsInt64VecAllScalar) GenFields(option GenFieldsOption) []*entity.Fi
 				}
 				fields = append(fields, arrayField)
 			}
-		} else if fieldType == entity.FieldTypeVarChar {
+		case entity.FieldTypeVarChar:
 			varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength)
 			fields = append(fields, varcharField)
-		} else {
+		default:
 			scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType)
 			fields = append(fields, scalarField)
 		}
diff --git a/tests/go_client/testcases/insert_test.go b/tests/go_client/testcases/insert_test.go
index 0de857ebbe372..33e1aaba26562 100644
--- a/tests/go_client/testcases/insert_test.go
+++ b/tests/go_client/testcases/insert_test.go
@@ -488,8 +488,8 @@ func TestInsertReadSparseEmptyVector(t *testing.T) {
 	// sparse vector: empty position and values
 	sparseVec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
 	common.CheckErr(t, err, true)
-	data2 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
-	insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data2...))
+	data = append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
+	insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data...))
 	common.CheckErr(t, err, true)
 	require.EqualValues(t, 1, insertRes.InsertCount)
 
@@ -526,8 +526,8 @@ func TestInsertSparseInvalidVector(t *testing.T) {
 	values = []float32{0.4}
 	sparseVec, err := entity.NewSliceSparseEmbedding(positions, values)
 	common.CheckErr(t, err, true)
-	data1 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
-	_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data1...))
+	data = append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
+	_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data...))
 	common.CheckErr(t, err, false, "invalid index in sparse float vector: must be less than 2^32-1")
 }
 
diff --git a/tests/go_client/testcases/main_test.go b/tests/go_client/testcases/main_test.go
index b20ec38cfcd6f..07b5be39be874 100644
--- a/tests/go_client/testcases/main_test.go
+++ b/tests/go_client/testcases/main_test.go
@@ -28,7 +28,8 @@ func teardown() {
 	defer cancel()
 	mc, err := base.NewMilvusClient(ctx, &defaultCfg)
 	if err != nil {
-		log.Fatal("teardown failed to connect milvus with error", zap.Error(err))
+		log.Error("teardown failed to connect milvus with error", zap.Error(err))
+		return // mc is nil on a failed connect; skip the teardown below to avoid a nil-pointer panic
 	}
 	defer mc.Close(ctx)
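Note on the verification pattern: the groupBy tests in this patch all repeat the same check — for every entity returned by a groupBy search, re-run the same query restricted to that entity's group via a boolean filter expression, and count how often the returned top1 pk matches. A minimal sketch of that pattern factored into a shared helper is shown below; it uses only client-v2 calls already exercised in this patch, but the helper name verifyGroupByTop1Rate and its exact signature are hypothetical and not part of the commit.

// verifyGroupByTop1Rate sketches the shared check used by the tests above
// (hypothetical helper, assumed to live in package testcases with the same imports).
// It returns the fraction of groupBy hits whose pk is also the top1 of its group.
func verifyGroupByTop1Rate(ctx context.Context, t *testing.T, mc *base.MilvusClient,
	collName, groupByField string, queryVec []entity.Vector, isVarchar bool,
) float32 {
	// groupBy search: at most one hit per distinct groupByField value
	resGroupBy, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).
		WithANNSField(common.DefaultFloatVecFieldName).
		WithGroupByField(groupByField).
		WithOutputFields(common.DefaultInt64FieldName, groupByField))
	common.CheckErr(t, err, true)

	hitsNum, total := 0, 0
	for i := 0; i < len(queryVec); i++ {
		for j := 0; j < resGroupBy[i].ResultCount; j++ {
			groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
			pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)

			// varchar values must be quoted inside the filter expression
			expr := fmt.Sprintf("%s == %v", groupByField, groupByValue)
			if isVarchar {
				expr = fmt.Sprintf("%s == '%v'", groupByField, groupByValue)
			}

			// re-search the same query vector within this group only:
			// its top1 should be the pk the groupBy search returned
			resFilter, err := mc.Search(ctx, client.NewSearchOption(collName, 1, []entity.Vector{queryVec[i]}).
				WithANNSField(common.DefaultFloatVecFieldName).
				WithGroupByField(groupByField).
				WithFilter(expr).
				WithOutputFields(common.DefaultInt64FieldName, groupByField))
			common.CheckErr(t, err, true)
			if top1, _ := resFilter[0].IDs.GetAsInt64(0); top1 == pkValue {
				hitsNum++
			}
			total++
		}
	}
	if total == 0 {
		return 0
	}
	return float32(hitsNum) / float32(total)
}

With such a helper, the per-field loop bodies in TestSearchGroupByFloatDefault, TestSearchGroupByFloatDefaultCosine, TestGroupBySearchSparseVector, and TestSearchGroupByFloatGrowing would each reduce to one call plus the existing require.GreaterOrEqualf threshold assertion on the returned rate.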