Skip to content

Commit

Permalink
Add test cases for sparse vector
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Apr 16, 2024
1 parent a963bd4 commit 319b4f4
Show file tree
Hide file tree
Showing 12 changed files with 622 additions and 17 deletions.
3 changes: 2 additions & 1 deletion client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error {
if field.DataType == entity.FieldTypeFloatVector ||
field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeFloat16Vector {
field.DataType == entity.FieldTypeFloat16Vector ||
field.DataType == entity.FieldTypeSparseVector {
vectors++
}
}
Expand Down
4 changes: 4 additions & 0 deletions entity/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ func AnyToColumns(rows []interface{}, schemas ...*Schema) ([]Column, error) {
}
col := NewColumnBFloat16Vector(field.Name, int(dim), data)
nameColumns[field.Name] = col
case FieldTypeSparseVector:
data := make([]SparseEmbedding, 0, rowsLen)
col := NewColumnSparseVectors(field.Name, data)
nameColumns[field.Name] = col
}
}

Expand Down
97 changes: 95 additions & 2 deletions test/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
DefaultBinaryVecFieldName = "binaryVec"
DefaultFloat16VecFieldName = "fp16Vec"
DefaultBFloat16VecFieldName = "bf16Vec"
DefaultSparseVecFieldName = "sparseVec"
DefaultDynamicNumberField = "dynamicNumber"
DefaultDynamicStringField = "dynamicString"
DefaultDynamicBoolField = "dynamicBool"
Expand Down Expand Up @@ -220,6 +221,22 @@ func GenBinaryVector(dim int64) []byte {
return vector
}

func GenSparseVector(maxLen int) entity.SparseEmbedding {
length := 1 + rand.Intn(1+maxLen)
positions := make([]uint32, length)
values := make([]float32, length)
for i := 0; i < length; i++ {
//positions[i] = rand.Uint32() - 1
positions[i] = uint32(i)
values[i] = rand.Float32()
}
vector, err := entity.NewSliceSparseEmbedding(positions, values)
if err != nil {
log.Fatalf("Generate vector failed %s", err)
}
return vector
}

// --- common utils ---

// --- gen fields ---
Expand Down Expand Up @@ -402,6 +419,13 @@ func GenColumnData(start int, nb int, fieldType entity.FieldType, fieldName stri
bf16Vectors = append(bf16Vectors, vec)
}
return entity.NewColumnBFloat16Vector(fieldName, int(opt.dim), bf16Vectors)
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, 0, nb)
for i := start; i < start+nb; i++ {
vec := GenSparseVector(opt.maxLenSparse)
vectors = append(vectors, vec)
}
return entity.NewColumnSparseVectors(fieldName, vectors)
default:
return nil
}
Expand Down Expand Up @@ -981,6 +1005,53 @@ func GenDefaultArrayRows(start int, nb int, dim int64, enableDynamicField bool,
return rows
}

func GenDefaultSparseRows(start int, nb int, dim int64, maxLenSparse int, enableDynamicField bool) []interface{} {
rows := make([]interface{}, 0, nb)
type BaseRow struct {
Int64 int64 `json:"int64" milvus:"name:int64"`
Varchar string `json:"varchar" milvus:"name:varchar"`
FloatVec []float32 `json:"floatVec" milvus:"name:floatVec"`
SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
}

type DynamicRow struct {
Int64 int64 `json:"int64" milvus:"name:int64"`
Varchar string `json:"varchar" milvus:"name:varchar"`
FloatVec []float32 `json:"floatVec" milvus:"name:floatVec"`
SparseVec entity.SparseEmbedding `json:"sparseVec" milvus:"name:sparseVec"`
Dynamic Dynamic `json:"dynamic" milvus:"name:dynamic"`
}

for i := start; i < start+nb; i++ {
baseRow := BaseRow{
Int64: int64(i),
Varchar: strconv.Itoa(i),
FloatVec: GenFloatVector(dim),
SparseVec: GenSparseVector(maxLenSparse),
}
// json and dynamic field
dynamicJSON := Dynamic{
Number: int32(i),
String: strconv.Itoa(i),
Bool: i%2 == 0,
List: []int64{int64(i), int64(i + 1)},
}
if enableDynamicField {
dynamicRow := DynamicRow{
Int64: baseRow.Int64,
Varchar: baseRow.Varchar,
FloatVec: baseRow.FloatVec,
SparseVec: baseRow.SparseVec,
Dynamic: dynamicJSON,
}
rows = append(rows, dynamicRow)
} else {
rows = append(rows, &baseRow)
}
}
return rows
}

func GenAllVectorsRows(start int, nb int, dim int64, enableDynamicField bool) []interface{} {
rows := make([]interface{}, 0, nb)
type BaseRow struct {
Expand Down Expand Up @@ -1231,11 +1302,28 @@ var SupportBinIvfFlatMetricType = []entity.MetricType{
entity.HAMMING,
}

var UnsupportedSparseVecMetricsType = []entity.MetricType{
entity.L2,
entity.COSINE,
entity.JACCARD,
entity.HAMMING,
entity.SUBSTRUCTURE,
entity.SUPERSTRUCTURE,
}

// GenAllFloatIndex gen all float vector index
func GenAllFloatIndex() []entity.Index {
func GenAllFloatIndex(metricTypes ...entity.MetricType) []entity.Index {
nlist := 128
var allFloatIndex []entity.Index
for _, metricType := range SupportFloatMetricType {
var allMetricTypes []entity.MetricType
log.Println(metricTypes)
if len(metricTypes) == 0 {
allMetricTypes = SupportFloatMetricType
} else {
allMetricTypes = metricTypes
}
for _, metricType := range allMetricTypes {
log.Println(metricType)
idxFlat, _ := entity.NewIndexFlat(metricType)
idxIvfFlat, _ := entity.NewIndexIvfFlat(metricType, nlist)
idxIvfSq8, _ := entity.NewIndexIvfSQ8(metricType, nlist)
Expand Down Expand Up @@ -1277,6 +1365,11 @@ func GenSearchVectors(nq int, dim int64, dataType entity.FieldType) []entity.Vec
vector := GenBFloat16Vector(dim)
vectors = append(vectors, entity.BFloat16Vector(vector))
}
case entity.FieldTypeSparseVector:
for i := 0; i < nq; i++ {
vec := GenSparseVector(int(dim))
vectors = append(vectors, vec)
}
}
return vectors
}
Expand Down
13 changes: 10 additions & 3 deletions test/common/utils_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ func GenSchema(name string, autoID bool, fields []*entity.Field, opts ...CreateS
// GenColumnDataOption -- create column data --
type GenColumnDataOption func(opt *genDataOpt)
type genDataOpt struct {
dim int64
ElementType entity.FieldType
capacity int64
dim int64
ElementType entity.FieldType
capacity int64
maxLenSparse int
}

func WithVectorDim(dim int64) GenColumnDataOption {
Expand All @@ -137,4 +138,10 @@ func WithArrayCapacity(capacity int64) GenColumnDataOption {
}
}

func WithSparseVectorLen(length int) GenColumnDataOption {
return func(opt *genDataOpt) {
opt.maxLenSparse = length
}
}

// -- create column data --
44 changes: 44 additions & 0 deletions test/testcases/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,50 @@ func TestCreateMultiVectorExceed(t *testing.T) {
common.CheckErr(t, errCreateCollection, false, "maximum vector field's number should be limited to 4")
}

// specify dim for sparse vector -> error
func TestCreateCollectionSparseVectorWithDim(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
mc := createMilvusClient(ctx, t)
allFields := []*entity.Field{
common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector, common.WithDim(common.DefaultDim)),
}
collName := common.GenRandomString(6)
schema := common.GenSchema(collName, false, allFields)

// create collection
errCreateCollection := mc.CreateCollection(ctx, schema, common.DefaultShards)
common.CheckErr(t, errCreateCollection, false, "dim should not be specified for sparse vector field sparseVec(0)")
}

// create collection with sparse vector
func TestCreateCollectionSparseVector(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
mc := createMilvusClient(ctx, t)
allFields := []*entity.Field{
common.GenField(common.DefaultIntFieldName, entity.FieldTypeInt64, common.WithIsPrimaryKey(true), common.WithAutoID(false)),
common.GenField(common.DefaultVarcharFieldName, entity.FieldTypeVarChar, common.WithMaxLength(common.TestMaxLen)),
common.GenField(common.DefaultSparseVecFieldName, entity.FieldTypeSparseVector),
}
collName := common.GenRandomString(6)
schema := common.GenSchema(collName, false, allFields)

// create collection
errCreateCollection := mc.CreateCollection(ctx, schema, common.DefaultShards)
common.CheckErr(t, errCreateCollection, true)

// describe collection
collection, err := mc.DescribeCollection(ctx, collName)
common.CheckErr(t, err, true)
common.CheckCollection(t, collection, collName, common.DefaultShards, schema, common.DefaultConsistencyLevel)
require.Len(t, collection.Schema.Fields, 3)
for _, field := range collection.Schema.Fields {
if field.DataType == entity.FieldTypeSparseVector {
require.Equal(t, common.DefaultSparseVecFieldName, field.Name)
}
}
}

// -- Get Collection Statistics --

func TestGetStaticsCollectionNotExisted(t *testing.T) {
Expand Down
55 changes: 53 additions & 2 deletions test/testcases/hybrid_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) {
queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector)

// search with different reranker and offset
sp.AddRadius(10)
sp.AddRadius(20)
sp.AddRangeFilter(0.01)
for _, reranker := range []client.Reranker{
client.NewRRFReranker(),
Expand All @@ -300,8 +300,59 @@ func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) {
for _, res := range resRange {
for _, score := range res.Scores {
require.GreaterOrEqual(t, score, float32(0.01))
require.LessOrEqual(t, score, float32(5))
require.LessOrEqual(t, score, float32(20))
}
}
}
}

func TestHybridSearchSparseVector(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/pull/32177")
t.Parallel()
idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
for _, idx := range []entity.Index{idxInverted, idxWand} {
ctx := createContext(t, time.Second*common.DefaultTimeout*2)
// connect
mc := createMilvusClient(ctx, t)

// create -> insert [0, 3000) -> flush -> index -> load
cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}

dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 3,
dim: common.DefaultDim, EnableDynamicField: true}

// index params
idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
ips := []IndexParams{
{BuildIndex: true, Index: idx, FieldName: common.DefaultSparseVecFieldName, async: false},
{BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false},
}
collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))

// search
queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector)
queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
sp1, _ := entity.NewIndexSparseInvertedSearchParam(0.2)
sp2, _ := entity.NewIndexHNSWSearchParam(20)
expr := fmt.Sprintf("%s > 1", common.DefaultIntFieldName)
sReqs := []*client.ANNSearchRequest{
client.NewANNSearchRequest(common.DefaultSparseVecFieldName, entity.IP, expr, queryVec1, sp1, common.DefaultTopK),
client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec2, sp2, common.DefaultTopK),
}
for _, reranker := range []client.Reranker{
client.NewRRFReranker(),
client.NewWeightedReranker([]float64{0.5, 0.6}),
} {
// hybrid search
searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, reranker, sReqs)
common.CheckErr(t, errSearch, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK)
common.CheckErr(t, errSearch, true)
outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName,
common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName}
common.CheckOutputFields(t, searchRes[0].Fields, outputFields)
}
}
}
Loading

0 comments on commit 319b4f4

Please sign in to comment.