Skip to content

Commit

Permalink
enhance: support db request in Restful api(#38140)(#38078) (#38188)
Browse files Browse the repository at this point in the history
pr: #38078
pr: #38140
issue: #38077

---------

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Dec 4, 2024
1 parent 7bd401c commit ab88d23
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 162 deletions.
3 changes: 3 additions & 0 deletions internal/distributed/proxy/httpserver/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
// v2
const (
// --- category ---
DataBaseCategory = "/databases/"
CollectionCategory = "/collections/"
EntityCategory = "/entities/"
PartitionCategory = "/partitions/"
Expand Down Expand Up @@ -74,6 +75,8 @@ const (
HTTPCollectionName = "collectionName"
HTTPCollectionID = "collectionID"
HTTPDbName = "dbName"
HTTPDbID = "dbID"
HTTPProperties = "properties"
HTTPPartitionName = "partitionName"
HTTPPartitionNames = "partitionNames"
HTTPUserName = "userName"
Expand Down
217 changes: 146 additions & 71 deletions internal/distributed/proxy/httpserver/handler_v2.go

Large diffs are not rendered by default.

191 changes: 114 additions & 77 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,118 +376,155 @@ func TestTimeout(t *testing.T) {
}
}

func TestDatabaseWrapper(t *testing.T) {
func TestCreateIndex(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)

postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "exist"},
}, nil).Twice()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
h := NewHandlersV2(mp)
ginHandler := gin.Default()
app := ginHandler.Group("", genAuthMiddleWare(false))
path := "/wrapper/database"
app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, h.wrapperCheckDatabase(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
return nil, nil
})))
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "exist"}`),
})
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(IndexCategory, CreateAction)
// the previous format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "non-exist"}`),
errMsg: "database not found, database: non-exist",
errCode: 800, // ErrDatabaseNotFound
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
})
// the current format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "test"}`),
errMsg: "",
errCode: 65535,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}

mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "default"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "test"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
rawTestCases := []rawTestCase{
{
errMsg: "database not found, database: test",
errCode: 800, // ErrDatabaseNotFound
},
{},
{
errMsg: "",
errCode: 65535,
},
}
for _, testcase := range rawTestCases {
t.Run("post with db"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`)))
req.Header.Set(HTTPHeaderDBName, "test")
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}

func TestCreateIndex(t *testing.T) {
func TestDatabase(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)

postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(IndexCategory, CreateAction)
// the previous format
path := versionalV2(DataBaseCategory, CreateAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
requestBody: []byte(`{"dbName":"test"}`),
})
// the current format
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
requestBody: []byte(`{"dbName":"invalid_name"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, DropAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
path = versionalV2(DataBaseCategory, ListAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once()
mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{
Status: &commonpb.Status{
Code: 1100,
Reason: "mock",
},
}, nil).Once()
path = versionalV2(DataBaseCategory, DescribeAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(
&commonpb.Status{
Code: 1100,
Reason: "mock",
}, nil).Once()
path = versionalV2(DataBaseCategory, AlterAction)
// success
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"test"}`),
})
// mock fail
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName":"mock"}`),
errMsg: "mock",
errCode: 1100, // ErrParameterInvalid
})

for _, testcase := range postTestCases {
Expand Down
11 changes: 11 additions & 0 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)

type EmptyReq struct{}

func (req *EmptyReq) GetDbName() string { return "" }

type DatabaseReq struct {
DbName string `json:"dbName"`
}

func (req *DatabaseReq) GetDbName() string { return req.DbName }

type DatabaseReqWithProperties struct {
DbName string `json:"dbName" binding:"required"`
Properties map[string]interface{} `json:"properties"`
}

func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName }

type CollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Expand Down
15 changes: 9 additions & 6 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,13 +987,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {

func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
dbKey := strings.ToLower(util.HeaderDBName)
if username == "" {
return contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
if dbName != "" {
ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
}
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName)
if username != "" {
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue)
}
return ctx
}

func AppendUserInfoForRPC(ctx context.Context) context.Context {
Expand Down
4 changes: 2 additions & 2 deletions internal/rootcoord/drop_collection_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error {
// dropping collection with `ts1` but a collection exists in catalog with newer ts which is bigger than `ts1`.
// fortunately, if ddls are promised to execute in sequence, then everything is OK. The `ts1` will always be latest.
collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if errors.Is(err, merr.ErrCollectionNotFound) {
if errors.Is(err, merr.ErrCollectionNotFound) || errors.Is(err, merr.ErrDatabaseNotFound) {
// make dropping collection idempotent.
log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName()))
log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName()), zap.String("database", t.Req.GetDbName()))
return nil
}

Expand Down
10 changes: 5 additions & 5 deletions internal/rootcoord/meta_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,11 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str
dbName = util.DefaultDBName
}

db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}

collectionID, ok := mt.aliases.get(dbName, collectionName)
if ok {
return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, false)
Expand All @@ -615,11 +620,6 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str
return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName)
}

db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}

// travel meta information from catalog. No need to check time travel logic again, since catalog already did.
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
coll, err := mt.catalog.GetCollectionByName(ctx1, db.ID, collectionName, ts)
Expand Down
Loading

0 comments on commit ab88d23

Please sign in to comment.