diff --git a/auth/api/auth/v1/client/client.go b/auth/api/auth/v1/client/client.go index 5ea160f107..b727a71454 100644 --- a/auth/api/auth/v1/client/client.go +++ b/auth/api/auth/v1/client/client.go @@ -74,7 +74,7 @@ func (h HTTPClient) CreateAccessToken(ctx context.Context, endpointURL url.URL, return nil, err } - if err := core.TestResponseCode(http.StatusOK, response); err != nil { + if err = core.TestResponseCode(http.StatusOK, response); err != nil { rse := err.(core.HttpError) // Cut off the response body to 100 characters max to prevent logging of large responses responseBodyString := string(rse.ResponseBody) diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 4a2b574b82..7216b061f6 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -255,7 +255,7 @@ func (r Wrapper) PresentationDefinition(_ context.Context, request PresentationD presentationDefinition := r.auth.PresentationDefinitions().ByScope(scopes[0]) if presentationDefinition == nil { return PresentationDefinition400JSONResponse{ - Error: "invalid_scope", + Code: "invalid_scope", }, nil } presentationDefinitions := []PresentationDefinition{*presentationDefinition} diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index e14c5a68e8..128f31ff03 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -197,7 +197,7 @@ func TestWrapper_PresentationDefinition(t *testing.T) { require.NoError(t, err) require.NotNil(t, response) - assert.Equal(t, "invalid_scope", (response.(PresentationDefinition400JSONResponse)).Error) + assert.Equal(t, InvalidScope, (response.(PresentationDefinition400JSONResponse)).Code) }) } diff --git a/auth/api/iam/client.go b/auth/api/iam/client.go index b83caccc7c..0075ce45bc 100644 --- a/auth/api/iam/client.go +++ b/auth/api/iam/client.go @@ -100,8 +100,12 @@ func (hb HTTPClient) PresentationDefinition(ctx context.Context, definitionEndpo if err != nil { return nil, fmt.Errorf("failed to call endpoint: %w", err) } - if err = core.TestResponseCode(http.StatusOK, response); err != nil { - return nil, err + if httpErr := core.TestResponseCode(http.StatusOK, response); httpErr != nil { + rse := httpErr.(core.HttpError) + if TestOAuthErrorCode(rse.ResponseBody, InvalidScope) { + return nil, ErrInvalidScope + } + return nil, httpErr } definitions := make([]PresentationDefinition, 0) diff --git a/auth/api/iam/client_test.go b/auth/api/iam/client_test.go index 8ee6623f51..f4012c545f 100644 --- a/auth/api/iam/client_test.go +++ b/auth/api/iam/client_test.go @@ -128,7 +128,16 @@ func TestHTTPClient_PresentationDefinition(t *testing.T) { require.NotNil(t, handler.Request) assert.Equal(t, url.Values{"scope": []string{"first second"}}, handler.Request.URL.Query()) }) + t.Run("error - invalid_scope", func(t *testing.T) { + handler := http2.Handler{StatusCode: http.StatusBadRequest, ResponseData: OAuth2Error{Code: InvalidScope}} + tlsServer, client := testServerAndClient(t, &handler) + + response, err := client.PresentationDefinition(ctx, tlsServer.URL, []string{"test"}) + require.Error(t, err) + assert.EqualError(t, err, "invalid scope") + assert.Nil(t, response) + }) t.Run("error - not found", func(t *testing.T) { handler := http2.Handler{StatusCode: http.StatusNotFound} tlsServer, client := testServerAndClient(t, &handler) diff --git a/auth/api/iam/error.go b/auth/api/iam/error.go index a70c005052..491685bae1 100644 --- a/auth/api/iam/error.go +++ b/auth/api/iam/error.go @@ -19,6 +19,7 @@ package iam import ( + "encoding/json" "errors" "github.com/labstack/echo/v4" "github.com/nuts-foundation/nuts-node/core" @@ -125,3 +126,16 @@ func (p oauth2ErrorWriter) Write(echoContext echo.Context, _ int, _ string, err redirectURI.RawQuery = query.Encode() return echoContext.Redirect(http.StatusFound, redirectURI.String()) } + +const InvalidScope = ErrorCode("invalid_scope") + +var ErrInvalidScope = errors.New("invalid scope") + +// TestOAuthErrorCode tests if the response is an OAuth2 error with the given code. +func TestOAuthErrorCode(responseBody []byte, code ErrorCode) bool { + var oauthErr OAuth2Error + if err := json.Unmarshal(responseBody, &oauthErr); err != nil { + return false + } + return oauthErr.Code == code +}