Skip to content

Commit

Permalink
fix: fallback the logic of GlobalMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
devhaozi committed Nov 15, 2023
1 parent 85fc279 commit 975e4f0
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 74 deletions.
7 changes: 7 additions & 0 deletions context_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ func (r *ContextRequest) Method() string {

func (r *ContextRequest) Next() {
if err := r.instance.Next(); err != nil {
var fiberErr *fiber.Error
if errors.As(err, &fiberErr) {
if err := r.instance.Status(fiberErr.Code).SendString(fiberErr.Message); err == nil {
return
}
}

panic(err)
}
}
Expand Down
53 changes: 18 additions & 35 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ type Group struct {
instance *fiber.App
originPrefix string
prefix string
globalMiddlewares []any
originMiddlewares []httpcontract.Middleware
middlewares []httpcontract.Middleware
lastMiddlewares []httpcontract.Middleware
}

func NewGroup(config config.Config, instance *fiber.App, prefix string, globalMiddlewares []any, originMiddlewares []httpcontract.Middleware, lastMiddlewares []httpcontract.Middleware) route.Router {
func NewGroup(config config.Config, instance *fiber.App, prefix string, originMiddlewares []httpcontract.Middleware, lastMiddlewares []httpcontract.Middleware) route.Router {
return &Group{
config: config,
instance: instance,
originPrefix: prefix,
globalMiddlewares: globalMiddlewares,
originMiddlewares: originMiddlewares,
lastMiddlewares: lastMiddlewares,
}
Expand All @@ -42,7 +40,7 @@ func (r *Group) Group(handler route.GroupFunc) {
prefix := pathToFiberPath(r.originPrefix + "/" + r.prefix)
r.prefix = ""

handler(NewGroup(r.config, r.instance, prefix, r.globalMiddlewares, middlewares, r.lastMiddlewares))
handler(NewGroup(r.config, r.instance, prefix, middlewares, r.lastMiddlewares))
}

func (r *Group) Prefix(addr string) route.Router {
Expand All @@ -58,67 +56,60 @@ func (r *Group) Middleware(middlewares ...httpcontract.Middleware) route.Router
}

func (r *Group) Any(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).All(relativePath, r.getMiddlewares(handler)...)
r.instance.All(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Get(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(relativePath, r.getMiddlewares(handler)...)
r.instance.Get(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Post(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Post(relativePath, r.getMiddlewares(handler)...)
r.instance.Post(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Delete(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Delete(relativePath, r.getMiddlewares(handler)...)
r.instance.Delete(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Patch(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Patch(relativePath, r.getMiddlewares(handler)...)
r.instance.Patch(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Put(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Put(relativePath, r.getMiddlewares(handler)...)
r.instance.Put(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Options(relativePath string, handler httpcontract.HandlerFunc) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Options(relativePath, r.getMiddlewares(handler)...)
r.instance.Options(r.getPath(relativePath), r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Resource(relativePath string, controller httpcontract.ResourceController) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(relativePath, r.getMiddlewares(controller.Index)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Post(relativePath, r.getMiddlewares(controller.Store)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Show)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Put(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Patch(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Delete(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Destroy)...)
r.instance.Get(relativePath, r.getMiddlewares(controller.Index)...)
r.instance.Post(relativePath, r.getMiddlewares(controller.Store)...)
r.instance.Get(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Show)...)
r.instance.Put(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Patch(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Delete(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Destroy)...)
r.clearMiddlewares()
}

func (r *Group) Static(relativePath, root string) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Static(relativePath, root)
r.instance.Use(r.getMiddlewaresWithPath(r.getPath(relativePath), nil)...).Static(relativePath, root)
r.clearMiddlewares()
}

func (r *Group) StaticFile(relativePath, filePath string) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, func(c *fiber.Ctx) error {
r.instance.Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, func(c *fiber.Ctx) error {
dir, file := filepath.Split(filePath)
escapedFile := url.PathEscape(file)
escapedPath := filepath.Join(dir, escapedFile)
Expand All @@ -130,7 +121,7 @@ func (r *Group) StaticFile(relativePath, filePath string) {

func (r *Group) StaticFS(relativePath string, fs http.FileSystem) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, filesystem.New(filesystem.Config{
r.instance.Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, filesystem.New(filesystem.Config{
Root: fs,
}))
r.clearMiddlewares()
Expand Down Expand Up @@ -174,14 +165,6 @@ func (r *Group) getMiddlewaresWithPath(relativePath string, handler httpcontract
return handlers
}

func (r *Group) getGlobalMiddlewaresWithPath(relativePath string) []any {
var handlers []any
handlers = append(handlers, relativePath)
handlers = append(handlers, r.globalMiddlewares...)

return handlers
}

func (r *Group) clearMiddlewares() {
r.middlewares = []httpcontract.Middleware{}
}
70 changes: 35 additions & 35 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(3)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(3)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(3)
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -288,13 +288,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(2)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(2)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(2)
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -313,13 +313,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(4)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(4)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(4)
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -338,13 +338,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(5)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(5)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(5)
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -363,13 +363,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(6)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(6)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(6)
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand Down
6 changes: 2 additions & 4 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ func NewRoute(config config.Config, parameters map[string]any) (*Route, error) {
config,
app,
"",
[]any{func(c *fiber.Ctx) error {
return c.Next()
}},
[]httpcontract.Middleware{},
[]httpcontract.Middleware{},
),
Expand Down Expand Up @@ -111,11 +108,12 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
tempMiddlewares = append(tempMiddlewares, middleware)
}

r.instance.Use(tempMiddlewares...)

r.Router = NewGroup(
r.config,
r.instance,
"",
tempMiddlewares,
[]httpcontract.Middleware{},
[]httpcontract.Middleware{ResponseMiddleware()},
)
Expand Down

0 comments on commit 975e4f0

Please sign in to comment.