Skip to content

Commit

Permalink
fix: Fix recent behavior change caused case failure (milvus-io#736)
Browse files Browse the repository at this point in the history
- Sparse Vector support range search
- Error message update
- Remove range filter to make case work

---------

Signed-off-by: Congqi Xia <[email protected]>
Signed-off-by: ThreadDao <[email protected]>
Co-authored-by: ThreadDao <[email protected]>
  • Loading branch information
congqixia and ThreadDao authored Apr 25, 2024
1 parent 626c1fe commit 41f8c8d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 28 deletions.
2 changes: 1 addition & 1 deletion test/testcases/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func TestDeleteInvalidExpr(t *testing.T) {
common.CheckErr(t, err, true)

err = mc.Delete(ctx, collName, "", "")
common.CheckErr(t, err, false, "invalid expression: invalid parameter")
common.CheckErr(t, err, false, "delete plan can't be empty or always true")

for _, _invalidExprs := range common.InvalidExpressions {
err := mc.Delete(ctx, collName, "", _invalidExprs.Expr)
Expand Down
91 changes: 64 additions & 27 deletions test/testcases/search_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:build L0
///go:build L0

package testcases

Expand Down Expand Up @@ -1358,6 +1358,7 @@ func TestRangeSearchScannL2(t *testing.T) {

// test range search with scann index and IP COSINE metric type
func TestRangeSearchScannIPCosine(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/32608")
t.Parallel()
for _, metricType := range []entity.MetricType{entity.IP, entity.COSINE} {
ctx := createContext(t, time.Second*common.DefaultTimeout)
Expand All @@ -1371,7 +1372,7 @@ func TestRangeSearchScannIPCosine(t *testing.T) {

// insert
dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: Int64FloatVecJSON,
start: 0, nb: common.DefaultNb, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false}
start: 0, nb: common.DefaultNb * 4, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false}
_, _ = insertData(ctx, t, mc, dp)
mc.Flush(ctx, collName, false)

Expand All @@ -1392,25 +1393,46 @@ func TestRangeSearchScannIPCosine(t *testing.T) {
// range search filter distance and output all fields
queryVec := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
sp, _ := entity.NewIndexSCANNSearchParam(8, 20)
sp.AddRadius(0)
sp.AddRangeFilter(100)

// search without range
resSearch, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultFloatVecFieldName,
metricType, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
for _, s := range resSearch[0].Scores {
log.Println(s)
}

// range search
var radius float64
var rangeFilter float64
if metricType == entity.COSINE {
radius = 10
rangeFilter = 50
}
if metricType == entity.IP {
radius = 0.2
rangeFilter = 0.8
}
sp.AddRadius(radius)
sp.AddRangeFilter(rangeFilter)
resRange, errRange := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultFloatVecFieldName,
metricType, common.DefaultTopK, sp)

// verify error nil, output all fields, range score
common.CheckErr(t, errSearch, true)
common.CheckSearchResult(t, resSearch, 1, common.DefaultTopK)
common.CheckOutputFields(t, resSearch[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName,
common.CheckErr(t, errRange, true)
common.CheckSearchResult(t, resRange, 1, common.DefaultTopK)
common.CheckOutputFields(t, resRange[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName,
common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultDynamicFieldName})
for _, s := range resSearch[0].Scores {
require.GreaterOrEqual(t, s, float32(0))
require.Less(t, s, float32(100))
log.Println(s)
require.GreaterOrEqual(t, s, float32(radius))
require.Less(t, s, float32(rangeFilter))
}

// invalid range search: radius > range filter
sp.AddRadius(20)
sp.AddRangeFilter(10)
_, errRange := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultFloatVecFieldName,
_, errRange = mc.Search(ctx, collName, []string{}, "", []string{""}, queryVec, common.DefaultFloatVecFieldName,
metricType, common.DefaultTopK, sp)
common.CheckErr(t, errRange, false, "must be greater than radius")
}
Expand Down Expand Up @@ -1472,7 +1494,7 @@ func TestRangeSearchScannBinary(t *testing.T) {
sp.AddRangeFilter(100)
_, errRange := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultBinaryVecFieldName,
metricType, common.DefaultTopK, sp)
common.CheckErr(t, errRange, false, "range_filter(100.000000) must be less than radius(0.000000)")
common.CheckErr(t, errRange, false, "range_filter(100) must be less than radius(0)")
}
}

Expand Down Expand Up @@ -1671,18 +1693,18 @@ func TestSearchInvalidSparseVector(t *testing.T) {
searchRes, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, []entity.Vector{vector1}, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
require.Len(t, searchRes, 0)
common.CheckSearchResult(t, searchRes, 1, 0)

positions := make([]uint32, 100)
values := make([]float32, 100)
for i := 0; i < 100; i++ {
positions[i] = uint32(1)
values[i] = rand.Float32()
}
vector, err := entity.NewSliceSparseEmbedding(positions, values)
searchRes, errSearch = mc.Search(ctx, collName, []string{}, "", []string{"*"}, []entity.Vector{vector}, common.DefaultSparseVecFieldName,
vector, _ := entity.NewSliceSparseEmbedding(positions, values)
_, errSearch1 := mc.Search(ctx, collName, []string{}, "", []string{"*"}, []entity.Vector{vector}, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, false, "unsorted or same indices in sparse float vector")
common.CheckErr(t, errSearch1, false, "unsorted or same indices in sparse float vector")
}
}

Expand Down Expand Up @@ -1762,14 +1784,12 @@ func TestSearchSparseVectorPagination(t *testing.T) {
}
}

// test sparse vector unsupported search: range search, TODO iterator search
// test sparse vector unsupported search: TODO iterator search
func TestSearchSparseVectorNotSupported(t *testing.T) {
// invalid sparse search params
for _, dropRatio := range []float64{1.2, -0.3, 1} {
_, err := entity.NewIndexSparseInvertedSearchParam(dropRatio)
common.CheckErr(t, err, false, fmt.Sprintf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio))
}
t.Skip("Go-sdk support iterator search in progress")
}

func TestRangeSearchSparseVector(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout*2)
// connect
mc := createMilvusClient(ctx, t)
Expand All @@ -1778,12 +1798,12 @@ func TestSearchSparseVectorNotSupported(t *testing.T) {
cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}

dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 2,
dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 4,
dim: common.DefaultDim, EnableDynamicField: true}

// index params
idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.1", "metric_type": "IP"})
ips := []IndexParams{
{BuildIndex: true, Index: idxWand, FieldName: common.DefaultSparseVecFieldName, async: false},
{BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false},
Expand All @@ -1793,11 +1813,28 @@ func TestSearchSparseVectorNotSupported(t *testing.T) {
// range search
queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
sp, _ := entity.NewIndexSparseInvertedSearchParam(0.3)
sp.AddRadius(10)
sp.AddRangeFilter(100)
_, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,

// without range
resRange, errSearch := mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, false, "RangeSearch not supported for current index type")
common.CheckErr(t, errSearch, true)
require.Len(t, resRange, common.DefaultNq)
for _, res := range resRange {
log.Println(res.Scores)
}

sp.AddRadius(0)
sp.AddRangeFilter(0.8)
resRange, errSearch = mc.Search(ctx, collName, []string{}, "", []string{"*"}, queryVec, common.DefaultSparseVecFieldName,
entity.IP, common.DefaultTopK, sp)
common.CheckErr(t, errSearch, true)
require.Len(t, resRange, common.DefaultNq)
for _, res := range resRange {
for _, s := range res.Scores {
require.GreaterOrEqual(t, s, float32(0))
require.Less(t, s, float32(0.8))
}
}
}

// TODO offset and limit
Expand Down

0 comments on commit 41f8c8d

Please sign in to comment.