Skip to content

Commit

Permalink
enhance: refactor createIndex in RESTful API (#37235)
Browse files Browse the repository at this point in the history
Make the parameter input method consistent with miluvs-client.

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Nov 7, 2024
1 parent 40b770c commit 86fd320
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 6 deletions.
12 changes: 10 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -1876,8 +1876,16 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)

for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
var err error
req.ExtraParams, err = convertToExtraParams(indexParam)
if err != nil {
// will not happen
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
Expand Down
40 changes: 40 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,46 @@ func TestDocInDocOutSearch(t *testing.T) {
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
}

func TestCreateIndex(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)

postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(IndexCategory, CreateAction)
// the previous format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
})
// the current format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}

func TestCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit
Expand Down
8 changes: 8 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,14 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
if indexParam.IndexType != "" {
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType})
}
if indexParam.IndexType == "" {
for key, value := range indexParam.Params {
if key == common.IndexTypeKey {
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: fmt.Sprintf("%v", value)})
break
}
}
}
if indexParam.MetricType != "" {
params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType})
}
Expand Down
8 changes: 4 additions & 4 deletions tests/restful_client_v2/testcases/test_index_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_index_for_scalar_field(self, dim, index_type):
# create index
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector",
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector", "indexType": "INVERTED",
"params": {"index_type": "INVERTED"}}]
}
rsp = self.index_client.index_create(payload)
Expand All @@ -177,7 +177,7 @@ def test_index_for_scalar_field(self, dim, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']

@pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"])
@pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"])
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type):
index_name = "binary_vector_index"
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type, "indexType": index_type,
"params": {"index_type": index_type}}]
}
if index_type == "BIN_IVF_FLAT":
Expand All @@ -247,7 +247,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']

@pytest.mark.parametrize("insert_round", [1])
@pytest.mark.parametrize("auto_id", [True])
Expand Down

0 comments on commit 86fd320

Please sign in to comment.