Skip to content

Commit

Permalink
enhance: remove no need database check in restful sdk
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Nov 29, 2024
1 parent 61f73d5 commit 620d2de
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 160 deletions.
101 changes: 38 additions & 63 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,70 +66,70 @@ func NewHandlersV2(proxyClient types.ProxyComponent) *HandlersV2 {
}

func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listCollections)))))
router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasCollection)))))
router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listCollections))))
router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.hasCollection))))
// todo review the return data
router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionDetails)))))
router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionStats)))))
router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionLoadState)))))
router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createCollection)))))
router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropCollection)))))
router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.renameCollection)))))
router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadCollection)))))
router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releaseCollection)))))
router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionDetails))))
router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionStats))))
router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionLoadState))))
router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.createCollection))))
router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.dropCollection))))
router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.renameCollection))))
router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection))))
router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection))))

// Query
router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &QueryReqV2{
Limit: 100,
OutputFields: []string{DefaultOutputFields},
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.query)))), true))
}, wrapperTraceLog(h.query))), true))
// Get
router.POST(EntityCategory+GetAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &CollectionIDReq{
OutputFields: []string{DefaultOutputFields},
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.get)))), true))
}, wrapperTraceLog(h.get))), true))
// Delete
router.POST(EntityCategory+DeleteAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &CollectionFilterReq{}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.delete)))), false))
}, wrapperTraceLog(h.delete))), false))
// Insert
router.POST(EntityCategory+InsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &CollectionDataReq{}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.insert)))), false))
}, wrapperTraceLog(h.insert))), false))
// Upsert
router.POST(EntityCategory+UpsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &CollectionDataReq{}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.upsert)))), false))
}, wrapperTraceLog(h.upsert))), false))
// Search
router.POST(EntityCategory+SearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &SearchReqV2{
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))), true))
}, wrapperTraceLog(h.search))), true))
// advanced_search, backward compatible uri
router.POST(EntityCategory+AdvancedSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &HybridSearchReq{
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true))
}, wrapperTraceLog(h.advancedSearch))), true))
// HybridSearch
router.POST(EntityCategory+HybridSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any {
return &HybridSearchReq{
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true))
}, wrapperTraceLog(h.advancedSearch))), true))

router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions)))))
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions)))))
router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.statsPartition)))))
router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listPartitions))))
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.hasPartitions))))
router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.statsPartition))))

router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createPartition)))))
router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropPartition)))))
router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadPartitions)))))
router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releasePartitions)))))
router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.createPartition))))
router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.dropPartition))))
router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.loadPartitions))))
router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.releasePartitions))))

router.POST(UserCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listUsers))))
router.POST(UserCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &UserReq{} }, wrapperTraceLog(h.describeUser))))
Expand Down Expand Up @@ -157,24 +157,24 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.addPrivilegesToGroup))))
router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup))))

router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes)))))
router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex)))))
router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listIndexes))))
router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.describeIndex))))

router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createIndex)))))
router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.createIndex))))
// todo cannot drop index before release it ?
router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropIndex)))))
router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.dropIndex))))

router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listAlias)))))
router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeAlias)))))
router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listAlias))))
router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.describeAlias))))

router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createAlias)))))
router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropAlias)))))
router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.alterAlias)))))
router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.createAlias))))
router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.dropAlias))))
router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.alterAlias))))

router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob)))))
router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob)))))
router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess)))))
router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listImportJob))))
router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.createImportJob))))
router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess))))
router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess))))
}

type (
Expand Down Expand Up @@ -350,31 +350,6 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu
return response, err
}

func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 {
return func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
if dbName == DefaultDbName || proxy.CheckDatabase(ctx, dbName) {
return v2(ctx, c, req, dbName)
}
resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.ListDatabases(reqCtx, &milvuspb.ListDatabasesRequest{})
})
if err != nil {
return resp, err
}
for _, db := range resp.(*milvuspb.ListDatabasesResponse).DbNames {
if db == dbName {
return v2(ctx, c, req, dbName)
}
}
log.Ctx(ctx).Warn("high level restful api, non-exist database", zap.String("database", dbName))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound),
HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName,
})
return nil, merr.ErrDatabaseNotFound
}
}

func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
getter, _ := anyReq.(requestutil.CollectionNameGetter)
collectionName := getter.GetCollectionName()
Expand Down
92 changes: 0 additions & 92 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,98 +409,6 @@ func TestTimeout(t *testing.T) {
}
}

func TestDatabaseWrapper(t *testing.T) {
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "exist"},
}, nil).Twice()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
h := NewHandlersV2(mp)
ginHandler := gin.Default()
app := ginHandler.Group("", genAuthMiddleWare(false))
path := "/wrapper/database"
app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, h.wrapperCheckDatabase(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
return nil, nil
})))
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "exist"}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "non-exist"}`),
errMsg: "database not found, database: non-exist",
errCode: 800, // ErrDatabaseNotFound
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"dbName": "test"}`),
errMsg: "",
errCode: 65535,
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}

mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "default"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: &StatusSuccess,
DbNames: []string{DefaultCollectionName, "test"},
}, nil).Once()
mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once()
rawTestCases := []rawTestCase{
{
errMsg: "database not found, database: test",
errCode: 800, // ErrDatabaseNotFound
},
{},
{
errMsg: "",
errCode: 65535,
},
}
for _, testcase := range rawTestCases {
t.Run("post with db"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`)))
req.Header.Set(HTTPHeaderDBName, "test")
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
if testcase.errCode != 0 {
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}

func TestDocInDocOutCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit
Expand Down
10 changes: 5 additions & 5 deletions internal/rootcoord/meta_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,11 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str
dbName = util.DefaultDBName
}

db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}

collectionID, ok := mt.aliases.get(dbName, collectionName)
if ok {
return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, false)
Expand All @@ -623,11 +628,6 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str
return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName)
}

db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}

// travel meta information from catalog. No need to check time travel logic again, since catalog already did.
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
coll, err := mt.catalog.GetCollectionByName(ctx1, db.ID, collectionName, ts)
Expand Down
Loading

0 comments on commit 620d2de

Please sign in to comment.