Skip to content

Commit

Permalink
enhance: support templates for expression in Restful api (milvus-io#3…
Browse files Browse the repository at this point in the history
…8040)

milvus-io#36672

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Dec 2, 2024
1 parent 3a7a8c7 commit cb92668
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 38 deletions.
10 changes: 7 additions & 3 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
QueryParams: []*commonpb.KeyValuePair{},
UseDefaultConsistency: true,
}
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
c.Set(ContextRequest, req)
if httpReq.Offset > 0 {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
Expand Down Expand Up @@ -748,6 +749,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
PartitionName: httpReq.PartitionName,
Expr: httpReq.Filter,
}
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
c.Set(ContextRequest, req)
if req.Expr == "" {
body, _ := c.Get(gin.BodyBytesKey)
Expand Down Expand Up @@ -976,7 +978,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}

func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
if reqSearchParams.Params == nil {
reqSearchParams.Params = make(map[string]any)
Expand Down Expand Up @@ -1017,7 +1019,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
return nil, err
}

searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
searchParams := generateSearchParams(httpReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
if httpReq.GroupByField != "" {
Expand All @@ -1040,6 +1042,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.SearchParams = searchParams
req.PlaceholderGroup = placeholderGroup
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
})
Expand Down Expand Up @@ -1093,7 +1096,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
searchParams := generateSearchParams(subReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
Expand All @@ -1116,6 +1119,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
}
searchReq.ExprTemplateValues = generateExpressionTemplate(subReq.ExprParams)
req.Requests = append(req.Requests, searchReq)
}
bs, _ := json.Marshal(httpReq.Rerank.Params)
Expand Down
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]}, "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
Expand Down Expand Up @@ -2077,7 +2077,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]},"metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
Expand Down
70 changes: 37 additions & 33 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,14 @@ type JobIDReq struct {
func (req *JobIDReq) GetJobID() string { return req.JobID }

type QueryReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
ExprParams map[string]interface{} `json:"exprParams"`
}

func (req *QueryReqV2) GetDbName() string { return req.DbName }
Expand All @@ -140,10 +141,11 @@ type CollectionIDReq struct {
func (req *CollectionIDReq) GetDbName() string { return req.DbName }

type CollectionFilterReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
ExprParams map[string]interface{} `json:"exprParams"`
}

func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
Expand All @@ -165,20 +167,21 @@ type searchParams struct {
}

type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
StrictGroupSize bool `json:"strictGroupSize"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
StrictGroupSize bool `json:"strictGroupSize"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
ExprParams map[string]interface{} `json:"exprParams"`
// not use Params any more, just for compatibility
Params map[string]float64 `json:"params"`
}
Expand All @@ -191,14 +194,15 @@ type Rand struct {
}

type SubSearchReq struct {
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"params"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"params"`
ExprParams map[string]interface{} `json:"exprParams"`
}

type HybridSearchReq struct {
Expand Down
168 changes: 168 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"fmt"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -1679,3 +1680,170 @@ func RequestHandlerFunc(c *gin.Context) {
}
c.Next()
}

func generateTemplateArrayData(list []interface{}) *schemapb.TemplateArrayValue {
dtype := getTemplateArrayType(list)
var data *schemapb.TemplateArrayValue
switch dtype {
case schemapb.DataType_Bool:
result := make([]bool, len(list))
for i, item := range list {
result[i] = item.(bool)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_BoolData{
BoolData: &schemapb.BoolArray{
Data: result,
},
},
}
case schemapb.DataType_String:
result := make([]string, len(list))
for i, item := range list {
result[i] = item.(string)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_StringData{
StringData: &schemapb.StringArray{
Data: result,
},
},
}
case schemapb.DataType_Int64:
result := make([]int64, len(list))
for i, item := range list {
result[i] = int64(item.(float64))
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_LongData{
LongData: &schemapb.LongArray{
Data: result,
},
},
}
case schemapb.DataType_Float:
result := make([]float64, len(list))
for i, item := range list {
result[i] = item.(float64)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: result,
},
},
}
case schemapb.DataType_Array:
result := make([]*schemapb.TemplateArrayValue, len(list))
for i, item := range list {
result[i] = generateTemplateArrayData(item.([]interface{}))
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_ArrayData{
ArrayData: &schemapb.TemplateArrayValueArray{
Data: result,
},
},
}
case schemapb.DataType_JSON:
result := make([][]byte, len(list))
for i, item := range list {
bytes, err := json.Marshal(item)
// won't happen
if err != nil {
panic(fmt.Sprintf("marshal data(%v) fail, please check it!", item))
}
result[i] = bytes
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_JsonData{
JsonData: &schemapb.JSONArray{
Data: result,
},
},
}
// won't happen
default:
panic(fmt.Sprintf("Unexpected data(%v) type when generateTemplateArrayData, please check it!", list))
}
return data
}

func getTemplateArrayType(value []interface{}) schemapb.DataType {
dtype := getTemplateType(value[0])

for _, v := range value {
if getTemplateType(v) != dtype {
return schemapb.DataType_JSON
}
}
return dtype
}

func getTemplateType(value interface{}) schemapb.DataType {
switch v := value.(type) {
case bool:
return schemapb.DataType_Bool
case string:
return schemapb.DataType_String
case float64:
// note: all passed number is float64 type
// if field type is float64, but value in ExpressionTemplate is int64, it's ok to use TemplateValue_Int64Val to store it
// it will convert to float64 in ./internal/parser/planparserv2/utils.go, Line 233
if v == math.Trunc(v) && v >= math.MinInt64 && v <= math.MaxInt64 {
return schemapb.DataType_Int64
}
return schemapb.DataType_Float
// it won't happen
// case int64:
case []interface{}:
return schemapb.DataType_Array
default:
panic(fmt.Sprintf("Unexpected data(%v) when getTemplateType, please check it!", value))
}
}

func generateExpressionTemplate(params map[string]interface{}) map[string]*schemapb.TemplateValue {
expressionTemplate := make(map[string]*schemapb.TemplateValue, len(params))

for name, value := range params {
dtype := getTemplateType(value)
var data *schemapb.TemplateValue
switch dtype {
case schemapb.DataType_Bool:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_BoolVal{
BoolVal: value.(bool),
},
}
case schemapb.DataType_String:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_StringVal{
StringVal: value.(string),
},
}
case schemapb.DataType_Int64:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_Int64Val{
Int64Val: int64(value.(float64)),
},
}
case schemapb.DataType_Float:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_FloatVal{
FloatVal: value.(float64),
},
}
case schemapb.DataType_Array:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_ArrayVal{
ArrayVal: generateTemplateArrayData(value.([]interface{})),
},
}
default:
panic(fmt.Sprintf("Unexpected data(%v) when generateExpressionTemplate, please check it!", data))
}
expressionTemplate[name] = data
}
return expressionTemplate
}
Loading

0 comments on commit cb92668

Please sign in to comment.