Skip to content

Commit

Permalink
fix: recovery (#122)
Browse files Browse the repository at this point in the history
* fix: recovery

* optimize

* optimize

* optimize sort
  • Loading branch information
hwbrzzl authored Dec 29, 2024
1 parent 1cbcae6 commit d790c69
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 24 deletions.
16 changes: 14 additions & 2 deletions middleware_timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net/http"
"time"

"github.com/gin-gonic/gin"

contractshttp "github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/errors"
)
Expand All @@ -20,9 +22,19 @@ func Timeout(timeout time.Duration) contractshttp.Middleware {
done := make(chan struct{})

go func() {
defer HandleRecover(ctx, globalRecoverCallback)
defer func() {
if err := recover(); err != nil {
if globalRecoverCallback != nil {
globalRecoverCallback(ctx, err)
} else {
LogFacade.Error(err)
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
}
}

close(done)
}()
ctx.Request().Next()
close(done)
}()

select {
Expand Down
26 changes: 18 additions & 8 deletions middleware_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ func TestTimeoutMiddleware(t *testing.T) {
panic(1)
})

globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"})
}

route.Recover(globalRecover)

w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)
Expand All @@ -56,12 +50,28 @@ func TestTimeoutMiddleware(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "normal", w.Body.String())

// Test with default recover callback
mockLog := mockslog.NewLog(t)
mockLog.EXPECT().Error(1).Once()
LogFacade = mockLog

w = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

mockLog := mockslog.NewLog(t)
LogFacade = mockLog
route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "{\"error\":\"Internal Server Error\"}", w.Body.String())

// Test with custom recover callback
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"})
}
route.Recover(globalRecover)

w = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
Expand Down
16 changes: 5 additions & 11 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,15 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
r.setMiddlewares(middlewares)
}

func HandleRecover(ctx httpcontract.Context, recoverCallback func(ctx httpcontract.Context, err any)) {
if err := recover(); err != nil {
if recoverCallback != nil {
recoverCallback(ctx, err)
} else {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
}
}
}

func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) {
globalRecoverCallback = callback
r.setMiddlewares([]httpcontract.Middleware{
func(ctx httpcontract.Context) {
defer HandleRecover(ctx, globalRecoverCallback)
defer func() {
if err := recover(); err != nil {
callback(ctx, err)
}
}()
ctx.Request().Next()
},
})
Expand Down
5 changes: 2 additions & 3 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,15 @@ func TestRecoverWithDefaultCallback(t *testing.T) {
mockConfig.EXPECT().GetBool("app.debug").Return(true).Once()
mockConfig.EXPECT().GetInt("http.drivers.gin.body_limit", 4096).Return(4096).Once()

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/recover", nil)

route, err := NewRoute(mockConfig, nil)
assert.Nil(t, err)

route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response {
panic(1)
})

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/recover", nil)
route.ServeHTTP(w, req)

assert.Equal(t, "", w.Body.String())
Expand Down

0 comments on commit d790c69

Please sign in to comment.