diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index ef39d46ef7c1f..d56ce8e7ff38e 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -9,6 +9,7 @@ import ( // v2 const ( // --- category --- + DataBaseCategory = "/databases/" CollectionCategory = "/collections/" EntityCategory = "/entities/" PartitionCategory = "/partitions/" @@ -74,6 +75,8 @@ const ( HTTPCollectionName = "collectionName" HTTPCollectionID = "collectionID" HTTPDbName = "dbName" + HTTPDbID = "dbID" + HTTPProperties = "properties" HTTPPartitionName = "partitionName" HTTPPartitionNames = "partitionNames" HTTPUserName = "userName" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 38f3984a1aa21..5ab7a2be52f1f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -50,70 +50,75 @@ func NewHandlersV2(proxyClient types.ProxyComponent) *HandlersV2 { } func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { - router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listCollections))))) - router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasCollection))))) + router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listCollections)))) + router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.hasCollection)))) // todo review the return data - router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionDetails))))) - router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionStats))))) - router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionLoadState))))) - router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createCollection))))) - router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropCollection))))) - router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.renameCollection))))) - router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadCollection))))) - router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releaseCollection))))) - + router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionDetails)))) + router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionStats)))) + router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.getCollectionLoadState)))) + router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{AutoID: DisableAutoID} }, wrapperTraceLog(h.createCollection)))) + router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.dropCollection)))) + router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.renameCollection)))) + router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection)))) + router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection)))) + + router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase)))) + router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.dropDatabase)))) + router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases)))) + router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.describeDatabase)))) + router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase)))) // Query router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &QueryReqV2{ Limit: 100, OutputFields: []string{DefaultOutputFields}, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.query)))), true)) + }, wrapperTraceLog(h.query))), true)) // Get router.POST(EntityCategory+GetAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionIDReq{ OutputFields: []string{DefaultOutputFields}, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.get)))), true)) + }, wrapperTraceLog(h.get))), true)) // Delete router.POST(EntityCategory+DeleteAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionFilterReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.delete)))), false)) + }, wrapperTraceLog(h.delete))), false)) // Insert router.POST(EntityCategory+InsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionDataReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.insert)))), false)) + }, wrapperTraceLog(h.insert))), false)) // Upsert router.POST(EntityCategory+UpsertAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &CollectionDataReq{} - }, wrapperTraceLog(h.wrapperCheckDatabase(h.upsert)))), false)) + }, wrapperTraceLog(h.upsert))), false)) // Search router.POST(EntityCategory+SearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &SearchReqV2{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))), true)) + }, wrapperTraceLog(h.search))), true)) // advanced_search, backward compatible uri router.POST(EntityCategory+AdvancedSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &HybridSearchReq{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true)) + }, wrapperTraceLog(h.advancedSearch))), true)) // HybridSearch router.POST(EntityCategory+HybridSearchAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &HybridSearchReq{ Limit: 100, } - }, wrapperTraceLog(h.wrapperCheckDatabase(h.advancedSearch)))), true)) + }, wrapperTraceLog(h.advancedSearch))), true)) - router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions))))) - router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions))))) - router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.statsPartition))))) + router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listPartitions)))) + router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.hasPartitions)))) + router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.statsPartition)))) - router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createPartition))))) - router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropPartition))))) - router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadPartitions))))) - router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releasePartitions))))) + router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.createPartition)))) + router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.dropPartition)))) + router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.loadPartitions)))) + router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.releasePartitions)))) router.POST(UserCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listUsers)))) router.POST(UserCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &UserReq{} }, wrapperTraceLog(h.describeUser)))) @@ -141,24 +146,24 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.addPrivilegesToGroup)))) router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup)))) - router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes))))) - router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex))))) + router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.listIndexes)))) + router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.describeIndex)))) - router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createIndex))))) + router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.createIndex)))) // todo cannot drop index before release it ? - router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropIndex))))) + router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.dropIndex)))) - router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listAlias))))) - router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeAlias))))) + router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listAlias)))) + router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.describeAlias)))) - router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createAlias))))) - router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropAlias))))) - router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.alterAlias))))) + router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.createAlias)))) + router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.dropAlias)))) + router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.alterAlias)))) - router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob))))) - router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob))))) - router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) - router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.listImportJob)))) + router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.createImportJob)))) + router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) + router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) } type ( @@ -191,13 +196,15 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc { return } dbName := "" - if getter, ok := req.(requestutil.DBNameGetter); ok { - dbName = getter.GetDbName() - } - if dbName == "" { - dbName = c.Request.Header.Get(HTTPHeaderDBName) + if req != nil { + if getter, ok := req.(requestutil.DBNameGetter); ok { + dbName = getter.GetDbName() + } if dbName == "" { - dbName = DefaultDbName + dbName = c.Request.Header.Get(HTTPHeaderDBName) + if dbName == "" { + dbName = DefaultDbName + } } } username, _ := c.Get(ContextUsername) @@ -261,7 +268,7 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 { if err != nil { log.Ctx(ctx).Info("trace info: all, error", zap.Error(err)) } else { - log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp)) + log.Ctx(ctx).Info("trace info: all, unknown") } } return resp, err @@ -334,31 +341,6 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu return response, err } -func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 { - return func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - if dbName == DefaultDbName || proxy.CheckDatabase(ctx, dbName) { - return v2(ctx, c, req, dbName) - } - resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { - return h.proxy.ListDatabases(reqCtx, &milvuspb.ListDatabasesRequest{}) - }) - if err != nil { - return resp, err - } - for _, db := range resp.(*milvuspb.ListDatabasesResponse).DbNames { - if db == dbName { - return v2(ctx, c, req, dbName) - } - } - log.Ctx(ctx).Warn("high level restful api, non-exist database", zap.String("database", dbName)) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), - HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, - }) - return nil, merr.ErrDatabaseNotFound - } -} - func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { getter, _ := anyReq.(requestutil.CollectionNameGetter) collectionName := getter.GetCollectionName() @@ -1381,6 +1363,99 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe return statusResponse, err } +func (h *HandlersV2) createDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.CreateDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.CreateDatabase(reqCtx, req.(*milvuspb.CreateDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DropDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DropDatabase(reqCtx, req.(*milvuspb.DropDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +// todo: use a more flexible way to handle the number of input parameters of req +func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListDatabasesRequest{} + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.ListDatabases(reqCtx, req.(*milvuspb.ListDatabasesRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListDatabasesResponse).DbNames)) + } + return resp, err +} + +func (h *HandlersV2) describeDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DescribeDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DescribeDatabase(reqCtx, req.(*milvuspb.DescribeDatabaseRequest)) + }) + if err != nil { + return nil, err + } + info, _ := resp.(*milvuspb.DescribeDatabaseResponse) + if info.Properties == nil { + info.Properties = []*commonpb.KeyValuePair{} + } + dataBaseInfo := map[string]any{ + HTTPDbName: info.DbName, + HTTPDbID: info.DbID, + HTTPProperties: info.Properties, + } + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: dataBaseInfo}) + return resp, err +} + +func (h *HandlersV2) alterDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.AlterDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) req := &milvuspb.ShowPartitionsRequest{ diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 5248f9c9a69b8..9cb86a26cbe4f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -376,99 +376,47 @@ func TestTimeout(t *testing.T) { } } -func TestDatabaseWrapper(t *testing.T) { +func TestCreateIndex(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + postTestCases := []requestBodyTestCase{} mp := mocks.NewMockProxy(t) - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "exist"}, - }, nil).Twice() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once() - h := NewHandlersV2(mp) - ginHandler := gin.Default() - app := ginHandler.Group("", genAuthMiddleWare(false)) - path := "/wrapper/database" - app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, h.wrapperCheckDatabase(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - return nil, nil - }))) - postTestCases = append(postTestCases, requestBodyTestCase{ - path: path, - requestBody: []byte(`{}`), - }) - postTestCases = append(postTestCases, requestBodyTestCase{ - path: path, - requestBody: []byte(`{"dbName": "exist"}`), - }) + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(IndexCategory, CreateAction) + // the previous format postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"dbName": "non-exist"}`), - errMsg: "database not found, database: non-exist", - errCode: 800, // ErrDatabaseNotFound + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`), }) + // the current format postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"dbName": "test"}`), - errMsg: "", - errCode: 65535, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`), }) for _, testcase := range postTestCases { t.Run("post"+testcase.path, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) w := httptest.NewRecorder() - ginHandler.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - fmt.Println(w.Body.String()) - if testcase.errCode != 0 { - returnBody := &ReturnErrMsg{} - err := json.Unmarshal(w.Body.Bytes(), returnBody) - assert.Nil(t, err) - assert.Equal(t, testcase.errCode, returnBody.Code) - assert.Equal(t, testcase.errMsg, returnBody.Message) - } - }) - } - - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "default"}, - }, nil).Once() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &StatusSuccess, - DbNames: []string{DefaultCollectionName, "test"}, - }, nil).Once() - mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once() - rawTestCases := []rawTestCase{ - { - errMsg: "database not found, database: test", - errCode: 800, // ErrDatabaseNotFound - }, - {}, - { - errMsg: "", - errCode: 65535, - }, - } - for _, testcase := range rawTestCases { - t.Run("post with db"+testcase.path, func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader([]byte(`{}`))) - req.Header.Set(HTTPHeaderDBName, "test") - w := httptest.NewRecorder() - ginHandler.ServeHTTP(w, req) + testEngine.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) if testcase.errCode != 0 { - returnBody := &ReturnErrMsg{} - err := json.Unmarshal(w.Body.Bytes(), returnBody) - assert.Nil(t, err) - assert.Equal(t, testcase.errCode, returnBody.Code) assert.Equal(t, testcase.errMsg, returnBody.Message) } }) } } -func TestCreateIndex(t *testing.T) { +func TestDatabase(t *testing.T) { paramtable.Init() // disable rate limit paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") @@ -476,18 +424,107 @@ func TestCreateIndex(t *testing.T) { postTestCases := []requestBodyTestCase{} mp := mocks.NewMockProxy(t) - mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() testEngine := initHTTPServerV2(mp, false) - path := versionalV2(IndexCategory, CreateAction) - // the previous format + path := versionalV2(DataBaseCategory, CreateAction) + // success postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`), + requestBody: []byte(`{"dbName":"test"}`), }) - // the current format + // mock fail postTestCases = append(postTestCases, requestBodyTestCase{ path: path, - requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`), + requestBody: []byte(`{"dbName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, DropAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}}, nil).Once() + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, ListAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once() + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, DescribeAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, AlterAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid }) for _, testcase := range postTestCases { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 0ef3c2045e0f6..5ae7babd346a0 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -9,12 +9,23 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +type EmptyReq struct{} + +func (req *EmptyReq) GetDbName() string { return "" } + type DatabaseReq struct { DbName string `json:"dbName"` } func (req *DatabaseReq) GetDbName() string { return req.DbName } +type DatabaseReqWithProperties struct { + DbName string `json:"dbName" binding:"required"` + Properties map[string]interface{} `json:"properties"` +} + +func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName } + type CollectionNameReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` diff --git a/internal/proxy/util.go b/internal/proxy/util.go index b5732a1a6dc1f..a3aadeb433b2f 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -987,13 +987,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { dbKey := strings.ToLower(util.HeaderDBName) - if username == "" { - return contextutil.AppendToIncomingContext(ctx, dbKey, dbName) + if dbName != "" { + ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName) } - originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) - authKey := strings.ToLower(util.HeaderAuthorize) - authValue := crypto.Base64Encode(originValue) - return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName) + if username != "" { + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue) + } + return ctx } func AppendUserInfoForRPC(ctx context.Context) context.Context { diff --git a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index 4d443a1ad3ad7..a5a65b1cd1d85 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -57,9 +57,9 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { // dropping collection with `ts1` but a collection exists in catalog with newer ts which is bigger than `ts1`. // fortunately, if ddls are promised to execute in sequence, then everything is OK. The `ts1` will always be latest. collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp) - if errors.Is(err, merr.ErrCollectionNotFound) { + if errors.Is(err, merr.ErrCollectionNotFound) || errors.Is(err, merr.ErrDatabaseNotFound) { // make dropping collection idempotent. - log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName())) + log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName()), zap.String("database", t.Req.GetDbName())) return nil } diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index 53e16e2486e0f..fd22fe7b0f64c 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -601,6 +601,11 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str dbName = util.DefaultDBName } + db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp) + if err != nil { + return nil, err + } + collectionID, ok := mt.aliases.get(dbName, collectionName) if ok { return mt.getCollectionByIDInternal(ctx, dbName, collectionID, ts, false) @@ -615,11 +620,6 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } - db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp) - if err != nil { - return nil, err - } - // travel meta information from catalog. No need to check time travel logic again, since catalog already did. ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue()) coll, err := mt.catalog.GetCollectionByName(ctx1, db.ID, collectionName, ts) diff --git a/internal/rootcoord/meta_table_test.go b/internal/rootcoord/meta_table_test.go index 3122eb4554318..3bf93aeb31cdb 100644 --- a/internal/rootcoord/meta_table_test.go +++ b/internal/rootcoord/meta_table_test.go @@ -537,6 +537,24 @@ func TestMetaTable_getCollectionByIDInternal(t *testing.T) { } func TestMetaTable_GetCollectionByName(t *testing.T) { + t.Run("db not found", func(t *testing.T) { + meta := &MetaTable{ + aliases: newNameDb(), + collID2Meta: map[typeutil.UniqueID]*model.Collection{ + 100: { + State: pb.CollectionState_CollectionCreated, + CreateTime: 99, + Partitions: []*model.Partition{}, + }, + }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, + } + ctx := context.Background() + _, err := meta.GetCollectionByName(ctx, "not_exist", "name", 101) + assert.Error(t, err) + }) t.Run("get by alias", func(t *testing.T) { meta := &MetaTable{ aliases: newNameDb(), @@ -550,6 +568,9 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { }, }, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } meta.aliases.insert(util.DefaultDBName, "alias", 100) ctx := context.Background() @@ -574,6 +595,9 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { }, }, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } meta.names.insert(util.DefaultDBName, "name", 100) ctx := context.Background() @@ -661,7 +685,13 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { t.Run("get latest version", func(t *testing.T) { ctx := context.Background() - meta := &MetaTable{names: newNameDb(), aliases: newNameDb()} + meta := &MetaTable{ + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, + names: newNameDb(), + aliases: newNameDb(), + } _, err := meta.GetCollectionByName(ctx, "", "not_exist", typeutil.MaxTimestamp) assert.Error(t, err) assert.ErrorIs(t, err, merr.ErrCollectionNotFound) @@ -1880,6 +1910,9 @@ func TestMetaTable_EmtpyDatabaseName(t *testing.T) { collID2Meta: map[typeutil.UniqueID]*model.Collection{ 1: {CollectionID: 1}, }, + dbName2Meta: map[string]*model.Database{ + util.DefaultDBName: model.NewDefaultDatabase(), + }, } mt.aliases.insert(util.DefaultDBName, "aliases", 1)