Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the Provider Match functionality to be overriden in the default AuthN decorator #99

Merged
merged 5 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions cmd/dev_server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ func main() {
server.Logger(accLogger),
server.AddHeader("Content-Type", "application/json"),
server.AddHeader("X-Content-Type-Options", "nosniff"),
server.Authentication([]auth.Provider{
auth.NewMTLSAuthProvider(certPool),
auth.NewGitHubProvider(authTimeout),
auth.NewSpiffeAuthProvider(certPool),
auth.NewSpiffeAuthFallbackProvider(certPool),
}),
server.Authentication(
[]auth.Provider{
auth.NewMTLSAuthProvider(certPool),
auth.NewGitHubProvider(authTimeout),
auth.NewSpiffeAuthProvider(certPool),
auth.NewSpiffeAuthFallbackProvider(certPool),
},
nil),
}

r, err := server.GetRouter(cryptor, db, decorators, make([]server.Route, 0))
Expand Down
27 changes: 20 additions & 7 deletions server/decorators.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,28 @@ func buildRequest(req *http.Request, p knox.Principal, params map[string]string)
return r
}

// ProviderMatcher is a function that determines whether or not the specified
// authentication provider is suitable for the specified HTTP request. It is
// expected to return a boolean value detailing whether or not the specified
// provider is a match and is also expected to return any applicable
// authentication payload that would then be passed to the provider.
type ProviderMatcher func(provider auth.Provider, request *http.Request) (providerSupportsRequest bool, authenticationPayload string)

// Authentication sets the principal or returns an error if the principal cannot be authenticated.
func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.HandlerFunc {
func Authentication(providers []auth.Provider, matcher ProviderMatcher) func(http.HandlerFunc) http.HandlerFunc {
if matcher == nil {
matcher = providerMatch
}

return func(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var defaultPrincipal knox.Principal
allPrincipals := map[string]knox.Principal{}
errReturned := fmt.Errorf("No matching authentication providers found")

for _, p := range providers {
if token, match := providerMatch(p, r.Header.Get("Authorization")); match {
principal, errAuthenticate := p.Authenticate(token, r)
if match, payload := matcher(p, r); match {
principal, errAuthenticate := p.Authenticate(payload, r)
if errAuthenticate != nil {
errReturned = errAuthenticate
continue
Expand All @@ -241,11 +252,13 @@ func Authentication(providers []auth.Provider) func(http.HandlerFunc) http.Handl
}
}

func providerMatch(provider auth.Provider, a string) (string, bool) {
if len(a) > 2 && a[0] == provider.Version() && a[1] == provider.Type() {
return a[2:], true
func providerMatch(provider auth.Provider, request *http.Request) (providerSupportsRequest bool, payload string) {
authorizationHeaderValue := request.Header.Get("Authorization")

if len(authorizationHeaderValue) > 2 && authorizationHeaderValue[0] == provider.Version() && authorizationHeaderValue[1] == provider.Type() {
return true, authorizationHeaderValue[2:]
}
return "", false
return false, ""
}

func parseParams(parameters []Parameter) func(http.HandlerFunc) http.HandlerFunc {
Expand Down
2 changes: 1 addition & 1 deletion server/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func setup() {
decorators := [](func(http.HandlerFunc) http.HandlerFunc){
AddHeader("Content-Type", "application/json"),
AddHeader("X-Content-Type-Options", "nosniff"),
Authentication([]auth.Provider{auth.MockGitHubProvider()}),
Authentication([]auth.Provider{auth.MockGitHubProvider()}, nil),
}
var err error
router, err = GetRouter(cryptor, db, decorators, make([]Route, 0))
Expand Down
Loading