diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 28b004c5f..e7fed8695 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -415,6 +415,22 @@ def with_token_uri(self, token_uri): new_cred._metrics_options = self._metrics_options return new_cred + def with_universe_domain(self, universe_domain): + """Create a copy of these credentials with the given universe domain. + + Args: + universe_domain (str): The universe domain value. + + Returns: + google.auth.external_account.Credentials: A new credentials + instance. + """ + kwargs = self._constructor_args() + kwargs.update(universe_domain=universe_domain) + new_cred = self.__class__(**kwargs) + new_cred._metrics_options = self._metrics_options + return new_cred + def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index ae204b45a..7e2173ebe 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -49,6 +49,7 @@ # The Google OAuth 2.0 token endpoint. Used for authorized user credentials. _GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" +_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaProject): @@ -85,6 +86,7 @@ def __init__( enable_reauth_refresh=False, granted_scopes=None, trust_boundary=None, + universe_domain=_DEFAULT_UNIVERSE_DOMAIN, ): """ Args: @@ -126,6 +128,9 @@ def __init__( granted_scopes (Optional[Sequence[str]]): The scopes that were consented/granted by the user. This could be different from the requested scopes and it could be empty if granted and requested scopes were same. + trust_boundary (str): String representation of trust boundary meta. + universe_domain (Optional[str]): The universe domain. The default + universe domain is googleapis.com. """ super(Credentials, self).__init__() self.token = token @@ -143,6 +148,7 @@ def __init__( self.refresh_handler = refresh_handler self._enable_reauth_refresh = enable_reauth_refresh self._trust_boundary = trust_boundary + self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called @@ -273,6 +279,7 @@ def with_quota_project(self, quota_project_id): rapt_token=self.rapt_token, enable_reauth_refresh=self._enable_reauth_refresh, trust_boundary=self._trust_boundary, + universe_domain=self._universe_domain, ) @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) @@ -292,6 +299,34 @@ def with_token_uri(self, token_uri): rapt_token=self.rapt_token, enable_reauth_refresh=self._enable_reauth_refresh, trust_boundary=self._trust_boundary, + universe_domain=self._universe_domain, + ) + + def with_universe_domain(self, universe_domain): + """Create a copy of the credential with the given universe domain. + + Args: + universe_domain (str): The universe domain value. + + Returns: + google.oauth2.credentials.Credentials: A new credentials instance. + """ + + return self.__class__( + self.token, + refresh_token=self.refresh_token, + id_token=self.id_token, + token_uri=self._token_uri, + client_id=self.client_id, + client_secret=self.client_secret, + scopes=self.scopes, + default_scopes=self.default_scopes, + granted_scopes=self.granted_scopes, + quota_project_id=self.quota_project_id, + rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, + trust_boundary=self._trust_boundary, + universe_domain=universe_domain, ) def _metric_header_for_usage(self): @@ -299,6 +334,17 @@ def _metric_header_for_usage(self): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): + if self._universe_domain != _DEFAULT_UNIVERSE_DOMAIN: + raise exceptions.RefreshError( + "User credential refresh is only supported in the default " + "googleapis.com universe domain, but the current universe " + "domain is {}. If you created the credential with an access " + "token, it's likely that the provided token is expired now, " + "please update your code with a valid token.".format( + self._universe_domain + ) + ) + scopes = self._scopes if self._scopes is not None else self._default_scopes # Use refresh handler if available and no refresh token is # available. This is useful in general when tokens are obtained by calling @@ -428,6 +474,7 @@ def from_authorized_user_info(cls, info, scopes=None): expiry=expiry, rapt_token=info.get("rapt_token"), # may not exist trust_boundary=info.get("trust_boundary"), # may not exist + universe_domain=info.get("universe_domain"), # may not exist ) @classmethod @@ -471,6 +518,7 @@ def to_json(self, strip=None): "client_secret": self.client_secret, "scopes": self.scopes, "rapt_token": self.rapt_token, + "universe_domain": self._universe_domain, } if self.expiry: # flatten expiry timestamp prep["expiry"] = self.expiry.isoformat() + "Z" diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 2b5e0d390..68db41af4 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -182,10 +182,7 @@ def __init__( self._quota_project_id = quota_project_id self._token_uri = token_uri self._always_use_jwt_access = always_use_jwt_access - if not universe_domain: - self._universe_domain = _DEFAULT_UNIVERSE_DOMAIN - else: - self._universe_domain = universe_domain + self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN if universe_domain != _DEFAULT_UNIVERSE_DOMAIN: self._always_use_jwt_access = True @@ -328,6 +325,22 @@ def with_always_use_jwt_access(self, always_use_jwt_access): cred._always_use_jwt_access = always_use_jwt_access return cred + def with_universe_domain(self, universe_domain): + """Create a copy of these credentials with the given universe domain. + + Args: + universe_domain (str): The universe domain value. + + Returns: + google.auth.service_account.Credentials: A new credentials + instance. + """ + cred = self._make_copy() + cred._universe_domain = universe_domain + if universe_domain != _DEFAULT_UNIVERSE_DOMAIN: + cred._always_use_jwt_access = True + return cred + def with_subject(self, subject): """Create a copy of these credentials with the specified subject. diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index 2f0c09c0e..e4c71790e 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index d265d22ed..054f79405 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -122,6 +122,17 @@ def test_invalid_refresh_handler(self): assert excinfo.match("The provided refresh_handler is not a callable or None.") + def test_refresh_with_non_default_universe_domain(self): + creds = credentials.Credentials( + token="token", universe_domain="dummy_universe.com" + ) + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(mock.Mock()) + + assert excinfo.match( + "refresh is only supported in the default googleapis.com universe domain" + ) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", @@ -774,6 +785,12 @@ def test_with_quota_project(self): creds.apply(headers) assert "x-goog-user-project" in headers + def test_with_universe_domain(self): + creds = credentials.Credentials(token="token") + assert creds.universe_domain == "googleapis.com" + new_creds = creds.with_universe_domain("dummy_universe.com") + assert new_creds.universe_domain == "dummy_universe.com" + def test_with_token_uri(self): info = AUTH_USER_INFO.copy() @@ -868,6 +885,7 @@ def test_to_json(self): assert json_asdict.get("scopes") == creds.scopes assert json_asdict.get("client_secret") == creds.client_secret assert json_asdict.get("expiry") == info["expiry"] + assert json_asdict.get("universe_domain") == creds.universe_domain # Test with a `strip` arg json_output = creds.to_json(strip=["client_secret"]) diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py index f9e0c1186..ebaab05fc 100644 --- a/tests/oauth2/test_service_account.py +++ b/tests/oauth2/test_service_account.py @@ -205,6 +205,17 @@ def test_with_token_uri(self): creds_with_new_token_uri = credentials.with_token_uri(new_token_uri) assert creds_with_new_token_uri._token_uri == new_token_uri + def test_with_universe_domain(self): + credentials = self.make_credentials() + + new_credentials = credentials.with_universe_domain("dummy_universe.com") + assert new_credentials.universe_domain == "dummy_universe.com" + assert new_credentials._always_use_jwt_access + + new_credentials = credentials.with_universe_domain("googleapis.com") + assert new_credentials.universe_domain == "googleapis.com" + assert not new_credentials._always_use_jwt_access + def test__with_always_use_jwt_access(self): credentials = self.make_credentials() assert not credentials._always_use_jwt_access diff --git a/tests/test_external_account.py b/tests/test_external_account.py index 6f6e18b2c..5225dcf34 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -505,6 +505,11 @@ def test_universe_domain(self): credentials = self.make_credentials() assert credentials.universe_domain == external_account._DEFAULT_UNIVERSE_DOMAIN + def test_with_universe_domain(self): + credentials = self.make_credentials() + new_credentials = credentials.with_universe_domain("dummy_universe.com") + assert new_credentials.universe_domain == "dummy_universe.com" + def test_info_workforce_pool(self): credentials = self.make_workforce_pool_credentials( workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT