From 5054903ec27ee238eb3727a1c6c65431969c60b1 Mon Sep 17 00:00:00 2001 From: Nikos Date: Mon, 15 Apr 2024 14:06:49 +0300 Subject: [PATCH] fix: update tests --- oauth2/oauth2_auth_code_test.go | 36 ++++ oauth2/oauth2_device_code_test.go | 300 +++++++++++++++++++++++++++++- 2 files changed, 331 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 8f33b16f835..cc5fccdffbb 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -1945,3 +1945,39 @@ func newOAuth2Client( Scopes: strings.Split(c.Scope, " "), } } + +func newDeviceClient( + t *testing.T, + reg interface { + config.Provider + client.Registry + }, + opts ...func(*client.Client), +) (*client.Client, *oauth2.Config) { + ctx := context.Background() + c := &client.Client{ + GrantTypes: []string{ + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + }, + Scope: "hydra offline openid", + Audience: []string{"https://api.ory.sh/"}, + TokenEndpointAuthMethod: "none", + } + + // apply options + for _, o := range opts { + o(c) + } + + require.NoError(t, reg.ClientManager().CreateClient(ctx, c)) + return c, &oauth2.Config{ + ClientID: c.GetID(), + Endpoint: oauth2.Endpoint{ + DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(), + TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), + AuthStyle: oauth2.AuthStyleInHeader, + }, + Scopes: strings.Split(c.Scope, " "), + } +} diff --git a/oauth2/oauth2_device_code_test.go b/oauth2/oauth2_device_code_test.go index 758b916be8a..aacdddf1d8d 100644 --- a/oauth2/oauth2_device_code_test.go +++ b/oauth2/oauth2_device_code_test.go @@ -5,6 +5,8 @@ package oauth2_test import ( "context" + "net/http" + "strconv" "strings" "testing" "time" @@ -15,15 +17,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "golang.org/x/oauth2" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" + hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" + "github.com/ory/x/pointerx" + "github.com/ory/x/requirex" ) func TestDeviceAuthRequest(t *testing.T) { @@ -33,11 +41,9 @@ func TestDeviceAuthRequest(t *testing.T) { secret := uuid.New() c := &client.Client{ - ID: "device-client", - Secret: secret, - GrantTypes: []string{ - string(fosite.GrantTypeDeviceCode), - }, + ID: "device-client", + Secret: secret, + GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, Scope: "hydra offline openid", Audience: []string{"https://api.ory.sh/"}, TokenEndpointAuthMethod: "client_secret_post", @@ -219,3 +225,287 @@ func TestDeviceTokenRequest(t *testing.T) { }) } } + +func TestDeviceCodeWithDefaultStrategy(t *testing.T) { + ctx := context.Background() + reg := internal.NewMockedRegistry(t, &contextx.Default{}) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "") + publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg) + + publicClient := hydra.NewAPIClient(hydra.NewConfiguration()) + publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}} + adminClient := hydra.NewAPIClient(hydra.NewConfiguration()) + adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} + + getDeviceCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (*oauth2.DeviceAuthResponse, error) { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } + + return conf.DeviceAuth(ctx, params...) + } + + acceptUserCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, devResp *oauth2.DeviceAuthResponse) *http.Response { + if c == nil { + c = testhelpers.NewEmptyJarClient(t) + } + + resp, err := c.Get(devResp.VerificationURIComplete) + require.NoError(t, err) + require.Contains(t, reg.Config().DeviceDoneURL(ctx).String(), resp.Request.URL.Path, "did not end up in post device URL") + require.Equal(t, resp.Request.URL.Query().Get("client_id"), conf.ClientID) + + return resp + } + + acceptDeviceHandler := func(t *testing.T, c *client.Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userCode := r.URL.Query().Get("user_code") + payload := hydra.AcceptDeviceUserCodeRequest{ + UserCode: &userCode, + } + + v, _, err := adminClient.OAuth2Api.AcceptUserCodeRequest(context.Background()). + DeviceChallenge(r.URL.Query().Get("device_challenge")). + AcceptDeviceUserCodeRequest(payload). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } + } + + acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2Api.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(t, rr.RequestUrl, hydraoauth2.DeviceVerificationPath) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: subject, + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b + } + } + + v, _, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(context.Background()). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } + } + + acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + rr, _, err := adminClient.OAuth2Api.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(t, err) + + assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(t, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(t, *rr.RequestUrl, hydraoauth2.DeviceVerificationPath) + if checkRequestPayload != nil { + checkRequestPayload(rr) + } + + assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2Api.AcceptOAuth2ConsentRequest(context.Background()). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz"}, + }, + }). + Execute() + require.NoError(t, err) + require.NotEmpty(t, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + } + } + + assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) { + actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64) + require.NoError(t, err) + requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second) + } + + assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result { + idt, ok := token.Extra("id_token").(string) + require.True(t, ok) + assert.NotEmpty(t, idt) + + body, err := x.DecodeSegment(strings.Split(idt, ".")[1]) + require.NoError(t, err) + + claims := gjson.ParseBytes(body) + assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims) + assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims) + assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims) + requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 2*time.Second) + assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims) + assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims) + assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims) + require.Len(t, claims.Get("amr").Array(), 1, "%s", claims) + assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims) + + require.Len(t, claims.Get("aud").Array(), 1, "%s", claims) + assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims) + assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims) + assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims) + + return claims + } + + introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.True(t, i.Get("active").Bool(), "%s", i) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + return i + } + + assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result { + require.NotEmpty(t, token.AccessToken) + parts := strings.Split(token.AccessToken, ".") + if strat != "jwt" { + require.Len(t, parts, 2) + return gjson.Parse("null") + } + require.Len(t, parts, 3) + + body, err := x.DecodeSegment(parts[1]) + require.NoError(t, err) + + i := gjson.ParseBytes(body) + assert.NotEmpty(t, i.Get("jti").String()) + assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) + assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i) + assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) + assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) + requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i) + return i + } + + waitForRefreshTokenExpiry := func() { + time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second) + } + + t.Run("case=checks if request fails when audience does not match", func(t *testing.T) { + testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) + _, conf := newDeviceClient(t, reg) + resp, err := getDeviceCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/")) + require.Error(t, err) + devErr := err.(*oauth2.RetrieveError) + require.Nil(t, resp) + require.Equal(t, devErr.Response.StatusCode, http.StatusBadRequest) + }) + + subject := "aeneas-rekkas" + nonce := uuid.New() + t.Run("case=perform device flow with ID token and refresh tokens", func(t *testing.T) { + run := func(t *testing.T, strategy string) { + c, conf := newDeviceClient(t, reg) + testhelpers.NewDeviceLoginConsentUI(t, reg.Config(), + acceptDeviceHandler(t, c), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + resp, err := getDeviceCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NoError(t, err) + require.NotEmpty(t, resp.DeviceCode) + require.NotEmpty(t, resp.UserCode) + loginFlowResp := acceptUserCode(t, conf, nil, resp) + require.NotNil(t, loginFlowResp) + token, err := conf.DeviceAccessToken(context.Background(), resp) + iat := time.Now() + require.NoError(t, err) + + assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested") + assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested") + introspectAccessToken(t, conf, token, subject) + assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + + t.Run("followup=successfully perform refresh token flow", func(t *testing.T) { + require.NotEmpty(t, token.RefreshToken) + token.Expiry = token.Expiry.Add(-time.Hour * 24) + iat = time.Now() + refreshedToken, err := conf.TokenSource(context.Background(), token).Token() + require.NoError(t, err) + + require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken) + require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken) + require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token")) + introspectAccessToken(t, conf, refreshedToken, subject) + + t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) { + assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`) + assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx))) + assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx))) + }) + + t.Run("followup=original access token is no longer valid", func(t *testing.T) { + i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS) + assert.False(t, i.Get("active").Bool(), "%s", i) + }) + + t.Run("followup=original refresh token is no longer valid", func(t *testing.T) { + _, err := conf.TokenSource(context.Background(), token).Token() + assert.Error(t, err) + }) + + t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) { + waitForRefreshTokenExpiry() + + // Force golang to refresh token + refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24) + _, err := conf.TokenSource(context.Background(), refreshedToken).Token() + require.Error(t, err) + }) + }) + } + + t.Run("strategy=jwt", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt") + run(t, "jwt") + }) + + t.Run("strategy=opaque", func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + run(t, "opaque") + }) + }) +}