diff --git a/oci/auth/login/login.go b/oci/auth/login/login.go index 72d57e79..6b7b92c4 100644 --- a/oci/auth/login/login.go +++ b/oci/auth/login/login.go @@ -113,24 +113,34 @@ func (m *Manager) WithACRClient(c *azure.Client) *Manager { // For generic registry provider, it is no-op. func (m *Manager) Login(ctx context.Context, url string, ref name.Reference, opts ProviderOptions) (authn.Authenticator, error) { log := log.FromContext(ctx) + provider := ImageRegistryProvider(url, ref) + var ( + key string + err error + ) if opts.Cache != nil { - auth, exists, err := getObjectFromCache(opts.Cache, url) + key, err = m.keyFromURL(url, provider) if err != nil { - log.Error(err, "failed to get auth object from cache") - } - if exists { - return auth, nil + log.Error(err, "failed to get cache key") + } else { + auth, exists, err := getObjectFromCache(opts.Cache, key) + if err != nil { + log.Error(err, "failed to get auth object from cache") + } + if exists { + return auth, nil + } } } - switch ImageRegistryProvider(url, ref) { + switch provider { case oci.ProviderAWS: auth, expiresAt, err := m.ecr.LoginWithExpiry(ctx, opts.AwsAutoLogin, url) if err != nil { return nil, err } if opts.Cache != nil { - err := cacheObject(opts.Cache, auth, url, expiresAt) + err := cacheObject(opts.Cache, auth, key, expiresAt) if err != nil { log.Error(err, "failed to cache auth object") } @@ -142,7 +152,7 @@ func (m *Manager) Login(ctx context.Context, url string, ref name.Reference, opt return nil, err } if opts.Cache != nil { - err := cacheObject(opts.Cache, auth, url, expiresAt) + err := cacheObject(opts.Cache, auth, key, expiresAt) if err != nil { log.Error(err, "failed to cache auth object") } @@ -154,7 +164,7 @@ func (m *Manager) Login(ctx context.Context, url string, ref name.Reference, opt return nil, err } if opts.Cache != nil { - err := cacheObject(opts.Cache, auth, url, expiresAt) + err := cacheObject(opts.Cache, auth, key, expiresAt) if err != nil { log.Error(err, "failed to cache auth object") } @@ -198,3 +208,28 @@ func (m *Manager) OIDCLogin(ctx context.Context, registryURL string, opts Provid } return nil, nil } + +// keyFromURL returns a key for the cache based on the URL and provider. +// Use this when you don't want to cache the full URL, +// but instead want to cache based on the provider secific way of identifying +// the authentication principal, i.e. the Domain for AWS and Azure, Project for GCP. +func (m *Manager) keyFromURL(ref string, provider oci.Provider) (string, error) { + if !strings.Contains(ref, "://") { + ref = fmt.Sprintf("//%s", ref) + } + u, err := url.Parse(ref) + if err != nil { + return "", err + } + switch provider { + case oci.ProviderAWS, oci.ProviderAzure: + return u.Host, nil + case oci.ProviderGCP: + paths := strings.Split(u.Path, "/") + if len(paths) > 1 { + return fmt.Sprintf("%s/%s", u.Host, paths[1]), nil + } + return u.Host, nil + } + return "", nil +} diff --git a/oci/auth/login/login_test.go b/oci/auth/login/login_test.go index e7c64bcb..2aebdfab 100644 --- a/oci/auth/login/login_test.go +++ b/oci/auth/login/login_test.go @@ -261,11 +261,13 @@ func TestLogin_WithCache(t *testing.T) { if tt.wantErr { g.Expect(err).To(HaveOccurred()) } else { - auth, exists, err := getObjectFromCache(cache, image) + key, err := mgr.keyFromURL(image, ImageRegistryProvider(image, ref)) + g.Expect(err).ToNot(HaveOccurred()) + auth, exists, err := getObjectFromCache(cache, key) g.Expect(err).ToNot(HaveOccurred()) g.Expect(exists).To(BeTrue()) g.Expect(auth).ToNot(BeNil()) - obj, _, err := cache.GetByKey(image) + obj, _, err := cache.GetByKey(key) g.Expect(err).ToNot(HaveOccurred()) expiration, err := cache.GetExpiration(obj) g.Expect(err).ToNot(HaveOccurred()) @@ -275,3 +277,34 @@ func TestLogin_WithCache(t *testing.T) { }) } } + +func Test_keyFromURL(t *testing.T) { + tests := []struct { + name string + image string + want string + }{ + {"gcr", "gcr.io/foo/bar:v1", "gcr.io/foo"}, + {"ecr", "012345678901.dkr.ecr.us-east-1.amazonaws.com/foo:v1", "012345678901.dkr.ecr.us-east-1.amazonaws.com"}, + {"ecr-root", "012345678901.dkr.ecr.us-east-1.amazonaws.com", "012345678901.dkr.ecr.us-east-1.amazonaws.com"}, + {"ecr-root with slash", "012345678901.dkr.ecr.us-east-1.amazonaws.com/", "012345678901.dkr.ecr.us-east-1.amazonaws.com"}, + {"gcr", "gcr.io/foo/bar:v1", "gcr.io/foo"}, + {"gcr-root", "gcr.io", "gcr.io"}, + {"acr", "foo.azurecr.io/bar:v1", "foo.azurecr.io"}, + {"acr-root", "foo.azurecr.io", "foo.azurecr.io"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + // Trim suffix to allow parsing it as reference without modifying + // the given image address. + ref, err := name.ParseReference(strings.TrimSuffix(tt.image, "/")) + g.Expect(err).ToNot(HaveOccurred()) + key, err := NewManager().keyFromURL(tt.image, ImageRegistryProvider(tt.image, ref)) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(key).To(Equal(tt.want)) + }) + } +}