Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
reinkrul committed Sep 27, 2023
1 parent 599990c commit f5c6a30
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 20 deletions.
40 changes: 21 additions & 19 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ import (
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"html/template"
"net/http"
"strings"
"sync"
)

var _ core.Routable = &Wrapper{}
var _ StrictServerInterface = &Wrapper{}

const apiPath = "iam"
const apiModuleName = auth.ModuleName + "/" + apiPath
const httpRequestContextKey = "http-request"

//go:embed assets
Expand Down Expand Up @@ -70,30 +72,12 @@ func New(authInstance auth.AuthenticationServices, vcrInstance vcr.VCR, vdrInsta
}

func (r Wrapper) Routes(router core.EchoRouter) {
const apiModuleName = auth.ModuleName + "/" + apiPath
RegisterHandlers(router, NewStrictHandler(r, []StrictMiddlewareFunc{
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx echo.Context, request interface{}) (response interface{}, err error) {
ctx.Set(core.OperationIDContextKey, operationID)
ctx.Set(core.ModuleNameContextKey, apiModuleName)
// Add http.Request to context, to allow reading URL query parameters
requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request())
ctx.SetRequest(ctx.Request().WithContext(requestCtx))
ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{})
return f(ctx, request)
return r.middleware(ctx, request, operationID, f)
}
},
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx echo.Context, args interface{}) (interface{}, error) {
if !r.auth.V2APIEnabled() {
return nil, core.Error(http.StatusForbidden, "Access denied")
}
return f(ctx, args)
}
},
func(f StrictHandlerFunc, operationID string) StrictHandlerFunc {
return audit.StrictMiddleware(f, apiModuleName, operationID)
},
}))
auditMiddleware := audit.Middleware(apiModuleName)
// The following handler is of the OpenID4VCI wallet which is called by the holder (wallet owner)
Expand All @@ -109,6 +93,24 @@ func (r Wrapper) Routes(router core.EchoRouter) {
router.POST("/iam/:did/openid4vp_demo", r.handleOpenID4VPDemoSendRequest, auditMiddleware)
}

func (r Wrapper) middleware(ctx echo.Context, request interface{}, operationID string, f StrictHandlerFunc) (interface{}, error) {
ctx.Set(core.OperationIDContextKey, operationID)
ctx.Set(core.ModuleNameContextKey, apiModuleName)

if !r.auth.V2APIEnabled() {
return nil, core.Error(http.StatusForbidden, "Access denied")
}

// Add http.Request to context, to allow reading URL query parameters
requestCtx := context.WithValue(ctx.Request().Context(), httpRequestContextKey, ctx.Request())
ctx.SetRequest(ctx.Request().WithContext(requestCtx))
if strings.HasPrefix(ctx.Request().URL.Path, "/iam/") {
ctx.Set(core.ErrorWriterContextKey, &oauth2ErrorWriter{})
}
audit.StrictMiddleware(f, apiModuleName, operationID)
return f(ctx, request)
}

// HandleTokenRequest handles calls to the token endpoint for exchanging a grant (e.g authorization code or pre-authorized code) for an access token.
func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequestRequestObject) (HandleTokenRequestResponseObject, error) {
switch request.Body.GrantType {
Expand Down
67 changes: 66 additions & 1 deletion auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package iam
import (
"context"
"errors"
"github.com/labstack/echo/v4"
ssi "github.com/nuts-foundation/go-did"
"github.com/nuts-foundation/go-did/did"
"github.com/nuts-foundation/nuts-node/audit"
Expand All @@ -32,6 +33,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
Expand Down Expand Up @@ -194,7 +196,7 @@ func TestWrapper_HandleTokenRequest(t *testing.T) {
},
})

requireOAuthError(t, err, InvalidRequest, "unsupported grant_type")
requireOAuthError(t, err, UnsupportedGrantType, "")
assert.Nil(t, res)
})
}
Expand Down Expand Up @@ -254,3 +256,66 @@ func newTestClient(t testing.TB) *testCtx {
},
}
}

func TestWrapper_Routes(t *testing.T) {
ctrl := gomock.NewController(t)
router := core.NewMockEchoRouter(ctrl)

router.EXPECT().GET(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
router.EXPECT().POST(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()

Wrapper{}.Routes(router)
}

func TestWrapper_middleware(t *testing.T) {
server := echo.New()
ctrl := gomock.NewController(t)
authService := auth.NewMockAuthenticationServices(ctrl)
authService.EXPECT().V2APIEnabled().Return(true).AnyTimes()

t.Run("API enabling", func(t *testing.T) {
t.Run("enabled", func(t *testing.T) {
var called strictServerCallCapturer

ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle)

assert.True(t, bool(called))
})
t.Run("disabled", func(t *testing.T) {
var called strictServerCallCapturer

authService := auth.NewMockAuthenticationServices(ctrl)
authService.EXPECT().V2APIEnabled().Return(false).AnyTimes()

ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", called.handle)

assert.False(t, bool(called))
})
})

t.Run("OAuth2 error handling", func(t *testing.T) {
var handler strictServerCallCapturer
t.Run("OAuth2 path", func(t *testing.T) {
ctx := server.NewContext(httptest.NewRequest("GET", "/iam/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)

assert.IsType(t, &oauth2ErrorWriter{}, ctx.Get(core.ErrorWriterContextKey))
})
t.Run("other path", func(t *testing.T) {
ctx := server.NewContext(httptest.NewRequest("GET", "/internal/foo", nil), httptest.NewRecorder())
_, _ = Wrapper{auth: authService}.middleware(ctx, nil, "Test", handler.handle)

assert.Nil(t, ctx.Get(core.ErrorWriterContextKey))
})
})

}

type strictServerCallCapturer bool

func (s *strictServerCallCapturer) handle(ctx echo.Context, request interface{}) (response interface{}, err error) {
*s = true
return nil, nil
}

0 comments on commit f5c6a30

Please sign in to comment.