From f8d619573cb18cee885b65a535647b54b3afee03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miloslav=20Trma=C4=8D?= Date: Mon, 29 May 2023 22:41:25 +0200 Subject: [PATCH] Move creation of bearerToken to obtainBearerToken MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of having getBearerToken* construct a new bearerToken object, have the caller provide one. This will allow us to record that a token is being obtained, so that others can wait for it. Should not change behavior. Signed-off-by: Miloslav Trmač --- docker/docker_client.go | 60 +++++++++++++++++++----------------- docker/docker_client_test.go | 10 +++--- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/docker/docker_client.go b/docker/docker_client.go index 4da5c0139..8d372770b 100644 --- a/docker/docker_client.go +++ b/docker/docker_client.go @@ -760,20 +760,18 @@ func (c *dockerClient) obtainBearerToken(ctx context.Context, challenge challeng token, inCache = c.tokenCache[cacheKey] }() if !inCache || time.Now().After(token.expirationTime) { - var ( - t *bearerToken - err error - ) + token = &bearerToken{} + + var err error if c.auth.IdentityToken != "" { - t, err = c.getBearerTokenOAuth2(ctx, challenge, scopes) + err = c.getBearerTokenOAuth2(ctx, token, challenge, scopes) } else { - t, err = c.getBearerToken(ctx, challenge, scopes) + err = c.getBearerToken(ctx, token, challenge, scopes) } if err != nil { return "", err } - token = t func() { // A scope for defer c.tokenCacheLock.Lock() defer c.tokenCacheLock.Unlock() @@ -783,16 +781,19 @@ func (c *dockerClient) obtainBearerToken(ctx context.Context, challenge challeng return token.token, nil } -func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge challenge, - scopes []authScope) (*bearerToken, error) { +// getBearerTokenOAuth2 obtains an "Authorization: Bearer" token using a pre-existing identity token per +// https://github.com/distribution/distribution/blob/main/docs/spec/auth/oauth.md for challenge and scopes, +// and writes it into dest. +func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, dest *bearerToken, challenge challenge, + scopes []authScope) error { realm, ok := challenge.Parameters["realm"] if !ok { - return nil, errors.New("missing realm in bearer auth challenge") + return errors.New("missing realm in bearer auth challenge") } authReq, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, nil) if err != nil { - return nil, err + return err } // Make the form data required against the oauth2 authentication @@ -817,26 +818,29 @@ func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge chall logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted()) res, err := c.client.Do(authReq) if err != nil { - return nil, err + return err } defer res.Body.Close() if err := httpResponseToError(res, "Trying to obtain access token"); err != nil { - return nil, err + return err } - return newBearerTokenFromHTTPResponseBody(res) + return dest.readFromHTTPResponseBody(res) } -func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge, - scopes []authScope) (*bearerToken, error) { +// getBearerToken obtains an "Authorization: Bearer" token using a GET request, per +// https://github.com/distribution/distribution/blob/main/docs/spec/auth/token.md for challenge and scopes, +// and writes it into dest. +func (c *dockerClient) getBearerToken(ctx context.Context, dest *bearerToken, challenge challenge, + scopes []authScope) error { realm, ok := challenge.Parameters["realm"] if !ok { - return nil, errors.New("missing realm in bearer auth challenge") + return errors.New("missing realm in bearer auth challenge") } authReq, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil) if err != nil { - return nil, err + return err } params := authReq.URL.Query() @@ -864,22 +868,22 @@ func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge, logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted()) res, err := c.client.Do(authReq) if err != nil { - return nil, err + return err } defer res.Body.Close() if err := httpResponseToError(res, "Requesting bearer token"); err != nil { - return nil, err + return err } - return newBearerTokenFromHTTPResponseBody(res) + return dest.readFromHTTPResponseBody(res) } -// newBearerTokenFromHTTPResponseBody parses a http.Response to obtain a bearerToken. +// readFromHTTPResponseBody sets token data by parsing a http.Response. // The caller is still responsible for ensuring res.Body is closed. -func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error) { +func (bt *bearerToken) readFromHTTPResponseBody(res *http.Response) error { blob, err := iolimits.ReadAtMost(res.Body, iolimits.MaxAuthTokenBodySize) if err != nil { - return nil, err + return err } var token struct { @@ -895,12 +899,10 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error if len(bodySample) > bodySampleLength { bodySample = bodySample[:bodySampleLength] } - return nil, fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err) + return fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err) } - bt := &bearerToken{ - token: token.Token, - } + bt.token = token.Token if bt.token == "" { bt.token = token.AccessToken } @@ -913,7 +915,7 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error token.IssuedAt = time.Now().UTC() } bt.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second) - return bt, nil + return nil } // detectPropertiesHelper performs the work of detectProperties which executes diff --git a/docker/docker_client_test.go b/docker/docker_client_test.go index 16e2c286b..3ee318070 100644 --- a/docker/docker_client_test.go +++ b/docker/docker_client_test.go @@ -106,7 +106,7 @@ func testTokenHTTPResponse(t *testing.T, body string) *http.Response { } } -func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { +func TestBearerTokenReadFromHTTPResponseBody(t *testing.T) { for _, c := range []struct { input string expected *bearerToken // or nil if a failure is expected @@ -128,7 +128,8 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { expected: &bearerToken{token: "IAmAToken", expirationTime: time.Unix(1514800802+60, 0)}, }, } { - token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, c.input)) + token := &bearerToken{} + err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, c.input)) if c.expected == nil { assert.Error(t, err, c.input) } else { @@ -140,11 +141,12 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { } } -func TestNewBearerTokenFromHTTPResponseBodyIssuedAtZero(t *testing.T) { +func TestBearerTokenReadFromHTTPResponseBodyIssuedAtZero(t *testing.T) { zeroTime := time.Time{}.Format(time.RFC3339) now := time.Now() tokenBlob := fmt.Sprintf(`{"token":"IAmAToken","expires_in":100,"issued_at":"%s"}`, zeroTime) - token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob)) + token := &bearerToken{} + err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob)) require.NoError(t, err) expectedExpiration := now.Add(time.Duration(100) * time.Second) require.False(t, token.expirationTime.Before(expectedExpiration),