Skip to content

Commit

Permalink
fix: update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Apr 15, 2024
1 parent 7ec7fcb commit 5054903
Show file tree
Hide file tree
Showing 2 changed files with 331 additions and 5 deletions.
36 changes: 36 additions & 0 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1945,3 +1945,39 @@ func newOAuth2Client(
Scopes: strings.Split(c.Scope, " "),
}
}

func newDeviceClient(
t *testing.T,
reg interface {
config.Provider
client.Registry
},
opts ...func(*client.Client),
) (*client.Client, *oauth2.Config) {
ctx := context.Background()
c := &client.Client{
GrantTypes: []string{
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "none",
}

// apply options
for _, o := range opts {
o(c)
}

require.NoError(t, reg.ClientManager().CreateClient(ctx, c))
return c, &oauth2.Config{
ClientID: c.GetID(),
Endpoint: oauth2.Endpoint{
DeviceAuthURL: reg.Config().OAuth2DeviceAuthorisationURL(ctx).String(),
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: strings.Split(c.Scope, " "),
}
}
300 changes: 295 additions & 5 deletions oauth2/oauth2_device_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package oauth2_test

import (
"context"
"net/http"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -15,15 +17,21 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/internal/testhelpers"
hydraoauth2 "github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/pointerx"
"github.com/ory/x/requirex"
)

func TestDeviceAuthRequest(t *testing.T) {
Expand All @@ -33,11 +41,9 @@ func TestDeviceAuthRequest(t *testing.T) {

secret := uuid.New()
c := &client.Client{
ID: "device-client",
Secret: secret,
GrantTypes: []string{
string(fosite.GrantTypeDeviceCode),
},
ID: "device-client",
Secret: secret,
GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
TokenEndpointAuthMethod: "client_secret_post",
Expand Down Expand Up @@ -219,3 +225,287 @@ func TestDeviceTokenRequest(t *testing.T) {
})
}
}

func TestDeviceCodeWithDefaultStrategy(t *testing.T) {
ctx := context.Background()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)

publicClient := hydra.NewAPIClient(hydra.NewConfiguration())
publicClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: publicTS.URL}}
adminClient := hydra.NewAPIClient(hydra.NewConfiguration())
adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}}

getDeviceCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (*oauth2.DeviceAuthResponse, error) {
if c == nil {
c = testhelpers.NewEmptyJarClient(t)
}

return conf.DeviceAuth(ctx, params...)
}

acceptUserCode := func(t *testing.T, conf *oauth2.Config, c *http.Client, devResp *oauth2.DeviceAuthResponse) *http.Response {
if c == nil {
c = testhelpers.NewEmptyJarClient(t)
}

resp, err := c.Get(devResp.VerificationURIComplete)
require.NoError(t, err)
require.Contains(t, reg.Config().DeviceDoneURL(ctx).String(), resp.Request.URL.Path, "did not end up in post device URL")
require.Equal(t, resp.Request.URL.Query().Get("client_id"), conf.ClientID)

return resp
}

acceptDeviceHandler := func(t *testing.T, c *client.Client) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userCode := r.URL.Query().Get("user_code")
payload := hydra.AcceptDeviceUserCodeRequest{
UserCode: &userCode,
}

v, _, err := adminClient.OAuth2Api.AcceptUserCodeRequest(context.Background()).
DeviceChallenge(r.URL.Query().Get("device_challenge")).
AcceptDeviceUserCodeRequest(payload).
Execute()
require.NoError(t, err)
require.NotEmpty(t, v.RedirectTo)
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
}
}

acceptLoginHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
rr, _, err := adminClient.OAuth2Api.GetOAuth2LoginRequest(context.Background()).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
require.NoError(t, err)

assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId))
assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret))
assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes)
assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
assert.EqualValues(t, r.URL.Query().Get("login_challenge"), rr.Challenge)
assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
assert.Contains(t, rr.RequestUrl, hydraoauth2.DeviceVerificationPath)

acceptBody := hydra.AcceptOAuth2LoginRequest{
Subject: subject,
Remember: pointerx.Ptr(!rr.Skip),
Acr: pointerx.Ptr("1"),
Amr: []string{"pwd"},
Context: map[string]interface{}{"context": "bar"},
}
if checkRequestPayload != nil {
if b := checkRequestPayload(rr); b != nil {
acceptBody = *b
}
}

v, _, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(context.Background()).
LoginChallenge(r.URL.Query().Get("login_challenge")).
AcceptOAuth2LoginRequest(acceptBody).
Execute()
require.NoError(t, err)
require.NotEmpty(t, v.RedirectTo)
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
}
}

acceptConsentHandler := func(t *testing.T, c *client.Client, subject string, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
rr, _, err := adminClient.OAuth2Api.GetOAuth2ConsentRequest(context.Background()).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute()
require.NoError(t, err)

assert.EqualValues(t, c.GetID(), pointerx.Deref(rr.Client.ClientId))
assert.Empty(t, pointerx.Deref(rr.Client.ClientSecret))
assert.EqualValues(t, c.GrantTypes, rr.Client.GrantTypes)
assert.EqualValues(t, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
assert.EqualValues(t, subject, pointerx.Deref(rr.Subject))
assert.EqualValues(t, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
assert.EqualValues(t, r.URL.Query().Get("consent_challenge"), rr.Challenge)
assert.Contains(t, *rr.RequestUrl, hydraoauth2.DeviceVerificationPath)
if checkRequestPayload != nil {
checkRequestPayload(rr)
}

assert.Equal(t, map[string]interface{}{"context": "bar"}, rr.Context)
v, _, err := adminClient.OAuth2Api.AcceptOAuth2ConsentRequest(context.Background()).
ConsentChallenge(r.URL.Query().Get("consent_challenge")).
AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0),
GrantAccessTokenAudience: rr.RequestedAccessTokenAudience,
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
},
}).
Execute()
require.NoError(t, err)
require.NotEmpty(t, v.RedirectTo)
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
}
}

assertRefreshToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedExp time.Time) {
actualExp, err := strconv.ParseInt(testhelpers.IntrospectToken(t, c, token.RefreshToken, adminTS).Get("exp").String(), 10, 64)
require.NoError(t, err)
requirex.EqualTime(t, expectedExp, time.Unix(actualExp, 0), time.Second)
}

assertIDToken := func(t *testing.T, token *oauth2.Token, c *oauth2.Config, expectedSubject, expectedNonce string, expectedExp time.Time) gjson.Result {
idt, ok := token.Extra("id_token").(string)
require.True(t, ok)
assert.NotEmpty(t, idt)

body, err := x.DecodeSegment(strings.Split(idt, ".")[1])
require.NoError(t, err)

claims := gjson.ParseBytes(body)
assert.True(t, time.Now().After(time.Unix(claims.Get("iat").Int(), 0)), "%s", claims)
assert.True(t, time.Now().After(time.Unix(claims.Get("nbf").Int(), 0)), "%s", claims)
assert.True(t, time.Now().Before(time.Unix(claims.Get("exp").Int(), 0)), "%s", claims)
requirex.EqualTime(t, expectedExp, time.Unix(claims.Get("exp").Int(), 0), 2*time.Second)
assert.NotEmpty(t, claims.Get("jti").String(), "%s", claims)
assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), claims.Get("iss").String(), "%s", claims)
assert.NotEmpty(t, claims.Get("sid").String(), "%s", claims)
assert.Equal(t, "1", claims.Get("acr").String(), "%s", claims)
require.Len(t, claims.Get("amr").Array(), 1, "%s", claims)
assert.EqualValues(t, "pwd", claims.Get("amr").Array()[0].String(), "%s", claims)

require.Len(t, claims.Get("aud").Array(), 1, "%s", claims)
assert.EqualValues(t, c.ClientID, claims.Get("aud").Array()[0].String(), "%s", claims)
assert.EqualValues(t, expectedSubject, claims.Get("sub").String(), "%s", claims)
assert.EqualValues(t, `baz`, claims.Get("bar").String(), "%s", claims)

return claims
}

introspectAccessToken := func(t *testing.T, conf *oauth2.Config, token *oauth2.Token, expectedSubject string) gjson.Result {
require.NotEmpty(t, token.AccessToken)
i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.True(t, i.Get("active").Bool(), "%s", i)
assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i)
assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i)
assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i)
return i
}

assertJWTAccessToken := func(t *testing.T, strat string, conf *oauth2.Config, token *oauth2.Token, expectedSubject string, expectedExp time.Time, scopes string) gjson.Result {
require.NotEmpty(t, token.AccessToken)
parts := strings.Split(token.AccessToken, ".")
if strat != "jwt" {
require.Len(t, parts, 2)
return gjson.Parse("null")
}
require.Len(t, parts, 3)

body, err := x.DecodeSegment(parts[1])
require.NoError(t, err)

i := gjson.ParseBytes(body)
assert.NotEmpty(t, i.Get("jti").String())
assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i)
assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i)
assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), i.Get("iss").String(), "%s", i)
assert.True(t, time.Now().After(time.Unix(i.Get("iat").Int(), 0)), "%s", i)
assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i)
assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i)
requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second)
assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i)
assert.EqualValues(t, scopes, i.Get("scp").Raw, "%s", i)
return i
}

waitForRefreshTokenExpiry := func() {
time.Sleep(reg.Config().GetRefreshTokenLifespan(ctx) + time.Second)
}

t.Run("case=checks if request fails when audience does not match", func(t *testing.T) {
testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t))
_, conf := newDeviceClient(t, reg)
resp, err := getDeviceCode(t, conf, nil, oauth2.SetAuthURLParam("audience", "https://not-ory-api/"))
require.Error(t, err)
devErr := err.(*oauth2.RetrieveError)
require.Nil(t, resp)
require.Equal(t, devErr.Response.StatusCode, http.StatusBadRequest)
})

subject := "aeneas-rekkas"
nonce := uuid.New()
t.Run("case=perform device flow with ID token and refresh tokens", func(t *testing.T) {
run := func(t *testing.T, strategy string) {
c, conf := newDeviceClient(t, reg)
testhelpers.NewDeviceLoginConsentUI(t, reg.Config(),
acceptDeviceHandler(t, c),
acceptLoginHandler(t, c, subject, nil),
acceptConsentHandler(t, c, subject, nil),
)

resp, err := getDeviceCode(t, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
require.NoError(t, err)
require.NotEmpty(t, resp.DeviceCode)
require.NotEmpty(t, resp.UserCode)
loginFlowResp := acceptUserCode(t, conf, nil, resp)
require.NotNil(t, loginFlowResp)
token, err := conf.DeviceAccessToken(context.Background(), resp)
iat := time.Now()
require.NoError(t, err)

assert.Empty(t, token.Extra("c_nonce_draft_00"), "should not be set if not requested")
assert.Empty(t, token.Extra("c_nonce_expires_in_draft_00"), "should not be set if not requested")
introspectAccessToken(t, conf, token, subject)
assertJWTAccessToken(t, strategy, conf, token, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, token, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, token, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))

t.Run("followup=successfully perform refresh token flow", func(t *testing.T) {
require.NotEmpty(t, token.RefreshToken)
token.Expiry = token.Expiry.Add(-time.Hour * 24)
iat = time.Now()
refreshedToken, err := conf.TokenSource(context.Background(), token).Token()
require.NoError(t, err)

require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken)
require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken)
require.NotEqual(t, token.Extra("id_token"), refreshedToken.Extra("id_token"))
introspectAccessToken(t, conf, refreshedToken, subject)

t.Run("followup=refreshed tokens contain valid tokens", func(t *testing.T) {
assertJWTAccessToken(t, strategy, conf, refreshedToken, subject, iat.Add(reg.Config().GetAccessTokenLifespan(ctx)), `["hydra","offline","openid"]`)
assertIDToken(t, refreshedToken, conf, subject, nonce, iat.Add(reg.Config().GetIDTokenLifespan(ctx)))
assertRefreshToken(t, refreshedToken, conf, iat.Add(reg.Config().GetRefreshTokenLifespan(ctx)))
})

t.Run("followup=original access token is no longer valid", func(t *testing.T) {
i := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.False(t, i.Get("active").Bool(), "%s", i)
})

t.Run("followup=original refresh token is no longer valid", func(t *testing.T) {
_, err := conf.TokenSource(context.Background(), token).Token()
assert.Error(t, err)
})

t.Run("followup=but fail subsequent refresh because expiry was reached", func(t *testing.T) {
waitForRefreshTokenExpiry()

// Force golang to refresh token
refreshedToken.Expiry = refreshedToken.Expiry.Add(-time.Hour * 24)
_, err := conf.TokenSource(context.Background(), refreshedToken).Token()
require.Error(t, err)
})
})
}

t.Run("strategy=jwt", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt")
run(t, "jwt")
})

t.Run("strategy=opaque", func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
run(t, "opaque")
})
})
}

0 comments on commit 5054903

Please sign in to comment.