From 3056b95db1dca893b1193e55cd3f002aa29f8937 Mon Sep 17 00:00:00 2001 From: Houssem Ben Mabrouk Date: Tue, 10 Oct 2023 13:14:00 +0200 Subject: [PATCH] implement pkce challenge to oidc code_exchange Signed-off-by: Houssem Ben Mabrouk --- connector/oidc/oidc.go | 73 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 2be973f4ae..6c1c863875 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -3,6 +3,10 @@ package oidc import ( "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -36,6 +40,11 @@ type Config struct { Scopes []string `json:"scopes"` // defaults to "profile" and "email" + PKCE struct { + // Configurable key which controls if pkce challenge should be created or not + Enabled bool `json:"enabled"` // defaults to "false" + } `json:"pkce"` + // HostedDomains was an optional list of whitelisted domains when using the OIDC connector with Google. // Only users from a whitelisted domain were allowed to log in. // Support for this option was removed from the OIDC connector. @@ -103,6 +112,13 @@ type connectorData struct { RefreshToken []byte } +// PKCE stores information about the pkce challenge if enabled +type PKCE struct { + CodeChallenge string + CodeChallengeMethod string + CodeVerifier string +} + // Detect auth header provider issues for known providers. This lets users // avoid having to explicitly set "basicAuthUnsupported" in their config. // @@ -118,6 +134,16 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool { return false } +func randomBytesInHex(count int) (string, error) { + buf := make([]byte, count) + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + return "", fmt.Errorf("Could not generate %d random bytes: %v", count, err) + } + + return hex.EncodeToString(buf), nil +} + // Open returns a connector which can be used to login users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { @@ -162,6 +188,25 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e c.PromptType = "consent" } + // pkce + pkce := PKCE{} + if c.PKCE.Enabled { + // GENERATE code verifier & code challenge + codeVerifier, err := randomBytesInHex(32) // 64 character string + if err != nil { + cancel() + return nil, fmt.Errorf("failed to generate the code verifier: %v", err) + } + sha2 := sha256.New() + if _, err = io.WriteString(sha2, codeVerifier); err != nil { + cancel() + return nil, fmt.Errorf("UNable to copy code_verifier to the hash: %v", err) + } + pkce.CodeChallenge = base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + pkce.CodeChallengeMethod = "S256" + pkce.CodeVerifier = codeVerifier + } + clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -176,6 +221,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e verifier: provider.Verifier( &oidc.Config{ClientID: clientID}, ), + pkce: pkce, logger: logger, cancel: cancel, httpClient: httpClient, @@ -203,6 +249,7 @@ type oidcConnector struct { redirectURI string oauth2Config *oauth2.Config verifier *oidc.IDTokenVerifier + pkce PKCE cancel context.CancelFunc logger log.Logger httpClient *http.Client @@ -239,7 +286,13 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + + if c.pkce.CodeChallenge != "" && c.pkce.CodeChallengeMethod != "" && c.pkce.CodeVerifier != "" { + opts = append(opts, oauth2.SetAuthURLParam("code_challenge", c.pkce.CodeChallenge)) + opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", c.pkce.CodeChallengeMethod)) + } + url := c.oauth2Config.AuthCodeURL(state, opts...) + return url, nil } type oauth2Error struct { @@ -262,16 +315,14 @@ const ( exchangeCaller ) -func (c *oidcConnector) getTokenViaClientCredentials(scopes string) (token *oauth2.Token, err error) { - if scopes == "" { - scopes = strings.Join(c.oauth2Config.Scopes, " ") - } +func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err error) { data := url.Values{ "grant_type": {"client_credentials"}, "client_id": {c.oauth2Config.ClientID}, "client_secret": {c.oauth2Config.ClientSecret}, - "scope": {scopes}, + "scope": {strings.Join(c.oauth2Config.Scopes, " ")}, } + resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data) if err != nil { return nil, fmt.Errorf("oidc: failed to get token: %v", err) @@ -318,14 +369,20 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) if q.Has("code") { // exchange code to token - token, err := c.oauth2Config.Exchange(ctx, q.Get("code")) + var opts []oauth2.AuthCodeOption + + if c.pkce.CodeVerifier != "" { + opts = append(opts, oauth2.SetAuthURLParam("code_verifier", c.pkce.CodeVerifier)) + } + + token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } return c.createIdentity(ctx, identity, token, createCaller) } else { // get token via client_credentials - token, err := c.getTokenViaClientCredentials(r.Form.Get("scope")) + token, err := c.getTokenViaClientCredentials() if err != nil { return identity, err }