diff --git a/google/auth/_credentials_base.py b/google/auth/_credentials_base.py index 64d5ce34b..41a557bf1 100644 --- a/google/auth/_credentials_base.py +++ b/google/auth/_credentials_base.py @@ -16,11 +16,13 @@ """Interface for base credentials.""" import abc +from typing import Optional from google.auth import _helpers +from google.auth.transport.requests import Request -class _BaseCredentials(metaclass=abc.ABCMeta): +class BaseCredentials(metaclass=abc.ABCMeta): """Base class for all credentials. All credentials have a :attr:`token` that is used for authentication and @@ -44,10 +46,10 @@ class _BaseCredentials(metaclass=abc.ABCMeta): """ def __init__(self): - self.token = None + self.token: Optional[str] = None @abc.abstractmethod - def refresh(self, request): + def refresh(self, request: Request) -> None: """Refreshes the access token. Args: @@ -62,14 +64,18 @@ def refresh(self, request): # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Refresh must be implemented") - def _apply(self, headers, token=None): + def _apply(self, headers: dict[str, str], token: Optional[str] = None): """Apply the token to the authentication header. Args: - headers (Mapping): The HTTP request headers. + headers (dict[str, str]): The HTTP request headers. token (Optional[str]): If specified, overrides the current access token. """ - headers["authorization"] = "Bearer {}".format( - _helpers.from_bytes(token or self.token) - ) + if token is not None: + value = token + elif self.token is not None: + value = self.token + else: + assert False, "token must be set" + headers["authorization"] = "Bearer {}".format(_helpers.from_bytes(value)) diff --git a/google/auth/_default.py b/google/auth/_default.py index 7bbcf8591..50736b278 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -21,10 +21,12 @@ import json import logging import os +from typing import Optional, Sequence import warnings from google.auth import environment_vars from google.auth import exceptions +from google.auth.credentials import Credentials import google.auth.transport._http_client _LOGGER = logging.getLogger(__name__) @@ -77,8 +79,12 @@ def _warn_about_problematic_credentials(credentials): def load_credentials_from_file( - filename, scopes=None, default_scopes=None, quota_project_id=None, request=None -): + filename: str, + scopes: Optional[Sequence[str]] = None, + default_scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + request: Optional[google.auth.transport.Request] = None, +) -> tuple[Credentials, Optional[str]]: """Loads Google credentials from a file. The credentials file must be a service account key, stored authorized @@ -173,7 +179,12 @@ def load_credentials_from_dict( def _load_credentials_from_info( - filename, info, scopes, default_scopes, quota_project_id, request + filename: str, + info, + scopes: Optional[Sequence[str]], + default_scopes: Optional[Sequence[str]], + quota_project_id: Optional[str], + request: Optional[google.auth.transport.Request], ): from google.auth.credentials import CredentialsWithQuotaProject @@ -508,8 +519,8 @@ def _get_gdch_service_account_credentials(filename, info): from google.oauth2 import gdch_credentials try: - credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info( - info + credentials = ( + gdch_credentials.ServiceAccountCredentials.from_service_account_info(info) ) except ValueError as caught_exc: msg = "Failed to load GDCH service account credentials from {}".format(filename) @@ -540,7 +551,12 @@ def _apply_quota_project_id(credentials, quota_project_id): return credentials -def default(scopes=None, request=None, quota_project_id=None, default_scopes=None): +def default( + scopes: Optional[Sequence[str]] = None, + request: Optional[google.auth.transport.Request] = None, + quota_project_id: Optional[str] = None, + default_scopes: Optional[Sequence[str]] = None, +) -> tuple[Credentials, Optional[str]]: """Gets the default credentials for the current environment. `Application Default Credentials`_ provides an easy way to obtain diff --git a/google/auth/_refresh_worker.py b/google/auth/_refresh_worker.py index 674032d84..96532905f 100644 --- a/google/auth/_refresh_worker.py +++ b/google/auth/_refresh_worker.py @@ -16,7 +16,9 @@ import logging import threading +from google.auth.credentials import Credentials import google.auth.exceptions as e +from google.auth.transport import Request _LOGGER = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def __init__(self): self._worker = None self._lock = threading.Lock() # protects access to worker threads. - def start_refresh(self, cred, request): + def start_refresh(self, cred: Credentials, request: Request) -> bool: """Starts a refresh thread for the given credentials. The credentials are refreshed using the request parameter. request and cred MUST not be None @@ -59,10 +61,10 @@ def start_refresh(self, cred, request): self._worker.start() return True - def clear_error(self): + def clear_error(self) -> None: + """ + Removes any errors that were stored from previous background refreshes. """ - Removes any errors that were stored from previous background refreshes. - """ with self._lock: if self._worker: self._worker._error_info = None diff --git a/google/auth/aio/credentials.py b/google/auth/aio/credentials.py index 3bc6a5a67..712c3009b 100644 --- a/google/auth/aio/credentials.py +++ b/google/auth/aio/credentials.py @@ -15,13 +15,12 @@ """Interfaces for asynchronous credentials.""" - from google.auth import _helpers from google.auth import exceptions -from google.auth._credentials_base import _BaseCredentials +from google.auth._credentials_base import BaseCredentials -class Credentials(_BaseCredentials): +class Credentials(BaseCredentials): """Base class for all asynchronous credentials. All credentials have a :attr:`token` that is used for authentication and diff --git a/google/auth/credentials.py b/google/auth/credentials.py index 2c67e0443..0a4676187 100644 --- a/google/auth/credentials.py +++ b/google/auth/credentials.py @@ -16,19 +16,36 @@ """Interfaces for credentials.""" import abc +import datetime from enum import Enum import os +from typing import Mapping, Optional, Self, Sequence from google.auth import _helpers, environment_vars from google.auth import exceptions from google.auth import metrics -from google.auth._credentials_base import _BaseCredentials +from google.auth._credentials_base import BaseCredentials from google.auth._refresh_worker import RefreshThreadManager +from google.auth.crypt import Signer +from google.auth.transport import Request DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" -class Credentials(_BaseCredentials): +class TokenState(Enum): + """ + Tracks the state of a token. + FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. + STALE: The token is close to expired, and should be refreshed. The token can be used normally. + INVALID: The token is expired or invalid. The token cannot be used for a normal operation. + """ + + FRESH = 1 + STALE = 2 + INVALID = 3 + + +class Credentials(BaseCredentials): """Base class for all credentials. All credentials have a :attr:`token` that is used for authentication and @@ -50,24 +67,24 @@ class Credentials(_BaseCredentials): def __init__(self): super(Credentials, self).__init__() - self.expiry = None + self.expiry: Optional[datetime.datetime] = None """Optional[datetime]: When the token expires and is no longer valid. If this is None, the token is assumed to never expire.""" - self._quota_project_id = None + self._quota_project_id: Optional[str] = None """Optional[str]: Project to use for quota and billing purposes.""" - self._trust_boundary = None + self._trust_boundary: Optional[dict[str, object]] = None """Optional[dict]: Cache of a trust boundary response which has a list of allowed regions and an encoded string representation of credentials trust boundary.""" - self._universe_domain = DEFAULT_UNIVERSE_DOMAIN + self._universe_domain: Optional[str] = DEFAULT_UNIVERSE_DOMAIN """Optional[str]: The universe domain value, default is googleapis.com """ - self._use_non_blocking_refresh = False - self._refresh_worker = RefreshThreadManager() + self._use_non_blocking_refresh: bool = False + self._refresh_worker: RefreshThreadManager = RefreshThreadManager() @property - def expired(self): + def expired(self) -> bool: """Checks if the credentials are expired. Note that credentials can be invalid but not expired because @@ -85,7 +102,7 @@ def expired(self): return _helpers.utcnow() >= skewed_expiry @property - def valid(self): + def valid(self) -> bool: """Checks the validity of the credentials. This is True if the credentials have a :attr:`token` and the token @@ -97,7 +114,7 @@ def valid(self): return self.token is not None and not self.expired @property - def token_state(self): + def token_state(self) -> TokenState: """ See `:obj:`TokenState` """ @@ -119,16 +136,16 @@ def token_state(self): return TokenState.FRESH @property - def quota_project_id(self): + def quota_project_id(self) -> Optional[str]: """Project to use for quota and billing purposes.""" return self._quota_project_id @property - def universe_domain(self): + def universe_domain(self) -> Optional[str]: """The universe domain value.""" return self._universe_domain - def get_cred_info(self): + def get_cred_info(self) -> Optional[Mapping[str, str]]: """The credential information JSON. The credential information will be added to auth related error messages @@ -140,7 +157,7 @@ def get_cred_info(self): return None @abc.abstractmethod - def refresh(self, request): + def refresh(self, request: Request) -> None: """Refreshes the access token. Args: @@ -155,7 +172,7 @@ def refresh(self, request): # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Refresh must be implemented") - def _metric_header_for_usage(self): + def _metric_header_for_usage(self) -> Optional[str]: """The x-goog-api-client header for token usage metric. This header will be added to the API service requests in before_request @@ -170,7 +187,7 @@ def _metric_header_for_usage(self): """ return None - def apply(self, headers, token=None): + def apply(self, headers: dict[str, str], token: Optional[str] = None) -> None: """Apply the token to the authentication header. Args: @@ -197,11 +214,11 @@ def apply(self, headers, token=None): if self.quota_project_id: headers["x-goog-user-project"] = self.quota_project_id - def _blocking_refresh(self, request): + def _blocking_refresh(self, request: Request) -> None: if not self.valid: self.refresh(request) - def _non_blocking_refresh(self, request): + def _non_blocking_refresh(self, request: Request) -> None: use_blocking_refresh_fallback = False if self.token_state == TokenState.STALE: @@ -216,7 +233,9 @@ def _non_blocking_refresh(self, request): # background thread. self._refresh_worker.clear_error() - def before_request(self, request, method, url, headers): + def before_request( + self, request: Request, method: str, url: str, headers: dict[str, str] + ) -> None: """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to @@ -241,14 +260,14 @@ def before_request(self, request, method, url, headers): metrics.add_metric_header(headers, self._metric_header_for_usage()) self.apply(headers) - def with_non_blocking_refresh(self): + def with_non_blocking_refresh(self) -> None: self._use_non_blocking_refresh = True class CredentialsWithQuotaProject(Credentials): """Abstract base for credentials supporting ``with_quota_project`` factory""" - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str) -> Credentials: """Returns a copy of these credentials with a modified quota project. Args: @@ -260,7 +279,7 @@ def with_quota_project(self, quota_project_id): """ raise NotImplementedError("This credential does not support quota project.") - def with_quota_project_from_environment(self): + def with_quota_project_from_environment(self) -> Credentials: quota_from_env = os.environ.get(environment_vars.GOOGLE_CLOUD_QUOTA_PROJECT) if quota_from_env: return self.with_quota_project(quota_from_env) @@ -270,7 +289,7 @@ def with_quota_project_from_environment(self): class CredentialsWithTokenUri(Credentials): """Abstract base for credentials supporting ``with_token_uri`` factory""" - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> Credentials: """Returns a copy of these credentials with a modified token uri. Args: @@ -285,7 +304,7 @@ def with_token_uri(self, token_uri): class CredentialsWithUniverseDomain(Credentials): """Abstract base for credentials supporting ``with_universe_domain`` factory""" - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> Credentials: """Returns a copy of these credentials with a modified universe domain. Args: @@ -307,21 +326,21 @@ class AnonymousCredentials(Credentials): """ @property - def expired(self): + def expired(self) -> bool: """Returns `False`, anonymous credentials never expire.""" return False @property - def valid(self): + def valid(self) -> bool: """Returns `True`, anonymous credentials are always valid.""" return True - def refresh(self, request): + def refresh(self, request: Request) -> None: """Raises :class:``InvalidOperation``, anonymous credentials cannot be refreshed.""" raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") - def apply(self, headers, token=None): + def apply(self, headers: dict[str, str], token: Optional[str] = None) -> None: """Anonymous credentials do nothing to the request. The optional ``token`` argument is not supported. @@ -332,7 +351,9 @@ def apply(self, headers, token=None): if token is not None: raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") - def before_request(self, request, method, url, headers): + def before_request( + self, request: Request, method: str, url: str, headers: dict[str, str] + ) -> None: """Anonymous credentials do nothing to the request.""" @@ -367,26 +388,26 @@ class ReadOnlyScoped(metaclass=abc.ABCMeta): def __init__(self): super(ReadOnlyScoped, self).__init__() - self._scopes = None - self._default_scopes = None + self._scopes: Optional[Sequence[str]] = None + self._default_scopes: Optional[Sequence[str]] = None @property - def scopes(self): + def scopes(self) -> Optional[Sequence[str]]: """Sequence[str]: the credentials' current set of scopes.""" return self._scopes @property - def default_scopes(self): + def default_scopes(self) -> Optional[Sequence[str]]: """Sequence[str]: the credentials' current set of default scopes.""" return self._default_scopes - @abc.abstractproperty - def requires_scopes(self): - """True if these credentials require scopes to obtain an access token. - """ + @property + @abc.abstractmethod + def requires_scopes(self) -> bool: + """True if these credentials require scopes to obtain an access token.""" return False - def has_scopes(self, scopes): + def has_scopes(self, scopes: Sequence[str]) -> bool: """Checks if the credentials have the given scopes. .. warning: This method is not guaranteed to be accurate if the @@ -434,7 +455,9 @@ class Scoped(ReadOnlyScoped): """ @abc.abstractmethod - def with_scopes(self, scopes, default_scopes=None): + def with_scopes( + self, scopes: Sequence[str], default_scopes: Optional[Sequence[str]] = None + ) -> Self: """Create a copy of these credentials with the specified scopes. Args: @@ -449,7 +472,11 @@ def with_scopes(self, scopes, default_scopes=None): raise NotImplementedError("This class does not require scoping.") -def with_scopes_if_required(credentials, scopes, default_scopes=None): +def with_scopes_if_required( + credentials: Credentials, + scopes: Sequence[str], + default_scopes: Optional[Sequence[str]] = None, +) -> Credentials: """Creates a copy of the credentials with scopes if scoping is required. This helper function is useful when you do not know (or care to know) the @@ -481,7 +508,7 @@ class Signing(metaclass=abc.ABCMeta): """Interface for credentials that can cryptographically sign messages.""" @abc.abstractmethod - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: """Signs the given message. Args: @@ -494,29 +521,18 @@ def sign_bytes(self, message): # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Sign bytes must be implemented.") - @abc.abstractproperty - def signer_email(self): + @property + @abc.abstractmethod + def signer_email(self) -> Optional[str]: """Optional[str]: An email address that identifies the signer.""" # pylint: disable=missing-raises-doc # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Signer email must be implemented.") - @abc.abstractproperty - def signer(self): + @property + @abc.abstractmethod + def signer(self) -> Signer: """google.auth.crypt.Signer: The signer used to sign bytes.""" # pylint: disable=missing-raises-doc # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Signer must be implemented.") - - -class TokenState(Enum): - """ - Tracks the state of a token. - FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. - STALE: The token is close to expired, and should be refreshed. The token can be used normally. - INVALID: The token is expired or invalid. The token cannot be used for a normal operation. - """ - - FRESH = 1 - STALE = 2 - INVALID = 3 diff --git a/google/auth/metrics.py b/google/auth/metrics.py index 11e4b0773..87b83fd1b 100644 --- a/google/auth/metrics.py +++ b/google/auth/metrics.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" We use x-goog-api-client header to report metrics. This module provides +"""We use x-goog-api-client header to report metrics. This module provides the constants and helper methods to construct x-goog-api-client header. """ import platform +from typing import Mapping, Optional from google.auth import version @@ -48,6 +49,7 @@ def python_and_auth_lib_version(): # Token request metric header values + # x-goog-api-client header value for access token request via metadata server. # Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds" def token_request_access_token_mds(): @@ -108,6 +110,7 @@ def token_request_user(): # Miscellenous metrics + # x-goog-api-client header value for metadata server ping. # Example: "gl-python/3.7 auth/1.1 auth-request-type/mds" def mds_ping(): @@ -135,11 +138,11 @@ def byoid_metrics_header(metrics_options): return header -def add_metric_header(headers, metric_header_value): +def add_metric_header(headers: dict[str, str], metric_header_value: Optional[str]): """Add x-goog-api-client header with the given value. Args: - headers (Mapping[str, str]): The headers to which we will add the + headers (dict[str, str]): The headers to which we will add the metric header. metric_header_value (Optional[str]): If value is None, do nothing; if headers already has a x-goog-api-client header, append the value diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 98dafa3e3..070a0bae7 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -72,6 +72,7 @@ import copy import datetime +from typing import Mapping, Optional, Self, Sequence from google.auth import _helpers from google.auth import _service_account_info @@ -80,6 +81,7 @@ from google.auth import iam from google.auth import jwt from google.auth import metrics +from google.auth.crypt import Signer from google.oauth2 import _client _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds @@ -129,18 +131,18 @@ class Credentials( def __init__( self, - signer, - service_account_email, - token_uri, - scopes=None, - default_scopes=None, - subject=None, - project_id=None, - quota_project_id=None, - additional_claims=None, - always_use_jwt_access=False, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, + signer: Signer, + service_account_email: str, + token_uri: str, + scopes: Optional[Sequence[str]] = None, + default_scopes: Optional[Sequence[str]] = None, + subject: Optional[str] = None, + project_id: Optional[str] = None, + quota_project_id: Optional[str] = None, + additional_claims: Optional[dict[str, str]] = None, + always_use_jwt_access: bool = False, + universe_domain: Optional[str] = credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary: Optional[str] = None, ): """ Args: @@ -225,7 +227,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): ) @classmethod - def from_service_account_info(cls, info, **kwargs): + def from_service_account_info(cls, info: Mapping[str, str], **kwargs) -> Self: """Creates a Credentials instance from parsed service account info. Args: @@ -281,7 +283,7 @@ def requires_scopes(self): """ return True if not self._scopes else False - def _make_copy(self): + def _make_copy(self) -> Self: cred = self.__class__( self._signer, service_account_email=self._service_account_email, @@ -440,8 +442,10 @@ def refresh(self, request): ) if self._use_self_signed_jwt(): + assert self._jwt_credentials is not None self._jwt_credentials.refresh(request) - self.token = self._jwt_credentials.token.decode() + assert self._jwt_credentials.token is not None + self.token = self._jwt_credentials.token self.expiry = self._jwt_credentials.expiry else: assertion = self._make_authorization_grant_assertion() @@ -473,7 +477,6 @@ def _create_self_signed_jwt(self, audience): self._jwt_credentials is None or self._jwt_credentials._audience != audience ): - self._jwt_credentials = jwt.Credentials.from_signing_credentials( self, audience ) @@ -567,13 +570,13 @@ class IDTokenCredentials( def __init__( self, - signer, - service_account_email, - token_uri, - target_audience, - additional_claims=None, - quota_project_id=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, + signer: Signer, + service_account_email: str, + token_uri: str, + target_audience: str, + additional_claims: Optional[dict[str, str]] = None, + quota_project_id: Optional[str] = None, + universe_domain: Optional[str] = credentials.DEFAULT_UNIVERSE_DOMAIN, ): """ Args: @@ -583,7 +586,7 @@ def __init__( target_audience (str): The intended audience for these credentials, used when requesting the ID Token. The ID Token's ``aud`` claim will be set to this string. - additional_claims (Mapping[str, str]): Any additional claims for + additional_claims (dict[str, str]): Any additional claims for the JWT assertion used in the authorization grant. quota_project_id (Optional[str]): The project ID used for quota and billing. universe_domain (str): The universe domain. The default @@ -811,7 +814,7 @@ def _refresh_with_iam_endpoint(self, request): self._iam_id_token_endpoint, self.signer_email, self._target_audience, - jwt_credentials.token.decode(), + jwt_credentials.token, ) @_helpers.copy_docstring(credentials.Credentials)