diff --git a/client/insert.go b/client/insert.go index fe24db40..d7b78cf7 100644 --- a/client/insert.go +++ b/client/insert.go @@ -224,7 +224,9 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, o if has && len(ids) > 0 { flushed := func() bool { resp, err := c.Service.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, + SegmentIDs: ids, + FlushTs: resp.GetCollFlushTs()[collName], + CollectionName: collName, }) if err != nil { // TODO max retry @@ -506,6 +508,12 @@ func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue { placeHolderType = commonpb.PlaceholderType_FloatVector case entity.BinaryVector: placeHolderType = commonpb.PlaceholderType_BinaryVector + case entity.BFloat16Vector: + placeHolderType = commonpb.PlaceholderType_BFloat16Vector + case entity.Float16Vector: + placeHolderType = commonpb.PlaceholderType_FloatVector + case entity.SparseEmbedding: + placeHolderType = commonpb.PlaceholderType_SparseFloatVector } ph.Type = placeHolderType for _, vector := range vectors { diff --git a/entity/columns.go b/entity/columns.go index cc3e4b90..27e6b93c 100644 --- a/entity/columns.go +++ b/entity/columns.go @@ -377,6 +377,25 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) { vector = append(vector, v) } return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_SparseFloatVector: + sparseVectors := fd.GetVectors().GetSparseFloatVector() + if sparseVectors == nil { + return nil, errFieldDataTypeNotMatch + } + data := sparseVectors.Contents + if end < 0 { + end = len(data) + } + data = data[begin:end] + vectors := make([]SparseEmbedding, 0, len(data)) + for _, bs := range data { + vector, err := deserializeSliceSparceEmbedding(bs) + if err != nil { + return nil, err + } + vectors = append(vectors, vector) + } + return NewColumnSparseVectors(fd.GetFieldName(), vectors), nil default: return nil, fmt.Errorf("unsupported data type %s", fd.GetType()) } diff --git a/entity/columns_sparse.go b/entity/columns_sparse.go new file mode 100644 index 00000000..6fd06add --- /dev/null +++ b/entity/columns_sparse.go @@ -0,0 +1,217 @@ +// Copyright (C) 2019-2021 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package entity + +import ( + "encoding/binary" + "fmt" + "math" + "sort" + + "github.com/cockroachdb/errors" + schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type SparseEmbedding interface { + Dim() int // the dimension + Len() int // the actual items in this vector + Get(idx int) (pos uint32, value float32, ok bool) + Serialize() []byte +} + +var _ SparseEmbedding = sliceSparseEmbedding{} +var _ Vector = sliceSparseEmbedding{} + +type sliceSparseEmbedding struct { + positions []uint32 + values []float32 + dim int + len int +} + +func (e sliceSparseEmbedding) Dim() int { + return e.dim +} + +func (e sliceSparseEmbedding) Len() int { + return e.len +} + +func (e sliceSparseEmbedding) FieldType() FieldType { + return FieldTypeSparseVector +} + +func (e sliceSparseEmbedding) Get(idx int) (uint32, float32, bool) { + if idx < 0 || idx >= int(e.len) { + return 0, 0, false + } + return e.positions[idx], e.values[idx], true +} + +func (e sliceSparseEmbedding) Serialize() []byte { + row := make([]byte, 8*e.Len()) + for idx := 0; idx < e.Len(); idx++ { + pos, value, _ := e.Get(idx) + binary.LittleEndian.PutUint32(row[idx*8:], pos) + binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value)) + } + return row +} + +// Less implements sort.Interce +func (e sliceSparseEmbedding) Less(i, j int) bool { + return e.positions[i] < e.positions[j] +} + +func (e sliceSparseEmbedding) Swap(i, j int) { + e.positions[i], e.positions[j] = e.positions[j], e.positions[i] + e.values[i], e.values[j] = e.values[j], e.values[i] +} + +func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) { + length := len(bs) + if length%8 != 0 { + return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes") + } + + length = length / 8 + + result := sliceSparseEmbedding{ + positions: make([]uint32, length), + values: make([]float32, length), + len: length, + } + + for i := 0; i < length; i++ { + result.positions[i] = binary.LittleEndian.Uint32(bs[i*8 : i*8+4]) + result.values[i] = math.Float32frombits(binary.LittleEndian.Uint32(bs[i*8+4 : i*8+8])) + } + return result, nil +} + +func NewSliceSparseEmbedding(positions []uint32, values []float32) (SparseEmbedding, error) { + if len(positions) != len(values) { + return nil, errors.New("invalid sparse embedding input, positions shall have same number of values") + } + + se := sliceSparseEmbedding{ + positions: positions, + values: values, + len: len(positions), + } + + sort.Sort(se) + + if se.len > 0 { + se.dim = int(se.positions[se.len-1]) + 1 + } + + return se, nil +} + +var _ (Column) = (*ColumnSparseFloatVector)(nil) + +type ColumnSparseFloatVector struct { + ColumnBase + + vectors []SparseEmbedding + name string +} + +// Name returns column name. +func (c *ColumnSparseFloatVector) Name() string { + return c.name +} + +// Type returns column FieldType. +func (c *ColumnSparseFloatVector) Type() FieldType { + return FieldTypeSparseVector +} + +// Len returns column values length. +func (c *ColumnSparseFloatVector) Len() int { + return len(c.vectors) +} + +// Get returns value at index as interface{}. +func (c *ColumnSparseFloatVector) Get(idx int) (interface{}, error) { + if idx < 0 || idx >= c.Len() { + return nil, errors.New("index out of range") + } + return c.vectors[idx], nil +} + +// ValueByIdx returns value of the provided index +// error occurs when index out of range +func (c *ColumnSparseFloatVector) ValueByIdx(idx int) (SparseEmbedding, error) { + var r SparseEmbedding // use default value + if idx < 0 || idx >= c.Len() { + return r, errors.New("index out of range") + } + return c.vectors[idx], nil +} + +func (c *ColumnSparseFloatVector) FieldData() *schema.FieldData { + fd := &schema.FieldData{ + Type: schema.DataType_SparseFloatVector, + FieldName: c.name, + } + + dim := int(0) + data := make([][]byte, 0, len(c.vectors)) + for _, vector := range c.vectors { + row := make([]byte, 8*vector.Len()) + for idx := 0; idx < vector.Len(); idx++ { + pos, value, _ := vector.Get(idx) + binary.LittleEndian.PutUint32(row[idx*8:], pos) + binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value)) + } + data = append(data, row) + if vector.Dim() > dim { + dim = vector.Dim() + } + } + + fd.Field = &schema.FieldData_Vectors{ + Vectors: &schema.VectorField{ + Dim: int64(dim), + Data: &schema.VectorField_SparseFloatVector{ + SparseFloatVector: &schema.SparseFloatArray{ + Dim: int64(dim), + Contents: data, + }, + }, + }, + } + return fd +} + +func (c *ColumnSparseFloatVector) AppendValue(i interface{}) error { + v, ok := i.(SparseEmbedding) + if !ok { + return fmt.Errorf("invalid type, expect SparseEmbedding interface, got %T", i) + } + c.vectors = append(c.vectors, v) + + return nil +} + +func (c *ColumnSparseFloatVector) Data() []SparseEmbedding { + return c.vectors +} + +func NewColumnSparseVectors(name string, values []SparseEmbedding) *ColumnSparseFloatVector { + return &ColumnSparseFloatVector{ + name: name, + vectors: values, + } +} diff --git a/entity/columns_sparse_test.go b/entity/columns_sparse_test.go new file mode 100644 index 00000000..67d953a8 --- /dev/null +++ b/entity/columns_sparse_test.go @@ -0,0 +1,120 @@ +// Copyright (C) 2019-2021 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package entity + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSliceSparseEmbedding(t *testing.T) { + t.Run("normal_case", func(t *testing.T) { + + length := 1 + rand.Intn(5) + positions := make([]uint32, length) + values := make([]float32, length) + for i := 0; i < length; i++ { + positions[i] = uint32(i) + values[i] = rand.Float32() + } + se, err := NewSliceSparseEmbedding(positions, values) + require.NoError(t, err) + + assert.EqualValues(t, length, se.Dim()) + assert.EqualValues(t, length, se.Len()) + + bs := se.Serialize() + nv, err := deserializeSliceSparceEmbedding(bs) + require.NoError(t, err) + + for i := 0; i < length; i++ { + pos, val, ok := se.Get(i) + require.True(t, ok) + assert.Equal(t, positions[i], pos) + assert.Equal(t, values[i], val) + + npos, nval, ok := nv.Get(i) + require.True(t, ok) + assert.Equal(t, positions[i], npos) + assert.Equal(t, values[i], nval) + } + + _, _, ok := se.Get(-1) + assert.False(t, ok) + _, _, ok = se.Get(length) + assert.False(t, ok) + }) + + t.Run("position values not match", func(t *testing.T) { + _, err := NewSliceSparseEmbedding([]uint32{1}, []float32{}) + assert.Error(t, err) + }) + +} + +func TestColumnSparseEmbedding(t *testing.T) { + columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int()) + columnLen := 8 + rand.Intn(10) + + v := make([]SparseEmbedding, 0, columnLen) + for i := 0; i < columnLen; i++ { + length := 1 + rand.Intn(5) + positions := make([]uint32, length) + values := make([]float32, length) + for j := 0; j < length; j++ { + positions[j] = uint32(j) + values[j] = rand.Float32() + } + se, err := NewSliceSparseEmbedding(positions, values) + require.NoError(t, err) + v = append(v, se) + } + column := NewColumnSparseVectors(columnName, v) + + t.Run("test column attribute", func(t *testing.T) { + assert.Equal(t, columnName, column.Name()) + assert.Equal(t, FieldTypeSparseVector, column.Type()) + assert.Equal(t, columnLen, column.Len()) + assert.EqualValues(t, v, column.Data()) + }) + + t.Run("test column field data", func(t *testing.T) { + fd := column.FieldData() + assert.NotNil(t, fd) + assert.Equal(t, fd.GetFieldName(), columnName) + }) + + t.Run("test column value by idx", func(t *testing.T) { + _, err := column.ValueByIdx(-1) + assert.Error(t, err) + _, err = column.ValueByIdx(columnLen) + assert.Error(t, err) + + _, err = column.Get(-1) + assert.Error(t, err) + _, err = column.Get(columnLen) + assert.Error(t, err) + + for i := 0; i < columnLen; i++ { + v, err := column.ValueByIdx(i) + assert.NoError(t, err) + assert.Equal(t, column.vectors[i], v) + getV, err := column.Get(i) + assert.NoError(t, err) + assert.Equal(t, v, getV) + } + }) +} diff --git a/entity/schema.go b/entity/schema.go index 868561e6..d9395f31 100644 --- a/entity/schema.go +++ b/entity/schema.go @@ -483,4 +483,6 @@ const ( FieldTypeFloat16Vector FieldType = 102 // FieldTypeBinaryVector field type bf16 vector FieldTypeBFloat16Vector FieldType = 103 + // FieldTypeBinaryVector field type sparse vector + FieldTypeSparseVector FieldType = 104 )