Skip to content

Commit

Permalink
implement pkce challenge to oidc code_exchange
Browse files Browse the repository at this point in the history
Signed-off-by: Houssem Ben Mabrouk <[email protected]>
  • Loading branch information
orange-hbenmabrouk committed Oct 16, 2023
1 parent e738058 commit 3056b95
Showing 1 changed file with 65 additions and 8 deletions.
73 changes: 65 additions & 8 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package oidc

import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -36,6 +40,11 @@ type Config struct {

Scopes []string `json:"scopes"` // defaults to "profile" and "email"

PKCE struct {
// Configurable key which controls if pkce challenge should be created or not
Enabled bool `json:"enabled"` // defaults to "false"
} `json:"pkce"`

// HostedDomains was an optional list of whitelisted domains when using the OIDC connector with Google.
// Only users from a whitelisted domain were allowed to log in.
// Support for this option was removed from the OIDC connector.
Expand Down Expand Up @@ -103,6 +112,13 @@ type connectorData struct {
RefreshToken []byte
}

// PKCE stores information about the pkce challenge if enabled
type PKCE struct {
CodeChallenge string
CodeChallengeMethod string
CodeVerifier string
}

// Detect auth header provider issues for known providers. This lets users
// avoid having to explicitly set "basicAuthUnsupported" in their config.
//
Expand All @@ -118,6 +134,16 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool {
return false
}

func randomBytesInHex(count int) (string, error) {
buf := make([]byte, count)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("Could not generate %d random bytes: %v", count, err)

Check failure on line 141 in connector/oidc/oidc.go

View workflow job for this annotation

GitHub Actions / Lint

ST1005: error strings should not be capitalized (stylecheck)
}

return hex.EncodeToString(buf), nil
}

// Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
Expand Down Expand Up @@ -162,6 +188,25 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
c.PromptType = "consent"
}

// pkce
pkce := PKCE{}
if c.PKCE.Enabled {
// GENERATE code verifier & code challenge
codeVerifier, err := randomBytesInHex(32) // 64 character string
if err != nil {
cancel()
return nil, fmt.Errorf("failed to generate the code verifier: %v", err)
}
sha2 := sha256.New()
if _, err = io.WriteString(sha2, codeVerifier); err != nil {
cancel()
return nil, fmt.Errorf("UNable to copy code_verifier to the hash: %v", err)
}
pkce.CodeChallenge = base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
pkce.CodeChallengeMethod = "S256"
pkce.CodeVerifier = codeVerifier
}

clientID := c.ClientID
return &oidcConnector{
provider: provider,
Expand All @@ -176,6 +221,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
),
pkce: pkce,
logger: logger,
cancel: cancel,
httpClient: httpClient,
Expand Down Expand Up @@ -203,6 +249,7 @@ type oidcConnector struct {
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
pkce PKCE
cancel context.CancelFunc
logger log.Logger
httpClient *http.Client
Expand Down Expand Up @@ -239,7 +286,13 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil

if c.pkce.CodeChallenge != "" && c.pkce.CodeChallengeMethod != "" && c.pkce.CodeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_challenge", c.pkce.CodeChallenge))
opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", c.pkce.CodeChallengeMethod))
}
url := c.oauth2Config.AuthCodeURL(state, opts...)
return url, nil
}

type oauth2Error struct {
Expand All @@ -262,16 +315,14 @@ const (
exchangeCaller
)

func (c *oidcConnector) getTokenViaClientCredentials(scopes string) (token *oauth2.Token, err error) {
if scopes == "" {
scopes = strings.Join(c.oauth2Config.Scopes, " ")
}
func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err error) {
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {c.oauth2Config.ClientID},
"client_secret": {c.oauth2Config.ClientSecret},
"scope": {scopes},
"scope": {strings.Join(c.oauth2Config.Scopes, " ")},
}

resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get token: %v", err)
Expand Down Expand Up @@ -318,14 +369,20 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)
if q.Has("code") {
// exchange code to token
token, err := c.oauth2Config.Exchange(ctx, q.Get("code"))
var opts []oauth2.AuthCodeOption

if c.pkce.CodeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", c.pkce.CodeVerifier))
}

token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...)
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
return c.createIdentity(ctx, identity, token, createCaller)
} else {
// get token via client_credentials
token, err := c.getTokenViaClientCredentials(r.Form.Get("scope"))
token, err := c.getTokenViaClientCredentials()
if err != nil {
return identity, err
}
Expand Down

0 comments on commit 3056b95

Please sign in to comment.