diff --git a/auth/api/iam/api.go b/auth/api/iam/api.go index 9a0b4e3c1c..ce87ed81d9 100644 --- a/auth/api/iam/api.go +++ b/auth/api/iam/api.go @@ -33,6 +33,7 @@ import ( "github.com/nuts-foundation/nuts-node/vdr/resolver" "html/template" "net/http" + "strings" "sync" ) @@ -40,6 +41,7 @@ var _ core.Routable = &Wrapper{} var _ StrictServerInterface = &Wrapper{} const apiPath = "iam" +const apiModuleName = auth.ModuleName + "/" + apiPath const httpRequestContextKey = "http-request" //go:embed assets @@ -70,30 +72,12 @@ func New(authInstance auth.AuthenticationServices, vcrInstance vcr.VCR, vdrInsta } func (r Wrapper) Routes(router core.EchoRouter) { - const apiModuleName = auth.ModuleName + "/" + apiPath RegisterHandlers(router, NewStrictHandler(r, []StrictMiddlewareFunc{ func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { return func(ctx echo.Context, request interface{}) (response interface{}, err error) { - ctx.Set(core.OperationIDContextKey, operationID) - ctx.Set(core.ModuleNameContextKey, apiModuleName) - // Add http.Request to context, to allow reading URL query parameters - requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request()) - ctx.SetRequest(ctx.Request().WithContext(requestCtx)) - ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{}) - return f(ctx, request) + return r.middleware(ctx, request, operationID, f) } }, - func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { - return func(ctx echo.Context, args interface{}) (interface{}, error) { - if !r.auth.V2APIEnabled() { - return nil, core.Error(http.StatusForbidden, "Access denied") - } - return f(ctx, args) - } - }, - func(f StrictHandlerFunc, operationID string) StrictHandlerFunc { - return audit.StrictMiddleware(f, apiModuleName, operationID) - }, })) auditMiddleware := audit.Middleware(apiModuleName) // The following handler is of the OpenID4VCI wallet which is called by the holder (wallet owner) @@ -109,6 +93,24 @@ func (r Wrapper) Routes(router core.EchoRouter) { router.POST("/iam/:did/openid4vp_demo", r.handleOpenID4VPDemoSendRequest, auditMiddleware) } +func (r Wrapper) middleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) { + ctx.Set(core.OperationIDContextKey, operationID) + ctx.Set(core.ModuleNameContextKey, apiModuleName) + + if !r.auth.V2APIEnabled() { + return nil, core.Error(http.StatusForbidden, "Access denied") + } + + // Add http.Request to context, to allow reading URL query parameters + requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request()) + ctx.SetRequest(ctx.Request().WithContext(requestCtx)) + if strings.HasPrefix(ctx.Request().URL.Path, "/iam/") { + ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{}) + } + audit.StrictMiddleware(f, apiModuleName, operationID) + return f(ctx, request) +} + // HandleTokenRequest handles calls to the token endpoint for exchanging a grant (e.g authorization code or pre-authorized code) for an access token. func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequestRequestObject) (HandleTokenRequestResponseObject, error) { switch request.Body.GrantType { diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index 4a650153c5..bff9a8c8c2 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -21,6 +21,7 @@ package iam import ( "context" "errors" + "github.com/labstack/echo/v4" ssi "github.com/nuts-foundation/go-did" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/audit" @@ -32,6 +33,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "net/http" + "net/http/httptest" "net/url" "testing" ) @@ -194,7 +196,7 @@ func TestWrapper_HandleTokenRequest(t *testing.T) { }, }) - requireOAuthError(t, err, InvalidRequest, "unsupported grant_type") + requireOAuthError(t, err, UnsupportedGrantType, "") assert.Nil(t, res) }) } @@ -254,3 +256,66 @@ func newTestClient(t testing.TB) *testCtx { }, } } + +func TestWrapper_Routes(t *testing.T) { + ctrl := gomock.NewController(t) + router := core.NewMockEchoRouter(ctrl) + + router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + Wrapper{}.Routes(router) +} + +func TestWrapper_middleware(t *testing.T) { + server := echo.New() + ctrl := gomock.NewController(t) + authService := auth.NewMockAuthenticationServices(ctrl) + authService.EXPECT().V2APIEnabled().Return(true).AnyTimes() + + t.Run("API enabling", func(t *testing.T) { + t.Run("enabled", func(t *testing.T) { + var called strictServerCallCapturer + + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle) + + assert.True(t, bool(called)) + }) + t.Run("disabled", func(t *testing.T) { + var called strictServerCallCapturer + + authService := auth.NewMockAuthenticationServices(ctrl) + authService.EXPECT().V2APIEnabled().Return(false).AnyTimes() + + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle) + + assert.False(t, bool(called)) + }) + }) + + t.Run("OAuth2 error handling", func(t *testing.T) { + var handler strictServerCallCapturer + t.Run("OAuth2 path", func(t *testing.T) { + ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle) + + assert.IsType(t, &oauth2ErrorWriter{}, ctx.Get(core.ErrorWriterContextKey)) + }) + t.Run("other path", func(t *testing.T) { + ctx := server.NewContext(httptest.NewRequest("GET", "/internal/foo", nil), httptest.NewRecorder()) + _, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle) + + assert.Nil(t, ctx.Get(core.ErrorWriterContextKey)) + }) + }) + +} + +type strictServerCallCapturer bool + +func (s *strictServerCallCapturer) handle(ctx echo.Context, request interface{}) (response interface{}, err error) { + *s = true + return nil, nil +}