diff --git a/auth/api/iam/api_test.go b/auth/api/iam/api_test.go index ba408ca952..6a1c3a478e 100644 --- a/auth/api/iam/api_test.go +++ b/auth/api/iam/api_test.go @@ -367,6 +367,7 @@ func statusCodeFrom(err error) int { } type testCtx struct { + ctrl *gomock.Controller client *Wrapper authnServices *auth.MockAuthenticationServices vdr *vdr.MockVDR @@ -390,6 +391,7 @@ func newTestClient(t testing.TB) *testCtx { vdr.EXPECT().Resolver().Return(resolver).AnyTimes() return &testCtx{ + ctrl: ctrl, authnServices: authnServices, relyingParty: relyingPary, resolver: resolver, diff --git a/auth/api/iam/session.go b/auth/api/iam/session.go index 6f7ee5b35f..8b1250be48 100644 --- a/auth/api/iam/session.go +++ b/auth/api/iam/session.go @@ -41,6 +41,11 @@ type UserSession struct { OwnDID did.DID } +type RedirectSession struct { + OwnDID did.DID + AccessTokenRequest RequestAccessTokenRequestObject +} + func (s OAuthSession) CreateRedirectURI(params map[string]string) string { redirectURI, _ := url.Parse(s.RedirectURI) r := http.AddQueryParams(*redirectURI, params) diff --git a/auth/api/iam/user.go b/auth/api/iam/user.go index 8dd08e5a90..c4f34d3bbe 100644 --- a/auth/api/iam/user.go +++ b/auth/api/iam/user.go @@ -20,14 +20,14 @@ package iam import ( "context" - "github.com/labstack/echo/v4" - "github.com/nuts-foundation/nuts-node/auth/log" - http2 "github.com/nuts-foundation/nuts-node/http" "net/http" "time" + "github.com/labstack/echo/v4" "github.com/nuts-foundation/go-did/did" + "github.com/nuts-foundation/nuts-node/auth/log" "github.com/nuts-foundation/nuts-node/crypto" + http2 "github.com/nuts-foundation/nuts-node/http" "github.com/nuts-foundation/nuts-node/vdr/didweb" ) @@ -36,7 +36,10 @@ func (r *Wrapper) requestUserAccessToken(_ context.Context, requestHolder did.DI token := crypto.GenerateNonce() store := r.storageEngine.GetSessionDatabase().GetStore(time.Second*5, "user", "redirect") // put the request in the store - err := store.Put(token, request) + err := store.Put(token, RedirectSession{ + OwnDID: requestHolder, + AccessTokenRequest: request, + }) if err != nil { return nil, err } @@ -74,24 +77,26 @@ func (r *Wrapper) handleUserLanding(echoCtx echo.Context) error { // extract request from store store := r.storageEngine.GetSessionDatabase().GetStore(time.Second, "user", "redirect") - request := RequestAccessTokenRequestObject{} - err := store.Get(token, &request) + redirectSession := RedirectSession{} + err := store.Get(token, &redirectSession) if err != nil { log.Logger().Debug("token not found in store") return echoCtx.String(http.StatusForbidden, "") } - requester, err := did.ParseDID(request.Did) + accessTokenRequest := redirectSession.AccessTokenRequest + // burn token + err = store.Delete(token) if err != nil { - return echoCtx.String(http.StatusInternalServerError, "") + //rare, log just in case + log.Logger().Warn("delete token failed") } - // create UserSession with userID from request // generate new sessionID and clientState with crypto.GenerateNonce() userSession := UserSession{ ClientState: crypto.GenerateNonce(), SessionID: crypto.GenerateNonce(), - UserID: *request.Body.UserID, // should be there... - OwnDID: *requester, + UserID: *accessTokenRequest.Body.UserID, // should be there... + OwnDID: redirectSession.OwnDID, } // store user session in session store under sessionID and clientState @@ -105,10 +110,13 @@ func (r *Wrapper) handleUserLanding(echoCtx echo.Context) error { if err != nil { return err } - verifier, err := did.ParseDID(request.Body.Verifier) + verifier, err := did.ParseDID(accessTokenRequest.Body.Verifier) + if err != nil { + return err + } + redirectURL, err := r.auth.RelyingParty().AuthorizationRequest(echoCtx.Request().Context(), redirectSession.OwnDID, *verifier, accessTokenRequest.Body.Scope, userSession.ClientState) if err != nil { return err } - redirectURL, err := r.auth.RelyingParty().AuthorizationRequest(echoCtx.Request().Context(), *requester, *verifier, request.Body.Scope, userSession.ClientState) return echoCtx.Redirect(http.StatusFound, redirectURL.String()) } diff --git a/auth/api/iam/user_test.go b/auth/api/iam/user_test.go index d6e19d5369..f0d00df87d 100644 --- a/auth/api/iam/user_test.go +++ b/auth/api/iam/user_test.go @@ -19,6 +19,10 @@ package iam import ( + "github.com/nuts-foundation/nuts-node/mock" + "go.uber.org/mock/gomock" + "net/http" + "net/url" "testing" "time" @@ -46,10 +50,10 @@ func TestWrapper_requestUserAccessToken(t *testing.T) { // assert session store := ctx.client.storageEngine.GetSessionDatabase().GetStore(time.Second*5, "user", "redirect") - var target RequestAccessTokenRequestObject + var target RedirectSession err = store.Get(redirectResponse.Headers.Location[37:], &target) require.NoError(t, err) - assert.Equal(t, walletDID.String(), target.Did) + assert.Equal(t, walletDID, target.OwnDID) }) t.Run("error - wrong did type", func(t *testing.T) { @@ -63,3 +67,96 @@ func TestWrapper_requestUserAccessToken(t *testing.T) { assert.EqualError(t, err, "unsupported DID method: test") }) } + +func TestWrapper_handleUserLanding(t *testing.T) { + walletDID := did.MustParseDID("did:web:test.test:iam:123") + verifierDID := did.MustParseDID("did:web:test.test:iam:456") + userID := "user" + redirectSession := RedirectSession{ + OwnDID: walletDID, + AccessTokenRequest: RequestAccessTokenRequestObject{ + Body: &RequestAccessTokenJSONRequestBody{ + Scope: "first second", + UserID: &userID, + Verifier: verifierDID.String(), + }, + Did: walletDID.String(), + }, + } + + t.Run("OK", func(t *testing.T) { + ctx := newTestClient(t) + expectedURL, _ := url.Parse("https://test.test/iam/123/user?token=token") + echoCtx := mock.NewMockContext(ctx.ctrl) + echoCtx.EXPECT().QueryParam("token").Return("token") + echoCtx.EXPECT().Request().Return(&http.Request{Host: "test.test"}) + echoCtx.EXPECT().Redirect(http.StatusFound, expectedURL.String()) + ctx.relyingParty.EXPECT().AuthorizationRequest(gomock.Any(), walletDID, verifierDID, "first second", gomock.Any()).Return(expectedURL, nil) + store := ctx.client.storageEngine.GetSessionDatabase().GetStore(time.Second*5, "user", "redirect") + err := store.Put("token", redirectSession) + require.NoError(t, err) + + err = ctx.client.handleUserLanding(echoCtx) + + require.NoError(t, err) + // check for deleted token + err = store.Get("token", &RedirectSession{}) + assert.Error(t, err) + }) + t.Run("error - no token", func(t *testing.T) { + ctx := newTestClient(t) + echoCtx := mock.NewMockContext(ctx.ctrl) + echoCtx.EXPECT().QueryParam("token").Return("") + echoCtx.EXPECT().String(http.StatusForbidden, "") + + err := ctx.client.handleUserLanding(echoCtx) + + require.NoError(t, err) + }) + t.Run("error - token not found", func(t *testing.T) { + ctx := newTestClient(t) + echoCtx := mock.NewMockContext(ctx.ctrl) + echoCtx.EXPECT().QueryParam("token").Return("token") + echoCtx.EXPECT().String(http.StatusForbidden, "") + + err := ctx.client.handleUserLanding(echoCtx) + + require.NoError(t, err) + }) + t.Run("error - verifier did parse error", func(t *testing.T) { + ctx := newTestClient(t) + echoCtx := mock.NewMockContext(ctx.ctrl) + echoCtx.EXPECT().QueryParam("token").Return("token") + store := ctx.client.storageEngine.GetSessionDatabase().GetStore(time.Second*5, "user", "redirect") + err := store.Put("token", RedirectSession{ + OwnDID: walletDID, + AccessTokenRequest: RequestAccessTokenRequestObject{ + Body: &RequestAccessTokenJSONRequestBody{ + Scope: "first second", + UserID: &userID, + Verifier: "invalid", + }, + Did: walletDID.String(), + }, + }) + require.NoError(t, err) + + err = ctx.client.handleUserLanding(echoCtx) + + require.Error(t, err) + }) + t.Run("error - authorization request error", func(t *testing.T) { + ctx := newTestClient(t) + echoCtx := mock.NewMockContext(ctx.ctrl) + echoCtx.EXPECT().QueryParam("token").Return("token") + echoCtx.EXPECT().Request().Return(&http.Request{Host: "test.test"}) + store := ctx.client.storageEngine.GetSessionDatabase().GetStore(time.Second*5, "user", "redirect") + err := store.Put("token", redirectSession) + require.NoError(t, err) + ctx.relyingParty.EXPECT().AuthorizationRequest(gomock.Any(), walletDID, verifierDID, "first second", gomock.Any()).Return(nil, assert.AnError) + + err = ctx.client.handleUserLanding(echoCtx) + + assert.Error(t, err) + }) +}