Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: [GoSDK] support unmarshal result set into orm receiver #36789

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions client/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)

// ResultSets is an empty struct type; no fields are declared on it here.
// NOTE(review): no usage is visible in this section — confirm whether it is still needed.
type ResultSets struct{}

// ResultSet is the element type of the slice returned by Client.Search.
// It carries the ids, scores and output field columns of one result set.
type ResultSet struct {
ResultCount int // the returning entry count
GroupByValue column.Column // NOTE(review): presumably the group-by values of the search — confirm
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields DataSet // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}

// DataSet is a named slice type whose elements are columns.
type DataSet []column.Column

// GetColumn looks up an output field column by its name.
// It returns the first entry of rs.Fields whose Name() equals fieldName,
// or nil when no such column exists.
func (rs ResultSet) GetColumn(fieldName string) column.Column {
	for idx := range rs.Fields {
		if col := rs.Fields[idx]; col.Name() == fieldName {
			return col
		}
	}
	return nil
}

func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
req := option.Request()
collection, err := c.getCollection(ctx, req.GetCollectionName())
Expand Down
194 changes: 194 additions & 0 deletions client/results.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"reflect"

"github.com/cockroachdb/errors"

"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/row"
)

// ResultSet is the struct holding one search result set.
// Besides the exported result data it keeps the schema needed by Unmarshal
// to locate the primary key field.
type ResultSet struct {
// internal schema for unmarshaling; its primary key field name is used
// to decide which struct field receives the values of IDs
sch *entity.Schema

ResultCount int // the returning entry count
GroupByValue column.Column // NOTE(review): presumably the group-by values of the search — confirm
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields DataSet // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}

// GetColumn looks up an output field column by its name.
// It returns the first entry of Fields whose Name() equals fieldName,
// or nil when no such column exists.
func (sr *ResultSet) GetColumn(fieldName string) column.Column {
	for idx := range sr.Fields {
		if col := sr.Fields[idx]; col.Name() == fieldName {
			return col
		}
	}
	return nil
}

// Unmarshal converts the result set into row-based form and stores it in receiver.
//
// The output field columns are unmarshaled first through DataSet.Unmarshal;
// afterwards the primary key values kept in IDs are written into the matching
// struct field of every row.
// `receiver` shall be a pointer to a slice of struct pointers,
// eg, *[]*Record, in which type `Record` defines the row data.
// Note that distance/score is not unmarshaled here.
func (sr *ResultSet) Unmarshal(receiver any) (err error) {
	if err = sr.Fields.Unmarshal(receiver); err != nil {
		return err
	}
	err = sr.fillPKEntry(receiver)
	return err
}

// fillPKEntry writes the primary key values stored in sr.IDs into the rows
// that are already present in receiver.
//
// receiver is a slice (or a pointer to a slice) of structs or struct pointers.
// Row i of the slice receives the value returned by sr.IDs.Get(i); the target
// struct field is the one whose candidate name (see row.ParseCandidate) equals
// the name of the schema's primary key field. When the struct has no such
// field, nothing is written and nil is returned.
//
// An error is returned when the schema or the id column is missing, when the
// receiver has fewer rows than there are ids, when a row is a nil pointer, or
// when an id value cannot be assigned to the target field. Rows processed
// before the failing one keep the value already written.
func (sr *ResultSet) fillPKEntry(receiver any) (err error) {
	// safety net: convert any remaining reflection panic into an error
	defer func() {
		if x := recover(); x != nil {
			err = errors.Newf("failed to unmarshal result set: %v", x)
		}
	}()

	rows := reflect.ValueOf(receiver)
	if rows.Kind() == reflect.Ptr {
		if rows.IsNil() {
			return errors.New("failed to unmarshal result set: receiver is a nil pointer")
		}
		rows = rows.Elem()
	}
	if rows.Kind() != reflect.Slice {
		return errors.Newf("receiver need to be slice or array but get %v", rows.Kind())
	}

	if sr.sch == nil {
		return errors.New("failed to unmarshal result set: schema is not set")
	}
	// a schema without primary key field is reported through the recover above
	pkField := sr.sch.PKField()

	// strip every pointer level to reach the row struct type
	rowType := rows.Type().Elem()
	for rowType.Kind() == reflect.Ptr {
		rowType = rowType.Elem()
	}
	if rowType.Kind() != reflect.Struct {
		return errors.Newf("failed to unmarshal result set: row type must be struct but get %v", rowType.Kind())
	}

	fieldIdx, ok := row.ParseCandidate(rowType)[pkField.Name]
	if !ok {
		// pk field not found in struct, skip
		return nil
	}

	if sr.IDs == nil {
		return errors.New("failed to unmarshal result set: id column is not set")
	}
	total := sr.IDs.Len()
	if rows.Len() < total {
		return errors.Newf("failed to unmarshal result set: receiver has %d rows but id column has %d", rows.Len(), total)
	}

	for i := 0; i < total; i++ {
		elem := rows.Index(i)
		for elem.Kind() == reflect.Ptr {
			if elem.IsNil() {
				return errors.Newf("failed to unmarshal result set: row %d is a nil pointer", i)
			}
			elem = elem.Elem()
		}

		val, getErr := sr.IDs.Get(i)
		if getErr != nil {
			return getErr
		}

		target := elem.Field(fieldIdx)
		pk := reflect.ValueOf(val)
		if !pk.IsValid() {
			return errors.Newf("failed to unmarshal result set: id value at row %d is nil", i)
		}
		// check assignability up front so a type mismatch is a regular error, not a panic
		if !pk.Type().AssignableTo(target.Type()) {
			return errors.Newf("failed to unmarshal result set: id value of type %v cannot be assigned to field of type %v",
				pk.Type(), target.Type())
		}
		target.Set(pk)
	}
	return nil
}

// DataSet is a named slice type whose elements are columns.
// Returned by query API.
type DataSet []column.Column

// Len reports the row count of the dataset, which is the length of its
// first column. A dataset without any column has zero rows.
func (ds DataSet) Len() int {
	if len(ds) > 0 {
		return ds[0].Len()
	}
	return 0
}

// Unmarshal puts dataset into receiver in row based way.
//
// `receiver` shall be a non-nil pointer to a slice of struct pointers,
// eg, *[]*Record, in which type `Record` defines the row data.
// One new `Record` is allocated per row of the dataset, filled from the
// columns whose names match a struct field (see fillData), and appended to
// the slice; elements already present in the slice are kept.
//
// An error is returned when receiver does not have the shape described above
// or when a row cannot be filled. The receiver slice is only updated after
// every row has been built successfully.
func (ds DataSet) Unmarshal(receiver any) (err error) {
	// safety net: convert any remaining reflection panic into an error
	defer func() {
		if x := recover(); x != nil {
			err = errors.Newf("failed to unmarshal result set: %v", x)
		}
	}()

	target := reflect.ValueOf(receiver)
	if target.Kind() == reflect.Ptr {
		if target.IsNil() {
			return errors.New("failed to unmarshal result set: receiver is a nil pointer")
		}
		target = target.Elem()
	}

	// TODO maybe support Array and just fill data
	if target.Kind() != reflect.Slice {
		return errors.Newf("receiver need to be slice or array but get %v", target.Kind())
	}

	elemType := target.Type().Elem()
	if elemType.Kind() != reflect.Ptr {
		return errors.Newf("receiver must be slice of pointers but get: %v", elemType.Kind())
	}
	rowType := elemType.Elem()
	if rowType.Kind() != reflect.Struct {
		return errors.Newf("receiver must be slice of struct pointers but get pointer of: %v", rowType.Kind())
	}
	// the slice header is replaced at the end, which needs a settable value,
	// i.e. the slice must have been passed by pointer
	if !target.CanSet() {
		return errors.New("failed to unmarshal result set: receiver must be a pointer to slice")
	}

	rows := target
	total := ds.Len()
	for i := 0; i < total; i++ {
		data := reflect.New(rowType)
		if fillErr := ds.fillData(data.Elem(), rowType, i); fillErr != nil {
			return fillErr
		}
		rows = reflect.Append(rows, data)
	}
	target.Set(rows)
	return nil
}

// fillData copies row idx of every column in the dataset into the struct value data.
//
// data must be a settable struct value of type dataType. Columns are matched
// to struct fields through the name map built by row.ParseCandidate; a column
// without a matching field is ignored.
//
// An error is returned when a column fails to provide the value at idx, or
// when the value cannot be assigned to the matched field's type.
func (ds DataSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error {
	fieldIndex := row.ParseCandidate(dataType)
	for _, col := range ds {
		name := col.Name()
		fidx, ok := fieldIndex[name]
		if !ok {
			// if target is not found, the behavior here is to ignore the column
			// `strict` mode could be added in the future to return error if any column missing
			continue
		}
		val, err := col.Get(idx)
		if err != nil {
			return err
		}

		field := data.Field(fidx)
		value := reflect.ValueOf(val)
		if !value.IsValid() {
			return errors.Newf("failed to unmarshal result set: column %q has nil value at row %d", name, idx)
		}
		// verify assignability first so a type mismatch is reported as an error
		// instead of a reflect panic
		if !value.Type().AssignableTo(field.Type()) {
			return errors.Newf("failed to unmarshal result set: column %q value of type %v cannot be assigned to field of type %v",
				name, value.Type(), field.Type())
		}
		field.Set(value)
	}
	return nil
}
127 changes: 127 additions & 0 deletions client/results_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"testing"

"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
)

// ResultSetSuite is the testify suite grouping the unmarshal tests of
// DataSet and ResultSet.
type ResultSetSuite struct {
suite.Suite
}

// TestResultsetUnmarshal checks DataSet.Unmarshal against receivers of
// different shapes: invalid receivers must yield an error, while a pointer to
// a slice of struct pointers must receive one row per dataset entry.
func (s *ResultSetSuite) TestResultsetUnmarshal() {
	type MyData struct {
		A int64     `milvus:"name:id"`
		V []float32 `milvus:"name:vector"`
	}
	type OtherData struct {
		A string    `milvus:"name:id"`
		V []float32 `milvus:"name:vector"`
	}

	var (
		idData     = []int64{1, 2, 3}
		vectorData = [][]float32{
			{0.1, 0.2},
			{0.1, 0.2},
			{0.1, 0.2},
		}
	)

	rs := DataSet([]column.Column{
		column.NewColumnInt64("id", idData),
		column.NewColumnFloatVector("vector", 2, vectorData),
	})

	// slice passed by value with non-pointer elements
	err := rs.Unmarshal([]MyData{})
	s.Error(err)

	// pointer to slice, but elements are not pointers
	receiver := []MyData{}
	err = rs.Unmarshal(&receiver)
	s.Error(err)

	var ptrReceiver []*MyData
	err = rs.Unmarshal(&ptrReceiver)
	s.Require().NoError(err)
	// without the length check the loop below would pass on an empty result
	s.Require().Len(ptrReceiver, len(idData))

	for idx, item := range ptrReceiver {
		s.Equal(idData[idx], item.A)
		s.Equal(vectorData[idx], item.V)
	}

	// field type (string) does not match the column type (int64)
	var otherReceiver []*OtherData
	err = rs.Unmarshal(&otherReceiver)
	s.Error(err)
}

// TestSearchResultUnmarshal checks ResultSet.Unmarshal: output fields come
// from Fields while the primary key values come from IDs. Invalid receivers
// must yield an error, and a valid receiver must hold one row per id.
func (s *ResultSetSuite) TestSearchResultUnmarshal() {
	type MyData struct {
		A int64     `milvus:"name:id"`
		V []float32 `milvus:"name:vector"`
	}
	type OtherData struct {
		A string    `milvus:"name:id"`
		V []float32 `milvus:"name:vector"`
	}

	var (
		idData     = []int64{1, 2, 3}
		vectorData = [][]float32{
			{0.1, 0.2},
			{0.1, 0.2},
			{0.1, 0.2},
		}
	)

	sr := ResultSet{
		sch: entity.NewSchema().
			WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)).
			WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)),
		IDs: column.NewColumnInt64("id", idData),
		Fields: DataSet([]column.Column{
			column.NewColumnFloatVector("vector", 2, vectorData),
		}),
	}

	// slice passed by value with non-pointer elements
	err := sr.Unmarshal([]MyData{})
	s.Error(err)

	// pointer to slice, but elements are not pointers
	receiver := []MyData{}
	err = sr.Unmarshal(&receiver)
	s.Error(err)

	var ptrReceiver []*MyData
	err = sr.Unmarshal(&ptrReceiver)
	s.Require().NoError(err)
	// without the length check the loop below would pass on an empty result
	s.Require().Len(ptrReceiver, len(idData))

	for idx, item := range ptrReceiver {
		s.Equal(idData[idx], item.A)
		s.Equal(vectorData[idx], item.V)
	}

	// primary key field type (string) does not match the id column type (int64)
	var otherReceiver []*OtherData
	err = sr.Unmarshal(&otherReceiver)
	s.Error(err)
}

func TestResults(t *testing.T) {
suite.Run(t, new(ResultSetSuite))
}
43 changes: 43 additions & 0 deletions client/row/type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package row

import (
"go/ast"
"reflect"
)

func ParseCandidate(dataType reflect.Type) map[string]int {
result := make(map[string]int)
for i := 0; i < dataType.NumField(); i++ {
f := dataType.Field(i)
// ignore anonymous field for now
if f.Anonymous || !ast.IsExported(f.Name) {
continue
}

name := f.Name
tag := f.Tag.Get(MilvusTag)
tagSettings := ParseTagSetting(tag, MilvusTagSep)
if tagName, has := tagSettings[MilvusTagName]; has {
name = tagName
}

result[name] = i
}
return result
}
Loading