Skip to content

Commit

Permalink
enhance: refactor createIndex in RESTful API(milvus-io#37235) (milvus…
Browse files Browse the repository at this point in the history
…-io#37237)

pr: milvus-io#37235 
2.5: milvus-io#37236

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Nov 7, 2024
1 parent c8ba682 commit 60f9631
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 7 deletions.
12 changes: 10 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -1801,8 +1801,16 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)

for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
var err error
req.ExtraParams, err = convertToExtraParams(indexParam)
if err != nil {
// will not happen
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
Expand Down
40 changes: 40 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,46 @@ func TestDatabaseWrapper(t *testing.T) {
}
}

func TestCreateIndex(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)

postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(IndexCategory, CreateAction)
// the previous format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
})
// the current format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}

func TestCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit
Expand Down
8 changes: 8 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,14 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
if indexParam.MetricType != "" {
params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType})
}
if indexParam.IndexType == "" {
for key, value := range indexParam.Params {
if key == common.IndexTypeKey {
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: fmt.Sprintf("%v", value)})
break
}
}
}
if len(indexParam.Params) != 0 {
v, err := json.Marshal(indexParam.Params)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions tests/restful_client_v2/testcases/test_index_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def test_index_e2e(self, dim, metric_type, index_type):
"metricType": f"{metric_type}"}]
}
if index_type == "HNSW":
payload["indexParams"][0]["indexType"]="HNSW"
payload["indexParams"][0]["params"] = {"index_type": "HNSW", "M": "16", "efConstruction": "200"}
if index_type == "AUTOINDEX":
payload["indexParams"][0]["indexType"]="AUTOINDEX"
payload["indexParams"][0]["params"] = {"index_type": "AUTOINDEX"}
rsp = self.index_client.index_create(payload)
assert rsp['code'] == 0
Expand All @@ -89,8 +91,8 @@ def test_index_e2e(self, dim, metric_type, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']
assert expected_index[i]['metricType'] == actual_index[i]['metricType']
assert expected_index[i]["params"]['index_type'] == actual_index[i]['indexType']

# drop index
for i in range(len(actual_index)):
Expand Down Expand Up @@ -152,7 +154,7 @@ def test_index_for_scalar_field(self, dim, index_type):
# create index
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector",
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector","indexType": "INVERTED",
"params": {"index_type": "INVERTED"}}]
}
rsp = self.index_client.index_create(payload)
Expand All @@ -169,7 +171,7 @@ def test_index_for_scalar_field(self, dim, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']

@pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"])
@pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"])
Expand Down Expand Up @@ -220,7 +222,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type):
index_name = "binary_vector_index"
payload = {
"collectionName": name,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,"indexType": index_type,
"params": {"index_type": index_type}}]
}
if index_type == "BIN_IVF_FLAT":
Expand All @@ -239,7 +241,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']


@pytest.mark.L1
Expand Down

0 comments on commit 60f9631

Please sign in to comment.