diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go
index 109a21089c..ce87ed81d9 100644
--- a/auth/api/iam/api.go
+++ b/auth/api/iam/api.go
@@ -29,11 +29,11 @@ import (
"github.com/nuts-foundation/nuts-node/auth/log"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/vcr"
- "github.com/nuts-foundation/nuts-node/vcr/openid4vci"
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"html/template"
"net/http"
+ "strings"
"sync"
)
@@ -41,6 +41,7 @@ var _ core.Routable = &Wrapper{}
var _ StrictServerInterface = &Wrapper{}
const apiPath = "iam"
+const apiModuleName = auth.ModuleName + "/" + apiPath
const httpRequestContextKey = "http-request"
//go:embed assets
@@ -71,31 +72,12 @@ func New(authInstance auth.AuthenticationServices, vcrInstance vcr.VCR, vdrInsta
}
func (r Wrapper) Routes(router core.EchoRouter) {
- const apiModuleName = auth.ModuleName + "/" + apiPath
RegisterHandlers(router, NewStrictHandler(r, []StrictMiddlewareFunc{
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
- ctx.Set(core.OperationIDContextKey, operationID)
- ctx.Set(core.ModuleNameContextKey, apiModuleName)
- // Add http.Request to context, to allow reading URL query parameters
- requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request())
- ctx.SetRequest(ctx.Request().WithContext(requestCtx))
- // TODO: Do we need a generic error handler?
- // ctx.Set(core.ErrorWriterContextKey, &protocolErrorWriter{})
- return f(ctx, request)
+ return r.middleware(ctx, request, operationID, f)
}
},
- func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
- return func(ctx echo.Context, args interface{}) (interface{}, error) {
- if !r.auth.V2APIEnabled() {
- return nil, core.Error(http.StatusForbidden, "Access denied")
- }
- return f(ctx, args)
- }
- },
- func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
- return audit.StrictMiddleware(f, apiModuleName, operationID)
- },
}))
auditMiddleware := audit.Middleware(apiModuleName)
// The following handler is of the OpenID4VCI wallet which is called by the holder (wallet owner)
@@ -111,6 +93,24 @@ func (r Wrapper) Routes(router core.EchoRouter) {
router.POST("/iam/:did/openid4vp_demo", r.handleOpenID4VPDemoSendRequest, auditMiddleware)
}
+func (r Wrapper) middleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) {
+ ctx.Set(core.OperationIDContextKey, operationID)
+ ctx.Set(core.ModuleNameContextKey, apiModuleName)
+
+ if !r.auth.V2APIEnabled() {
+ return nil, core.Error(http.StatusForbidden, "Access denied")
+ }
+
+ // Add http.Request to context, to allow reading URL query parameters
+ requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request())
+ ctx.SetRequest(ctx.Request().WithContext(requestCtx))
+ if strings.HasPrefix(ctx.Request().URL.Path, "/iam/") {
+ ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{})
+ }
+ audit.StrictMiddleware(f, apiModuleName, operationID)
+ return f(ctx, request)
+}
+
// HandleTokenRequest handles calls to the token endpoint for exchanging a grant (e.g authorization code or pre-authorized code) for an access token.
func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequestRequestObject) (HandleTokenRequestResponseObject, error) {
switch request.Body.GrantType {
@@ -118,30 +118,29 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ
// Options:
// - OpenID4VCI
// - OpenID4VP, vp_token is sent in Token Response
+ return nil, OAuth2Error{
+ Code: UnsupportedGrantType,
+ Description: "not implemented yet",
+ }
case "vp_token":
// Options:
// - service-to-service vp_token flow
+ return nil, OAuth2Error{
+ Code: UnsupportedGrantType,
+ Description: "not implemented yet",
+ }
case "urn:ietf:params:oauth:grant-type:pre-authorized_code":
// Options:
// - OpenID4VCI
+ return nil, OAuth2Error{
+ Code: UnsupportedGrantType,
+ Description: "not implemented yet",
+ }
default:
- // TODO: Don't use openid4vci package for errors
- return nil, openid4vci.Error{
- Code: openid4vci.InvalidRequest,
- StatusCode: http.StatusBadRequest,
- //Description: "invalid grant type",
+ return nil, OAuth2Error{
+ Code: UnsupportedGrantType,
}
}
-
- // TODO: Handle?
- //scope, err := handler(request.Body.AdditionalProperties)
- //if err != nil {
- // return nil, err
- //}
- // TODO: Generate access token with scope
- return HandleTokenRequest200JSONResponse(TokenResponse{
- AccessToken: "",
- }), nil
}
// HandleAuthorizeRequest handles calls to the authorization endpoint for starting an authorization code flow.
@@ -161,7 +160,10 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho
// TODO: Spec says that the redirect URI is optional, but it's not clear what to do if it's not provided.
// Threat models say it's unsafe to omit redirect_uri.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
- return nil, errors.New("missing redirect URI")
+ return nil, OAuth2Error{
+ Code: InvalidRequest,
+ Description: "redirect_uri is required",
+ }
}
switch session.ResponseType {
@@ -171,33 +173,24 @@ func (r Wrapper) HandleAuthorizeRequest(ctx context.Context, request HandleAutho
// - OpenID4VCI; authorization code flow for credential issuance to (end-user) wallet
// - OpenID4VP, vp_token is sent in Token Response; authorization code flow for presentation exchange (not required a.t.m.)
// TODO: Switch on parameters to right flow
+ panic("not implemented")
case responseTypeVPToken:
// Options:
// - OpenID4VP flow, vp_token is sent in Authorization Response
// TODO: Check parameters for right flow
// TODO: Do we actually need this? (probably not)
+ panic("not implemented")
case responseTypeVPIDToken:
// Options:
// - OpenID4VP+SIOP flow, vp_token is sent in Authorization Response
return r.handlePresentationRequest(params, session)
default:
// TODO: This should be a redirect?
- // TODO: Don't use openid4vci package for errors
- return nil, openid4vci.Error{
- Code: openid4vci.InvalidRequest,
- StatusCode: http.StatusBadRequest,
- //Description: "invalid/unsupported response_type",
+ return nil, OAuth2Error{
+ Code: UnsupportedResponseType,
+ RedirectURI: session.RedirectURI,
}
}
-
- // No handler could handle the request
- // TODO: This should be a redirect?
- // TODO: Don't use openid4vci package for errors
- return nil, openid4vci.Error{
- Code: openid4vci.InvalidRequest,
- StatusCode: http.StatusBadRequest,
- //Description: "missing or invalid parameters",
- }
}
// OAuthAuthorizationServerMetadata returns the Authorization Server's metadata
diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go
index 3b2d25b524..bff9a8c8c2 100644
--- a/auth/api/iam/api_test.go
+++ b/auth/api/iam/api_test.go
@@ -19,7 +19,9 @@
package iam
import (
+ "context"
"errors"
+ "github.com/labstack/echo/v4"
ssi "github.com/nuts-foundation/go-did"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/nuts-node/audit"
@@ -30,18 +32,21 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
+ "net/http"
+ "net/http/httptest"
"net/url"
"testing"
)
+var nutsDID = did.MustParseDID("did:nuts:123")
+
func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) {
- testDID := did.MustParseDID("did:nuts:123")
t.Run("ok", func(t *testing.T) {
// 200
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(true, nil)
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil)
- res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID})
+ res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID})
require.NoError(t, err)
assert.IsType(t, OAuthAuthorizationServerMetadata200JSONResponse{}, res)
@@ -50,9 +55,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) {
t.Run("error - did not managed by this node", func(t *testing.T) {
//404
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, testDID)
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID)
- res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID})
+ res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID})
assert.Equal(t, 404, statusCodeFrom(err))
assert.EqualError(t, err, "authz server metadata: did not owned")
@@ -61,9 +66,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) {
t.Run("error - did does not exist", func(t *testing.T) {
//404
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, resolver.ErrNotFound)
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, resolver.ErrNotFound)
- res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID})
+ res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID})
assert.Equal(t, 404, statusCodeFrom(err))
assert.EqualError(t, err, "authz server metadata: unable to find the DID document")
@@ -72,9 +77,9 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) {
t.Run("error - internal error 500", func(t *testing.T) {
//500
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, testDID).Return(false, errors.New("unknown error"))
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error"))
- res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: testDID.ID})
+ res, err := ctx.client.OAuthAuthorizationServerMetadata(nil, OAuthAuthorizationServerMetadataRequestObject{Id: nutsDID.ID})
assert.Equal(t, 500, statusCodeFrom(err))
assert.EqualError(t, err, "authz server metadata: unknown error")
@@ -83,7 +88,6 @@ func TestWrapper_OAuthAuthorizationServerMetadata(t *testing.T) {
}
func TestWrapper_GetWebDID(t *testing.T) {
- nutsDID := did.MustParseDID("did:nuts:123")
webDID := did.MustParseDID("did:web:example.com:iam:123")
publicURL := ssi.MustParseURI("https://example.com").URL
webDIDBaseURL := publicURL.JoinPath("/iam")
@@ -125,30 +129,29 @@ func TestWrapper_GetWebDID(t *testing.T) {
}
func TestWrapper_GetOAuthClientMetadata(t *testing.T) {
- did := did.MustParseDID("did:nuts:123")
t.Run("ok", func(t *testing.T) {
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, did).Return(true, nil)
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(true, nil)
- res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID})
+ res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID})
require.NoError(t, err)
assert.IsType(t, OAuthClientMetadata200JSONResponse{}, res)
})
t.Run("error - did not managed by this node", func(t *testing.T) {
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, did)
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID)
- res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID})
+ res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID})
assert.Equal(t, 404, statusCodeFrom(err))
assert.Nil(t, res)
})
t.Run("error - internal error 500", func(t *testing.T) {
ctx := newTestClient(t)
- ctx.vdr.EXPECT().IsOwner(nil, did).Return(false, errors.New("unknown error"))
+ ctx.vdr.EXPECT().IsOwner(nil, nutsDID).Return(false, errors.New("unknown error"))
- res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: did.ID})
+ res, err := ctx.client.OAuthClientMetadata(nil, OAuthClientMetadataRequestObject{Id: nutsDID.ID})
assert.Equal(t, 500, statusCodeFrom(err))
assert.EqualError(t, err, "unknown error")
@@ -156,6 +159,68 @@ func TestWrapper_GetOAuthClientMetadata(t *testing.T) {
})
}
+func TestWrapper_HandleAuthorizeRequest(t *testing.T) {
+ t.Run("missing redirect_uri", func(t *testing.T) {
+ ctx := newTestClient(t)
+
+ res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{}), HandleAuthorizeRequestRequestObject{
+ Id: nutsDID.String(),
+ })
+
+ requireOAuthError(t, err, InvalidRequest, "redirect_uri is required")
+ assert.Nil(t, res)
+ })
+ t.Run("unsupported response type", func(t *testing.T) {
+ ctx := newTestClient(t)
+
+ res, err := ctx.client.HandleAuthorizeRequest(requestContext(map[string]string{
+ "redirect_uri": "https://example.com",
+ "response_type": "unsupported",
+ }), HandleAuthorizeRequestRequestObject{
+ Id: nutsDID.String(),
+ })
+
+ requireOAuthError(t, err, UnsupportedResponseType, "")
+ assert.Nil(t, res)
+ })
+}
+
+func TestWrapper_HandleTokenRequest(t *testing.T) {
+ t.Run("unsupported grant type", func(t *testing.T) {
+ ctx := newTestClient(t)
+
+ res, err := ctx.client.HandleTokenRequest(nil, HandleTokenRequestRequestObject{
+ Id: nutsDID.String(),
+ Body: &HandleTokenRequestFormdataRequestBody{
+ GrantType: "unsupported",
+ },
+ })
+
+ requireOAuthError(t, err, UnsupportedGrantType, "")
+ assert.Nil(t, res)
+ })
+}
+
+func requireOAuthError(t *testing.T, err error, errorCode ErrorCode, errorDescription string) {
+ var oauthErr OAuth2Error
+ require.ErrorAs(t, err, &oauthErr)
+ assert.Equal(t, errorCode, oauthErr.Code)
+ assert.Equal(t, errorDescription, oauthErr.Description)
+}
+
+func requestContext(queryParams map[string]string) context.Context {
+ vals := url.Values{}
+ for key, value := range queryParams {
+ vals.Add(key, value)
+ }
+ httpRequest := &http.Request{
+ URL: &url.URL{
+ RawQuery: vals.Encode(),
+ },
+ }
+ return context.WithValue(audit.TestContext(), httpRequestContextKey, httpRequest)
+}
+
// statusCodeFrom returns the statuscode if err is core.HTTPStatusCodeError, or 0 if it isn't
func statusCodeFrom(err error) int {
var SE core.HTTPStatusCodeError
@@ -191,3 +256,66 @@ func newTestClient(t testing.TB) *testCtx {
},
}
}
+
+func TestWrapper_Routes(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ router := core.NewMockEchoRouter(ctrl)
+
+ router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+ router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+
+ Wrapper{}.Routes(router)
+}
+
+func TestWrapper_middleware(t *testing.T) {
+ server := echo.New()
+ ctrl := gomock.NewController(t)
+ authService := auth.NewMockAuthenticationServices(ctrl)
+ authService.EXPECT().V2APIEnabled().Return(true).AnyTimes()
+
+ t.Run("API enabling", func(t *testing.T) {
+ t.Run("enabled", func(t *testing.T) {
+ var called strictServerCallCapturer
+
+ ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
+ _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle)
+
+ assert.True(t, bool(called))
+ })
+ t.Run("disabled", func(t *testing.T) {
+ var called strictServerCallCapturer
+
+ authService := auth.NewMockAuthenticationServices(ctrl)
+ authService.EXPECT().V2APIEnabled().Return(false).AnyTimes()
+
+ ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
+ _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle)
+
+ assert.False(t, bool(called))
+ })
+ })
+
+ t.Run("OAuth2 error handling", func(t *testing.T) {
+ var handler strictServerCallCapturer
+ t.Run("OAuth2 path", func(t *testing.T) {
+ ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
+ _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)
+
+ assert.IsType(t, &oauth2ErrorWriter{}, ctx.Get(core.ErrorWriterContextKey))
+ })
+ t.Run("other path", func(t *testing.T) {
+ ctx := server.NewContext(httptest.NewRequest("GET", "/internal/foo", nil), httptest.NewRecorder())
+ _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)
+
+ assert.Nil(t, ctx.Get(core.ErrorWriterContextKey))
+ })
+ })
+
+}
+
+type strictServerCallCapturer bool
+
+func (s *strictServerCallCapturer) handle(ctx echo.Context, request interface{}) (response interface{}, err error) {
+ *s = true
+ return nil, nil
+}
diff --git a/auth/api/iam/authorized_code.go b/auth/api/iam/authorized_code.go
index 59faa1dbde..2e4cc2521b 100644
--- a/auth/api/iam/authorized_code.go
+++ b/auth/api/iam/authorized_code.go
@@ -26,7 +26,6 @@ import (
"fmt"
"github.com/labstack/echo/v4"
"github.com/nuts-foundation/nuts-node/core"
- "github.com/nuts-foundation/nuts-node/vcr/openid4vci"
"html/template"
"net/http"
"net/url"
@@ -92,20 +91,16 @@ func (a authorizedCodeFlow) handleAuthConsent(c echo.Context) error {
func (a authorizedCodeFlow) validateCode(params map[string]string) (string, error) {
code, ok := params["code"]
+ invalidCodeError := OAuth2Error{
+ Code: InvalidRequest,
+ Description: "missing or invalid code parameter",
+ }
if !ok {
- return "", openid4vci.Error{
- Code: openid4vci.InvalidRequest,
- StatusCode: http.StatusBadRequest,
- //Description: "missing or invalid code parameter",
- }
+ return "", invalidCodeError
}
session := a.sessions.Get(code)
if session == nil {
- return "", openid4vci.Error{
- Code: openid4vci.InvalidRequest,
- StatusCode: http.StatusBadRequest,
- //Description: "invalid code",
- }
+ return "", invalidCodeError
}
return session.Scope, nil
}
diff --git a/auth/api/iam/error.go b/auth/api/iam/error.go
new file mode 100644
index 0000000000..a70c005052
--- /dev/null
+++ b/auth/api/iam/error.go
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2023 Nuts community
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package iam
+
+import (
+ "errors"
+ "github.com/labstack/echo/v4"
+ "github.com/nuts-foundation/nuts-node/core"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+// ErrorCode specifies error codes as defined by the OAuth2 specifications.
+// Codes and descriptions are taken from https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
+type ErrorCode string
+
+const (
+ // InvalidRequest is returned when the request is missing a required parameter, includes an invalid parameter value,
+ // includes a parameter more than once, or is otherwise malformed.
+ InvalidRequest ErrorCode = "invalid_request"
+ // UnsupportedGrantType is returned when the authorization grant type is not supported by the authorization server.
+ UnsupportedGrantType ErrorCode = "unsupported_grant_type"
+ // UnsupportedResponseType is returned when the authorization server does not support obtaining an authorization code using this method.
+ UnsupportedResponseType ErrorCode = "unsupported_response_type"
+ // ServerError is returned when the Authorization Server encounters an unexpected condition that prevents it from fulfilling the request.
+ ServerError ErrorCode = "server_error"
+)
+
+// Make sure the error implements core.HTTPStatusCodeError, so the HTTP request logger can log the correct status code.
+var _ core.HTTPStatusCodeError = OAuth2Error{}
+
+// OAuth2Error is an OAuth2 error that signals the error was (probably) caused by the client (e.g. bad request),
+// or that the client can recover from the error (e.g. retry).
+type OAuth2Error struct {
+ // Code is the error code as defined by the OAuth2 specification.
+ Code ErrorCode `json:"error"`
+ // Description is a human-readable ASCII [USASCII] text providing additional information, used to assist the client developer in understanding the error that occurred.
+ Description string `json:"error_description,omitempty"`
+ // InternalError is the underlying error, may be omitted. It is not intended to be returned to the client, only to be logged.
+ InternalError error `json:"-"`
+ // RedirectURI is the redirect URI that should be used to redirect the client to, in case the user-agent is a browser.
+ // It should not be set if the user-agent is not a browser, or there is no redirect_uri (because the request was malformed), this field is empty.
+ // When the field is set, the user-agent is redirected to the specified URI with the error code and description as query parameters.
+ // If it's not set, the error code and description are returned in the response body (plain text or JSON).
+ RedirectURI string `json:"-"`
+}
+
+// StatusCode returns the HTTP status code to be returned to the client, in case the user-agent can't be redirected with HTTP 302 - Found.
+func (e OAuth2Error) StatusCode() int {
+ switch e.Code {
+ case ServerError:
+ return http.StatusInternalServerError
+ default:
+ return http.StatusBadRequest
+ }
+}
+
+// OAuth2Error returns the error message, which is either the underlying error or the code if there is no underlying error
+func (e OAuth2Error) Error() string {
+ var parts []string
+ parts = append(parts, string(e.Code))
+ if e.InternalError != nil {
+ parts = append(parts, e.InternalError.Error())
+ }
+ if e.Description != "" {
+ parts = append(parts, e.Description)
+ }
+ return strings.Join(parts, " - ")
+}
+
+type oauth2ErrorWriter struct{}
+
+func (p oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err error) error {
+ var oauthErr OAuth2Error
+ if !errors.As(err, &oauthErr) {
+ // Internal error, wrap it in an OAuth2 error
+ oauthErr = OAuth2Error{
+ Code: ServerError,
+ InternalError: err,
+ }
+ }
+ if oauthErr.Code == "" {
+ // Somebody forgot to set a code
+ oauthErr.Code = ServerError
+ }
+ redirectURI, _ := url.Parse(oauthErr.RedirectURI)
+ if oauthErr.RedirectURI == "" || redirectURI == nil {
+ // Can't redirect the user-agent back, render error as JSON or plain text (depending on content-type)
+ contentType := echoContext.Request().Header.Get("Content-Type")
+ if strings.Contains(contentType, "application/json") {
+ // Return JSON response
+ return echoContext.JSON(oauthErr.StatusCode(), oauthErr)
+ } else {
+ // Return plain text response
+ parts := []string{string(oauthErr.Code)}
+ if oauthErr.Description != "" {
+ parts = append(parts, oauthErr.Description)
+ }
+ return echoContext.String(oauthErr.StatusCode(), strings.Join(parts, " - "))
+ }
+ }
+ // Redirect the user-agent back to the client
+ query := redirectURI.Query()
+ query.Set("error", string(oauthErr.Code))
+ if oauthErr.Description != "" {
+ query.Set("error_description", oauthErr.Description)
+ }
+ redirectURI.RawQuery = query.Encode()
+ return echoContext.Redirect(http.StatusFound, redirectURI.String())
+}
diff --git a/auth/api/iam/error_test.go b/auth/api/iam/error_test.go
new file mode 100644
index 0000000000..9225282c92
--- /dev/null
+++ b/auth/api/iam/error_test.go
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2023 Nuts community
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package iam
+
+import (
+ "errors"
+ "github.com/labstack/echo/v4"
+ "github.com/stretchr/testify/assert"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestError_Error(t *testing.T) {
+ t.Run("with underlying error", func(t *testing.T) {
+ assert.EqualError(t, OAuth2Error{InternalError: errors.New("token has expired"), Code: InvalidRequest}, "invalid_request - token has expired")
+ })
+ t.Run("without underlying error", func(t *testing.T) {
+ assert.EqualError(t, OAuth2Error{Code: InvalidRequest}, "invalid_request")
+ })
+}
+
+func Test_oauth2ErrorWriter_Write(t *testing.T) {
+ t.Run("user-agent is browser with redirect URI", func(t *testing.T) {
+ server := echo.New()
+ httpRequest := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+ ctx := server.NewContext(httpRequest, rec)
+
+ err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
+ Code: InvalidRequest,
+ Description: "failure",
+ RedirectURI: "https://example.com",
+ })
+
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusFound, rec.Code)
+ assert.Equal(t, "https://example.com?error=invalid_request&error_description=failure", rec.Header().Get("Location"))
+ })
+ t.Run("user-agent is browser without redirect URI", func(t *testing.T) {
+ server := echo.New()
+ httpRequest := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+ ctx := server.NewContext(httpRequest, rec)
+
+ err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
+ Code: InvalidRequest,
+ Description: "failure",
+ })
+
+ assert.NoError(t, err)
+ body, _ := io.ReadAll(rec.Body)
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+ assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get("Content-Type"))
+ assert.Equal(t, "invalid_request - failure", string(body))
+ assert.Empty(t, rec.Header().Get("Location"))
+ })
+ t.Run("user-agent is API client (sent JSON)", func(t *testing.T) {
+ server := echo.New()
+ httpRequest := httptest.NewRequest("GET", "/", nil)
+ httpRequest.Header["Content-Type"] = []string{"application/json"}
+ rec := httptest.NewRecorder()
+ ctx := server.NewContext(httpRequest, rec)
+
+ err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
+ Code: InvalidRequest,
+ Description: "failure",
+ })
+
+ assert.NoError(t, err)
+ body, _ := io.ReadAll(rec.Body)
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
+ assert.Equal(t, "application/json; charset=UTF-8", rec.Header().Get("Content-Type"))
+ assert.Equal(t, `{"error":"invalid_request","error_description":"failure"}`, strings.TrimSpace(string(body)))
+ assert.Empty(t, rec.Header().Get("Location"))
+ })
+ t.Run("OAuth2 error without code, defaults to server_error", func(t *testing.T) {
+ server := echo.New()
+ httpRequest := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+ ctx := server.NewContext(httpRequest, rec)
+
+ err := oauth2ErrorWriter{}.Write(ctx, 0, "", OAuth2Error{
+ Description: "failure",
+ })
+
+ assert.NoError(t, err)
+ body, _ := io.ReadAll(rec.Body)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Equal(t, `server_error - failure`, strings.TrimSpace(string(body)))
+ })
+ t.Run("error is not an OAuth2 error, should be wrapped", func(t *testing.T) {
+ server := echo.New()
+ httpRequest := httptest.NewRequest("GET", "/", nil)
+ rec := httptest.NewRecorder()
+ ctx := server.NewContext(httpRequest, rec)
+
+ err := oauth2ErrorWriter{}.Write(ctx, 0, "", errors.New("catastrophic"))
+
+ assert.NoError(t, err)
+ body, _ := io.ReadAll(rec.Body)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Equal(t, `server_error`, strings.TrimSpace(string(body)))
+ })
+}
diff --git a/auth/api/iam/generated.go b/auth/api/iam/generated.go
index 1aff918daa..4acaf152bb 100644
--- a/auth/api/iam/generated.go
+++ b/auth/api/iam/generated.go
@@ -15,12 +15,6 @@ import (
strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo"
)
-// ErrorResponse defines model for ErrorResponse.
-type ErrorResponse struct {
- // Error Code identifying the error that occurred.
- Error string `json:"error"`
-}
-
// TokenResponse Token Responses are made as defined in (RFC6749)[https://datatracker.ietf.org/doc/html/rfc6749#section-5.1]
type TokenResponse struct {
// AccessToken The access token issued by the authorization server.
diff --git a/auth/api/iam/openid4vp.go b/auth/api/iam/openid4vp.go
index 46f7aa5730..30f467a85a 100644
--- a/auth/api/iam/openid4vp.go
+++ b/auth/api/iam/openid4vp.go
@@ -81,14 +81,22 @@ func (r *Wrapper) handlePresentationRequest(params map[string]string, session *S
}
// Response mode is always direct_post for now
if params[responseModeParam] != responseModeDirectPost {
- return nil, errors.New("response_mode must be direct_post")
+ return nil, OAuth2Error{
+ Code: InvalidRequest,
+ Description: "response_mode must be direct_post",
+ RedirectURI: session.RedirectURI,
+ }
}
// TODO: This is the easiest for now, but is this the way?
// For compatibility, we probably need to support presentation_definition and/or presentation_definition_uri.
presentationDefinition := r.auth.PresentationDefinitions().ByScope(params[scopeParam])
if presentationDefinition == nil {
- return nil, fmt.Errorf("unsupported scope for presentation exchange: %s", params[scopeParam])
+ return nil, OAuth2Error{
+ Code: InvalidRequest,
+ Description: fmt.Sprintf("unsupported scope for presentation exchange: %s", params[scopeParam]),
+ RedirectURI: session.RedirectURI,
+ }
}
// Render HTML
@@ -204,7 +212,6 @@ func (r *Wrapper) handlePresentationRequestAccept(c echo.Context) error {
}
func (r *Wrapper) handlePresentationRequestCompleted(ctx echo.Context) error {
- // TODO: support error response
// TODO: response direct_post mode
vpToken := ctx.QueryParams()[vpTokenParam]
if len(vpToken) == 0 {
diff --git a/auth/api/iam/openid4vp_test.go b/auth/api/iam/openid4vp_test.go
index effc889e6d..91fec61bfd 100644
--- a/auth/api/iam/openid4vp_test.go
+++ b/auth/api/iam/openid4vp_test.go
@@ -118,6 +118,40 @@ func TestWrapper_handlePresentationRequest(t *testing.T) {
require.Equal(t, http.StatusOK, httpResponse.statusCode)
assert.Contains(t, httpResponse.body.String(), "