Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize code and bug fix #13

Merged
merged 14 commits into from
Aug 28, 2023
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ Fiber http driver for Goravel.
go get -u github.com/goravel/fiber
```

2. Register service provider, make sure it is registered first.
2. Register service provider

```
// config/app.go
import "github.com/goravel/fiber"

"providers": []foundation.ServiceProvider{
&fiber.ServiceProvider{},
...
&fiber.ServiceProvider{},
}
```

Expand All @@ -47,15 +47,14 @@ import (
"default": "fiber",

"drivers": map[string]any{
...
"fiber": map[string]any{
// prefork mode, see https://docs.gofiber.io/api/fiber/#config
"prefork": false,
"route": func() (route.Engine, error) {
return fiberfacades.Route(), nil
},
},
}
},
```

## Testing
Expand Down
25 changes: 25 additions & 0 deletions config/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package config

import (
"github.com/goravel/framework/facades"
)

func init() {
config := facades.Config()
config.Add("cors", map[string]any{
// Cross-Origin Resource Sharing (CORS) Configuration
//
// Here you may configure your settings for cross-origin resource sharing
// or "CORS". This determines what cross-origin operations may execute
// in web browsers. You are free to adjust these settings as needed.
//
// To learn more: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
"paths": []string{"*"},
"allowed_methods": []string{"*"},
"allowed_origins": []string{"*"},
"allowed_headers": []string{"*"},
"exposed_headers": []string{"*"},
"max_age": 0,
"supports_credentials": false,
})
}
3 changes: 1 addition & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ import (
"time"

"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"

"github.com/goravel/framework/contracts/http"
"github.com/valyala/fasthttp"
)

func Background() http.Context {
Expand Down
88 changes: 88 additions & 0 deletions cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package fiber

import (
nethttp "net/http"

"github.com/gofiber/fiber/v2/middleware/cors"
httpcontract "github.com/goravel/framework/contracts/http"
)

func Cors() httpcontract.Middleware {
return func(ctx httpcontract.Context) {
switch ctx := ctx.(type) {
case *Context:
var allowedMethods string
allowedMethodConfigs := ConfigFacade.Get("cors.allowed_methods").([]string)
for i, method := range allowedMethodConfigs {
if method == "*" {
allowedMethods = "GET,POST,HEAD,PUT,DELETE,PATCH"
devhaozi marked this conversation as resolved.
Show resolved Hide resolved
break
}
if i == len(allowedMethodConfigs)-1 {
allowedMethods += method
break
}

allowedMethods += method + ","
}
var allowedOrigins string
allowedOriginConfigs := ConfigFacade.Get("cors.allowed_origins").([]string)
for i, origin := range allowedOriginConfigs {
if origin == "*" {
allowedOrigins = "*"
break
}
if i == len(allowedOriginConfigs)-1 {
allowedOrigins += origin
break
}

allowedOrigins += origin + ","
}
var allowedHeaders string
allowedHeaderConfigs := ConfigFacade.Get("cors.allowed_headers").([]string)
for i, header := range allowedHeaderConfigs {
if header == "*" {
allowedHeaders = "*"
break
}
if i == len(allowedHeaderConfigs)-1 {
allowedHeaders += header
break
}

allowedHeaders += header + ","
}
var exposedHeaders string
exposedHeaderConfigs := ConfigFacade.Get("cors.exposed_headers").([]string)
for i, header := range exposedHeaderConfigs {
if header == "*" {
exposedHeaders = "*"
break
}
if i == len(exposedHeaderConfigs)-1 {
exposedHeaders += header
break
}

exposedHeaders += header + ","
}

_ = cors.New(cors.Config{
AllowMethods: allowedMethods,
AllowOrigins: allowedOrigins,
AllowHeaders: allowedHeaders,
ExposeHeaders: exposedHeaders,
MaxAge: ConfigFacade.GetInt("cors.max_age"),
AllowCredentials: ConfigFacade.GetBool("cors.supports_credentials"),
})(ctx.Instance())

if ctx.Request().Origin().Method == nethttp.MethodOptions &&
ctx.Request().Header("Access-Control-Request-Method") != "" {
ctx.Request().AbortWithStatus(nethttp.StatusNoContent)
}
}

ctx.Request().Next()
}
}
184 changes: 184 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package fiber

import (
"net/http"
"testing"

configmocks "github.com/goravel/framework/contracts/config/mocks"
contractshttp "github.com/goravel/framework/contracts/http"
"github.com/stretchr/testify/assert"
)

func TestCors(t *testing.T) {
var (
mockConfig *configmocks.Config
resp *http.Response
)
beforeEach := func() {
mockConfig = &configmocks.Config{}
}

tests := []struct {
name string
setup func()
assert func()
}{
{
name: "allow all paths",
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
{
name: "not allow path",
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"api"}).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
{
name: "allow path with *",
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"any/*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
{
name: "allow POST",
devhaozi marked this conversation as resolved.
Show resolved Hide resolved
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"GET"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
{
name: "allow origin",
devhaozi marked this conversation as resolved.
Show resolved Hide resolved
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"goravel.dev"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
{
name: "not allow exposed headers",
setup: func() {
mockConfig.On("GetString", "app.name", "Goravel").Return("Goravel").Once()
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"Goravel"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
ConfigFacade = mockConfig
},
assert: func() {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Goravel", resp.Header.Get("Access-Control-Expose-Headers"))
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
beforeEach()
test.setup()

f := NewRoute(mockConfig)
f.Any("/any/{id}", func(ctx contractshttp.Context) {
ctx.Response().Success().Json(contractshttp.Json{
"id": ctx.Request().Input("id"),
})
})

req, err := http.NewRequest("POST", "/any/1", nil)
assert.Nil(t, err)
req.Header.Set("Origin", "http://127.0.0.1")

resp, err = f.Test(req)
assert.NoError(t, err, test.name)

test.assert()

mockConfig.AssertExpectations(t)
})
}
}
Loading