From 744a36c28750f255f9b584e40c220e60d9b47b48 Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 15 Oct 2024 10:37:23 +0800 Subject: [PATCH] enhance: [GoSDK] support unmarshal result set into orm receiver (#36789) Related to milvus-io/milvus-sdk-go#800 Signed-off-by: Congqi Xia --- client/read.go | 24 ----- client/results.go | 194 +++++++++++++++++++++++++++++++++++++++++ client/results_test.go | 127 +++++++++++++++++++++++++++ client/row/type.go | 43 +++++++++ 4 files changed, 364 insertions(+), 24 deletions(-) create mode 100644 client/results.go create mode 100644 client/results_test.go create mode 100644 client/row/type.go diff --git a/client/read.go b/client/read.go index d13f5e2601cf9..d6f14b12a692e 100644 --- a/client/read.go +++ b/client/read.go @@ -29,30 +29,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) -type ResultSets struct{} - -type ResultSet struct { - ResultCount int // the returning entry count - GroupByValue column.Column - IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API - Fields DataSet // output field data - Scores []float32 // distance to the target vector - Err error // search error if any -} - -// DataSet is an alias type for column slice. -type DataSet []column.Column - -// GetColumn returns column with provided field name. -func (rs ResultSet) GetColumn(fieldName string) column.Column { - for _, column := range rs.Fields { - if column.Name() == fieldName { - return column - } - } - return nil -} - func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) { req := option.Request() collection, err := c.getCollection(ctx, req.GetCollectionName()) diff --git a/client/results.go b/client/results.go new file mode 100644 index 0000000000000..95a711b5177e8 --- /dev/null +++ b/client/results.go @@ -0,0 +1,194 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "reflect" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/client/v2/column" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/client/v2/row" +) + +// ResultSet is struct for search result set. +type ResultSet struct { + // internal schema for unmarshaling + sch *entity.Schema + + ResultCount int // the returning entry count + GroupByValue column.Column + IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API + Fields DataSet // output field data + Scores []float32 // distance to the target vector + Err error // search error if any +} + +// GetColumn returns column with provided field name. +func (rs *ResultSet) GetColumn(fieldName string) column.Column { + for _, column := range rs.Fields { + if column.Name() == fieldName { + return column + } + } + return nil +} + +// Unmarshal puts dataset into receiver in row based way. +// `receiver` shall be a slice of pointer of model struct +// eg, []*Records, in which type `Record` defines the row data. +// note that distance/score is not unmarshaled here. +func (sr *ResultSet) Unmarshal(receiver any) (err error) { + err = sr.Fields.Unmarshal(receiver) + if err != nil { + return err + } + return sr.fillPKEntry(receiver) +} + +func (sr *ResultSet) fillPKEntry(receiver any) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + case reflect.Slice: + pkField := sr.sch.PKField() + + et := rt.Elem() + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + + candidates := row.ParseCandidate(et) + candi, ok := candidates[pkField.Name] + if !ok { + // pk field not found in struct, skip + return nil + } + for i := 0; i < sr.IDs.Len(); i++ { + row := rv.Index(i) + for row.Kind() == reflect.Ptr { + row = row.Elem() + } + + val, err := sr.IDs.Get(i) + if err != nil { + return err + } + row.Field(candi).Set(reflect.ValueOf(val)) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + +// DataSet is an alias type for column slice. +// Returned by query API. +type DataSet []column.Column + +// Len returns the row count of dataset. +// if there is no column, it shall return 0. +func (ds DataSet) Len() int { + if len(ds) == 0 { + return 0 + } + return ds[0].Len() +} + +// Unmarshal puts dataset into receiver in row based way. +// `receiver` shall be a slice of pointer of model struct +// eg, []*Records, in which type `Record` defines the row data. +func (ds DataSet) Unmarshal(receiver any) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + // TODO maybe support Array and just fill data + // case reflect.Array: + case reflect.Slice: + et := rt.Elem() + if et.Kind() != reflect.Ptr { + return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind()) + } + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + for i := 0; i < ds.Len(); i++ { + data := reflect.New(et) + err := ds.fillData(data.Elem(), et, i) + if err != nil { + return err + } + rv = reflect.Append(rv, data) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + +func (ds DataSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error { + m := row.ParseCandidate(dataType) + for i := 0; i < len(ds); i++ { + name := ds[i].Name() + fidx, ok := m[name] + if !ok { + // if target is not found, the behavior here is to ignore the column + // `strict` mode could be added in the future to return error if any column missing + continue + } + val, err := ds[i].Get(idx) + if err != nil { + return err + } + // TODO check datatype, return error here instead of reflect panicking & recover + data.Field(fidx).Set(reflect.ValueOf(val)) + } + return nil +} diff --git a/client/results_test.go b/client/results_test.go new file mode 100644 index 0000000000000..3d27847e68386 --- /dev/null +++ b/client/results_test.go @@ -0,0 +1,127 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/client/v2/column" + "github.com/milvus-io/milvus/client/v2/entity" +) + +type ResultSetSuite struct { + suite.Suite +} + +func (s *ResultSetSuite) TestResultsetUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + rs := DataSet([]column.Column{ + column.NewColumnInt64("id", idData), + column.NewColumnFloatVector("vector", 2, vectorData), + }) + err := rs.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = rs.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = rs.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = rs.Unmarshal(&otherReceiver) + s.Error(err) +} + +func (s *ResultSetSuite) TestSearchResultUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + sr := ResultSet{ + sch: entity.NewSchema(). + WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)). + WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)), + IDs: column.NewColumnInt64("id", idData), + Fields: DataSet([]column.Column{ + column.NewColumnFloatVector("vector", 2, vectorData), + }), + } + err := sr.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = sr.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = sr.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = sr.Unmarshal(&otherReceiver) + s.Error(err) +} + +func TestResults(t *testing.T) { + suite.Run(t, new(ResultSetSuite)) +} diff --git a/client/row/type.go b/client/row/type.go new file mode 100644 index 0000000000000..5815487163833 --- /dev/null +++ b/client/row/type.go @@ -0,0 +1,43 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package row + +import ( + "go/ast" + "reflect" +) + +func ParseCandidate(dataType reflect.Type) map[string]int { + result := make(map[string]int) + for i := 0; i < dataType.NumField(); i++ { + f := dataType.Field(i) + // ignore anonymous field for now + if f.Anonymous || !ast.IsExported(f.Name) { + continue + } + + name := f.Name + tag := f.Tag.Get(MilvusTag) + tagSettings := ParseTagSetting(tag, MilvusTagSep) + if tagName, has := tagSettings[MilvusTagName]; has { + name = tagName + } + + result[name] = i + } + return result +}