From 7ab8356d410c374b83f731b80523d9993b0c1390 Mon Sep 17 00:00:00 2001 From: reinkrul Date: Mon, 2 Oct 2023 12:24:08 +0200 Subject: [PATCH] IAM: Implement correct OAuth2 error handling (#2515) --- auth/api/iam/api.go | 95 +++++++++---------- auth/api/iam/api_test.go | 162 ++++++++++++++++++++++++++++---- auth/api/iam/authorized_code.go | 17 ++-- auth/api/iam/error.go | 127 +++++++++++++++++++++++++ auth/api/iam/error_test.go | 123 ++++++++++++++++++++++++ auth/api/iam/generated.go | 6 -- auth/api/iam/openid4vp.go | 13 ++- auth/api/iam/openid4vp_test.go | 34 +++++++ auth/api/iam/types.go | 3 + codegen/configs/auth_iam.yaml | 1 + http/requestlogger.go | 14 +-- 11 files changed, 500 insertions(+), 95 deletions(-) create mode 100644 auth/api/iam/error.go create mode 100644 auth/api/iam/error_test.go diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 109a21089c..ce87ed81d9 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -29,11 +29,11 @@ import ( "github.com/nuts-foundation/nuts-node/auth/log" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/vcr" - "github.com/nuts-foundation/nuts-node/vcr/openid4vci" "github.com/nuts-foundation/nuts-node/vdr" "github.com/nuts-foundation/nuts-node/vdr/resolver" "html/template" "net/http" + "strings" "sync" ) @@ -41,6 +41,7 @@ var _ core.Routable = &Wrapper{} var _ StrictServerInterface = &Wrapper{} const apiPath = "iam" +const apiModuleName = auth.ModuleName + "/" + apiPath const httpRequestContextKey = "http-request" //go:embed assets @@ -71,31 +72,12 @@ func New(authInstance auth.AuthenticationServices, vcrInstance vcr.VCR, vdrInsta } func (r Wrapper) Routes(router core.EchoRouter) { - const apiModuleName = auth.ModuleName + "/" + apiPath RegisterHandlers(router, NewStrictHandler(r, []StrictMiddlewareFunc{ func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { return func(ctx echo.Context, request interface{}) (response interface{}, err error) { - ctx.Set(core.OperationIDContextKey, operationID) - ctx.Set(core.ModuleNameContextKey, apiModuleName) - // Add http.Request to context, to allow reading URL query parameters - requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request()) - ctx.SetRequest(ctx.Request().WithContext(requestCtx)) - // TODO: Do we need a generic error handler? - // ctx.Set(core.ErrorWriterContextKey, &protocolErrorWriter{}) - return f(ctx, request) + return r.middleware(ctx, request, operationID, f) } }, - func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { - return func(ctx echo.Context, args interface{}) (interface{}, error) { - if !r.auth.V2APIEnabled() { - return nil, core.Error(http.StatusForbidden, "Access denied") - } - return f(ctx, args) - } - }, - func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { - return audit.StrictMiddleware(f, apiModuleName, operationID) - }, })) auditMiddleware := audit.Middleware(apiModuleName) // The following handler is of the OpenID4VCI wallet which is called by the holder (wallet owner) @@ -111,6 +93,24 @@ func (r Wrapper) Routes(router core.EchoRouter) { router.POST("/iam/:did/openid4vp_demo", r.handleOpenID4VPDemoSendRequest, auditMiddleware) } +func (r Wrapper) middleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) { + ctx.Set(core.OperationIDContextKey, operationID) + ctx.Set(core.ModuleNameContextKey, apiModuleName) + + if !r.auth.V2APIEnabled() { + return nil, core.Error(http.StatusForbidden, "Access denied") + } + + // Add http.Request to context, to allow reading URL query parameters + requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request()) + ctx.SetRequest(ctx.Request().WithContext(requestCtx)) + if strings.HasPrefix(ctx.Request().URL.Path, "/iam/") { + ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{}) + } + audit.StrictMiddleware(f, apiModuleName, operationID) + return f(ctx, request) +} + // HandleTokenRequest handles calls to the token endpoint for exchanging a grant (e.g authorization code or pre-authorized code) for an access token. func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequestRequestObject) (HandleTokenRequestResponseObject, error) { switch request.Body.GrantType { @@ -118,30 +118,29 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ // Options: // - OpenID4VCI // - OpenID4VP, vp_token is sent in Token Response + return nil, OAuth2Error{ + Code: UnsupportedGrantType, + Description: "not implemented yet", + } case "vp_token": // Options: // - service-to-service vp_token flow + return nil, OAuth2Error{ + Code: UnsupportedGrantType, + Description: "not implemented yet", + } case "urn:ietf:params:oauth:grant-type:pre-authorized_code": // Options: // - OpenID4VCI + return nil, OAuth2Error{ + Code: UnsupportedGrantType, + Description: "not implemented yet", + } default: - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid grant type", + return nil, OAuth2Error{ + Code: UnsupportedGrantType, } } - - // TODO: Handle? - //scope, err := handler(request.Body.AdditionalProperties) - //if err != nil { - // return nil, err - //} - // TODO: Generate access token with scope - return HandleTokenRequest200JSONResponse(TokenResponse{ - AccessToken: "", - }), nil } // HandleAuthorizeRequest handles calls to the authorization endpoint for starting an authorization code flow. @@ -161,7 +160,10 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho // TODO: Spec says that the redirect URI is optional, but it's not clear what to do if it's not provided. // Threat models say it's unsafe to omit redirect_uri. // See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 - return nil, errors.New("missing redirect URI") + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: "redirect_uri is required", + } } switch session.ResponseType { @@ -171,33 +173,24 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho // - OpenID4VCI; authorization code flow for credential issuance to (end-user) wallet // - OpenID4VP, vp_token is sent in Token Response; authorization code flow for presentation exchange (not required a.t.m.) // TODO: Switch on parameters to right flow + panic("not implemented") case responseTypeVPToken: // Options: // - OpenID4VP flow, vp_token is sent in Authorization Response // TODO: Check parameters for right flow // TODO: Do we actually need this? (probably not) + panic("not implemented") case responseTypeVPIDToken: // Options: // - OpenID4VP+SIOP flow, vp_token is sent in Authorization Response return r.handlePresentationRequest(params, session) default: // TODO: This should be a redirect? - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid/unsupported response_type", + return nil, OAuth2Error{ + Code: UnsupportedResponseType, + RedirectURI: session.RedirectURI, } } - - // No handler could handle the request - // TODO: This should be a redirect? - // TODO: Don't use openid4vci package for errors - return nil, openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "missing or invalid parameters", - } } // OAuthAuthorizationServerMetadata returns the Authorization Server's metadata diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index 3b2d25b524..bff9a8c8c2 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -19,7 +19,9 @@ package iam import ( + "context" "errors" + "github.com/labstack/echo/v4" ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/audit" @@ -30,18 +32,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "net/http" + "net/http/httptest" "net/url" "testing" ) +var nutsDID = did.MustParseDID("did:nuts:123") + func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { - testDID := did.MustParseDID("did:nuts:123") t.Run("ok", func(t *testing.T) { // 200 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(true, nil) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) require.NoError(t, err) assert.IsType(t, OAuthAuthorizationServerMetadata200JSONResponse{}, res) @@ -50,9 +55,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - did not managed by this node", func(t *testing.T) { //404 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: did not owned") @@ -61,9 +66,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - did does not exist", func(t *testing.T) { //404 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, resolver.ErrNotFound) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, resolver.ErrNotFound) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: unable to find the DID document") @@ -72,9 +77,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { t.Run("error - internal error 500", func(t *testing.T) { //500 ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, errors.New("unknown error")) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error")) - res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID}) + res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 500, statusCodeFrom(err)) assert.EqualError(t, err, "authz server metadata: unknown error") @@ -83,7 +88,6 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) { } func TestWrapper_GetWebDID(t *testing.T) { - nutsDID := did.MustParseDID("did:nuts:123") webDID := did.MustParseDID("did:web:example.com:iam:123") publicURL := ssi.MustParseURI("https://example.com").URL webDIDBaseURL := publicURL.JoinPath("/iam") @@ -125,30 +129,29 @@ func TestWrapper_GetWebDID(t *testing.T) { } func TestWrapper_GetOAuthClientMetadata(t *testing.T) { - did := did.MustParseDID("did:nuts:123") t.Run("ok", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did).Return(true, nil) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) require.NoError(t, err) assert.IsType(t, OAuthClientMetadata200JSONResponse{}, res) }) t.Run("error - did not managed by this node", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 404, statusCodeFrom(err)) assert.Nil(t, res) }) t.Run("error - internal error 500", func(t *testing.T) { ctx := newTestClient(t) - ctx.vdr.EXPECT().IsOwner(nil, did).Return(false, errors.New("unknown error")) + ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error")) - res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID}) + res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID}) assert.Equal(t, 500, statusCodeFrom(err)) assert.EqualError(t, err, "unknown error") @@ -156,6 +159,68 @@ func TestWrapper_GetOAuthClientMetadata(t *testing.T) { }) } +func TestWrapper_HandleAuthorizeRequest(t *testing.T) { + t.Run("missing redirect_uri", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{}), HandleAuthorizeRequestRequestObject{ + Id: nutsDID.String(), + }) + + requireOAuthError(t, err, InvalidRequest, "redirect_uri is required") + assert.Nil(t, res) + }) + t.Run("unsupported response type", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{ + "redirect_uri": "https://example.com", + "response_type": "unsupported", + }), HandleAuthorizeRequestRequestObject{ + Id: nutsDID.String(), + }) + + requireOAuthError(t, err, UnsupportedResponseType, "") + assert.Nil(t, res) + }) +} + +func TestWrapper_HandleTokenRequest(t *testing.T) { + t.Run("unsupported grant type", func(t *testing.T) { + ctx := newTestClient(t) + + res, err := ctx.client.HandleTokenRequest(nil, HandleTokenRequestRequestObject{ + Id: nutsDID.String(), + Body: &HandleTokenRequestFormdataRequestBody{ + GrantType: "unsupported", + }, + }) + + requireOAuthError(t, err, UnsupportedGrantType, "") + assert.Nil(t, res) + }) +} + +func requireOAuthError(t *testing.T, err error, errorCode ErrorCode, errorDescription string) { + var oauthErr OAuth2Error + require.ErrorAs(t, err, &oauthErr) + assert.Equal(t, errorCode, oauthErr.Code) + assert.Equal(t, errorDescription, oauthErr.Description) +} + +func requestContext(queryParams map[string]string) context.Context { + vals := url.Values{} + for key, value := range queryParams { + vals.Add(key, value) + } + httpRequest := &http.Request{ + URL: &url.URL{ + RawQuery: vals.Encode(), + }, + } + return context.WithValue(audit.TestContext(), httpRequestContextKey, httpRequest) +} + // statusCodeFrom returns the statuscode if err is core.HTTPStatusCodeError, or 0 if it isn't func statusCodeFrom(err error) int { var SE core.HTTPStatusCodeError @@ -191,3 +256,66 @@ func newTestClient(t testing.TB) *testCtx { }, } } + +func TestWrapper_Routes(t *testing.T) { + ctrl := gomock.NewController(t) + router := core.NewMockEchoRouter(ctrl) + + router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + Wrapper{}.Routes(router) +} + +func TestWrapper_middleware(t *testing.T) { + server := echo.New() + ctrl := gomock.NewController(t) + authService := auth.NewMockAuthenticationServices(ctrl) + authService.EXPECT().V2APIEnabled().Return(true).AnyTimes() + + t.Run("API enabling", func(t *testing.T) { + t.Run("enabled", func(t *testing.T) { + var called strictServerCallCapturer + + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle) + + assert.True(t, bool(called)) + }) + t.Run("disabled", func(t *testing.T) { + var called strictServerCallCapturer + + authService := auth.NewMockAuthenticationServices(ctrl) + authService.EXPECT().V2APIEnabled().Return(false).AnyTimes() + + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle) + + assert.False(t, bool(called)) + }) + }) + + t.Run("OAuth2 error handling", func(t *testing.T) { + var handler strictServerCallCapturer + t.Run("OAuth2 path", func(t *testing.T) { + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle) + + assert.IsType(t, &oauth2ErrorWriter{}, ctx.Get(core.ErrorWriterContextKey)) + }) + t.Run("other path", func(t *testing.T) { + ctx := server.NewContext(httptest.NewRequest("GET", "/internal/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle) + + assert.Nil(t, ctx.Get(core.ErrorWriterContextKey)) + }) + }) + +} + +type strictServerCallCapturer bool + +func (s *strictServerCallCapturer) handle(ctx echo.Context, request interface{}) (response interface{}, err error) { + *s = true + return nil, nil +} diff --git a/auth/api/iam/authorized_code.go b/auth/api/iam/authorized_code.go index 59faa1dbde..2e4cc2521b 100644 --- a/auth/api/iam/authorized_code.go +++ b/auth/api/iam/authorized_code.go @@ -26,7 +26,6 @@ import ( "fmt" "github.com/labstack/echo/v4" "github.com/nuts-foundation/nuts-node/core" - "github.com/nuts-foundation/nuts-node/vcr/openid4vci" "html/template" "net/http" "net/url" @@ -92,20 +91,16 @@ func (a authorizedCodeFlow) handleAuthConsent(c echo.Context) error { func (a authorizedCodeFlow) validateCode(params map[string]string) (string, error) { code, ok := params["code"] + invalidCodeError := OAuth2Error{ + Code: InvalidRequest, + Description: "missing or invalid code parameter", + } if !ok { - return "", openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "missing or invalid code parameter", - } + return "", invalidCodeError } session := a.sessions.Get(code) if session == nil { - return "", openid4vci.Error{ - Code: openid4vci.InvalidRequest, - StatusCode: http.StatusBadRequest, - //Description: "invalid code", - } + return "", invalidCodeError } return session.Scope, nil } diff --git a/auth/api/iam/error.go b/auth/api/iam/error.go new file mode 100644 index 0000000000..a70c005052 --- /dev/null +++ b/auth/api/iam/error.go @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package iam + +import ( + "errors" + "github.com/labstack/echo/v4" + "github.com/nuts-foundation/nuts-node/core" + "net/http" + "net/url" + "strings" +) + +// ErrorCode specifies error codes as defined by the OAuth2 specifications. +// Codes and descriptions are taken from https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 +type ErrorCode string + +const ( + // InvalidRequest is returned when the request is missing a required parameter, includes an invalid parameter value, + // includes a parameter more than once, or is otherwise malformed. + InvalidRequest ErrorCode = "invalid_request" + // UnsupportedGrantType is returned when the authorization grant type is not supported by the authorization server. + UnsupportedGrantType ErrorCode = "unsupported_grant_type" + // UnsupportedResponseType is returned when the authorization server does not support obtaining an authorization code using this method. + UnsupportedResponseType ErrorCode = "unsupported_response_type" + // ServerError is returned when the Authorization Server encounters an unexpected condition that prevents it from fulfilling the request. + ServerError ErrorCode = "server_error" +) + +// Make sure the error implements core.HTTPStatusCodeError, so the HTTP request logger can log the correct status code. +var _ core.HTTPStatusCodeError = OAuth2Error{} + +// OAuth2Error is an OAuth2 error that signals the error was (probably) caused by the client (e.g. bad request), +// or that the client can recover from the error (e.g. retry). +type OAuth2Error struct { + // Code is the error code as defined by the OAuth2 specification. + Code ErrorCode `json:"error"` + // Description is a human-readable ASCII [USASCII] text providing additional information, used to assist the client developer in understanding the error that occurred. + Description string `json:"error_description,omitempty"` + // InternalError is the underlying error, may be omitted. It is not intended to be returned to the client, only to be logged. + InternalError error `json:"-"` + // RedirectURI is the redirect URI that should be used to redirect the client to, in case the user-agent is a browser. + // It should not be set if the user-agent is not a browser, or there is no redirect_uri (because the request was malformed), this field is empty. + // When the field is set, the user-agent is redirected to the specified URI with the error code and description as query parameters. + // If it's not set, the error code and description are returned in the response body (plain text or JSON). + RedirectURI string `json:"-"` +} + +// StatusCode returns the HTTP status code to be returned to the client, in case the user-agent can't be redirected with HTTP 302 - Found. +func (e OAuth2Error) StatusCode() int { + switch e.Code { + case ServerError: + return http.StatusInternalServerError + default: + return http.StatusBadRequest + } +} + +// OAuth2Error returns the error message, which is either the underlying error or the code if there is no underlying error +func (e OAuth2Error) Error() string { + var parts []string + parts = append(parts, string(e.Code)) + if e.InternalError != nil { + parts = append(parts, e.InternalError.Error()) + } + if e.Description != "" { + parts = append(parts, e.Description) + } + return strings.Join(parts, " - ") +} + +type oauth2ErrorWriter struct{} + +func (p oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err error) error { + var oauthErr OAuth2Error + if !errors.As(err, &oauthErr) { + // Internal error, wrap it in an OAuth2 error + oauthErr = OAuth2Error{ + Code: ServerError, + InternalError: err, + } + } + if oauthErr.Code == "" { + // Somebody forgot to set a code + oauthErr.Code = ServerError + } + redirectURI, _ := url.Parse(oauthErr.RedirectURI) + if oauthErr.RedirectURI == "" || redirectURI == nil { + // Can't redirect the user-agent back, render error as JSON or plain text (depending on content-type) + contentType := echoContext.Request().Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + // Return JSON response + return echoContext.JSON(oauthErr.StatusCode(), oauthErr) + } else { + // Return plain text response + parts := []string{string(oauthErr.Code)} + if oauthErr.Description != "" { + parts = append(parts, oauthErr.Description) + } + return echoContext.String(oauthErr.StatusCode(), strings.Join(parts, " - ")) + } + } + // Redirect the user-agent back to the client + query := redirectURI.Query() + query.Set("error", string(oauthErr.Code)) + if oauthErr.Description != "" { + query.Set("error_description", oauthErr.Description) + } + redirectURI.RawQuery = query.Encode() + return echoContext.Redirect(http.StatusFound, redirectURI.String()) +} diff --git a/auth/api/iam/error_test.go b/auth/api/iam/error_test.go new file mode 100644 index 0000000000..9225282c92 --- /dev/null +++ b/auth/api/iam/error_test.go @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package iam + +import ( + "errors" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestError_Error(t *testing.T) { + t.Run("with underlying error", func(t *testing.T) { + assert.EqualError(t, OAuth2Error{InternalError: errors.New("token has expired"), Code: InvalidRequest}, "invalid_request - token has expired") + }) + t.Run("without underlying error", func(t *testing.T) { + assert.EqualError(t, OAuth2Error{Code: InvalidRequest}, "invalid_request") + }) +} + +func Test_oauth2ErrorWriter_Write(t *testing.T) { + t.Run("user-agent is browser with redirect URI", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + RedirectURI: "https://example.com", + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusFound, rec.Code) + assert.Equal(t, "https://example.com?error=invalid_request&error_description=failure", rec.Header().Get("Location")) + }) + t.Run("user-agent is browser without redirect URI", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get("Content-Type")) + assert.Equal(t, "invalid_request - failure", string(body)) + assert.Empty(t, rec.Header().Get("Location")) + }) + t.Run("user-agent is API client (sent JSON)", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + httpRequest.Header["Content-Type"] = []string{"application/json"} + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Code: InvalidRequest, + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "application/json; charset=UTF-8", rec.Header().Get("Content-Type")) + assert.Equal(t, `{"error":"invalid_request","error_description":"failure"}`, strings.TrimSpace(string(body))) + assert.Empty(t, rec.Header().Get("Location")) + }) + t.Run("OAuth2 error without code, defaults to server_error", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{ + Description: "failure", + }) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `server_error - failure`, strings.TrimSpace(string(body))) + }) + t.Run("error is not an OAuth2 error, should be wrapped", func(t *testing.T) { + server := echo.New() + httpRequest := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + ctx := server.NewContext(httpRequest, rec) + + err := oauth2ErrorWriter{}.Write(ctx, 0, "", errors.New("catastrophic")) + + assert.NoError(t, err) + body, _ := io.ReadAll(rec.Body) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, `server_error`, strings.TrimSpace(string(body))) + }) +} diff --git a/auth/api/iam/generated.go b/auth/api/iam/generated.go index 1aff918daa..4acaf152bb 100644 --- a/auth/api/iam/generated.go +++ b/auth/api/iam/generated.go @@ -15,12 +15,6 @@ import ( strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" ) -// ErrorResponse defines model for ErrorResponse. -type ErrorResponse struct { - // Error Code identifying the error that occurred. - Error string `json:"error"` -} - // TokenResponse Token Responses are made as defined in (RFC6749)[https://datatracker.ietf.org/doc/html/rfc6749#section-5.1] type TokenResponse struct { // AccessToken The access token issued by the authorization server. diff --git a/auth/api/iam/openid4vp.go b/auth/api/iam/openid4vp.go index 46f7aa5730..30f467a85a 100644 --- a/auth/api/iam/openid4vp.go +++ b/auth/api/iam/openid4vp.go @@ -81,14 +81,22 @@ func (r *Wrapper) handlePresentationRequest(params map[string]string, session *S } // Response mode is always direct_post for now if params[responseModeParam] != responseModeDirectPost { - return nil, errors.New("response_mode must be direct_post") + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: "response_mode must be direct_post", + RedirectURI: session.RedirectURI, + } } // TODO: This is the easiest for now, but is this the way? // For compatibility, we probably need to support presentation_definition and/or presentation_definition_uri. presentationDefinition := r.auth.PresentationDefinitions().ByScope(params[scopeParam]) if presentationDefinition == nil { - return nil, fmt.Errorf("unsupported scope for presentation exchange: %s", params[scopeParam]) + return nil, OAuth2Error{ + Code: InvalidRequest, + Description: fmt.Sprintf("unsupported scope for presentation exchange: %s", params[scopeParam]), + RedirectURI: session.RedirectURI, + } } // Render HTML @@ -204,7 +212,6 @@ func (r *Wrapper) handlePresentationRequestAccept(c echo.Context) error { } func (r *Wrapper) handlePresentationRequestCompleted(ctx echo.Context) error { - // TODO: support error response // TODO: response direct_post mode vpToken := ctx.QueryParams()[vpTokenParam] if len(vpToken) == 0 { diff --git a/auth/api/iam/openid4vp_test.go b/auth/api/iam/openid4vp_test.go index effc889e6d..91fec61bfd 100644 --- a/auth/api/iam/openid4vp_test.go +++ b/auth/api/iam/openid4vp_test.go @@ -118,6 +118,40 @@ func TestWrapper_handlePresentationRequest(t *testing.T) { require.Equal(t, http.StatusOK, httpResponse.statusCode) assert.Contains(t, httpResponse.body.String(), "") }) + t.Run("unsupported scope", func(t *testing.T) { + ctrl := gomock.NewController(t) + peStore := &pe.DefinitionResolver{} + _ = peStore.LoadFromFile("test/presentation_definition_mapping.json") + mockAuth := auth.NewMockAuthenticationServices(ctrl) + mockAuth.EXPECT().PresentationDefinitions().Return(peStore) + instance := New(mockAuth, nil, nil) + + params := map[string]string{ + "scope": "unsupported", + "response_type": "code", + "response_mode": "direct_post", + "client_metadata_uri": "https://example.com/client_metadata.xml", + } + + response, err := instance.handlePresentationRequest(params, createSession(params, holderDID)) + + requireOAuthError(t, err, InvalidRequest, "unsupported scope for presentation exchange: unsupported") + assert.Nil(t, response) + }) + t.Run("invalid response_mode", func(t *testing.T) { + instance := New(nil, nil, nil) + params := map[string]string{ + "scope": "eOverdracht-overdrachtsbericht", + "response_type": "code", + "response_mode": "invalid", + "client_metadata_uri": "https://example.com/client_metadata.xml", + } + + response, err := instance.handlePresentationRequest(params, createSession(params, holderDID)) + + requireOAuthError(t, err, InvalidRequest, "response_mode must be direct_post") + assert.Nil(t, response) + }) } type stubResponseWriter struct { diff --git a/auth/api/iam/types.go b/auth/api/iam/types.go index 901000ab96..76c128ba28 100644 --- a/auth/api/iam/types.go +++ b/auth/api/iam/types.go @@ -29,6 +29,9 @@ type DIDDocument = did.Document // DIDDocumentMetadata is an alias type DIDDocumentMetadata = resolver.DocumentMetadata +// ErrorResponse is an alias +type ErrorResponse = OAuth2Error + const ( // responseTypeParam is the name of the response_type parameter. // Specified by https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.1 diff --git a/codegen/configs/auth_iam.yaml b/codegen/configs/auth_iam.yaml index c2d29e3163..98b1f763ee 100644 --- a/codegen/configs/auth_iam.yaml +++ b/codegen/configs/auth_iam.yaml @@ -9,3 +9,4 @@ output-options: - DIDDocument - OAuthAuthorizationServerMetadata - OAuthClientMetadata + - ErrorResponse diff --git a/http/requestlogger.go b/http/requestlogger.go index 67385f565a..14156c038a 100644 --- a/http/requestlogger.go +++ b/http/requestlogger.go @@ -21,7 +21,6 @@ package http import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" - "github.com/nuts-foundation/nuts-node/core" "github.com/sirupsen/logrus" "mime" "net/http" @@ -40,12 +39,13 @@ func requestLoggerMiddleware(skipper middleware.Skipper, logger *logrus.Entry) e LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { status := values.Status if values.Error != nil { - switch errWithStatus := values.Error.(type) { - case *echo.HTTPError: - status = errWithStatus.Code - case core.HTTPStatusCodeError: - status = errWithStatus.StatusCode() - default: + // In case the error provides `func StatusCode() int` + // (e.g. core.HTTPStatusCodeError) + if x, ok := values.Error.(interface{ StatusCode() int }); ok { + status = x.StatusCode() + } else if x, ok := values.Error.(*echo.HTTPError); ok { + status = x.Code + } else { status = http.StatusInternalServerError } }