diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index e61004f3c885a..d6c4463be6315 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -37,12 +37,14 @@ import ( "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -315,6 +317,7 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu if checkAuth { err := checkAuthorizationV2(ctx, c, ignoreErr, req) if err != nil { + hookutil.GetExtension().ReportRefused(ctx, req, WrapErrorToResponse(err), nil, c.FullPath()) return nil, err } } @@ -322,6 +325,7 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu _, err := CheckLimiter(ctx, req, pxy) if err != nil { log.Warn("high level restful api, fail to check limiter", zap.Error(err), zap.String("method", fullMethod)) + hookutil.GetExtension().ReportRefused(ctx, req, WrapErrorToResponse(merr.ErrHTTPRateLimit), nil, c.FullPath()) HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit), HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(), @@ -334,13 +338,15 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu if !ok { username = "" } - response, err := proxy.HookInterceptor(ctx, req, username.(string), fullMethod, handler) + + response, err := proxy.HookInterceptor(context.WithValue(ctx, hook.GinParamsKey, c.Keys), req, username.(string), fullMethod, handler) if err == nil { status, ok := requestutil.GetStatusFromResponse(response) if ok { err = merr.Error(status) } } + if err != nil { log.Ctx(ctx).Warn("high level restful api, grpc call failed", zap.Error(err)) if !ignoreErr { @@ -2162,15 +2168,6 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a } c.Set(ContextRequest, req) - if h.checkAuth { - err := checkAuthorizationV2(ctx, c, false, &milvuspb.ListImportsAuthPlaceholder{ - DbName: dbName, - CollectionName: collectionName, - }) - if err != nil { - return nil, err - } - } resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListImports", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListImports(reqCtx, req.(*internalpb.ListImportsRequest)) }) @@ -2214,16 +2211,6 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq } c.Set(ContextRequest, req) - if h.checkAuth { - err := checkAuthorizationV2(ctx, c, false, &milvuspb.ImportAuthPlaceholder{ - DbName: dbName, - CollectionName: collectionGetter.GetCollectionName(), - PartitionName: partitionGetter.GetPartitionName(), - }) - if err != nil { - return nil, err - } - } resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ImportV2", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ImportV2(reqCtx, req.(*internalpb.ImportRequest)) }) @@ -2243,14 +2230,6 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an } c.Set(ContextRequest, req) - if h.checkAuth { - err := checkAuthorizationV2(ctx, c, false, &milvuspb.GetImportProgressAuthPlaceholder{ - DbName: dbName, - }) - if err != nil { - return nil, err - } - } resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/GetImportProgress", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.GetImportProgress(reqCtx, req.(*internalpb.GetImportProgressRequest)) }) diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 9df0e2a64ebd4..cc1babd11030c 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1679,3 +1679,9 @@ func RequestHandlerFunc(c *gin.Context) { } c.Next() } + +func WrapErrorToResponse(err error) *milvuspb.BoolResponse { + return &milvuspb.BoolResponse{ + Status: merr.Status(err), + } +} diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 33c21e505f2b0..f8c0f3333feb8 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -63,6 +63,7 @@ import ( "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" _ "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/tracer" @@ -141,6 +142,10 @@ func authenticate(c *gin.Context) { c.Set(httpserver.ContextUsername, user) return } + + hookutil.GetExtension().ReportRefused(context.Background(), nil, &milvuspb.BoolResponse{ + Status: merr.Status(err), + }, nil, c.FullPath()) log.Warn("fail to verify apikey", zap.Error(err)) } c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{mhttp.HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), mhttp.HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) diff --git a/internal/util/hookutil/default.go b/internal/util/hookutil/default.go index 7bbd467bb6f6a..546ed80c4388c 100644 --- a/internal/util/hookutil/default.go +++ b/internal/util/hookutil/default.go @@ -59,3 +59,7 @@ var _ hook.Extension = (*DefaultExtension)(nil) func (d DefaultExtension) Report(info any) int { return 0 } + +func (d DefaultExtension) ReportRefused(ctx context.Context, req interface{}, resp interface{}, err error, fullMethod string) error { + return nil +}