Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Only obtain a bearer token once at a time #1968

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 142 additions & 67 deletions docker/docker_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
digest "github.com/opencontainers/go-digest"
imgspecv1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/sirupsen/logrus"
"golang.org/x/sync/semaphore"
)

const (
Expand Down Expand Up @@ -86,8 +87,19 @@ type extensionSignatureList struct {
Signatures []extensionSignature `json:"signatures"`
}

// bearerToken records a cached token we can use to authenticate.
// bearerToken records a cached token we can use to authenticate, or a pending process to obtain one.
//
// The goroutine obtaining the token holds lock to block concurrent token requests, and fills the structure (err and possibly the other fields)
// before releasing the lock.
// Other goroutines obtain lock to block on the token request, if any; and then inspect err to see if the token is usable.
// If it is not, they try to get a new one.
type bearerToken struct {
// lock is held while obtaining the token. Potentially nested inside dockerClient.tokenCacheLock.
// This is a counting semaphore only because we need a cancellable lock operation.
lock *semaphore.Weighted

// The following fields can only be accessed with lock held.
err error // nil if the token was successfully obtained (but may be expired); an error if the next lock holder _must_ obtain a new token.
token string
expirationTime time.Time
}
Expand Down Expand Up @@ -117,7 +129,8 @@ type dockerClient struct {
supportsSignatures bool

// Private state for setupRequestAuth (key: string, value: bearerToken)
tokenCache sync.Map
tokenCacheLock sync.Mutex // Protects tokenCache.
tokenCache map[string]*bearerToken
// Private state for detectProperties:
detectPropertiesOnce sync.Once // detectPropertiesOnce is used to execute detectProperties() at most once.
detectPropertiesError error // detectPropertiesError caches the initial error.
Expand Down Expand Up @@ -269,6 +282,7 @@ func newDockerClient(sys *types.SystemContext, registry, reference string) (*doc
registry: registry,
userAgent: userAgent,
tlsClientConfig: tlsClientConfig,
tokenCache: map[string]*bearerToken{},
reportedWarnings: set.New[string](),
}, nil
}
Expand Down Expand Up @@ -712,50 +726,11 @@ func (c *dockerClient) setupRequestAuth(req *http.Request, extraScope *authScope
req.SetBasicAuth(c.auth.Username, c.auth.Password)
return nil
case "bearer":
registryToken := c.registryToken
if registryToken == "" {
cacheKey := ""
scopes := []authScope{c.scope}
if extraScope != nil {
// Using ':' as a separator here is unambiguous because getBearerToken below
// uses the same separator when formatting a remote request (and because
// repository names that we create can't contain colons, and extraScope values
// coming from a server come from `parseAuthScope`, which also splits on colons).
cacheKey = fmt.Sprintf("%s:%s:%s", extraScope.resourceType, extraScope.remoteName, extraScope.actions)
if colonCount := strings.Count(cacheKey, ":"); colonCount != 2 {
return fmt.Errorf(
"Internal error: there must be exactly 2 colons in the cacheKey ('%s') but got %d",
cacheKey,
colonCount,
)
}
scopes = append(scopes, *extraScope)
}
var token bearerToken
t, inCache := c.tokenCache.Load(cacheKey)
if inCache {
token = t.(bearerToken)
}
if !inCache || time.Now().After(token.expirationTime) {
var (
t *bearerToken
err error
)
if c.auth.IdentityToken != "" {
t, err = c.getBearerTokenOAuth2(req.Context(), challenge, scopes)
} else {
t, err = c.getBearerToken(req.Context(), challenge, scopes)
}
if err != nil {
return err
}

token = *t
c.tokenCache.Store(cacheKey, token)
}
registryToken = token.token
token, err := c.obtainBearerToken(req.Context(), challenge, extraScope)
if err != nil {
return err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", registryToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return nil
default:
logrus.Debugf("no handler for %s authentication", challenge.Scheme)
Expand All @@ -765,16 +740,115 @@ func (c *dockerClient) setupRequestAuth(req *http.Request, extraScope *authScope
return nil
}

func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge challenge,
scopes []authScope) (*bearerToken, error) {
// obtainBearerToken gets an "Authorization: Bearer" token if one is available, or obtains a fresh one.
func (c *dockerClient) obtainBearerToken(ctx context.Context, challenge challenge, extraScope *authScope) (string, error) {
if c.registryToken != "" {
return c.registryToken, nil
}

cacheKey := ""
scopes := []authScope{c.scope}
if extraScope != nil {
// Using ':' as a separator here is unambiguous because getBearerToken below
// uses the same separator when formatting a remote request (and because
// repository names that we create can't contain colons, and extraScope values
// coming from a server come from `parseAuthScope`, which also splits on colons).
cacheKey = fmt.Sprintf("%s:%s:%s", extraScope.resourceType, extraScope.remoteName, extraScope.actions)
if colonCount := strings.Count(cacheKey, ":"); colonCount != 2 {
return "", fmt.Errorf(
"Internal error: there must be exactly 2 colons in the cacheKey ('%s') but got %d",
cacheKey,
colonCount,
)
}
scopes = append(scopes, *extraScope)
}

logrus.Debugf("REMOVE: Checking token cache for key %q", cacheKey)
token, newEntry, err := func() (*bearerToken, bool, error) { // A scope for defer
c.tokenCacheLock.Lock()
defer c.tokenCacheLock.Unlock()
token, ok := c.tokenCache[cacheKey]
if ok {
return token, false, nil
} else {
logrus.Debugf("REMOVE: No token cache for key %q, allocating one…", cacheKey)
token = &bearerToken{
lock: semaphore.NewWeighted(1),
}
// If this is a new *bearerToken, lock the entry before adding it to the cache, so that any other goroutine that finds
// this entry blocks until we obtain the token for the first time, and does not see an empty object
// (and does not try to obtain the token itself when we are going to do so).
if err := token.lock.Acquire(ctx, 1); err != nil {
// We do not block on this Acquire, so we don’t really expect to fail here — but if ctx is canceled,
// there is no point in trying to continue anyway.
return nil, false, err
}
c.tokenCache[cacheKey] = token
return token, true, nil
}
}()
if err != nil {
return "", err
}
if !newEntry {
// If this is an existing *bearerToken, obtain the lock only after releasing c.tokenCacheLock,
// so that users of other cacheKey values are not blocked for the whole duration of our HTTP roundtrip.
logrus.Debugf("REMOVE: Found existing token cache for key %q, getting lock", cacheKey)
if err := token.lock.Acquire(ctx, 1); err != nil {
return "", err
}
logrus.Debugf("REMOVE: Locked existing token cache for key %q", cacheKey)
}

defer token.lock.Release(1)

// Determine if the bearerToken is usable: if it is not, log the cause and fall through, otherwise return early.
switch {
case newEntry:
logrus.Debugf("REMOVE: New token cache entry for key %q, getting first token", cacheKey)
case token.err != nil:
// If obtaining a token fails for any reason, the request that triggered that will fail;
// other requests will see token.err and try obtaining their own token, one goroutine at a time.
// (Consider that a request can fail because a very short timeout was provided to _that one operation_ using a context.Context;
// that clearly shouldn’t prevent other operations from trying with a longer timeout.)
//
// If we got here while holding token.lock, we are the goroutine responsible for trying again; others are blocked
// on token.lock.
logrus.Debugf("REMOVE: Token cache for key %q records failure %v, getting new token", cacheKey, token.err)
case time.Now().After(token.expirationTime):
logrus.Debugf("REMOVE: Token cache for key %q is expired, getting new token", cacheKey)

default:
return token.token, nil
}

if c.auth.IdentityToken != "" {
err = c.getBearerTokenOAuth2(ctx, token, challenge, scopes)
} else {
err = c.getBearerToken(ctx, token, challenge, scopes)
}
logrus.Debugf("REMOVE: Obtaining a token for key %q, error %v", cacheKey, err)
token.err = err
if token.err != nil {
return "", token.err
}
return token.token, nil
}

// getBearerTokenOAuth2 obtains an "Authorization: Bearer" token using a pre-existing identity token per
// https://github.com/distribution/distribution/blob/main/docs/spec/auth/oauth.md for challenge and scopes,
// and writes it into dest.
func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, dest *bearerToken, challenge challenge,
scopes []authScope) error {
realm, ok := challenge.Parameters["realm"]
if !ok {
return nil, errors.New("missing realm in bearer auth challenge")
return errors.New("missing realm in bearer auth challenge")
}

authReq, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, nil)
if err != nil {
return nil, err
return err
}

// Make the form data required against the oauth2 authentication
Expand All @@ -799,26 +873,29 @@ func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge chall
logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted())
res, err := c.client.Do(authReq)
if err != nil {
return nil, err
return err
}
defer res.Body.Close()
if err := httpResponseToError(res, "Trying to obtain access token"); err != nil {
return nil, err
return err
}

return newBearerTokenFromHTTPResponseBody(res)
return dest.readFromHTTPResponseBody(res)
}

func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge,
scopes []authScope) (*bearerToken, error) {
// getBearerToken obtains an "Authorization: Bearer" token using a GET request, per
// https://github.com/distribution/distribution/blob/main/docs/spec/auth/token.md for challenge and scopes,
// and writes it into dest.
func (c *dockerClient) getBearerToken(ctx context.Context, dest *bearerToken, challenge challenge,
scopes []authScope) error {
realm, ok := challenge.Parameters["realm"]
if !ok {
return nil, errors.New("missing realm in bearer auth challenge")
return errors.New("missing realm in bearer auth challenge")
}

authReq, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil)
if err != nil {
return nil, err
return err
}

params := authReq.URL.Query()
Expand Down Expand Up @@ -846,22 +923,22 @@ func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge,
logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted())
res, err := c.client.Do(authReq)
if err != nil {
return nil, err
return err
}
defer res.Body.Close()
if err := httpResponseToError(res, "Requesting bearer token"); err != nil {
return nil, err
return err
}

return newBearerTokenFromHTTPResponseBody(res)
return dest.readFromHTTPResponseBody(res)
}

// newBearerTokenFromHTTPResponseBody parses a http.Response to obtain a bearerToken.
// readFromHTTPResponseBody sets token data by parsing a http.Response.
// The caller is still responsible for ensuring res.Body is closed.
func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error) {
func (bt *bearerToken) readFromHTTPResponseBody(res *http.Response) error {
blob, err := iolimits.ReadAtMost(res.Body, iolimits.MaxAuthTokenBodySize)
if err != nil {
return nil, err
return err
}

var token struct {
Expand All @@ -877,12 +954,10 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error
if len(bodySample) > bodySampleLength {
bodySample = bodySample[:bodySampleLength]
}
return nil, fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err)
return fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err)
}

bt := &bearerToken{
token: token.Token,
}
bt.token = token.Token
if bt.token == "" {
bt.token = token.AccessToken
}
Expand All @@ -895,7 +970,7 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error
token.IssuedAt = time.Now().UTC()
}
bt.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)
return bt, nil
return nil
}

// detectPropertiesHelper performs the work of detectProperties which executes
Expand Down
10 changes: 6 additions & 4 deletions docker/docker_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func testTokenHTTPResponse(t *testing.T, body string) *http.Response {
}
}

func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) {
func TestBearerTokenReadFromHTTPResponseBody(t *testing.T) {
for _, c := range []struct {
input string
expected *bearerToken // or nil if a failure is expected
Expand All @@ -128,7 +128,8 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) {
expected: &bearerToken{token: "IAmAToken", expirationTime: time.Unix(1514800802+60, 0)},
},
} {
token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, c.input))
token := &bearerToken{}
err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, c.input))
if c.expected == nil {
assert.Error(t, err, c.input)
} else {
Expand All @@ -140,11 +141,12 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) {
}
}

func TestNewBearerTokenFromHTTPResponseBodyIssuedAtZero(t *testing.T) {
func TestBearerTokenReadFromHTTPResponseBodyIssuedAtZero(t *testing.T) {
zeroTime := time.Time{}.Format(time.RFC3339)
now := time.Now()
tokenBlob := fmt.Sprintf(`{"token":"IAmAToken","expires_in":100,"issued_at":"%s"}`, zeroTime)
token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob))
token := &bearerToken{}
err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob))
require.NoError(t, err)
expectedExpiration := now.Add(time.Duration(100) * time.Second)
require.False(t, token.expirationTime.Before(expectedExpiration),
Expand Down