Skip to content

Commit

Permalink
test: update gosdk case for valid empty sparse vector (#35621)
Browse files Browse the repository at this point in the history
/kind improvement

Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao authored Aug 22, 2024
1 parent c992a61 commit 570a887
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 14 deletions.
43 changes: 34 additions & 9 deletions tests/go_client/testcases/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,40 @@ func TestInsertSparseDataMaxDim(t *testing.T) {
common.CheckInsertResult(t, pkColumn, inRes)
}

// TestInsertReadSparseEmptyVector inserts a single entity whose sparse vector
// has no positions and no values, then queries it back and checks that the
// returned sparse vector column equals the inserted one.
//
// Author's note: an empty sparse vector can't be searched, but can be queried.
func TestInsertReadSparseEmptyVector(t *testing.T) {
	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
	mc := createDefaultMilvusClient(ctx, t)

	// create, index and load the collection before inserting
	cp := hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec)
	prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption())
	prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
	prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

	// insert data column
	columnOpt := hp.TNewDataOption()
	data := []column.Column{
		hp.GenColumnData(1, entity.FieldTypeInt64, *columnOpt),
		hp.GenColumnData(1, entity.FieldTypeVarChar, *columnOpt),
	}

	// sparse vector: empty position and values; construction and insert are
	// both expected to succeed
	sparseVec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
	common.CheckErr(t, err, true)
	data2 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
	insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data2...))
	common.CheckErr(t, err, true)
	require.EqualValues(t, 1, insertRes.InsertCount)

	// query and check vector is empty
	resQuery, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithLimit(10).WithOutputFields([]string{common.DefaultSparseVecFieldName}).WithConsistencyLevel(entity.ClStrong))
	common.CheckErr(t, err, true)
	require.Equal(t, 1, resQuery.ResultCount)

	// Use the comma-ok form so an unexpected column type fails the test with a
	// readable message instead of panicking inside the log statement.
	queriedCol := resQuery.GetColumn(common.DefaultSparseVecFieldName)
	sparseCol, ok := queriedCol.(*column.ColumnSparseFloatVector)
	require.True(t, ok, "queried sparse vector field is not a sparse float vector column")
	log.Info("sparseVec", zap.Any("data", sparseCol.Data()))
	common.EqualColumn(t, queriedCol, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
}

func TestInsertSparseInvalidVector(t *testing.T) {
// invalid sparse vector: len(positions) != len(values)
positions := []uint32{1, 10}
Expand Down Expand Up @@ -455,15 +489,6 @@ func TestInsertSparseInvalidVector(t *testing.T) {
data1 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data1...))
common.CheckErr(t, err, false, "invalid index in sparse float vector: must be less than 2^32-1")

// invalid sparse vector: empty position and values
positions = []uint32{}
values = []float32{}
sparseVec, err = entity.NewSliceSparseEmbedding(positions, values)
common.CheckErr(t, err, true)
data2 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data2...))
common.CheckErr(t, err, false, "empty sparse float vector row")
}

func TestInsertSparseVectorSamePosition(t *testing.T) {
Expand Down
73 changes: 68 additions & 5 deletions tests/go_client/testcases/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,11 +937,6 @@ func TestSearchInvalidSparseVector(t *testing.T) {
_, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, []entity.Vector{}).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, errSearch, false, "nq (number of search vector per search request) should be in range [1, 16384]")

vector1, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
common.CheckErr(t, err, true)
_, errSearch1 := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, []entity.Vector{vector1}).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, errSearch1, false, "Sparse row data should not be empty")

positions := make([]uint32, 100)
values := make([]float32, 100)
for i := 0; i < 100; i++ {
Expand All @@ -954,6 +949,74 @@ func TestSearchInvalidSparseVector(t *testing.T) {
}
}

// TestSearchWithEmptySparseVector searches with an empty sparse vector (no
// positions, no values) against a collection holding generated sparse data.
// The scenario is repeated for the sparse inverted index and the sparse WAND
// index, and the search itself is expected to succeed in both cases.
func TestSearchWithEmptySparseVector(t *testing.T) {
	t.Parallel()
	sparseIndexes := []index.Index{
		index.NewSparseInvertedIndex(entity.IP, 0.1),
		index.NewSparseWANDIndex(entity.IP, 0.1),
	}
	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
	mc := createDefaultMilvusClient(ctx, t)

	for _, sparseIdx := range sparseIndexes {
		// Fresh collection per index type: create, insert, flush, index, load.
		createParams := hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec)
		fieldsOpt := hp.TNewFieldsOption()
		schemaOpt := hp.TNewSchemaOption().TWithEnableDynamicField(true)
		prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, createParams, fieldsOpt, schemaOpt)

		dataOpt := hp.TNewDataOption().TWithSparseMaxLen(128)
		prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), dataOpt)
		prepare.FlushData(ctx, t, mc, schema.CollectionName)

		fieldIndexes := map[string]index.Index{common.DefaultSparseVecFieldName: sparseIdx}
		prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(fieldIndexes))
		prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

		// An empty sparse vector is considered to be uncorrelated with any
		// other vector.
		emptyVec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
		common.CheckErr(t, err, true)

		searchOpt := client.NewSearchOption(schema.CollectionName, common.DefaultLimit, []entity.Vector{emptyVec}).
			WithConsistencyLevel(entity.ClStrong)
		searchRes, errSearch := mc.Search(ctx, searchOpt)
		common.CheckErr(t, errSearch, true)
		// NOTE(review): presumably asserts one result set (nq=1) with zero
		// hits — confirm against common.CheckSearchResult.
		common.CheckSearchResult(t, searchRes, 1, 0)
	}
}

// TestSearchFromEmptySparseVector fills a collection exclusively with empty
// sparse vectors (no positions, no values) and then searches it twice: once
// with an empty sparse vector and once with a non-empty one. Both searches
// are expected to succeed.
func TestSearchFromEmptySparseVector(t *testing.T) {
	t.Parallel()
	idxInverted := index.NewSparseInvertedIndex(entity.IP, 0.1)
	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
	mc := createDefaultMilvusClient(ctx, t)

	for _, idx := range []index.Index{idxInverted} {
		prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption(), hp.TNewSchemaOption().
			TWithEnableDynamicField(true))
		prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultSparseVecFieldName: idx}))
		prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

		// insert sparse vector: empty position and values
		columnOpt := hp.TNewDataOption()
		data := []column.Column{
			hp.GenColumnData(common.DefaultNb, entity.FieldTypeInt64, *columnOpt),
			hp.GenColumnData(common.DefaultNb, entity.FieldTypeVarChar, *columnOpt),
		}
		sparseVecs := make([]entity.SparseEmbedding, 0, common.DefaultNb)
		for i := 0; i < common.DefaultNb; i++ {
			// Check the constructor error instead of discarding it, so a
			// failure here is reported at its source rather than surfacing
			// later as a confusing insert error.
			vec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
			common.CheckErr(t, err, true)
			sparseVecs = append(sparseVecs, vec)
		}

		data = append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, sparseVecs))
		insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data...))
		common.CheckErr(t, err, true)
		require.EqualValues(t, common.DefaultNb, insertRes.InsertCount)
		prepare.FlushData(ctx, t, mc, schema.CollectionName)

		// search with both an empty and a non-empty sparse vector
		vector1, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
		common.CheckErr(t, err, true)
		vector2, err := entity.NewSliceSparseEmbedding([]uint32{0, 2, 5, 10, 100}, []float32{rand.Float32(), rand.Float32(), rand.Float32(), rand.Float32(), rand.Float32()})
		common.CheckErr(t, err, true)

		for _, vector := range []entity.Vector{vector1, vector2} {
			searchRes, errSearch1 := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, []entity.Vector{vector}).WithConsistencyLevel(entity.ClStrong))
			common.CheckErr(t, errSearch1, true)
			// NOTE(review): presumably asserts one result set (nq=1) with
			// zero hits — confirm against common.CheckSearchResult.
			common.CheckSearchResult(t, searchRes, 1, 0)
		}
	}
}

func TestSearchSparseVectorPagination(t *testing.T) {
t.Parallel()
idxInverted := index.NewGenericIndex(common.DefaultSparseVecFieldName, map[string]string{"drop_ratio_build": "0.2", index.MetricTypeKey: "IP", index.IndexTypeKey: "SPARSE_INVERTED_INDEX"})
Expand Down

0 comments on commit 570a887

Please sign in to comment.