Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: support templates for expression in Restful api(#38040) #38161

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
QueryParams: []*commonpb.KeyValuePair{},
UseDefaultConsistency: true,
}
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
c.Set(ContextRequest, req)
if httpReq.Offset > 0 {
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
Expand Down Expand Up @@ -731,6 +732,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
PartitionName: httpReq.PartitionName,
Expr: httpReq.Filter,
}
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
c.Set(ContextRequest, req)
if req.Expr == "" {
body, _ := c.Get(gin.BodyBytesKey)
Expand Down Expand Up @@ -943,7 +945,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
})
}

func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair {
func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair {
var searchParams []*commonpb.KeyValuePair
if reqSearchParams.Params == nil {
reqSearchParams.Params = make(map[string]any)
Expand Down Expand Up @@ -983,7 +985,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
return nil, err
}

searchParams := generateSearchParams(ctx, c, httpReq.SearchParams)
searchParams := generateSearchParams(httpReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
Expand All @@ -1000,6 +1002,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
req.SearchParams = searchParams
req.PlaceholderGroup = placeholderGroup
req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams)
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
})
Expand Down Expand Up @@ -1052,7 +1055,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
body, _ := c.Get(gin.BodyBytesKey)
searchArray := gjson.Get(string(body.([]byte)), "search").Array()
for i, subReq := range httpReq.Search {
searchParams := generateSearchParams(ctx, c, subReq.SearchParams)
searchParams := generateSearchParams(subReq.SearchParams)
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
Expand All @@ -1076,6 +1079,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
}
searchReq.ExprTemplateValues = generateExpressionTemplate(subReq.ExprParams)
req.Requests = append(req.Requests, searchReq)
}
bs, _ := json.Marshal(httpReq.Rerank.Params)
Expand Down
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]}, "limit": 4, "outputFields": ["word_count"]}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
Expand Down Expand Up @@ -1575,7 +1575,7 @@ func TestSearchV2(t *testing.T) {
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]},"metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
Expand Down
66 changes: 35 additions & 31 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ type JobIDReq struct {
func (req *JobIDReq) GetJobID() string { return req.JobID }

type QueryReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
ExprParams map[string]interface{} `json:"exprParams"`
}

func (req *QueryReqV2) GetDbName() string { return req.DbName }
Expand All @@ -124,10 +125,11 @@ type CollectionIDReq struct {
func (req *CollectionIDReq) GetDbName() string { return req.DbName }

type CollectionFilterReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
ExprParams map[string]interface{} `json:"exprParams"`
}

func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
Expand All @@ -149,18 +151,19 @@ type searchParams struct {
}

type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
SearchParams searchParams `json:"searchParams"`
ConsistencyLevel string `json:"consistencyLevel"`
ExprParams map[string]interface{} `json:"exprParams"`
// not use Params any more, just for compatibility
Params map[string]float64 `json:"params"`
}
Expand All @@ -173,14 +176,15 @@ type Rand struct {
}

type SubSearchReq struct {
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"params"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
SearchParams searchParams `json:"params"`
ExprParams map[string]interface{} `json:"exprParams"`
}

type HybridSearchReq struct {
Expand Down
167 changes: 167 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1340,3 +1340,170 @@
// ConsistencyLevel_Bounded default in PyMilvus
return commonpb.ConsistencyLevel_Bounded, true, nil
}

func generateTemplateArrayData(list []interface{}) *schemapb.TemplateArrayValue {
dtype := getTemplateArrayType(list)
var data *schemapb.TemplateArrayValue
switch dtype {
case schemapb.DataType_Bool:
result := make([]bool, len(list))
for i, item := range list {
result[i] = item.(bool)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_BoolData{
BoolData: &schemapb.BoolArray{
Data: result,
},
},
}
case schemapb.DataType_String:
result := make([]string, len(list))
for i, item := range list {
result[i] = item.(string)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_StringData{
StringData: &schemapb.StringArray{
Data: result,
},
},
}
case schemapb.DataType_Int64:
result := make([]int64, len(list))
for i, item := range list {
result[i] = int64(item.(float64))
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_LongData{
LongData: &schemapb.LongArray{
Data: result,
},
},
}
case schemapb.DataType_Float:
result := make([]float64, len(list))
for i, item := range list {
result[i] = item.(float64)
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: result,
},
},
}
case schemapb.DataType_Array:
result := make([]*schemapb.TemplateArrayValue, len(list))
for i, item := range list {
result[i] = generateTemplateArrayData(item.([]interface{}))
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_ArrayData{
ArrayData: &schemapb.TemplateArrayValueArray{
Data: result,
},
},
}

Check warning on line 1407 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1396-L1407

Added lines #L1396 - L1407 were not covered by tests
case schemapb.DataType_JSON:
result := make([][]byte, len(list))
for i, item := range list {
bytes, err := json.Marshal(item)
// won't happen
if err != nil {
panic(fmt.Sprintf("marshal data(%v) fail, please check it!", item))

Check warning on line 1414 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1414

Added line #L1414 was not covered by tests
}
result[i] = bytes
}
data = &schemapb.TemplateArrayValue{
Data: &schemapb.TemplateArrayValue_JsonData{
JsonData: &schemapb.JSONArray{
Data: result,
},
},
}
// won't happen
default:
panic(fmt.Sprintf("Unexpected data(%v) type when generateTemplateArrayData, please check it!", list))

Check warning on line 1427 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1426-L1427

Added lines #L1426 - L1427 were not covered by tests
}
return data
}

// getTemplateArrayType reports the element type shared by the items of value.
//
// It returns the type of the first element when getTemplateType yields that
// same type for every element, and schemapb.DataType_JSON as soon as one
// element differs. An empty slice has no element type to report, so it also
// yields schemapb.DataType_JSON instead of indexing value[0] and panicking.
func getTemplateArrayType(value []interface{}) schemapb.DataType {
	if len(value) == 0 {
		return schemapb.DataType_JSON
	}

	dtype := getTemplateType(value[0])
	// The first element defines dtype, so only the rest needs comparing.
	for _, v := range value[1:] {
		if getTemplateType(v) != dtype {
			return schemapb.DataType_JSON
		}
	}
	return dtype
}

func getTemplateType(value interface{}) schemapb.DataType {
switch v := value.(type) {
case bool:
return schemapb.DataType_Bool
case string:
return schemapb.DataType_String
case float64:
// note: all passed number is float64 type
// if field type is float64, but value in ExpressionTemplate is int64, it's ok to use TemplateValue_Int64Val to store it
// it will convert to float64 in ./internal/parser/planparserv2/utils.go, Line 233
if v == math.Trunc(v) && v >= math.MinInt64 && v <= math.MaxInt64 {
return schemapb.DataType_Int64
}
return schemapb.DataType_Float
// it won't happen
// case int64:
case []interface{}:
return schemapb.DataType_Array
default:
panic(fmt.Sprintf("Unexpected data(%v) when getTemplateType, please check it!", value))

Check warning on line 1462 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1461-L1462

Added lines #L1461 - L1462 were not covered by tests
}
}

func generateExpressionTemplate(params map[string]interface{}) map[string]*schemapb.TemplateValue {
expressionTemplate := make(map[string]*schemapb.TemplateValue, len(params))

for name, value := range params {
dtype := getTemplateType(value)
var data *schemapb.TemplateValue
switch dtype {
case schemapb.DataType_Bool:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_BoolVal{
BoolVal: value.(bool),
},
}
case schemapb.DataType_String:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_StringVal{
StringVal: value.(string),
},
}
case schemapb.DataType_Int64:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_Int64Val{
Int64Val: int64(value.(float64)),
},
}
case schemapb.DataType_Float:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_FloatVal{
FloatVal: value.(float64),
},
}
case schemapb.DataType_Array:
data = &schemapb.TemplateValue{
Val: &schemapb.TemplateValue_ArrayVal{
ArrayVal: generateTemplateArrayData(value.([]interface{})),
},
}
default:
panic(fmt.Sprintf("Unexpected data(%v) when generateExpressionTemplate, please check it!", data))

Check warning on line 1504 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1503-L1504

Added lines #L1503 - L1504 were not covered by tests
}
expressionTemplate[name] = data
}
return expressionTemplate
}
Loading
Loading