Skip to content

Commit

Permalink
fix search case
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Apr 18, 2024
1 parent 85c7cef commit 5b00816
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions test/testcases/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func TestSearchEmptyCollection(t *testing.T) {

// search vector
sp, _ := entity.NewIndexHNSWSearchParam(74)
searchRes, _ := mc.Search(
searchRes, errSearch := mc.Search(
ctx, collName,
[]string{common.DefaultPartition},
"",
Expand All @@ -177,10 +177,8 @@ func TestSearchEmptyCollection(t *testing.T) {
common.DefaultTopK,
sp,
)
require.Len(t, searchRes, common.DefaultNq)
for _, resultSet := range searchRes {
assert.EqualValues(t, 0, resultSet.ResultCount)
}
common.CheckErr(t, errSearch, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, 0)
}
}

Expand Down Expand Up @@ -216,9 +214,7 @@ func TestSearchEmptyCollection2(t *testing.T) {
resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, nv.queryVec, nv.fieldName,
nv.metricType, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
for _, res := range resSearch {
require.Nil(t, res)
}
common.CheckSearchResult(t, resSearch, common.DefaultNq, 0)
}
}

Expand Down Expand Up @@ -278,8 +274,6 @@ func TestSearchEmptyPartitions(t *testing.T) {
nq0IDs := searchResult[0].IDs.(*entity.ColumnInt64).Data()
nq1IDs := searchResult[1].IDs.(*entity.ColumnInt64).Data()
common.CheckSearchResult(t, searchResult, 2, common.DefaultTopK)
log.Println(nq0IDs)
log.Println(nq1IDs)
require.Contains(t, nq0IDs, vecColumnDefault.IdsColumn.(*entity.ColumnInt64).Data()[0])
require.Contains(t, nq1IDs, vecColumnPartition.IdsColumn.(*entity.ColumnInt64).Data()[0])
}
Expand Down Expand Up @@ -1600,7 +1594,6 @@ func TestSearchMultiVectors(t *testing.T) {
}

func TestSearchSparseVector(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus-sdk-go/issues/725")
t.Parallel()
idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
Expand Down Expand Up @@ -1630,10 +1623,15 @@ func TestSearchSparseVector(t *testing.T) {
resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
common.CheckSearchResult(t, resSearch, common.DefaultNq, common.DefaultTopK)
require.Len(t, resSearch, common.DefaultNq)
outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName,
common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName}
common.CheckOutputFields(t, resSearch[0].Fields, outputFields)
for _, res := range resSearch {
require.LessOrEqual(t, res.ResultCount, common.DefaultTopK)
if res.ResultCount == common.DefaultTopK {
common.CheckOutputFields(t, resSearch[0].Fields, outputFields)
}
}
}
}

Expand Down Expand Up @@ -1688,7 +1686,6 @@ func TestSearchInvalidSparseVector(t *testing.T) {
}
}

// TODO https://github.com/milvus-io/milvus-sdk-go/issues/725
func TestSearchEmptySparseCollection(t *testing.T) {
t.Parallel()
idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
Expand Down Expand Up @@ -1717,11 +1714,7 @@ func TestSearchEmptySparseCollection(t *testing.T) {
resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
require.Empty(t, resSearch)
//require.Len(t, resSearch, common.DefaultNq)
//for _, res := range resSearch {
// require.Nil(t, res)
//}
common.CheckSearchResult(t, resSearch, common.DefaultNq, 0)
}
}

Expand Down Expand Up @@ -1755,12 +1748,16 @@ func TestSearchSparseVectorPagination(t *testing.T) {
resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
require.Len(t, resSearch, common.DefaultNq)

pageSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, 5, sp, client.WithOffset(5))
common.CheckErr(t, errSearch, true)
require.Len(t, pageSearch, common.DefaultNq)
for i := 0; i < len(resSearch); i++ {
require.Equal(t, resSearch[i].IDs.(*entity.ColumnInt64).Data()[5:], pageSearch[i].IDs.(*entity.ColumnInt64).Data())
if resSearch[i].ResultCount == common.DefaultTopK && pageSearch[i].ResultCount == 5 {
require.Equal(t, resSearch[i].IDs.(*entity.ColumnInt64).Data()[5:], pageSearch[i].IDs.(*entity.ColumnInt64).Data())
}
}
}
}
Expand Down

0 comments on commit 5b00816

Please sign in to comment.