From 09d8b760487e9311831767f4072ff8e16377bd68 Mon Sep 17 00:00:00 2001 From: PowderLi <135960789+PowderLi@users.noreply.github.com> Date: Tue, 17 Oct 2023 07:00:14 -0500 Subject: [PATCH] [restful] new context with grpc metadata (#27668) Signed-off-by: PowderLi --- .../unittest/test_disk_file_manager_test.cpp | 2 +- .../proxy/httpserver/handler_v1.go | 107 +++++++++++------- .../proxy/httpserver/handler_v1_test.go | 2 + internal/distributed/proxy/service.go | 1 + internal/proxy/privilege_interceptor.go | 17 --- internal/proxy/util.go | 13 +++ 6 files changed, 80 insertions(+), 62 deletions(-) diff --git a/internal/core/unittest/test_disk_file_manager_test.cpp b/internal/core/unittest/test_disk_file_manager_test.cpp index 52762a57c902d..310dec776caea 100644 --- a/internal/core/unittest/test_disk_file_manager_test.cpp +++ b/internal/core/unittest/test_disk_file_manager_test.cpp @@ -101,7 +101,7 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) { int test_worker(string s) { std::cout << s << std::endl; - sleep(4); + std::this_thread::sleep_for(std::chrono::seconds(4)); std::cout << s << std::endl; return 1; } diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 256e444d2fddf..9b40a90c9b5fe 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -1,6 +1,7 @@ package httpserver import ( + "context" "encoding/json" "net/http" "strconv" @@ -20,14 +21,14 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) -func checkAuthorization(c *gin.Context, req interface{}) error { +func checkAuthorization(ctx context.Context, c *gin.Context, req interface{}) error { if proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() { username, ok := c.Get(ContextUsername) - if !ok { + if !ok || username.(string) == "" { c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) return merr.ErrNeedAuthenticate } - _, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), req) + _, authErr := proxy.PrivilegeInterceptor(ctx, req) if authErr != nil { c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()}) return authErr @@ -36,11 +37,11 @@ func checkAuthorization(c *gin.Context, req interface{}) error { return nil } -func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool { +func (h *Handlers) checkDatabase(ctx context.Context, c *gin.Context, dbName string) bool { if dbName == DefaultDbName { return true } - response, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{}) + response, err := h.proxy.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) if err == nil { err = merr.Error(response.GetStatus()) } @@ -57,17 +58,17 @@ func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool { return false } -func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) { +func (h *Handlers) describeCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) { req := milvuspb.DescribeCollectionRequest{ DbName: dbName, CollectionName: collectionName, } if needAuth { - if err := checkAuthorization(c, &req); err != nil { + if err := checkAuthorization(ctx, c, &req); err != nil { return nil, err } } - response, err := h.proxy.DescribeCollection(c, &req) + response, err := h.proxy.DescribeCollection(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -83,12 +84,12 @@ func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionN return response, nil } -func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName string) (bool, error) { +func (h *Handlers) hasCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string) (bool, error) { req := milvuspb.HasCollectionRequest{ DbName: dbName, CollectionName: collectionName, } - response, err := h.proxy.HasCollection(c, &req) + response, err := h.proxy.HasCollection(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -116,13 +117,15 @@ func (h *Handlers) listCollections(c *gin.Context) { req := milvuspb.ShowCollectionsRequest{ DbName: dbName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, dbName) { + if !h.checkDatabase(ctx, c, dbName) { return } - response, err := h.proxy.ShowCollections(c, &req) + response, err := h.proxy.ShowCollections(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -195,13 +198,15 @@ func (h *Handlers) createCollection(c *gin.Context) { ShardsNum: ShardNumDefault, ConsistencyLevel: commonpb.ConsistencyLevel_Bounded, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.CreateCollection(c, &req) + response, err := h.proxy.CreateCollection(ctx, &req) if err == nil { err = merr.Error(response) } @@ -210,7 +215,7 @@ func (h *Handlers) createCollection(c *gin.Context) { return } - response, err = h.proxy.CreateIndex(c, &milvuspb.CreateIndexRequest{ + response, err = h.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, FieldName: httpReq.VectorField, @@ -224,7 +229,7 @@ func (h *Handlers) createCollection(c *gin.Context) { c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } - response, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{ + response, err = h.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, }) @@ -246,14 +251,16 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) { return } dbName := c.DefaultQuery(HTTPDbName, DefaultDbName) - if !h.checkDatabase(c, dbName) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), dbName) + if !h.checkDatabase(ctx, c, dbName) { return } - coll, err := h.describeCollection(c, dbName, collectionName, true) + coll, err := h.describeCollection(ctx, c, dbName, collectionName, true) if err != nil { return } - stateResp, err := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{ + stateResp, err := h.proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ DbName: dbName, CollectionName: collectionName, }) @@ -276,7 +283,7 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) { break } } - indexResp, err := h.proxy.DescribeIndex(c, &milvuspb.DescribeIndexRequest{ + indexResp, err := h.proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{ DbName: dbName, CollectionName: collectionName, FieldName: vectorField, @@ -324,13 +331,15 @@ func (h *Handlers) dropCollection(c *gin.Context) { DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName) + has, err := h.hasCollection(ctx, c, httpReq.DbName, httpReq.CollectionName) if err != nil { return } @@ -338,7 +347,7 @@ func (h *Handlers) dropCollection(c *gin.Context) { c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()}) return } - response, err := h.proxy.DropCollection(c, &req) + response, err := h.proxy.DropCollection(ctx, &req) if err == nil { err = merr.Error(response) } @@ -379,13 +388,15 @@ func (h *Handlers) query(c *gin.Context) { if httpReq.Limit > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.Query(c, &req) + response, err := h.proxy.Query(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -423,13 +434,15 @@ func (h *Handlers) get(c *gin.Context) { OutputFields: httpReq.OutputFields, GuaranteeTimestamp: BoundedTimestamp, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } @@ -440,7 +453,7 @@ func (h *Handlers) get(c *gin.Context) { return } req.Expr = filter - response, err := h.proxy.Query(c, &req) + response, err := h.proxy.Query(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -476,13 +489,15 @@ func (h *Handlers) delete(c *gin.Context) { DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } @@ -493,7 +508,7 @@ func (h *Handlers) delete(c *gin.Context) { return } req.Expr = filter - response, err := h.proxy.Delete(c, &req) + response, err := h.proxy.Delete(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -532,13 +547,15 @@ func (h *Handlers) insert(c *gin.Context) { PartitionName: "_default", NumRows: uint32(len(httpReq.Data)), } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } @@ -555,7 +572,7 @@ func (h *Handlers) insert(c *gin.Context) { c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) return } - response, err := h.proxy.Insert(c, &req) + response, err := h.proxy.Insert(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } @@ -609,13 +626,15 @@ func (h *Handlers) search(c *gin.Context) { GuaranteeTimestamp: BoundedTimestamp, Nq: int64(1), } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.Search(c, &req) + response, err := h.proxy.Search(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) } diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index 8f29d6b3bba99..faea5d5fffaca 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -103,6 +103,7 @@ func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine { func genAuthMiddleWare(needAuth bool) gin.HandlerFunc { if needAuth { return func(c *gin.Context) { + c.Set(ContextUsername, "") username, password, ok := ParseUsernamePassword(c) if !ok { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) @@ -1317,6 +1318,7 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { h := NewHandlers(mp) testEngine := gin.New() app := testEngine.Group("/", func(c *gin.Context) { + c.Set(ContextUsername, "") username, _, ok := ParseUsernamePassword(c) if ok { c.Set(ContextUsername, username) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index cd9b928d733a6..a0d82324950c9 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -116,6 +116,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) } func authenticate(c *gin.Context) { + c.Set(httpserver.ContextUsername, "") if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() { return } diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 3a4c8882c9426..b928ef95e9135 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -77,23 +77,6 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context log.Warn("GetCurUserFromContext fail", zap.Error(err)) return ctx, err } - return privilegeInterceptor(ctx, privilegeExt, username, req) -} - -func PrivilegeInterceptorWithUsername(ctx context.Context, username string, req interface{}) (context.Context, error) { - if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() { - return ctx, nil - } - log.Debug("PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String())) - privilegeExt, err := funcutil.GetPrivilegeExtObj(req) - if err != nil { - log.Info("GetPrivilegeExtObj err", zap.Error(err)) - return ctx, nil - } - return privilegeInterceptor(ctx, privilegeExt, username, req) -} - -func privilegeInterceptor(ctx context.Context, privilegeExt commonpb.PrivilegeExt, username string, req interface{}) (context.Context, error) { if username == util.UserRoot { return ctx, nil } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index e12131edca0a5..70babc48db475 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -894,6 +894,19 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { return dbNameData[0] } +func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + dbKey := strings.ToLower(util.HeaderDBName) + contextMap := map[string]string{ + authKey: authValue, + dbKey: dbName, + } + md := metadata.New(contextMap) + return metadata.NewIncomingContext(ctx, md) +} + func GetRole(username string) ([]string, error) { if globalMetaCache == nil { return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")