diff --git a/client/data.go b/client/data.go
index e6b526ebd..4cb19b274 100644
--- a/client/data.go
+++ b/client/data.go
@@ -143,25 +143,20 @@ func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []st
 			Scores:      results.GetScores()[offset : offset+rc],
 		}
-		// parse result set if current nq is not empty
-		if rc > 0 {
-			entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc)
+		entry.IDs, entry.Err = entity.IDColumns(schema, results.GetIds(), offset, offset+rc)
+		if entry.Err != nil {
+			continue
+		}
+		// parse group-by values
+		if gb != nil {
+			entry.GroupByValue, entry.Err = entity.FieldDataColumn(gb, offset, offset+rc)
 			if entry.Err != nil {
 				offset += rc
 				continue
 			}
-			// parse group-by values
-			if gb != nil {
-				entry.GroupByValue, entry.Err = entity.FieldDataColumn(gb, offset, offset+rc)
-				if entry.Err != nil {
-					offset += rc
-					continue
-				}
-			}
-			// entry.GroupByValue, entry.Err = c.parseSearchResult()
-			entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
-			sr = append(sr, entry)
 		}
+		entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
+		sr = append(sr, entry)
 		offset += rc
 	}
diff --git a/client/insert.go b/client/insert.go
index d23301276..2f52f0e2e 100644
--- a/client/insert.go
+++ b/client/insert.go
@@ -77,7 +77,7 @@ func (c *GrpcClient) Insert(ctx context.Context, collName string, partitionName
 	}
 	MetaCache.setSessionTs(collName, resp.Timestamp)
 	// 3. parse id column
-	return entity.IDColumns(resp.GetIDs(), 0, -1)
+	return entity.IDColumns(coll.Schema, resp.GetIDs(), 0, -1)
 }
 
 func (c *GrpcClient) processInsertColumns(colSchema *entity.Schema, columns ...entity.Column) ([]*schemapb.FieldData, int, error) {
@@ -392,7 +392,7 @@ func (c *GrpcClient) Upsert(ctx context.Context, collName string, partitionName
 	}
 	MetaCache.setSessionTs(collName, resp.Timestamp)
 	// 3. parse id column
-	return entity.IDColumns(resp.GetIDs(), 0, -1)
+	return entity.IDColumns(coll.Schema, resp.GetIDs(), 0, -1)
 }
 
 // BulkInsert data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments
diff --git a/client/row.go b/client/row.go
index 42001ea70..64ac5c022 100644
--- a/client/row.go
+++ b/client/row.go
@@ -118,7 +118,7 @@ func (c *GrpcClient) InsertRows(ctx context.Context, collName string, partitionN
 	}
 	MetaCache.setSessionTs(collName, resp.Timestamp)
 	// 3. parse id column
-	return entity.IDColumns(resp.GetIDs(), 0, -1)
+	return entity.IDColumns(coll.Schema, resp.GetIDs(), 0, -1)
 }
 
 // SearchResultByRows search result for row-based Search
diff --git a/entity/columns.go b/entity/columns.go
index 27e6b93c1..8c06cee12 100644
--- a/entity/columns.go
+++ b/entity/columns.go
@@ -16,7 +16,7 @@ import (
 	"fmt"
 	"math"
 
-	schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 
 	"github.com/cockroachdb/errors"
 )
@@ -28,7 +28,7 @@ type Column interface {
 	Name() string
 	Type() FieldType
 	Len() int
-	FieldData() *schema.FieldData
+	FieldData() *schemapb.FieldData
 	AppendValue(interface{}) error
 	Get(int) (interface{}, error)
 	GetAsInt64(int) (int64, error)
@@ -142,38 +142,51 @@ func (bv BinaryVector) FieldType() FieldType {
 
 var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
 
-// IDColumns converts schema.IDs to corresponding column
+// IDColumns converts schemapb.IDs to corresponding column
 // currently Int64 / string may be in IDs
-func IDColumns(idField *schema.IDs, begin, end int) (Column, error) {
+func IDColumns(schema *Schema, idField *schemapb.IDs, begin, end int) (Column, error) {
 	var idColumn Column
-	if idField == nil {
-		return nil, errors.New("nil Ids from response")
+
+	pkField := schema.PKField()
+	if pkField == nil {
+		return nil, errors.New("PK Field not found")
 	}
-	switch field := idField.GetIdField().(type) {
-	case *schema.IDs_IntId:
+	switch pkField.DataType {
+	case FieldTypeInt64:
+		data := idField.GetIntId().GetData()
+		if data == nil {
+			return NewColumnInt64(pkField.Name, nil), nil
+		}
 		if end >= 0 {
-			idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
+			idColumn = NewColumnInt64(pkField.Name, data[begin:end])
 		} else {
-			idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
+			idColumn = NewColumnInt64(pkField.Name, data[begin:])
+		}
+	case FieldTypeVarChar, FieldTypeString:
+		data := idField.GetStrId().GetData()
+		if data == nil {
+			return NewColumnVarChar(pkField.Name, nil), nil
 		}
-	case *schema.IDs_StrId:
 		if end >= 0 {
-			idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
+			idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
 		} else {
-			idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
+			idColumn = NewColumnVarChar(pkField.Name, data[begin:])
 		}
 	default:
-		return nil, fmt.Errorf("unsupported id type %v", field)
+		return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
+	}
+	if idField == nil {
+		return nil, errors.New("nil Ids from response")
 	}
 	return idColumn, nil
 }
 
-// FieldDataColumn converts schema.FieldData to Column, used int search result conversion logic
+// FieldDataColumn converts schemapb.FieldData to Column, used int search result conversion logic
 // begin, end specifies the start and end positions
-func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
+func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
 	switch fd.GetType() {
-	case schema.DataType_Bool:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_BoolData)
+	case schemapb.DataType_Bool:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_BoolData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -182,7 +195,7 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:end]), nil
 
-	case schema.DataType_Int8:
+	case schemapb.DataType_Int8:
 		data, ok := getIntData(fd)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
@@ -198,7 +211,7 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 
 		return NewColumnInt8(fd.GetFieldName(), values[begin:end]), nil
 
-	case schema.DataType_Int16:
+	case schemapb.DataType_Int16:
 		data, ok := getIntData(fd)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
@@ -213,7 +226,7 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 
 		return NewColumnInt16(fd.GetFieldName(), values[begin:end]), nil
 
-	case schema.DataType_Int32:
+	case schemapb.DataType_Int32:
 		data, ok := getIntData(fd)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
@@ -223,8 +236,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:end]), nil
 
-	case schema.DataType_Int64:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_LongData)
+	case schemapb.DataType_Int64:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_LongData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -233,8 +246,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:end]), nil
 
-	case schema.DataType_Float:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_FloatData)
+	case schemapb.DataType_Float:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_FloatData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -243,8 +256,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:end]), nil
 
-	case schema.DataType_Double:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_DoubleData)
+	case schemapb.DataType_Double:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_DoubleData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -253,8 +266,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:end]), nil
 
-	case schema.DataType_String:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_StringData)
+	case schemapb.DataType_String:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -263,8 +276,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
 
-	case schema.DataType_VarChar:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_StringData)
+	case schemapb.DataType_VarChar:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -273,12 +286,12 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
 
-	case schema.DataType_Array:
+	case schemapb.DataType_Array:
 		data := fd.GetScalars().GetArrayData()
 		if data == nil {
 			return nil, errFieldDataTypeNotMatch
 		}
-		var arrayData []*schema.ScalarField
+		var arrayData []*schemapb.ScalarField
 		if end < 0 {
 			arrayData = data.GetData()[begin:]
 		} else {
@@ -287,8 +300,8 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 
 		return parseArrayData(fd.GetFieldName(), data.GetElementType(), arrayData)
 
-	case schema.DataType_JSON:
-		data, ok := fd.GetScalars().GetData().(*schema.ScalarField_JsonData)
+	case schemapb.DataType_JSON:
+		data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_JsonData)
 		isDynamic := fd.GetIsDynamic()
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
@@ -298,9 +311,9 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:end]).WithIsDynamic(isDynamic), nil
 
-	case schema.DataType_FloatVector:
+	case schemapb.DataType_FloatVector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_FloatVector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -317,9 +330,9 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
 
-	case schema.DataType_BinaryVector:
+	case schemapb.DataType_BinaryVector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_BinaryVector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -340,9 +353,9 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
 
-	case schema.DataType_Float16Vector:
+	case schemapb.DataType_Float16Vector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -359,9 +372,9 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 		}
 		return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
 
-	case schema.DataType_BFloat16Vector:
+	case schemapb.DataType_BFloat16Vector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -377,7 +390,7 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 			vector = append(vector, v)
 		}
 		return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
-	case schema.DataType_SparseFloatVector:
+	case schemapb.DataType_SparseFloatVector:
 		sparseVectors := fd.GetVectors().GetSparseFloatVector()
 		if sparseVectors == nil {
 			return nil, errFieldDataTypeNotMatch
@@ -401,17 +414,17 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
 	}
 }
 
-func parseArrayData(fieldName string, elementType schema.DataType, fieldDataList []*schema.ScalarField) (Column, error) {
+func parseArrayData(fieldName string, elementType schemapb.DataType, fieldDataList []*schemapb.ScalarField) (Column, error) {
 	switch elementType {
-	case schema.DataType_Bool:
+	case schemapb.DataType_Bool:
 		var data [][]bool
 		for _, fd := range fieldDataList {
 			data = append(data, fd.GetBoolData().GetData())
 		}
 
 		return NewColumnBoolArray(fieldName, data), nil
 
-	case schema.DataType_Int8:
+	case schemapb.DataType_Int8:
 		var data [][]int8
 		for _, fd := range fieldDataList {
 			raw := fd.GetIntData().GetData()
@@ -423,7 +436,7 @@ func parseArrayData(fieldName string, elementType schema.DataType, fieldDataList
 		}
 		return NewColumnInt8Array(fieldName, data), nil
 
-	case schema.DataType_Int16:
+	case schemapb.DataType_Int16:
 		var data [][]int16
 		for _, fd := range fieldDataList {
 			raw := fd.GetIntData().GetData()
@@ -435,35 +448,35 @@ func parseArrayData(fieldName string, elementType schema.DataType, fieldDataList
 		}
 		return NewColumnInt16Array(fieldName, data), nil
 
-	case schema.DataType_Int32:
+	case schemapb.DataType_Int32:
 		var data [][]int32
 		for _, fd := range fieldDataList {
 			data = append(data, fd.GetIntData().GetData())
 		}
 		return NewColumnInt32Array(fieldName, data), nil
 
-	case schema.DataType_Int64:
+	case schemapb.DataType_Int64:
 		var data [][]int64
 		for _, fd := range fieldDataList {
 			data = append(data, fd.GetLongData().GetData())
 		}
 		return NewColumnInt64Array(fieldName, data), nil
 
-	case schema.DataType_Float:
+	case schemapb.DataType_Float:
 		var data [][]float32
 		for _, fd := range fieldDataList {
 			data = append(data, fd.GetFloatData().GetData())
 		}
		return NewColumnFloatArray(fieldName, data), nil
 
-	case schema.DataType_Double:
+	case schemapb.DataType_Double:
 		var data [][]float64
 		for _, fd := range fieldDataList {
 			data = append(data, fd.GetDoubleData().GetData())
 		}
 		return NewColumnDoubleArray(fieldName, data), nil
 
-	case schema.DataType_VarChar, schema.DataType_String:
+	case schemapb.DataType_VarChar, schemapb.DataType_String:
 		var data [][][]byte
 		for _, fd := range fieldDataList {
 			strs := fd.GetStringData().GetData()
@@ -483,15 +496,15 @@ func parseArrayData(fieldName string, elementType schema.DataType, fieldDataList
 
 // getIntData get int32 slice from result field data
 // also handles LongData bug (see also https://github.com/milvus-io/milvus/issues/23850)
-func getIntData(fd *schema.FieldData) (*schema.ScalarField_IntData, bool) {
+func getIntData(fd *schemapb.FieldData) (*schemapb.ScalarField_IntData, bool) {
 	switch data := fd.GetScalars().GetData().(type) {
-	case *schema.ScalarField_IntData:
+	case *schemapb.ScalarField_IntData:
 		return data, true
-	case *schema.ScalarField_LongData:
+	case *schemapb.ScalarField_LongData:
 		// only alway empty LongData for backward compatibility
 		if len(data.LongData.GetData()) == 0 {
-			return &schema.ScalarField_IntData{
-				IntData: &schema.IntArray{},
+			return &schemapb.ScalarField_IntData{
+				IntData: &schemapb.IntArray{},
 			}, true
 		}
 		return nil, false
@@ -500,12 +513,12 @@ func getIntData(fd *schema.FieldData) (*schema.ScalarField_IntData, bool) {
 	}
 }
 
-// FieldDataColumn converts schema.FieldData to vector Column
-func FieldDataVector(fd *schema.FieldData) (Column, error) {
+// FieldDataColumn converts schemapb.FieldData to vector Column
+func FieldDataVector(fd *schemapb.FieldData) (Column, error) {
 	switch fd.GetType() {
-	case schema.DataType_FloatVector:
+	case schemapb.DataType_FloatVector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_FloatVector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -518,9 +531,9 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) {
 			vector = append(vector, v)
 		}
 		return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
-	case schema.DataType_BinaryVector:
+	case schemapb.DataType_BinaryVector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_BinaryVector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -537,9 +550,9 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) {
 			vector = append(vector, v)
 		}
 		return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
-	case schema.DataType_Float16Vector:
+	case schemapb.DataType_Float16Vector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
@@ -552,9 +565,9 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) {
 			vector = append(vector, v)
 		}
 		return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
-	case schema.DataType_BFloat16Vector:
+	case schemapb.DataType_BFloat16Vector:
 		vectors := fd.GetVectors()
-		x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
+		x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
 		if !ok {
 			return nil, errFieldDataTypeNotMatch
 		}
diff --git a/entity/columns_test.go b/entity/columns_test.go
index 833d788e3..c9b79c1e9 100644
--- a/entity/columns_test.go
+++ b/entity/columns_test.go
@@ -42,12 +42,27 @@ func TestIDColumns(t *testing.T) {
 	dataLen := rand.Intn(100) + 1
 	base := rand.Intn(5000) // id start point
 
+	intPKCol := NewSchema().WithField(
+		NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(FieldTypeInt64),
+	)
+	strPKCol := NewSchema().WithField(
+		NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(FieldTypeVarChar),
+	)
+
 	t.Run("nil id", func(t *testing.T) {
-		_, err := IDColumns(nil, 0, -1)
-		assert.NotNil(t, err)
+		col, err := IDColumns(intPKCol, nil, 0, -1)
+		assert.NoError(t, err)
+		assert.EqualValues(t, 0, col.Len())
+		col, err = IDColumns(strPKCol, nil, 0, -1)
+		assert.NoError(t, err)
+		assert.EqualValues(t, 0, col.Len())
 		idField := &schema.IDs{}
-		_, err = IDColumns(idField, 0, -1)
-		assert.NotNil(t, err)
+		col, err = IDColumns(intPKCol, idField, 0, -1)
+		assert.NoError(t, err)
+		assert.EqualValues(t, 0, col.Len())
+		col, err = IDColumns(strPKCol, idField, 0, -1)
+		assert.NoError(t, err)
+		assert.EqualValues(t, 0, col.Len())
 	})
 
 	t.Run("int ids", func(t *testing.T) {
@@ -62,12 +77,12 @@ func TestIDColumns(t *testing.T) {
 				},
 			},
 		}
-		column, err := IDColumns(idField, 0, dataLen)
+		column, err := IDColumns(intPKCol, idField, 0, dataLen)
 		assert.Nil(t, err)
 		assert.NotNil(t, column)
 		assert.Equal(t, dataLen, column.Len())
 
-		column, err = IDColumns(idField, 0, -1) // test -1 method
+		column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method
 		assert.Nil(t, err)
 		assert.NotNil(t, column)
 		assert.Equal(t, dataLen, column.Len())
@@ -84,12 +99,12 @@ func TestIDColumns(t *testing.T) {
 				},
 			},
 		}
-		column, err := IDColumns(idField, 0, dataLen)
+		column, err := IDColumns(strPKCol, idField, 0, dataLen)
 		assert.Nil(t, err)
 		assert.NotNil(t, column)
 		assert.Equal(t, dataLen, column.Len())
 
-		column, err = IDColumns(idField, 0, -1) // test -1 method
+		column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method
 		assert.Nil(t, err)
 		assert.NotNil(t, column)
 		assert.Equal(t, dataLen, column.Len())
diff --git a/entity/schema.go b/entity/schema.go
index d9395f31a..6b31bc085 100644
--- a/entity/schema.go
+++ b/entity/schema.go
@@ -55,6 +55,7 @@ type Schema struct {
 	AutoID             bool
 	Fields             []*Field
 	EnableDynamicField bool
+	pkField            *Field
 }
 
 // NewSchema creates an empty schema object.
@@ -86,6 +87,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
 
 // WithField adds a field into schema and returns schema itself.
 func (s *Schema) WithField(f *Field) *Schema {
+	if f.PrimaryKey {
+		s.pkField = f
+	}
 	s.Fields = append(s.Fields, f)
 	return s
 }
@@ -112,7 +116,11 @@ func (s *Schema) ReadProto(p *schema.CollectionSchema) *Schema {
 	s.CollectionName = p.GetName()
 	s.Fields = make([]*Field, 0, len(p.GetFields()))
 	for _, fp := range p.GetFields() {
-		s.Fields = append(s.Fields, NewField().ReadProto(fp))
+		field := NewField().ReadProto(fp)
+		if field.PrimaryKey {
+			s.pkField = field
+		}
+		s.Fields = append(s.Fields, field)
 	}
 	s.EnableDynamicField = p.GetEnableDynamicField()
 	return s
 }
@@ -120,12 +128,15 @@ func (s *Schema) ReadProto(p *schema.CollectionSchema) *Schema {
 
 // PKFieldName returns pk field name for this schema.
 func (s *Schema) PKFieldName() string {
-	for _, field := range s.Fields {
-		if field.PrimaryKey {
-			return field.Name
-		}
+	if s.pkField == nil {
+		return ""
 	}
-	return ""
+	return s.pkField.Name
+}
+
+// PKField returns PK Field schema for this schema.
+func (s *Schema) PKField() *Field {
+	return s.pkField
 }
 
 // Field represent field schema in milvus
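For reference, here is a minimal sketch (not part of the patch) of how the reworked `entity.IDColumns` is called after this change. It mirrors the updated test above; the schema, field name, and ID values are illustrative only.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

func main() {
	// Collection schema with an Int64 primary key named "pk"; the PK field
	// now determines both the name and the type of the returned ID column.
	sch := entity.NewSchema().WithField(
		entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64),
	)

	// IDs proto as returned in Insert/Upsert/Search responses.
	ids := &schemapb.IDs{
		IdField: &schemapb.IDs_IntId{
			IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}},
		},
	}

	// Previously the resulting column was named ""; it now carries the PK field name,
	// and a nil/empty IDs payload yields an empty column instead of an error.
	col, err := entity.IDColumns(sch, ids, 0, -1)
	if err != nil {
		panic(err)
	}
	fmt.Println(col.Name(), col.Len()) // pk 3
}
```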