diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index 596d95152..f1c76c78e 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -25,6 +25,7 @@ import ( "io" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -85,29 +86,59 @@ type STSWebIdentity struct { // assuming. RoleARN string + // Policy is the policy where the credentials should be limited too. + Policy string + // roleSessionName is the identifier for the assumed role session. roleSessionName string } // NewSTSWebIdentity returns a pointer to a new // Credentials object wrapping the STSWebIdentity. -func NewSTSWebIdentity(stsEndpoint string, getWebIDTokenExpiry func() (*WebIdentityToken, error)) (*Credentials, error) { +func NewSTSWebIdentity(stsEndpoint string, getWebIDTokenExpiry func() (*WebIdentityToken, error), opts ...func(*STSWebIdentity)) (*Credentials, error) { if stsEndpoint == "" { return nil, errors.New("STS endpoint cannot be empty") } if getWebIDTokenExpiry == nil { return nil, errors.New("Web ID token and expiry retrieval function should be defined") } - return New(&STSWebIdentity{ + i := &STSWebIdentity{ Client: &http.Client{ Transport: http.DefaultTransport, }, STSEndpoint: stsEndpoint, GetWebIDTokenExpiry: getWebIDTokenExpiry, - }), nil + } + for _, o := range opts { + o(i) + } + return New(i), nil +} + +// NewKubernetesIdentity returns a pointer to a new +// Credentials object using the Kubernetes service account +func NewKubernetesIdentity(stsEndpoint string, opts ...func(*STSWebIdentity)) (*Credentials, error) { + return NewSTSWebIdentity(stsEndpoint, func() (*WebIdentityToken, error) { + token, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/token") + if err != nil { + return nil, err + } + + return &WebIdentityToken{ + Token: string(token), + }, nil + }, opts...) +} + +// WithPolicy option will enforce that the returned credentials +// will be scoped down to the specified policy +func WithPolicy(policy string) func(*STSWebIdentity) { + return func(i *STSWebIdentity) { + i.Policy = policy + } } -func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSessionName string, +func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSessionName string, policy string, getWebIDTokenExpiry func() (*WebIdentityToken, error), ) (AssumeRoleWithWebIdentityResponse, error) { idToken, err := getWebIDTokenExpiry() @@ -133,6 +164,9 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession if idToken.Expiry > 0 { v.Set("DurationSeconds", fmt.Sprintf("%d", idToken.Expiry)) } + if policy != "" { + v.Set("Policy", policy) + } v.Set("Version", STSVersion) u, err := url.Parse(endpoint) @@ -183,7 +217,7 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSWebIdentity) Retrieve() (Value, error) { - a, err := getWebIdentityCredentials(m.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.GetWebIDTokenExpiry) + a, err := getWebIdentityCredentials(m.Client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) if err != nil { return Value{}, err }