diff --git a/context_request.go b/context_request.go index 554b138..da0b7ab 100644 --- a/context_request.go +++ b/context_request.go @@ -44,11 +44,12 @@ func NewContextRequest(ctx *Context, log log.Log, validation contractsvalidate.V request := contextRequestPool.Get().(*ContextRequest) httpBody, err := getHttpBody(ctx) if err != nil { - log.Error(fmt.Sprintf("%+v", err)) + LogFacade.Error(fmt.Sprintf("%+v", errors.Unwrap(err))) } request.ctx = ctx request.instance = ctx.instance request.httpBody = httpBody + request.log = log request.validation = validation return request } diff --git a/context_request_test.go b/context_request_test.go index 8c2536e..2846300 100644 --- a/context_request_test.go +++ b/context_request_test.go @@ -44,8 +44,10 @@ func (s *ContextRequestSuite) SetupTest() { ValidationFacade = validation.NewValidation() var err error - s.route, err = NewRoute(s.mockConfig, nil) - s.Require().Nil(err) + route, err := NewRoute(s.mockConfig, nil) + s.Require().NotNil(route) + s.Require().NoError(err) + s.route = route } func (s *ContextRequestSuite) TearDownTest() { diff --git a/facades/gin.go b/facades/gin.go index 06af911..9fae238 100644 --- a/facades/gin.go +++ b/facades/gin.go @@ -3,15 +3,16 @@ package facades import ( "log" - "github.com/goravel/framework/contracts/route" - "github.com/goravel/gin" + + "github.com/goravel/framework/contracts/route" ) func Route(driver string) route.Route { instance, err := gin.App.MakeWith(gin.RouteBinding, map[string]any{ "driver": driver, }) + if err != nil { log.Fatalln(err) return nil diff --git a/middleware_timeout.go b/middleware_timeout.go index 9d0dafa..907291d 100644 --- a/middleware_timeout.go +++ b/middleware_timeout.go @@ -20,26 +20,15 @@ func Timeout(timeout time.Duration) contractshttp.Middleware { done := make(chan struct{}) go func() { - defer func() { - if r := recover(); r != nil { - if LogFacade != nil { - LogFacade.Request(ctx.Request()).Error(r) - } - - // TODO can be customized in https://github.com/goravel/goravel/issues/521 - _ = ctx.Response().Status(http.StatusInternalServerError).String("Internal Server Error").Render() - } - - close(done) - }() - + defer HandleRecover(ctx, globalRecoverCallback) ctx.Request().Next() + close(done) }() select { case <-done: - case <-ctx.Request().Origin().Context().Done(): - if errors.Is(ctx.Request().Origin().Context().Err(), context.DeadlineExceeded) { + case <-timeoutCtx.Done(): + if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { ctx.Request().AbortWithStatus(http.StatusGatewayTimeout) } } diff --git a/middleware_timeout_test.go b/middleware_timeout_test.go index b5c0e4d..1183235 100644 --- a/middleware_timeout_test.go +++ b/middleware_timeout_test.go @@ -6,11 +6,11 @@ import ( "testing" "time" + "github.com/gin-gonic/gin" contractshttp "github.com/goravel/framework/contracts/http" mocksconfig "github.com/goravel/framework/mocks/config" mockslog "github.com/goravel/framework/mocks/log" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -24,16 +24,23 @@ func TestTimeoutMiddleware(t *testing.T) { route.Middleware(Timeout(1*time.Second)).Get("/timeout", func(ctx contractshttp.Context) contractshttp.Response { time.Sleep(2 * time.Second) - - return ctx.Response().Success().String("timeout") + return nil }) + route.Middleware(Timeout(1*time.Second)).Get("/normal", func(ctx contractshttp.Context) contractshttp.Response { return ctx.Response().Success().String("normal") }) + route.Middleware(Timeout(1*time.Second)).Get("/panic", func(ctx contractshttp.Context) contractshttp.Response { panic(1) }) + globalRecover := func(ctx contractshttp.Context, err any) { + ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"}) + } + + route.Recover(globalRecover) + w := httptest.NewRecorder() req, err := http.NewRequest("GET", "/timeout", nil) require.NoError(t, err) @@ -54,11 +61,9 @@ func TestTimeoutMiddleware(t *testing.T) { require.NoError(t, err) mockLog := mockslog.NewLog(t) - mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once() - mockLog.EXPECT().Error(mock.Anything).Once() LogFacade = mockLog route.ServeHTTP(w, req) assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, "Internal Server Error", w.Body.String()) + assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String()) } diff --git a/route.go b/route.go index 8e1515a..257b5d3 100644 --- a/route.go +++ b/route.go @@ -19,6 +19,8 @@ import ( "github.com/savioxavier/termlink" ) +var globalRecoverCallback func(ctx httpcontract.Context, err any) + type Route struct { route.Router config config.Config @@ -89,6 +91,26 @@ func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) { r.setMiddlewares(middlewares) } +func HandleRecover(ctx httpcontract.Context, recoverCallback func(ctx httpcontract.Context, err any)) { + if err := recover(); err != nil { + if recoverCallback != nil { + recoverCallback(ctx, err) + } else { + ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"}) + } + } +} + +func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) { + globalRecoverCallback = callback + r.setMiddlewares([]httpcontract.Middleware{ + func(ctx httpcontract.Context) { + defer HandleRecover(ctx, globalRecoverCallback) + ctx.Request().Next() + }, + }) +} + func (r *Route) Listen(l net.Listener) error { r.outputRoutes() color.Green().Println(termlink.Link("[HTTP] Listening and serving HTTP on", "http://"+l.Addr().String())) diff --git a/route_test.go b/route_test.go index c449253..6fa2d21 100644 --- a/route_test.go +++ b/route_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/render" contractshttp "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/contracts/validation" @@ -21,6 +22,58 @@ import ( "github.com/stretchr/testify/assert" ) +func TestRecoverWithCustomCallback(t *testing.T) { + mockConfig := configmocks.NewConfig(t) + mockConfig.EXPECT().GetBool("app.debug").Return(true).Once() + mockConfig.EXPECT().GetInt("http.drivers.gin.body_limit", 4096).Return(4096).Once() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/recover", nil) + + route, err := NewRoute(mockConfig, nil) + assert.Nil(t, err) + + globalRecover := func(ctx contractshttp.Context, err any) { + ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"}) + } + + route.Recover(globalRecover) + + route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response { + panic(1) + }) + + route.ServeHTTP(w, req) + + assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String()) + assert.Equal(t, http.StatusInternalServerError, w.Code) + + mockConfig.AssertExpectations(t) +} + +func TestRecoverWithDefaultCallback(t *testing.T) { + mockConfig := configmocks.NewConfig(t) + mockConfig.EXPECT().GetBool("app.debug").Return(true).Once() + mockConfig.EXPECT().GetInt("http.drivers.gin.body_limit", 4096).Return(4096).Once() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/recover", nil) + + route, err := NewRoute(mockConfig, nil) + assert.Nil(t, err) + + route.Get("/recover", func(ctx contractshttp.Context) contractshttp.Response { + panic(1) + }) + + route.ServeHTTP(w, req) + + assert.Equal(t, "", w.Body.String()) + assert.Equal(t, http.StatusInternalServerError, w.Code) + + mockConfig.AssertExpectations(t) +} + func TestFallback(t *testing.T) { mockConfig := &configmocks.Config{} mockConfig.EXPECT().GetBool("app.debug").Return(true).Once() diff --git a/service_provider.go b/service_provider.go index 2938d80..418afe0 100644 --- a/service_provider.go +++ b/service_provider.go @@ -36,12 +36,15 @@ func (receiver *ServiceProvider) Boot(app foundation.Application) { if ConfigFacade = app.MakeConfig(); ConfigFacade == nil { color.Errorln(errors.ConfigFacadeNotSet.SetModule(module)) } + if LogFacade = app.MakeLog(); LogFacade == nil { color.Errorln(errors.LogFacadeNotSet.SetModule(module)) } + if ValidationFacade = app.MakeValidation(); ValidationFacade == nil { color.Errorln(errors.New("validation facade is not initialized").SetModule(module)) } + if ViewFacade = app.MakeView(); ViewFacade == nil { color.Errorln(errors.New("view facade is not initialized").SetModule(module)) } @@ -50,3 +53,4 @@ func (receiver *ServiceProvider) Boot(app foundation.Application) { "config/cors.go": app.ConfigPath("cors.go"), }) } + diff --git a/view.go b/view.go index 4c94426..fe56148 100644 --- a/view.go +++ b/view.go @@ -18,6 +18,7 @@ func NewView(instance *gin.Context) *View { func (receive *View) Make(view string, data ...any) contractshttp.Response { shared := ViewFacade.GetShared() + if len(data) == 0 { return &HtmlResponse{shared, receive.instance, view} } else { @@ -28,11 +29,9 @@ func (receive *View) Make(view string, data ...any) contractshttp.Response { for key, value := range dataMap { shared[key] = value } - return &HtmlResponse{shared, receive.instance, view} case reflect.Map: fillShared(data[0], shared) - return &HtmlResponse{data[0], receive.instance, view} default: panic(fmt.Sprintf("make %s view failed, data must be map or struct", view)) @@ -82,6 +81,7 @@ func structToMap(data any) map[string]any { func fillShared(data any, shared map[string]any) { dataValue := reflect.ValueOf(data) keys := dataValue.MapKeys() + for key, value := range shared { exist := false for _, k := range keys { @@ -90,6 +90,7 @@ func fillShared(data any, shared map[string]any) { break } } + if !exist { dataValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value)) }