From e814b34339b46de087418a54586d170e00957acf Mon Sep 17 00:00:00 2001 From: Rein Krul Date: Tue, 26 Sep 2023 14:19:14 +0200 Subject: [PATCH] IAM: Implement correct OAuth2 error handling --- auth/api/iam/api.go | 49 ++++--------- auth/api/iam/api_test.go | 97 ++++++++++++++++++++----- auth/api/iam/authorized_code.go | 17 ++--- auth/api/iam/error.go | 125 ++++++++++++++++++++++++++++++++ auth/api/iam/error_test.go | 123 +++++++++++++++++++++++++++++++ auth/api/iam/openid4vp.go | 12 ++- auth/api/iam/openid4vp_test.go | 34 +++++++++ cmd/root.go | 2 +- http/requestlogger.go | 14 ++-- 9 files changed, 402 insertions(+), 71 deletions(-) create mode 100644 auth/api/iam/error.go create mode 100644 auth/api/iam/error_test.go diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 816de10f7f..8652558e8f 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -29,7 +29,6 @@ import ( "github.com/nuts-foundation/nuts-node/auth/log" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/vcr" - "github.com/nuts-foundation/nuts-node/vcr/openid4vci" "github.com/nuts-foundation/nuts-node/vdr/didservice" vdr "github.com/nuts-foundation/nuts-node/vdr/types" "html/template" @@ -80,8 +79,7 @@ func (r Wrapper) Routes(router core.EchoRouter) { // Add http.Request to context, to allow reading URL query parameters requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request()) ctx.SetRequest(ctx.Request().WithContext(requestCtx)) - // TODO: Do we need a generic error handler? - // ctx.Set(core.ErrorWriterContextKey, &protocolErrorWriter{}) + ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{}) return f(ctx, request) } }, @@ -118,30 +116,21 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ // Options: // - OpenID4VCI // - OpenID4VP, vp_token is sent in Token Response + panic("not implemented") case "vp_token": // Options: // - service-to-service vp_token flow + panic("not implemented") case "urn:ietf:params:oauth:grant-type:pre-authorized_code": // Options: // - OpenID4VCI + panic("not implemented") default: - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid grant type", + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: "unsupported grant_type", } } - - // TODO: Handle? - //scope, err := handler(request.Body.AdditionalProperties) - //if err != nil { - // return nil, err - //} - // TODO: Generate access token with scope - return HandleTokenRequest200JSONResponse(TokenResponse{ - AccessToken: "", - }), nil } // HandleAuthorizeRequest handles calls to the authorization endpoint for starting an authorization code flow. @@ -161,7 +150,10 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho // TODO: Spec says that the redirect URI is optional, but it's not clear what to do if it's not provided. // Threat models say it's unsafe to omit redirect_uri. // See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 - return nil, errors.New("missing redirect URI") + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: "redirect_uri is required", + } } switch session.ResponseType { @@ -171,33 +163,24 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho // - OpenID4VCI; authorization code flow for credential issuance to (end-user) wallet // - OpenID4VP, vp_token is sent in Token Response; authorization code flow for presentation exchange (not required a.t.m.) // TODO: Switch on parameters to right flow + panic("not implemented") case responseTypeVPToken: // Options: // - OpenID4VP flow, vp_token is sent in Authorization Response // TODO: Check parameters for right flow // TODO: Do we actually need this? (probably not) + panic("not implemented") case responseTypeVPIDToken: // Options: // - OpenID4VP+SIOP flow, vp_token is sent in Authorization Response return r.handlePresentationRequest(params, session) default: // TODO: This should be a redirect? - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid/unsupported response_type", + return nil, OAuth2Error{ + Code: UnsupportedResponseType, + RedirectURI: session.RedirectURI, } } - - // No handler could handle the request - // TODO: This should be a redirect? - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "missing or invalid parameters", - } } // OAuthAuthorizationServerMetadata returns the Authorization Server's metadata diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index 61bd6bc3ae..abdf666caa 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -19,6 +19,7 @@ package iam import ( + "context" "errors" ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/did" @@ -29,18 +30,20 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "net/http" "net/url" "testing" ) +var nutsDID = did.MustParseDID("did:nuts:123") + func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { - testDID := did.MustParseDID("did:nuts:123") t.Run("ok", func(t *testing.T) { // 200 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(true, nil) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) require.NoError(t, err) assert.IsType(t, OAuthAuthorizationServerMetadata200JSONResponse{}, res) @@ -49,9 +52,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - did not managed by this node", func(t *testing.T) { //404 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: did not owned") @@ -60,9 +63,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - did does not exist", func(t *testing.T) { //404 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, vdr.ErrNotFound) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, vdr.ErrNotFound) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: unable to find the DID document") @@ -71,9 +74,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - internal error 500", func(t *testing.T) { //500 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, errors.New("unknown error")) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error")) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 500, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: unknown error") @@ -82,7 +85,6 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { } func TestWrapper_GetWebDID(t *testing.T) { - nutsDID := did.MustParseDID("did:nuts:123") webDID := did.MustParseDID("did:web:example.com:iam:123") publicURL := ssi.MustParseURI("https://example.com").URL webDIDBaseURL := publicURL.JoinPath("/iam") @@ -124,30 +126,29 @@ func TestWrapper_GetWebDID(t *testing.T) { } func TestWrapper_GetOAuthClientMetadata(t *testing.T) { - did := did.MustParseDID("did:nuts:123") t.Run("ok", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did).Return(true, nil) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) require.NoError(t, err) assert.IsType(t, OAuthClientMetadata200JSONResponse{}, res) }) t.Run("error - did not managed by this node", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.Nil(t, res) }) t.Run("error - internal error 500", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did).Return(false, errors.New("unknown error")) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error")) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 500, statusCodeFrom(err)) assert.EqualError(t, err, "unknown error") @@ -155,6 +156,68 @@ func TestWrapper_GetOAuthClientMetadata(t *testing.T) { }) } +func TestWrapper_HandleAuthorizeRequest(t *testing.T) { + t.Run("missing redirect_uri", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{}), HandleAuthorizeRequestRequestObject{ + Id: nutsDID.String(), + }) + + requireOAuthError(t, err, InvalidRequest, "redirect_uri is required") + assert.Nil(t, res) + }) + t.Run("unsupported response type", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{ + "redirect_uri": "https://example.com", + "response_type": "unsupported", + }), HandleAuthorizeRequestRequestObject{ + Id: nutsDID.String(), + }) + + requireOAuthError(t, err, UnsupportedResponseType, "") + assert.Nil(t, res) + }) +} + +func TestWrapper_HandleTokenRequest(t *testing.T) { + t.Run("unsupported grant type", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleTokenRequest(nil, HandleTokenRequestRequestObject{ + Id: nutsDID.String(), + Body: &HandleTokenRequestFormdataRequestBody{ + GrantType: "unsupported", + }, + }) + + requireOAuthError(t, err, InvalidRequest, "unsupported grant_type") + assert.Nil(t, res) + }) +} + +func requireOAuthError(t *testing.T, err error, errorCode ErrorCode, errorDescription string) { + var oauthErr OAuth2Error + require.ErrorAs(t, err, &oauthErr) + assert.Equal(t, errorCode, oauthErr.Code) + assert.Equal(t, errorDescription, oauthErr.Description) +} + +func requestContext(queryParams map[string]string) context.Context { + vals := url.Values{} + for key, value := range queryParams { + vals.Add(key, value) + } + httpRequest := &http.Request{ + URL: &url.URL{ + RawQuery: vals.Encode(), + }, + } + return context.WithValue(audit.TestContext(), httpRequestContextKey, httpRequest) +} + // statusCodeFrom returns the statuscode if err is core.HTTPStatusCodeError, or 0 if it isn't func statusCodeFrom(err error) int { var SE core.HTTPStatusCodeError diff --git a/auth/api/iam/authorized_code.go b/auth/api/iam/authorized_code.go index 59faa1dbde..2e4cc2521b 100644 --- a/auth/api/iam/authorized_code.go +++ b/auth/api/iam/authorized_code.go @@ -26,7 +26,6 @@ import ( "fmt" "github.com/labstack/echo/v4" "github.com/nuts-foundation/nuts-node/core" - "github.com/nuts-foundation/nuts-node/vcr/openid4vci" "html/template" "net/http" "net/url" @@ -92,20 +91,16 @@ func (a authorizedCodeFlow) handleAuthConsent(c echo.Context) error { func (a authorizedCodeFlow) validateCode(params map[string]string) (string, error) { code, ok := params["code"] + invalidCodeError := OAuth2Error{ + Code: InvalidRequest, + Description: "missing or invalid code parameter", + } if !ok { - return "", openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "missing or invalid code parameter", - } + return "", invalidCodeError } session := a.sessions.Get(code) if session == nil { - return "", openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid code", - } + return "", invalidCodeError } return session.Scope, nil } diff --git a/auth/api/iam/error.go b/auth/api/iam/error.go new file mode 100644 index 0000000000..8f243ebd74 --- /dev/null +++ b/auth/api/iam/error.go @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package iam + +import ( + "errors" + "github.com/labstack/echo/v4" + "github.com/nuts-foundation/nuts-node/core" + "net/http" + "net/url" + "strings" +) + +// ErrorCode specifies error codes as defined by the OAuth2 specifications. +// Codes and descriptions are taken from https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 +type ErrorCode string + +const ( + // InvalidRequest is returned when the request is missing a required parameter, includes an invalid parameter value, + // includes a parameter more than once, or is otherwise malformed. + InvalidRequest ErrorCode = "invalid_request" + // UnsupportedResponseType is returned when the authorization server does not support obtaining an authorization code using this method. + UnsupportedResponseType ErrorCode = "unsupported_response_type" + // ServerError is returned when the Authorization Server encounters an unexpected condition that prevents it from fulfilling the request. + ServerError ErrorCode = "server_error" +) + +// Make sure the error implements core.HTTPStatusCodeError, so the HTTP request logger can log the correct status code. +var _ core.HTTPStatusCodeError = OAuth2Error{} + +// OAuth2Error is an OAuth2 error that signals the error was (probably) caused by the client (e.g. bad request), +// or that the client can recover from the error (e.g. retry). +type OAuth2Error struct { + // Code is the error code as defined by the OAuth2 specification. + Code ErrorCode `json:"error"` + // Description is a human-readable ASCII [USASCII] text providing additional information, used to assist the client developer in understanding the error that occurred. + Description string `json:"error_description,omitempty"` + // InternalError is the underlying error, may be omitted. It is not intended to be returned to the client, only to be logged. + InternalError error `json:"-"` + // RedirectURI is the redirect URI that should be used to redirect the client to, in case the user-agent is a browser. + // It should not be set if the user-agent is not a browser, or there is no redirect_uri (because the request was malformed), this field is empty. + // When the field is set, the user-agent is redirected to the specified URI with the error code and description as query parameters. + // If it's not set, the error code and description are returned in the response body (plain text or JSON). + RedirectURI string `json:"-"` +} + +// StatusCode returns the HTTP status code to be returned to the client. It is defined for compatibility with core.HTTPStatusCodeError. +func (e OAuth2Error) StatusCode() int { + switch e.Code { + case ServerError: + return http.StatusInternalServerError + default: + return http.StatusBadRequest + } +} + +// OAuth2Error returns the error message, which is either the underlying error or the code if there is no underlying error +func (e OAuth2Error) Error() string { + var parts []string + parts = append(parts, string(e.Code)) + if e.InternalError != nil { + parts = append(parts, e.InternalError.Error()) + } + if e.Description != "" { + parts = append(parts, e.Description) + } + return strings.Join(parts, " - ") +} + +type oauth2ErrorWriter struct{} + +func (p oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err error) error { + var oauthErr OAuth2Error + if !errors.As(err, &oauthErr) { + // Internal error, wrap it in an OAuth2 error + oauthErr = OAuth2Error{ + Code: ServerError, + InternalError: err, + } + } + if oauthErr.Code == "" { + // Somebody forgot to set a code + oauthErr.Code = ServerError + } + redirectURI, _ := url.Parse(oauthErr.RedirectURI) + if oauthErr.RedirectURI == "" || redirectURI == nil { + // Can't redirect the user-agent back, render error as JSON or plain text (depending on content-type) + contentType := echoContext.Request().Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + // Return JSON response + return echoContext.JSON(oauthErr.StatusCode(), oauthErr) + } else { + // Return plain text response + parts := []string{string(oauthErr.Code)} + if oauthErr.Description != "" { + parts = append(parts, oauthErr.Description) + } + return echoContext.String(oauthErr.StatusCode(), strings.Join(parts, " - ")) + } + } + // Redirect the user-agent back to the client + query := redirectURI.Query() + query.Set("error", string(oauthErr.Code)) + if oauthErr.Description != "" { + query.Set("error_description", oauthErr.Description) + } + redirectURI.RawQuery = query.Encode() + return echoContext.Redirect(http.StatusFound, redirectURI.String()) +} diff --git a/auth/api/iam/error_test.go b/auth/api/iam/error_test.go new file mode 100644 index 0000000000..9225282c92 --- /dev/null +++ b/auth/api/iam/error_test.go @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package iam + +import ( + "errors" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestError_Error(t *testing.T) { + t.Run("with underlying error", func(t *testing.T) { + assert.EqualError(t, OAuth2Error{InternalError: errors.New("token has expired"), Code: InvalidRequest}, "invalid_request - token has expired") + }) + t.Run("without underlying error", func(t *testing.T) { + assert.EqualError(t, OAuth2Error{Code: InvalidRequest}, "invalid_request") + }) +} + +func Test_oauth2ErrorWriter_Write(t *testing.T) { + t.Run("user-agent is browser with redirect URI", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + RedirectURI: "https://example.com", + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusFound, rec.Code) + assert.Equal(t, "https://example.com?error=invalid_request&error_description=failure", rec.Header().Get("Location")) + }) + t.Run("user-agent is browser without redirect URI", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get("Content-Type")) + assert.Equal(t, "invalid_request - failure", string(body)) + assert.Empty(t, rec.Header().Get("Location")) + }) + t.Run("user-agent is API client (sent JSON)", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + httpRequest.Header["Content-Type"] = []string{"application/json"} + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "application/json; charset=UTF-8", rec.Header().Get("Content-Type")) + assert.Equal(t, `{"error":"invalid_request","error_description":"failure"}`, strings.TrimSpace(string(body))) + assert.Empty(t, rec.Header().Get("Location")) + }) + t.Run("OAuth2 error without code, defaults to server_error", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `server_error - failure`, strings.TrimSpace(string(body))) + }) + t.Run("error is not an OAuth2 error, should be wrapped", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", errors.New("catastrophic")) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `server_error`, strings.TrimSpace(string(body))) + }) +} diff --git a/auth/api/iam/openid4vp.go b/auth/api/iam/openid4vp.go index 46f7aa5730..76313d480b 100644 --- a/auth/api/iam/openid4vp.go +++ b/auth/api/iam/openid4vp.go @@ -81,14 +81,22 @@ func (r *Wrapper) handlePresentationRequest(params map[string]string, session *S } // Response mode is always direct_post for now if params[responseModeParam] != responseModeDirectPost { - return nil, errors.New("response_mode must be direct_post") + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: "response_mode must be direct_post", + RedirectURI: session.RedirectURI, + } } // TODO: This is the easiest for now, but is this the way? // For compatibility, we probably need to support presentation_definition and/or presentation_definition_uri. presentationDefinition := r.auth.PresentationDefinitions().ByScope(params[scopeParam]) if presentationDefinition == nil { - return nil, fmt.Errorf("unsupported scope for presentation exchange: %s", params[scopeParam]) + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: fmt.Sprintf("unsupported scope for presentation exchange: %s", params[scopeParam]), + RedirectURI: session.RedirectURI, + } } // Render HTML diff --git a/auth/api/iam/openid4vp_test.go b/auth/api/iam/openid4vp_test.go index 32ff266848..05918e7661 100644 --- a/auth/api/iam/openid4vp_test.go +++ b/auth/api/iam/openid4vp_test.go @@ -118,6 +118,40 @@ func TestWrapper_handlePresentationRequest(t *testing.T) { require.Equal(t, http.StatusOK, httpResponse.statusCode) assert.Contains(t, httpResponse.body.String(), "") }) + t.Run("unsupported scope", func(t *testing.T) { + ctrl := gomock.NewController(t) + peStore := &pe.DefinitionResolver{} + _ = peStore.LoadFromFile("test/presentation_definition_mapping.json") + mockAuth := auth.NewMockAuthenticationServices(ctrl) + mockAuth.EXPECT().PresentationDefinitions().Return(peStore) + instance := New(mockAuth, nil, nil) + + params := map[string]string{ + "scope": "unsupported", + "response_type": "code", + "response_mode": "direct_post", + "client_metadata_uri": "https://example.com/client_metadata.xml", + } + + response, err := instance.handlePresentationRequest(params, createSession(params, holderDID)) + + requireOAuthError(t, err, InvalidRequest, "unsupported scope for presentation exchange: unsupported") + assert.Nil(t, response) + }) + t.Run("invalid response_mode", func(t *testing.T) { + instance := New(nil, nil, nil) + params := map[string]string{ + "scope": "eOverdracht-overdrachtsbericht", + "response_type": "code", + "response_mode": "invalid", + "client_metadata_uri": "https://example.com/client_metadata.xml", + } + + response, err := instance.handlePresentationRequest(params, createSession(params, holderDID)) + + requireOAuthError(t, err, InvalidRequest, "response_mode must be direct_post") + assert.Nil(t, response) + }) } type stubResponseWriter struct { diff --git a/cmd/root.go b/cmd/root.go index ec8869fd16..c0cd5db74c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -160,7 +160,7 @@ func startServer(ctx context.Context, system *core.System) error { logrus.Info("Shutting down...") err := system.Shutdown() if err != nil { - logrus.Errorf("Error shutting down system: %v", err) + logrus.Errorf("OAuth2Error shutting down system: %v", err) } else { logrus.Info("Shutdown complete. Goodbye!") } diff --git a/http/requestlogger.go b/http/requestlogger.go index 67385f565a..14156c038a 100644 --- a/http/requestlogger.go +++ b/http/requestlogger.go @@ -21,7 +21,6 @@ package http import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" - "github.com/nuts-foundation/nuts-node/core" "github.com/sirupsen/logrus" "mime" "net/http" @@ -40,12 +39,13 @@ func requestLoggerMiddleware(skipper middleware.Skipper, logger *logrus.Entry) e LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { status := values.Status if values.Error != nil { - switch errWithStatus := values.Error.(type) { - case *echo.HTTPError: - status = errWithStatus.Code - case core.HTTPStatusCodeError: - status = errWithStatus.StatusCode() - default: + // In case the error provides `func StatusCode() int` + // (e.g. core.HTTPStatusCodeError) + if x, ok := values.Error.(interface{ StatusCode() int }); ok { + status = x.StatusCode() + } else if x, ok := values.Error.(*echo.HTTPError); ok { + status = x.Code + } else { status = http.StatusInternalServerError } }