From 9f3f25f39db894fabdd72b3d3d8b4882476dffcf Mon Sep 17 00:00:00 2001 From: Congqi Xia Date: Mon, 18 Mar 2024 19:11:58 +0800 Subject: [PATCH] Fix: Adapt bf16/fp16 in row-based API See also #673 Signed-off-by: Congqi Xia --- entity/rows.go | 38 +++++++++++++++++++++-- entity/rows_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/entity/rows.go b/entity/rows.go index 4c473cb4..33799d41 100644 --- a/entity/rows.go +++ b/entity/rows.go @@ -38,6 +38,9 @@ const ( // VectorDimTag struct tag const for vector dimension VectorDimTag = `DIM` + // VectorTypeTag struct tag const for binary vector type + VectorTypeTag = `VECTOR_TYPE` + // MilvusPrimaryKey struct tag const for primary key indicator MilvusPrimaryKey = `PRIMARY_KEY` @@ -197,8 +200,15 @@ func ParseSchemaAny(r interface{}) (*Schema, error) { } elemType := ft.Elem() switch elemType.Kind() { - case reflect.Uint8: // []byte! - field.DataType = FieldTypeBinaryVector + case reflect.Uint8: // []byte, could be BinaryVector, fp16, bf16 + switch tagSettings[VectorTypeTag] { + case "fp16": + field.DataType = FieldTypeFloat16Vector + case "bf16": + field.DataType = FieldTypeBFloat16Vector + default: + field.DataType = FieldTypeBinaryVector + } case reflect.Float32: field.DataType = FieldTypeFloatVector default: @@ -355,6 +365,30 @@ func AnyToColumns(rows []interface{}, schemas ...*Schema) ([]Column, error) { } col := NewColumnBinaryVector(field.Name, int(dim), data) nameColumns[field.Name] = col + case FieldTypeFloat16Vector: + data := make([][]byte, 0, rowsLen) + dimStr, has := field.TypeParams[TypeParamDim] + if !has { + return []Column{}, errors.New("vector field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return []Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error()) + } + col := NewColumnFloat16Vector(field.Name, int(dim), data) + nameColumns[field.Name] = col + case
FieldTypeBFloat16Vector: + data := make([][]byte, 0, rowsLen) + dimStr, has := field.TypeParams[TypeParamDim] + if !has { + return []Column{}, errors.New("vector field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return []Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error()) + } + col := NewColumnBFloat16Vector(field.Name, int(dim), data) + nameColumns[field.Name] = col } } diff --git a/entity/rows_test.go b/entity/rows_test.go index a9e0333d..19d182d6 100644 --- a/entity/rows_test.go +++ b/entity/rows_test.go @@ -117,6 +117,18 @@ func TestParseSchema(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "RowBase", sch.CollectionName) + getVectorField := func(schema *Schema) *Field { + for _, field := range schema.Fields { + if field.DataType == FieldTypeFloatVector || + field.DataType == FieldTypeBinaryVector || + field.DataType == FieldTypeBFloat16Vector || + field.DataType == FieldTypeFloat16Vector { + return field + } + } + return nil + } + type ValidStruct struct { RowBase ID int64 `milvus:"primary_key"` @@ -134,6 +146,44 @@ func TestParseSchema(t *testing.T) { assert.NotNil(t, sch) assert.Equal(t, "ValidStruct", sch.CollectionName) + type ValidFp16Struct struct { + RowBase + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Vector []byte `milvus:"dim:128;vector_type:fp16"` + } + fp16Vs := &ValidFp16Struct{} + sch, err = ParseSchema(fp16Vs) + assert.Nil(t, err) + assert.NotNil(t, sch) + assert.Equal(t, "ValidFp16Struct", sch.CollectionName) + vectorField := getVectorField(sch) + assert.Equal(t, FieldTypeFloat16Vector, vectorField.DataType) + + type ValidBf16Struct struct { + RowBase + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Vector []byte `milvus:"dim:128;vector_type:bf16"` + } + bf16Vs := &ValidBf16Struct{} + sch, err = ParseSchema(bf16Vs) + 
assert.Nil(t, err) + assert.NotNil(t, sch) + assert.Equal(t, "ValidBf16Struct", sch.CollectionName) + vectorField = getVectorField(sch) + assert.Equal(t, FieldTypeBFloat16Vector, vectorField.DataType) + type ValidByteStruct struct { RowBase ID int64 `milvus:"primary_key"` @@ -240,6 +290,32 @@ func (s *RowsSuite) TestRowsToColumns() { s.Equal("Vector", columns[0].Name()) }) + s.Run("fp16", func() { + type BF16Struct struct { + RowBase + ID int64 `milvus:"primary_key;auto_id"` + Vector []byte `milvus:"dim:16;vector_type:bf16"` + } + columns, err := RowsToColumns([]Row{&BF16Struct{}}) + s.Nil(err) + s.Require().Equal(1, len(columns)) + s.Equal("Vector", columns[0].Name()) + s.Equal(FieldTypeBFloat16Vector, columns[0].Type()) + }) + + s.Run("fp16", func() { + type FP16Struct struct { + RowBase + ID int64 `milvus:"primary_key;auto_id"` + Vector []byte `milvus:"dim:16;vector_type:fp16"` + } + columns, err := RowsToColumns([]Row{&FP16Struct{}}) + s.Nil(err) + s.Require().Equal(1, len(columns)) + s.Equal("Vector", columns[0].Name()) + s.Equal(FieldTypeFloat16Vector, columns[0].Type()) + }) + s.Run("invalid_cases", func() { // empty input _, err := RowsToColumns([]Row{})