Skip to content

Commit

Permalink
fix: add with_universe_domain (#1408)
Browse files Browse the repository at this point in the history
* fix: add with_universe_domain to service account and external cred

* update

* update

* chore: refresh sys test cred
  • Loading branch information
arithmetic1728 authored Dec 2, 2023
1 parent 39eb287 commit 505910c
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 4 deletions.
16 changes: 16 additions & 0 deletions google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,22 @@ def with_token_uri(self, token_uri):
new_cred._metrics_options = self._metrics_options
return new_cred

def with_universe_domain(self, universe_domain):
"""Create a copy of these credentials with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.auth.external_account.Credentials: A new credentials
instance.
"""
kwargs = self._constructor_args()
kwargs.update(universe_domain=universe_domain)
new_cred = self.__class__(**kwargs)
new_cred._metrics_options = self._metrics_options
return new_cred

def _initialize_impersonated_credentials(self):
"""Generates an impersonated credentials.
Expand Down
48 changes: 48 additions & 0 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

# The Google OAuth 2.0 token endpoint. Used for authorized user credentials.
_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"


class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaProject):
Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
enable_reauth_refresh=False,
granted_scopes=None,
trust_boundary=None,
universe_domain=_DEFAULT_UNIVERSE_DOMAIN,
):
"""
Args:
Expand Down Expand Up @@ -126,6 +128,9 @@ def __init__(
granted_scopes (Optional[Sequence[str]]): The scopes that were consented/granted by the user.
This could be different from the requested scopes and it could be empty if granted
and requested scopes were same.
trust_boundary (str): String representation of trust boundary meta.
universe_domain (Optional[str]): The universe domain. The default
universe domain is googleapis.com.
"""
super(Credentials, self).__init__()
self.token = token
Expand All @@ -143,6 +148,7 @@ def __init__(
self.refresh_handler = refresh_handler
self._enable_reauth_refresh = enable_reauth_refresh
self._trust_boundary = trust_boundary
self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN

def __getstate__(self):
"""A __getstate__ method must exist for the __setstate__ to be called
Expand Down Expand Up @@ -273,6 +279,7 @@ def with_quota_project(self, quota_project_id):
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=self._universe_domain,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
Expand All @@ -292,13 +299,52 @@ def with_token_uri(self, token_uri):
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=self._universe_domain,
)

def with_universe_domain(self, universe_domain):
"""Create a copy of the credential with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
"""

return self.__class__(
self.token,
refresh_token=self.refresh_token,
id_token=self.id_token,
token_uri=self._token_uri,
client_id=self.client_id,
client_secret=self.client_secret,
scopes=self.scopes,
default_scopes=self.default_scopes,
granted_scopes=self.granted_scopes,
quota_project_id=self.quota_project_id,
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=universe_domain,
)

def _metric_header_for_usage(self):
return metrics.CRED_TYPE_USER

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
if self._universe_domain != _DEFAULT_UNIVERSE_DOMAIN:
raise exceptions.RefreshError(
"User credential refresh is only supported in the default "
"googleapis.com universe domain, but the current universe "
"domain is {}. If you created the credential with an access "
"token, it's likely that the provided token is expired now, "
"please update your code with a valid token.".format(
self._universe_domain
)
)

scopes = self._scopes if self._scopes is not None else self._default_scopes
# Use refresh handler if available and no refresh token is
# available. This is useful in general when tokens are obtained by calling
Expand Down Expand Up @@ -428,6 +474,7 @@ def from_authorized_user_info(cls, info, scopes=None):
expiry=expiry,
rapt_token=info.get("rapt_token"), # may not exist
trust_boundary=info.get("trust_boundary"), # may not exist
universe_domain=info.get("universe_domain"), # may not exist
)

@classmethod
Expand Down Expand Up @@ -471,6 +518,7 @@ def to_json(self, strip=None):
"client_secret": self.client_secret,
"scopes": self.scopes,
"rapt_token": self.rapt_token,
"universe_domain": self._universe_domain,
}
if self.expiry: # flatten expiry timestamp
prep["expiry"] = self.expiry.isoformat() + "Z"
Expand Down
21 changes: 17 additions & 4 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,7 @@ def __init__(
self._quota_project_id = quota_project_id
self._token_uri = token_uri
self._always_use_jwt_access = always_use_jwt_access
if not universe_domain:
self._universe_domain = _DEFAULT_UNIVERSE_DOMAIN
else:
self._universe_domain = universe_domain
self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN

if universe_domain != _DEFAULT_UNIVERSE_DOMAIN:
self._always_use_jwt_access = True
Expand Down Expand Up @@ -328,6 +325,22 @@ def with_always_use_jwt_access(self, always_use_jwt_access):
cred._always_use_jwt_access = always_use_jwt_access
return cred

def with_universe_domain(self, universe_domain):
"""Create a copy of these credentials with the given universe domain.
Args:
universe_domain (str): The universe domain value.
Returns:
google.auth.service_account.Credentials: A new credentials
instance.
"""
cred = self._make_copy()
cred._universe_domain = universe_domain
if universe_domain != _DEFAULT_UNIVERSE_DOMAIN:
cred._always_use_jwt_access = True
return cred

def with_subject(self, subject):
"""Create a copy of these credentials with the specified subject.
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
18 changes: 18 additions & 0 deletions tests/oauth2/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def test_invalid_refresh_handler(self):

assert excinfo.match("The provided refresh_handler is not a callable or None.")

def test_refresh_with_non_default_universe_domain(self):
creds = credentials.Credentials(
token="token", universe_domain="dummy_universe.com"
)
with pytest.raises(exceptions.RefreshError) as excinfo:
creds.refresh(mock.Mock())

assert excinfo.match(
"refresh is only supported in the default googleapis.com universe domain"
)

@mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
Expand Down Expand Up @@ -774,6 +785,12 @@ def test_with_quota_project(self):
creds.apply(headers)
assert "x-goog-user-project" in headers

def test_with_universe_domain(self):
creds = credentials.Credentials(token="token")
assert creds.universe_domain == "googleapis.com"
new_creds = creds.with_universe_domain("dummy_universe.com")
assert new_creds.universe_domain == "dummy_universe.com"

def test_with_token_uri(self):
info = AUTH_USER_INFO.copy()

Expand Down Expand Up @@ -868,6 +885,7 @@ def test_to_json(self):
assert json_asdict.get("scopes") == creds.scopes
assert json_asdict.get("client_secret") == creds.client_secret
assert json_asdict.get("expiry") == info["expiry"]
assert json_asdict.get("universe_domain") == creds.universe_domain

# Test with a `strip` arg
json_output = creds.to_json(strip=["client_secret"])
Expand Down
11 changes: 11 additions & 0 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ def test_with_token_uri(self):
creds_with_new_token_uri = credentials.with_token_uri(new_token_uri)
assert creds_with_new_token_uri._token_uri == new_token_uri

def test_with_universe_domain(self):
credentials = self.make_credentials()

new_credentials = credentials.with_universe_domain("dummy_universe.com")
assert new_credentials.universe_domain == "dummy_universe.com"
assert new_credentials._always_use_jwt_access

new_credentials = credentials.with_universe_domain("googleapis.com")
assert new_credentials.universe_domain == "googleapis.com"
assert not new_credentials._always_use_jwt_access

def test__with_always_use_jwt_access(self):
credentials = self.make_credentials()
assert not credentials._always_use_jwt_access
Expand Down
5 changes: 5 additions & 0 deletions tests/test_external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,11 @@ def test_universe_domain(self):
credentials = self.make_credentials()
assert credentials.universe_domain == external_account._DEFAULT_UNIVERSE_DOMAIN

def test_with_universe_domain(self):
credentials = self.make_credentials()
new_credentials = credentials.with_universe_domain("dummy_universe.com")
assert new_credentials.universe_domain == "dummy_universe.com"

def test_info_workforce_pool(self):
credentials = self.make_workforce_pool_credentials(
workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT
Expand Down

0 comments on commit 505910c

Please sign in to comment.