Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add type hints to credentials #1605

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
22 changes: 14 additions & 8 deletions google/auth/_credentials_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""Interface for base credentials."""

import abc
from typing import Optional

from google.auth import _helpers
from google.auth.transport.requests import Request


class _BaseCredentials(metaclass=abc.ABCMeta):
class BaseCredentials(metaclass=abc.ABCMeta):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to change this to a public class?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BaseCredentials is imported from outside the _credentials_base.py file, so it's not private in that scope. You can see that it's captured as an error by running:

pyright google/auth/credentials.py

"""Base class for all credentials.

All credentials have a :attr:`token` that is used for authentication and
Expand All @@ -44,10 +46,10 @@ class _BaseCredentials(metaclass=abc.ABCMeta):
"""

def __init__(self):
self.token = None
self.token: Optional[str] = None

@abc.abstractmethod
def refresh(self, request):
def refresh(self, request: Request) -> None:
"""Refreshes the access token.

Args:
Expand All @@ -62,14 +64,18 @@ def refresh(self, request):
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Refresh must be implemented")

def _apply(self, headers, token=None):
def _apply(self, headers: dict[str, str], token: Optional[str] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use Mapping from typing (here and in other places)? Also update the docstring to Mapping[str, str].

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operation requires the type of headers has __setitem__ method that's why being a Mapping is not enough:

headers["authorization"] = "Bearer {}".format(
    _helpers.from_bytes(token or self.token)
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want me to modify the doctring instead?

"""Apply the token to the authentication header.

Args:
headers (Mapping): The HTTP request headers.
headers (dict[str, str]): The HTTP request headers.
token (Optional[str]): If specified, overrides the current access
token.
"""
headers["authorization"] = "Bearer {}".format(
_helpers.from_bytes(token or self.token)
)
if token is not None:
value = token
elif self.token is not None:
value = self.token
else:
assert False, "token must be set"
headers["authorization"] = "Bearer {}".format(_helpers.from_bytes(value))
28 changes: 22 additions & 6 deletions google/auth/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import json
import logging
import os
from typing import Optional, Sequence
import warnings

from google.auth import environment_vars
from google.auth import exceptions
from google.auth.credentials import Credentials
import google.auth.transport._http_client

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,8 +79,12 @@ def _warn_about_problematic_credentials(credentials):


def load_credentials_from_file(
filename, scopes=None, default_scopes=None, quota_project_id=None, request=None
):
filename: str,
scopes: Optional[Sequence[str]] = None,
default_scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
request: Optional[google.auth.transport.Request] = None,
) -> tuple[Credentials, Optional[str]]:
"""Loads Google credentials from a file.

The credentials file must be a service account key, stored authorized
Expand Down Expand Up @@ -173,7 +179,12 @@ def load_credentials_from_dict(


def _load_credentials_from_info(
filename, info, scopes, default_scopes, quota_project_id, request
filename: str,
info,
scopes: Optional[Sequence[str]],
default_scopes: Optional[Sequence[str]],
quota_project_id: Optional[str],
request: Optional[google.auth.transport.Request],
):
from google.auth.credentials import CredentialsWithQuotaProject

Expand Down Expand Up @@ -508,8 +519,8 @@ def _get_gdch_service_account_credentials(filename, info):
from google.oauth2 import gdch_credentials

try:
credentials = gdch_credentials.ServiceAccountCredentials.from_service_account_info(
info
credentials = (
gdch_credentials.ServiceAccountCredentials.from_service_account_info(info)
)
except ValueError as caught_exc:
msg = "Failed to load GDCH service account credentials from {}".format(filename)
Expand Down Expand Up @@ -540,7 +551,12 @@ def _apply_quota_project_id(credentials, quota_project_id):
return credentials


def default(scopes=None, request=None, quota_project_id=None, default_scopes=None):
def default(
scopes: Optional[Sequence[str]] = None,
request: Optional[google.auth.transport.Request] = None,
quota_project_id: Optional[str] = None,
default_scopes: Optional[Sequence[str]] = None,
) -> tuple[Credentials, Optional[str]]:
"""Gets the default credentials for the current environment.

`Application Default Credentials`_ provides an easy way to obtain
Expand Down
10 changes: 6 additions & 4 deletions google/auth/_refresh_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import logging
import threading

from google.auth.credentials import Credentials
import google.auth.exceptions as e
from google.auth.transport import Request

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,7 +34,7 @@ def __init__(self):
self._worker = None
self._lock = threading.Lock() # protects access to worker threads.

def start_refresh(self, cred, request):
def start_refresh(self, cred: Credentials, request: Request) -> bool:
"""Starts a refresh thread for the given credentials.
The credentials are refreshed using the request parameter.
request and cred MUST not be None
Expand All @@ -59,10 +61,10 @@ def start_refresh(self, cred, request):
self._worker.start()
return True

def clear_error(self):
def clear_error(self) -> None:
"""
Removes any errors that were stored from previous background refreshes.
"""
Removes any errors that were stored from previous background refreshes.
"""
with self._lock:
if self._worker:
self._worker._error_info = None
Expand Down
5 changes: 2 additions & 3 deletions google/auth/aio/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@

"""Interfaces for asynchronous credentials."""


from google.auth import _helpers
from google.auth import exceptions
from google.auth._credentials_base import _BaseCredentials
from google.auth._credentials_base import BaseCredentials


class Credentials(_BaseCredentials):
class Credentials(BaseCredentials):
"""Base class for all asynchronous credentials.

All credentials have a :attr:`token` that is used for authentication and
Expand Down
Loading