From adefe1c89be2cc7338e862db261fa01afb91cf46 Mon Sep 17 00:00:00 2001 From: lixinguo Date: Wed, 27 Nov 2024 10:39:19 +0800 Subject: [PATCH] enhance: support templates for expression in Restful api Signed-off-by: lixinguo --- .../proxy/httpserver/handler_v2.go | 10 +- .../proxy/httpserver/handler_v2_test.go | 4 +- .../proxy/httpserver/request_v2.go | 66 +++---- .../distributed/proxy/httpserver/utils.go | 167 ++++++++++++++++++ .../proxy/httpserver/utils_test.go | 140 +++++++++++++++ 5 files changed, 351 insertions(+), 36 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 38f3984a1aa21..b035e058f0550 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -640,6 +640,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa QueryParams: []*commonpb.KeyValuePair{}, UseDefaultConsistency: true, } + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) c.Set(ContextRequest, req) if httpReq.Offset > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) @@ -731,6 +732,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN PartitionName: httpReq.PartitionName, Expr: httpReq.Filter, } + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) c.Set(ContextRequest, req) if req.Expr == "" { body, _ := c.Get(gin.BodyBytesKey) @@ -943,7 +945,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair { +func generateSearchParams(reqSearchParams searchParams) []*commonpb.KeyValuePair { var searchParams []*commonpb.KeyValuePair if reqSearchParams.Params == nil { reqSearchParams.Params = make(map[string]any) @@ -983,7 +985,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN return nil, err } - searchParams := generateSearchParams(ctx, c, httpReq.SearchParams) + searchParams := generateSearchParams(httpReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) @@ -1000,6 +1002,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN } req.SearchParams = searchParams req.PlaceholderGroup = placeholderGroup + req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest)) }) @@ -1052,7 +1055,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams := generateSearchParams(ctx, c, subReq.SearchParams) + searchParams := generateSearchParams(subReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) @@ -1076,6 +1079,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq PartitionNames: httpReq.PartitionNames, SearchParams: searchParams, } + searchReq.ExprTemplateValues = generateExpressionTemplate(subReq.ExprParams) req.Requests = append(req.Requests, searchReq) } bs, _ := json.Marshal(httpReq.Rerank.Params) diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 5248f9c9a69b8..dc54bd58b673a 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1494,7 +1494,7 @@ func TestSearchV2(t *testing.T) { queryTestCases := []requestBodyTestCase{} queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]}, "limit": 4, "outputFields": ["word_count"]}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1575,7 +1575,7 @@ func TestSearchV2(t *testing.T) { queryTestCases = append(queryTestCases, requestBodyTestCase{ path: AdvancedSearchAction, requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + - `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "filter": "book_id in {list}", "exprParams":{"list": [2, 4, 6, 8]},"metricType": "L2", "limit": 3},` + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 0ef3c2045e0f6..b9b2fd74855d4 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -101,13 +101,14 @@ type JobIDReq struct { func (req *JobIDReq) GetJobID() string { return req.JobID } type QueryReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - OutputFields []string `json:"outputFields"` - Filter string `json:"filter"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + OutputFields []string `json:"outputFields"` + Filter string `json:"filter"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + ExprParams map[string]interface{} `json:"exprParams"` } func (req *QueryReqV2) GetDbName() string { return req.DbName } @@ -124,10 +125,11 @@ type CollectionIDReq struct { func (req *CollectionIDReq) GetDbName() string { return req.DbName } type CollectionFilterReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionName string `json:"partitionName"` - Filter string `json:"filter" binding:"required"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + Filter string `json:"filter" binding:"required"` + ExprParams map[string]interface{} `json:"exprParams"` } func (req *CollectionFilterReq) GetDbName() string { return req.DbName } @@ -149,18 +151,19 @@ type searchParams struct { } type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - SearchParams searchParams `json:"searchParams"` - ConsistencyLevel string `json:"consistencyLevel"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + ExprParams map[string]interface{} `json:"exprParams"` // not use Params any more, just for compatibility Params map[string]float64 `json:"params"` } @@ -173,14 +176,15 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - SearchParams searchParams `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"params"` + ExprParams map[string]interface{} `json:"exprParams"` } type HybridSearchReq struct { diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 4d2ed437c3b66..e6f6edc546658 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1340,3 +1340,170 @@ func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLe // ConsistencyLevel_Bounded default in PyMilvus return commonpb.ConsistencyLevel_Bounded, true, nil } + +func generateTemplateArrayData(list []interface{}) *schemapb.TemplateArrayValue { + dtype := getTemplateArrayType(list) + var data *schemapb.TemplateArrayValue + switch dtype { + case schemapb.DataType_Bool: + result := make([]bool, len(list)) + for i, item := range list { + result[i] = item.(bool) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: result, + }, + }, + } + case schemapb.DataType_String: + result := make([]string, len(list)) + for i, item := range list { + result[i] = item.(string) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_StringData{ + StringData: &schemapb.StringArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Int64: + result := make([]int64, len(list)) + for i, item := range list { + result[i] = int64(item.(float64)) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_LongData{ + LongData: &schemapb.LongArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Float: + result := make([]float64, len(list)) + for i, item := range list { + result[i] = item.(float64) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: result, + }, + }, + } + case schemapb.DataType_Array: + result := make([]*schemapb.TemplateArrayValue, len(list)) + for i, item := range list { + result[i] = generateTemplateArrayData(item.([]interface{})) + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_ArrayData{ + ArrayData: &schemapb.TemplateArrayValueArray{ + Data: result, + }, + }, + } + case schemapb.DataType_JSON: + result := make([][]byte, len(list)) + for i, item := range list { + bytes, err := json.Marshal(item) + // won't happen + if err != nil { + panic(fmt.Sprintf("marshal data(%v) fail, please check it!", item)) + } + result[i] = bytes + } + data = &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: result, + }, + }, + } + // won't happen + default: + panic(fmt.Sprintf("Unexpected data(%v) type when generateTemplateArrayData, please check it!", list)) + } + return data +} + +func getTemplateArrayType(value []interface{}) schemapb.DataType { + dtype := getTemplateType(value[0]) + + for _, v := range value { + if getTemplateType(v) != dtype { + return schemapb.DataType_JSON + } + } + return dtype +} + +func getTemplateType(value interface{}) schemapb.DataType { + switch v := value.(type) { + case bool: + return schemapb.DataType_Bool + case string: + return schemapb.DataType_String + case float64: + // note: all passed number is float64 type + // if field type is float64, but value in ExpressionTemplate is int64, it's ok to use TemplateValue_Int64Val to store it + // it will convert to float64 in ./internal/parser/planparserv2/utils.go, Line 233 + if v == math.Trunc(v) && v >= math.MinInt64 && v <= math.MaxInt64 { + return schemapb.DataType_Int64 + } + return schemapb.DataType_Float + // it won't happen + // case int64: + case []interface{}: + return schemapb.DataType_Array + default: + panic(fmt.Sprintf("Unexpected data(%v) when getTemplateType, please check it!", value)) + } +} + +func generateExpressionTemplate(params map[string]interface{}) map[string]*schemapb.TemplateValue { + expressionTemplate := make(map[string]*schemapb.TemplateValue, len(params)) + + for name, value := range params { + dtype := getTemplateType(value) + var data *schemapb.TemplateValue + switch dtype { + case schemapb.DataType_Bool: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_BoolVal{ + BoolVal: value.(bool), + }, + } + case schemapb.DataType_String: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_StringVal{ + StringVal: value.(string), + }, + } + case schemapb.DataType_Int64: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: int64(value.(float64)), + }, + } + case schemapb.DataType_Float: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_FloatVal{ + FloatVal: value.(float64), + }, + } + case schemapb.DataType_Array: + data = &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: generateTemplateArrayData(value.([]interface{})), + }, + } + default: + panic(fmt.Sprintf("Unexpected data(%v) when generateExpressionTemplate, please check it!", data)) + } + expressionTemplate[name] = data + } + return expressionTemplate +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index d05f0e0661f42..c2dd7f8742f7a 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1592,3 +1592,143 @@ func TestConvertConsistencyLevel(t *testing.T) { _, _, err = convertConsistencyLevel("test") assert.NotNil(t, err) } + +func TestGenerateExpressionTemplate(t *testing.T) { + var mixedList []interface{} + var mixedAns [][]byte + + mixedList = append(mixedList, float64(1)) + mixedList = append(mixedList, "10") + mixedList = append(mixedList, true) + + val, _ := json.Marshal(1) + mixedAns = append(mixedAns, val) + val, _ = json.Marshal("10") + mixedAns = append(mixedAns, val) + val, _ = json.Marshal(true) + mixedAns = append(mixedAns, val) + // all passed number is float64 type, so all the number type has convert to float64 + expressionTemplates := []map[string]interface{}{ + {"str": "10"}, + {"min": float64(1), "max": float64(10)}, + {"bool": true}, + {"float64": 1.1}, + {"int64": float64(1)}, + {"list_of_str": []interface{}{"1", "10", "100"}}, + {"list_of_bool": []interface{}{true, false, true}}, + {"list_of_float": []interface{}{1.1, 10.1, 100.1}}, + {"list_of_int": []interface{}{float64(1), float64(10), float64(100)}}, + {"list_of_json": mixedList}, + } + ans := []map[string]*schemapb.TemplateValue{ + { + "str": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_StringVal{ + StringVal: "10", + }, + }, + }, + { + "min": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 1, + }, + }, + "max": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 10, + }, + }, + }, + { + "bool": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_BoolVal{ + BoolVal: true, + }, + }, + }, + { + "float64": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_FloatVal{ + FloatVal: 1.1, + }, + }, + }, + { + "int64": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_Int64Val{ + Int64Val: 1, + }, + }, + }, + { + "list_of_str": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"1", "10", "100"}, + }, + }, + }, + }, + }, + }, + { + "list_of_bool": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false, true}, + }, + }, + }, + }, + }, + }, + { + "list_of_float": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1.1, 10.1, 100.1}, + }, + }, + }, + }, + }, + }, + { + "list_of_int": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 10, 100}, + }, + }, + }, + }, + }, + }, + { + "list_of_json": &schemapb.TemplateValue{ + Val: &schemapb.TemplateValue_ArrayVal{ + ArrayVal: &schemapb.TemplateArrayValue{ + Data: &schemapb.TemplateArrayValue_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: mixedAns, + }, + }, + }, + }, + }, + }, + } + for i, template := range expressionTemplates { + actual := generateExpressionTemplate(template) + assert.Equal(t, actual, ans[i]) + } +}