-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enhance: [GoSDK] support unmarshal result set into orm receiver (#36789)
Related to milvus-io/milvus-sdk-go#800 Signed-off-by: Congqi Xia <[email protected]>
- Loading branch information
Showing
4 changed files
with
364 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
// Licensed to the LF AI & Data foundation under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package client | ||
|
||
import ( | ||
"reflect" | ||
|
||
"github.com/cockroachdb/errors" | ||
|
||
"github.com/milvus-io/milvus/client/v2/column" | ||
"github.com/milvus-io/milvus/client/v2/entity" | ||
"github.com/milvus-io/milvus/client/v2/row" | ||
) | ||
|
||
// ResultSet is struct for search result set. | ||
type ResultSet struct { | ||
// internal schema for unmarshaling | ||
sch *entity.Schema | ||
|
||
ResultCount int // the returning entry count | ||
GroupByValue column.Column | ||
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API | ||
Fields DataSet // output field data | ||
Scores []float32 // distance to the target vector | ||
Err error // search error if any | ||
} | ||
|
||
// GetColumn returns column with provided field name. | ||
func (rs *ResultSet) GetColumn(fieldName string) column.Column { | ||
for _, column := range rs.Fields { | ||
if column.Name() == fieldName { | ||
return column | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
// Unmarshal puts dataset into receiver in row based way. | ||
// `receiver` shall be a slice of pointer of model struct | ||
// eg, []*Records, in which type `Record` defines the row data. | ||
// note that distance/score is not unmarshaled here. | ||
func (sr *ResultSet) Unmarshal(receiver any) (err error) { | ||
err = sr.Fields.Unmarshal(receiver) | ||
if err != nil { | ||
return err | ||
} | ||
return sr.fillPKEntry(receiver) | ||
} | ||
|
||
func (sr *ResultSet) fillPKEntry(receiver any) (err error) { | ||
defer func() { | ||
if x := recover(); x != nil { | ||
err = errors.Newf("failed to unmarshal result set: %v", x) | ||
} | ||
}() | ||
rr := reflect.ValueOf(receiver) | ||
|
||
if rr.Kind() == reflect.Ptr { | ||
if rr.IsNil() && rr.CanAddr() { | ||
rr.Set(reflect.New(rr.Type().Elem())) | ||
} | ||
rr = rr.Elem() | ||
} | ||
|
||
rt := rr.Type() | ||
rv := rr | ||
|
||
switch rt.Kind() { | ||
case reflect.Slice: | ||
pkField := sr.sch.PKField() | ||
|
||
et := rt.Elem() | ||
for et.Kind() == reflect.Ptr { | ||
et = et.Elem() | ||
} | ||
|
||
candidates := row.ParseCandidate(et) | ||
candi, ok := candidates[pkField.Name] | ||
if !ok { | ||
// pk field not found in struct, skip | ||
return nil | ||
} | ||
for i := 0; i < sr.IDs.Len(); i++ { | ||
row := rv.Index(i) | ||
for row.Kind() == reflect.Ptr { | ||
row = row.Elem() | ||
} | ||
|
||
val, err := sr.IDs.Get(i) | ||
if err != nil { | ||
return err | ||
} | ||
row.Field(candi).Set(reflect.ValueOf(val)) | ||
} | ||
rr.Set(rv) | ||
default: | ||
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) | ||
} | ||
return nil | ||
} | ||
|
||
// DataSet is an alias type for column slice. | ||
// Returned by query API. | ||
type DataSet []column.Column | ||
|
||
// Len returns the row count of dataset. | ||
// if there is no column, it shall return 0. | ||
func (ds DataSet) Len() int { | ||
if len(ds) == 0 { | ||
return 0 | ||
} | ||
return ds[0].Len() | ||
} | ||
|
||
// Unmarshal puts dataset into receiver in row based way. | ||
// `receiver` shall be a slice of pointer of model struct | ||
// eg, []*Records, in which type `Record` defines the row data. | ||
func (ds DataSet) Unmarshal(receiver any) (err error) { | ||
defer func() { | ||
if x := recover(); x != nil { | ||
err = errors.Newf("failed to unmarshal result set: %v", x) | ||
} | ||
}() | ||
rr := reflect.ValueOf(receiver) | ||
|
||
if rr.Kind() == reflect.Ptr { | ||
if rr.IsNil() && rr.CanAddr() { | ||
rr.Set(reflect.New(rr.Type().Elem())) | ||
} | ||
rr = rr.Elem() | ||
} | ||
|
||
rt := rr.Type() | ||
rv := rr | ||
|
||
switch rt.Kind() { | ||
// TODO maybe support Array and just fill data | ||
// case reflect.Array: | ||
case reflect.Slice: | ||
et := rt.Elem() | ||
if et.Kind() != reflect.Ptr { | ||
return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind()) | ||
} | ||
for et.Kind() == reflect.Ptr { | ||
et = et.Elem() | ||
} | ||
for i := 0; i < ds.Len(); i++ { | ||
data := reflect.New(et) | ||
err := ds.fillData(data.Elem(), et, i) | ||
if err != nil { | ||
return err | ||
} | ||
rv = reflect.Append(rv, data) | ||
} | ||
rr.Set(rv) | ||
default: | ||
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) | ||
} | ||
return nil | ||
} | ||
|
||
func (ds DataSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error { | ||
m := row.ParseCandidate(dataType) | ||
for i := 0; i < len(ds); i++ { | ||
name := ds[i].Name() | ||
fidx, ok := m[name] | ||
if !ok { | ||
// if target is not found, the behavior here is to ignore the column | ||
// `strict` mode could be added in the future to return error if any column missing | ||
continue | ||
} | ||
val, err := ds[i].Get(idx) | ||
if err != nil { | ||
return err | ||
} | ||
// TODO check datatype, return error here instead of reflect panicking & recover | ||
data.Field(fidx).Set(reflect.ValueOf(val)) | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
// Licensed to the LF AI & Data foundation under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package client | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/suite" | ||
|
||
"github.com/milvus-io/milvus/client/v2/column" | ||
"github.com/milvus-io/milvus/client/v2/entity" | ||
) | ||
|
||
type ResultSetSuite struct { | ||
suite.Suite | ||
} | ||
|
||
func (s *ResultSetSuite) TestResultsetUnmarshal() { | ||
type MyData struct { | ||
A int64 `milvus:"name:id"` | ||
V []float32 `milvus:"name:vector"` | ||
} | ||
type OtherData struct { | ||
A string `milvus:"name:id"` | ||
V []float32 `milvus:"name:vector"` | ||
} | ||
|
||
var ( | ||
idData = []int64{1, 2, 3} | ||
vectorData = [][]float32{ | ||
{0.1, 0.2}, | ||
{0.1, 0.2}, | ||
{0.1, 0.2}, | ||
} | ||
) | ||
|
||
rs := DataSet([]column.Column{ | ||
column.NewColumnInt64("id", idData), | ||
column.NewColumnFloatVector("vector", 2, vectorData), | ||
}) | ||
err := rs.Unmarshal([]MyData{}) | ||
s.Error(err) | ||
|
||
receiver := []MyData{} | ||
err = rs.Unmarshal(&receiver) | ||
s.Error(err) | ||
|
||
var ptrReceiver []*MyData | ||
err = rs.Unmarshal(&ptrReceiver) | ||
s.NoError(err) | ||
|
||
for idx, row := range ptrReceiver { | ||
s.Equal(row.A, idData[idx]) | ||
s.Equal(row.V, vectorData[idx]) | ||
} | ||
|
||
var otherReceiver []*OtherData | ||
err = rs.Unmarshal(&otherReceiver) | ||
s.Error(err) | ||
} | ||
|
||
func (s *ResultSetSuite) TestSearchResultUnmarshal() { | ||
type MyData struct { | ||
A int64 `milvus:"name:id"` | ||
V []float32 `milvus:"name:vector"` | ||
} | ||
type OtherData struct { | ||
A string `milvus:"name:id"` | ||
V []float32 `milvus:"name:vector"` | ||
} | ||
|
||
var ( | ||
idData = []int64{1, 2, 3} | ||
vectorData = [][]float32{ | ||
{0.1, 0.2}, | ||
{0.1, 0.2}, | ||
{0.1, 0.2}, | ||
} | ||
) | ||
|
||
sr := ResultSet{ | ||
sch: entity.NewSchema(). | ||
WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)). | ||
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)), | ||
IDs: column.NewColumnInt64("id", idData), | ||
Fields: DataSet([]column.Column{ | ||
column.NewColumnFloatVector("vector", 2, vectorData), | ||
}), | ||
} | ||
err := sr.Unmarshal([]MyData{}) | ||
s.Error(err) | ||
|
||
receiver := []MyData{} | ||
err = sr.Unmarshal(&receiver) | ||
s.Error(err) | ||
|
||
var ptrReceiver []*MyData | ||
err = sr.Unmarshal(&ptrReceiver) | ||
s.NoError(err) | ||
|
||
for idx, row := range ptrReceiver { | ||
s.Equal(row.A, idData[idx]) | ||
s.Equal(row.V, vectorData[idx]) | ||
} | ||
|
||
var otherReceiver []*OtherData | ||
err = sr.Unmarshal(&otherReceiver) | ||
s.Error(err) | ||
} | ||
|
||
func TestResults(t *testing.T) { | ||
suite.Run(t, new(ResultSetSuite)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Licensed to the LF AI & Data foundation under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package row | ||
|
||
import ( | ||
"go/ast" | ||
"reflect" | ||
) | ||
|
||
func ParseCandidate(dataType reflect.Type) map[string]int { | ||
result := make(map[string]int) | ||
for i := 0; i < dataType.NumField(); i++ { | ||
f := dataType.Field(i) | ||
// ignore anonymous field for now | ||
if f.Anonymous || !ast.IsExported(f.Name) { | ||
continue | ||
} | ||
|
||
name := f.Name | ||
tag := f.Tag.Get(MilvusTag) | ||
tagSettings := ParseTagSetting(tag, MilvusTagSep) | ||
if tagName, has := tagSettings[MilvusTagName]; has { | ||
name = tagName | ||
} | ||
|
||
result[name] = i | ||
} | ||
return result | ||
} |