Skip to content

Commit

Permalink
enhance: Support Sparse Embedding (#684)
Browse files Browse the repository at this point in the history
See also milvus-io/milvus#29419

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Mar 22, 2024
1 parent ffd629b commit 7ab04e8
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 1 deletion.
10 changes: 9 additions & 1 deletion client/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, o
if has && len(ids) > 0 {
flushed := func() bool {
resp, err := c.Service.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
SegmentIDs: ids,
SegmentIDs: ids,
FlushTs: resp.GetCollFlushTs()[collName],
CollectionName: collName,
})
if err != nil {
// TODO max retry
Expand Down Expand Up @@ -506,6 +508,12 @@ func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
placeHolderType = commonpb.PlaceholderType_FloatVector
case entity.BinaryVector:
placeHolderType = commonpb.PlaceholderType_BinaryVector
case entity.BFloat16Vector:
placeHolderType = commonpb.PlaceholderType_BFloat16Vector
case entity.Float16Vector:
placeHolderType = commonpb.PlaceholderType_FloatVector
case entity.SparseEmbedding:
placeHolderType = commonpb.PlaceholderType_SparseFloatVector
}
ph.Type = placeHolderType
for _, vector := range vectors {
Expand Down
19 changes: 19 additions & 0 deletions entity/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,25 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_SparseFloatVector:
sparseVectors := fd.GetVectors().GetSparseFloatVector()
if sparseVectors == nil {
return nil, errFieldDataTypeNotMatch
}
data := sparseVectors.Contents
if end < 0 {
end = len(data)
}
data = data[begin:end]
vectors := make([]SparseEmbedding, 0, len(data))
for _, bs := range data {
vector, err := deserializeSliceSparceEmbedding(bs)
if err != nil {
return nil, err
}
vectors = append(vectors, vector)
}
return NewColumnSparseVectors(fd.GetFieldName(), vectors), nil
default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
Expand Down
217 changes: 217 additions & 0 deletions entity/columns_sparse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// Copyright (C) 2019-2021 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

package entity

import (
"encoding/binary"
"fmt"
"math"
"sort"

"github.com/cockroachdb/errors"
schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

type SparseEmbedding interface {
Dim() int // the dimension
Len() int // the actual items in this vector
Get(idx int) (pos uint32, value float32, ok bool)
Serialize() []byte
}

var _ SparseEmbedding = sliceSparseEmbedding{}
var _ Vector = sliceSparseEmbedding{}

type sliceSparseEmbedding struct {
positions []uint32
values []float32
dim int
len int
}

func (e sliceSparseEmbedding) Dim() int {
return e.dim
}

func (e sliceSparseEmbedding) Len() int {
return e.len
}

func (e sliceSparseEmbedding) FieldType() FieldType {
return FieldTypeSparseVector
}

func (e sliceSparseEmbedding) Get(idx int) (uint32, float32, bool) {
if idx < 0 || idx >= int(e.len) {
return 0, 0, false
}
return e.positions[idx], e.values[idx], true
}

func (e sliceSparseEmbedding) Serialize() []byte {
row := make([]byte, 8*e.Len())
for idx := 0; idx < e.Len(); idx++ {
pos, value, _ := e.Get(idx)
binary.LittleEndian.PutUint32(row[idx*8:], pos)
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
}
return row
}

// Less implements sort.Interce
func (e sliceSparseEmbedding) Less(i, j int) bool {
return e.positions[i] < e.positions[j]
}

func (e sliceSparseEmbedding) Swap(i, j int) {
e.positions[i], e.positions[j] = e.positions[j], e.positions[i]
e.values[i], e.values[j] = e.values[j], e.values[i]
}

func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
length := len(bs)
if length%8 != 0 {
return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
}

length = length / 8

result := sliceSparseEmbedding{
positions: make([]uint32, length),
values: make([]float32, length),
len: length,
}

for i := 0; i < length; i++ {
result.positions[i] = binary.LittleEndian.Uint32(bs[i*8 : i*8+4])
result.values[i] = math.Float32frombits(binary.LittleEndian.Uint32(bs[i*8+4 : i*8+8]))
}
return result, nil
}

func NewSliceSparseEmbedding(positions []uint32, values []float32) (SparseEmbedding, error) {
if len(positions) != len(values) {
return nil, errors.New("invalid sparse embedding input, positions shall have same number of values")
}

se := sliceSparseEmbedding{
positions: positions,
values: values,
len: len(positions),
}

sort.Sort(se)

if se.len > 0 {
se.dim = int(se.positions[se.len-1]) + 1
}

return se, nil
}

var _ (Column) = (*ColumnSparseFloatVector)(nil)

type ColumnSparseFloatVector struct {
ColumnBase

vectors []SparseEmbedding
name string
}

// Name returns column name.
func (c *ColumnSparseFloatVector) Name() string {
return c.name
}

// Type returns column FieldType.
func (c *ColumnSparseFloatVector) Type() FieldType {
return FieldTypeSparseVector
}

// Len returns column values length.
func (c *ColumnSparseFloatVector) Len() int {
return len(c.vectors)
}

// Get returns value at index as interface{}.
func (c *ColumnSparseFloatVector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.vectors[idx], nil
}

// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnSparseFloatVector) ValueByIdx(idx int) (SparseEmbedding, error) {
var r SparseEmbedding // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.vectors[idx], nil
}

func (c *ColumnSparseFloatVector) FieldData() *schema.FieldData {
fd := &schema.FieldData{
Type: schema.DataType_SparseFloatVector,
FieldName: c.name,
}

dim := int(0)
data := make([][]byte, 0, len(c.vectors))
for _, vector := range c.vectors {
row := make([]byte, 8*vector.Len())
for idx := 0; idx < vector.Len(); idx++ {
pos, value, _ := vector.Get(idx)
binary.LittleEndian.PutUint32(row[idx*8:], pos)
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
}
data = append(data, row)
if vector.Dim() > dim {
dim = vector.Dim()
}
}

fd.Field = &schema.FieldData_Vectors{
Vectors: &schema.VectorField{
Dim: int64(dim),
Data: &schema.VectorField_SparseFloatVector{
SparseFloatVector: &schema.SparseFloatArray{
Dim: int64(dim),
Contents: data,
},
},
},
}
return fd
}

func (c *ColumnSparseFloatVector) AppendValue(i interface{}) error {
v, ok := i.(SparseEmbedding)
if !ok {
return fmt.Errorf("invalid type, expect SparseEmbedding interface, got %T", i)
}
c.vectors = append(c.vectors, v)

return nil
}

func (c *ColumnSparseFloatVector) Data() []SparseEmbedding {
return c.vectors
}

func NewColumnSparseVectors(name string, values []SparseEmbedding) *ColumnSparseFloatVector {
return &ColumnSparseFloatVector{
name: name,
vectors: values,
}
}
120 changes: 120 additions & 0 deletions entity/columns_sparse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (C) 2019-2021 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

package entity

import (
"fmt"
"math/rand"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSliceSparseEmbedding(t *testing.T) {
t.Run("normal_case", func(t *testing.T) {

length := 1 + rand.Intn(5)
positions := make([]uint32, length)
values := make([]float32, length)
for i := 0; i < length; i++ {
positions[i] = uint32(i)
values[i] = rand.Float32()
}
se, err := NewSliceSparseEmbedding(positions, values)
require.NoError(t, err)

assert.EqualValues(t, length, se.Dim())
assert.EqualValues(t, length, se.Len())

bs := se.Serialize()
nv, err := deserializeSliceSparceEmbedding(bs)
require.NoError(t, err)

for i := 0; i < length; i++ {
pos, val, ok := se.Get(i)
require.True(t, ok)
assert.Equal(t, positions[i], pos)
assert.Equal(t, values[i], val)

npos, nval, ok := nv.Get(i)
require.True(t, ok)
assert.Equal(t, positions[i], npos)
assert.Equal(t, values[i], nval)
}

_, _, ok := se.Get(-1)
assert.False(t, ok)
_, _, ok = se.Get(length)
assert.False(t, ok)
})

t.Run("position values not match", func(t *testing.T) {
_, err := NewSliceSparseEmbedding([]uint32{1}, []float32{})
assert.Error(t, err)
})

}

func TestColumnSparseEmbedding(t *testing.T) {
columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int())
columnLen := 8 + rand.Intn(10)

v := make([]SparseEmbedding, 0, columnLen)
for i := 0; i < columnLen; i++ {
length := 1 + rand.Intn(5)
positions := make([]uint32, length)
values := make([]float32, length)
for j := 0; j < length; j++ {
positions[j] = uint32(j)
values[j] = rand.Float32()
}
se, err := NewSliceSparseEmbedding(positions, values)
require.NoError(t, err)
v = append(v, se)
}
column := NewColumnSparseVectors(columnName, v)

t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, FieldTypeSparseVector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})

t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})

t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.Error(t, err)
_, err = column.ValueByIdx(columnLen)
assert.Error(t, err)

_, err = column.Get(-1)
assert.Error(t, err)
_, err = column.Get(columnLen)
assert.Error(t, err)

for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.NoError(t, err)
assert.Equal(t, column.vectors[i], v)
getV, err := column.Get(i)
assert.NoError(t, err)
assert.Equal(t, v, getV)
}
})
}
2 changes: 2 additions & 0 deletions entity/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,4 +483,6 @@ const (
FieldTypeFloat16Vector FieldType = 102
// FieldTypeBinaryVector field type bf16 vector
FieldTypeBFloat16Vector FieldType = 103
// FieldTypeBinaryVector field type sparse vector
FieldTypeSparseVector FieldType = 104
)

0 comments on commit 7ab04e8

Please sign in to comment.