diff --git a/oci/auth/aws/auth.go b/oci/auth/aws/auth.go index 735086d5..4fb43812 100644 --- a/oci/auth/aws/auth.go +++ b/oci/auth/aws/auth.go @@ -21,6 +21,8 @@ import ( "encoding/base64" "errors" "fmt" + "net/http" + "net/url" "regexp" "strings" "sync" @@ -51,8 +53,19 @@ func ParseRegistry(registry string) (accountId, awsEcrRegion string, ok bool) { // Client is a AWS ECR client which can log into the registry and return // authorization information. type Client struct { - config *aws.Config - mu sync.Mutex + config *aws.Config + mu sync.Mutex + proxyURL *url.URL +} + +// Option is a functional option for configuring the client. +type Option func(*Client) + +// WithProxyURL sets the proxy URL for the client. +func WithProxyURL(proxyURL *url.URL) Option { + return func(c *Client) { + c.proxyURL = proxyURL + } } // NewClient creates a new empty ECR client. @@ -60,8 +73,12 @@ type Client struct { // config, return an empty Client. Client.getLoginAuth() loads the default // config if Client.config is nil. This also enables tests to configure the // Client with stub before calling the login method using Client.WithConfig(). -func NewClient() *Client { - return &Client{} +func NewClient(opts ...Option) *Client { + client := &Client{} + for _, opt := range opts { + opt(client) + } + return client } // WithConfig allows setting the client config if it's uninitialized. @@ -87,8 +104,16 @@ func (c *Client) getLoginAuth(ctx context.Context, awsEcrRegion string) (authn.A if c.config != nil { cfg = c.config.Copy() } else { + var confOpts []func(*config.LoadOptions) error + confOpts = append(confOpts, config.WithRegion(awsEcrRegion)) + if c.proxyURL != nil { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = http.ProxyURL(c.proxyURL) + confOpts = append(confOpts, config.WithHTTPClient(&http.Client{Transport: transport})) + } + var err error - cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(awsEcrRegion)) + cfg, err = config.LoadDefaultConfig(ctx, confOpts...) if err != nil { c.mu.Unlock() return authConfig, time.Time{}, fmt.Errorf("failed to load default configuration: %w", err) diff --git a/oci/auth/azure/auth.go b/oci/auth/azure/auth.go index d4ead684..f085ba02 100644 --- a/oci/auth/azure/auth.go +++ b/oci/auth/azure/auth.go @@ -19,6 +19,8 @@ package azure import ( "context" "fmt" + "net/http" + "net/url" "strings" "time" @@ -44,11 +46,26 @@ const defaultCacheExpirationInSeconds = 600 type Client struct { credential azcore.TokenCredential scheme string + proxyURL *url.URL +} + +// Option is a functional option for configuring the client. +type Option func(*Client) + +// WithProxyURL sets the proxy URL for the client. +func WithProxyURL(proxyURL *url.URL) Option { + return func(c *Client) { + c.proxyURL = proxyURL + } } // NewClient creates a new ACR client with default configurations. -func NewClient() *Client { - return &Client{scheme: "https"} +func NewClient(opts ...Option) *Client { + client := &Client{scheme: "https"} + for _, opt := range opts { + opt(client) + } + return client } // WithTokenCredential sets the token credential used by the ACR client. @@ -73,7 +90,14 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au // NOTE: NewDefaultAzureCredential() performs a lot of environment lookup // for creating default token credential. Load it only when it's needed. if c.credential == nil { - cred, err := azidentity.NewDefaultAzureCredential(nil) + opts := &azidentity.DefaultAzureCredentialOptions{} + if c.proxyURL != nil { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = http.ProxyURL(c.proxyURL) + opts.Transport = &http.Client{Transport: transport} + } + + cred, err := azidentity.NewDefaultAzureCredential(opts) if err != nil { return authConfig, time.Time{}, err } @@ -90,7 +114,7 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au } // Obtain ACR access token using exchanger. - ex := newExchanger(registryURL) + ex := newExchanger(registryURL, c.proxyURL) accessToken, err := ex.ExchangeACRAccessToken(string(armToken.Token)) if err != nil { return authConfig, time.Time{}, fmt.Errorf("error exchanging token: %w", err) diff --git a/oci/auth/azure/exchanger.go b/oci/auth/azure/exchanger.go index 9ab07ea4..14c47505 100644 --- a/oci/auth/azure/exchanger.go +++ b/oci/auth/azure/exchanger.go @@ -67,13 +67,15 @@ type acrError struct { type exchanger struct { endpoint string + proxyURL *url.URL } // newExchanger returns an Azure Exchanger for Azure Container Registry with // a given endpoint, for example https://azurecr.io. -func newExchanger(endpoint string) *exchanger { +func newExchanger(endpoint string, proxyURL *url.URL) *exchanger { return &exchanger{ endpoint: endpoint, + proxyURL: proxyURL, } } @@ -87,12 +89,20 @@ func (e *exchanger) ExchangeACRAccessToken(armToken string) (string, error) { } exchangeURL.Path = path.Join(exchangeURL.Path, "oauth2/exchange") + httpClient := &http.Client{} + + if e.proxyURL != nil { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = http.ProxyURL(e.proxyURL) + httpClient.Transport = transport + } + parameters := url.Values{} parameters.Add("grant_type", "access_token") parameters.Add("service", exchangeURL.Hostname()) parameters.Add("access_token", armToken) - resp, err := http.PostForm(exchangeURL.String(), parameters) + resp, err := httpClient.PostForm(exchangeURL.String(), parameters) if err != nil { return "", fmt.Errorf("failed to send token exchange request: %w", err) } diff --git a/oci/auth/azure/exchanger_test.go b/oci/auth/azure/exchanger_test.go index 0c84079f..89209bb3 100644 --- a/oci/auth/azure/exchanger_test.go +++ b/oci/auth/azure/exchanger_test.go @@ -84,7 +84,7 @@ func TestExchanger_ExchangeACRAccessToken(t *testing.T) { srv.Close() }) - ex := newExchanger(srv.URL) + ex := newExchanger(srv.URL, nil /*proxyURL*/) token, err := ex.ExchangeACRAccessToken("some-access-token") g.Expect(err != nil).To(Equal(tt.wantErr)) if tt.statusCode == http.StatusOK { diff --git a/oci/auth/gcp/auth.go b/oci/auth/gcp/auth.go index 52b5f0d6..584f0b25 100644 --- a/oci/auth/gcp/auth.go +++ b/oci/auth/gcp/auth.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" @@ -50,11 +51,26 @@ func ValidHost(host string) bool { // authorization information. type Client struct { tokenURL string + proxyURL *url.URL +} + +// Option is a functional option for configuring the client. +type Option func(*Client) + +// WithProxyURL sets the proxy URL for the client. +func WithProxyURL(proxyURL *url.URL) Option { + return func(c *Client) { + c.proxyURL = proxyURL + } } // NewClient creates a new GCR client with default configurations. -func NewClient() *Client { - return &Client{tokenURL: GCP_TOKEN_URL} +func NewClient(opts ...Option) *Client { + client := &Client{tokenURL: GCP_TOKEN_URL} + for _, opt := range opts { + opt(client) + } + return client } // WithTokenURL sets the token URL used by the GCR client. @@ -77,7 +93,14 @@ func (c *Client) getLoginAuth(ctx context.Context) (authn.AuthConfig, time.Time, request.Header.Add("Metadata-Flavor", "Google") - client := &http.Client{} + var transport http.RoundTripper + if c.proxyURL != nil { + t := http.DefaultTransport.(*http.Transport).Clone() + t.Proxy = http.ProxyURL(c.proxyURL) + transport = t + } + + client := &http.Client{Transport: transport} response, err := client.Do(request) if err != nil { return authConfig, time.Time{}, err diff --git a/oci/auth/login/login.go b/oci/auth/login/login.go index e618aa48..9ffaef94 100644 --- a/oci/auth/login/login.go +++ b/oci/auth/login/login.go @@ -81,13 +81,42 @@ type Manager struct { acr *azure.Client } +// Option is a functional option for configuring the manager. +type Option func(*options) + +type options struct { + proxyURL *url.URL +} + +// WithProxyURL sets the proxy URL for the manager. +func WithProxyURL(proxyURL *url.URL) Option { + return func(o *options) { + o.proxyURL = proxyURL + } +} + // NewManager initializes a Manager with default registry clients // configurations. -func NewManager() *Manager { +func NewManager(opts ...Option) *Manager { + var o options + for _, opt := range opts { + opt(&o) + } + + var awsOpts []aws.Option + var gcpOpts []gcp.Option + var azureOpts []azure.Option + + if o.proxyURL != nil { + awsOpts = append(awsOpts, aws.WithProxyURL(o.proxyURL)) + gcpOpts = append(gcpOpts, gcp.WithProxyURL(o.proxyURL)) + azureOpts = append(azureOpts, azure.WithProxyURL(o.proxyURL)) + } + return &Manager{ - ecr: aws.NewClient(), - gcr: gcp.NewClient(), - acr: azure.NewClient(), + ecr: aws.NewClient(awsOpts...), + gcr: gcp.NewClient(gcpOpts...), + acr: azure.NewClient(azureOpts...), } }