Skip to content

Commit

Permalink
feat: Add optional non blocking refresh for sync auth code (#1368)
Browse files Browse the repository at this point in the history
feat: Add optional non blocking refresh for sync auth code
  • Loading branch information
clundin25 authored Dec 18, 2023
1 parent cc960e6 commit a6dc2c3
Show file tree
Hide file tree
Showing 12 changed files with 485 additions and 8 deletions.
98 changes: 98 additions & 0 deletions google/auth/_refresh_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import threading

import google.auth.exceptions as e

_LOGGER = logging.getLogger(__name__)


class RefreshThreadManager:
"""
Organizes exactly one background job that refresh a token.
"""

def __init__(self):
"""Initializes the manager."""

self._worker = None
self._lock = threading.Lock() # protects access to worker threads.

def start_refresh(self, cred, request):
"""Starts a refresh thread for the given credentials.
The credentials are refreshed using the request parameter.
request and cred MUST not be None
Returns True if a background refresh was kicked off. False otherwise.
Args:
cred: A credentials object.
request: A request object.
Returns:
bool
"""
if cred is None or request is None:
raise e.InvalidValue(
"Unable to start refresh. cred and request must be valid and instantiated objects."
)

with self._lock:
if self._worker is not None and self._worker._error_info is not None:
return False

if self._worker is None or not self._worker.is_alive(): # pragma: NO COVER
self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request))
self._worker.start()
return True

def clear_error(self):
"""
Removes any errors that were stored from previous background refreshes.
"""
with self._lock:
if self._worker:
self._worker._error_info = None


class RefreshThread(threading.Thread):
"""
Thread that refreshes credentials.
"""

def __init__(self, cred, request, **kwargs):
"""Initializes the thread.
Args:
cred: A Credential object to refresh.
request: A Request object used to perform a credential refresh.
**kwargs: Additional keyword arguments.
"""

super().__init__(**kwargs)
self._cred = cred
self._request = request
self._error_info = None

def run(self):
"""
Perform the credential refresh.
"""
try:
self._cred.refresh(self._request)
except Exception as err: # pragma: NO COVER
_LOGGER.error(f"Background refresh failed due to: {err}")
self._error_info = err
76 changes: 73 additions & 3 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""Interfaces for credentials."""

import abc
from enum import Enum
import os

from google.auth import _helpers, environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth._refresh_worker import RefreshThreadManager


class Credentials(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -59,17 +61,22 @@ def __init__(self):
"""Optional[str]: The universe domain value, default is googleapis.com
"""

self._use_non_blocking_refresh = False
self._refresh_worker = RefreshThreadManager()

@property
def expired(self):
"""Checks if the credentials are expired.
Note that credentials can be invalid but not expired because
Credentials with :attr:`expiry` set to None is considered to never
expire.
.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
if not self.expiry:
return False

# Remove some threshold from expiry to err on the side of reporting
# expiration early so that we avoid the 401-refresh-retry loop.
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD
Expand All @@ -81,9 +88,34 @@ def valid(self):
This is True if the credentials have a :attr:`token` and the token
is not :attr:`expired`.
.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
return self.token is not None and not self.expired

@property
def token_state(self):
"""
See `:obj:`TokenState`
"""
if self.token is None:
return TokenState.INVALID

# Credentials that can't expire are always treated as fresh.
if self.expiry is None:
return TokenState.FRESH

expired = _helpers.utcnow() >= self.expiry
if expired:
return TokenState.INVALID

is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD)
if is_stale:
return TokenState.STALE

return TokenState.FRESH

@property
def quota_project_id(self):
"""Project to use for quota and billing purposes."""
Expand Down Expand Up @@ -154,6 +186,25 @@ def apply(self, headers, token=None):
if self.quota_project_id:
headers["x-goog-user-project"] = self.quota_project_id

def _blocking_refresh(self, request):
if not self.valid:
self.refresh(request)

def _non_blocking_refresh(self, request):
use_blocking_refresh_fallback = False

if self.token_state == TokenState.STALE:
use_blocking_refresh_fallback = not self._refresh_worker.start_refresh(
self, request
)

if self.token_state == TokenState.INVALID or use_blocking_refresh_fallback:
self.refresh(request)
# If the blocking refresh succeeds then we can clear the error info
# on the background refresh worker, and perform refreshes in a
# background thread.
self._refresh_worker.clear_error()

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.
Expand All @@ -171,11 +222,17 @@ def before_request(self, request, method, url, headers):
# pylint: disable=unused-argument
# (Subclasses may use these arguments to ascertain information about
# the http request.)
if not self.valid:
self.refresh(request)
if self._use_non_blocking_refresh:
self._non_blocking_refresh(request)
else:
self._blocking_refresh(request)

metrics.add_metric_header(headers, self._metric_header_for_usage())
self.apply(headers)

def with_non_blocking_refresh(self):
self._use_non_blocking_refresh = True


class CredentialsWithQuotaProject(Credentials):
"""Abstract base for credentials supporting ``with_quota_project`` factory"""
Expand Down Expand Up @@ -439,3 +496,16 @@ def signer(self):
# pylint: disable=missing-raises-doc
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Signer must be implemented.")


class TokenState(Enum):
"""
Tracks the state of a token.
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
"""

FRESH = 1
STALE = 2
INVALID = 3
5 changes: 4 additions & 1 deletion google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def _update_token(self, request):
"""

# Refresh our source credentials if it is not valid.
if not self._source_credentials.valid:
if (
self._source_credentials.token_state == credentials.TokenState.STALE
or self._source_credentials.token_state == credentials.TokenState.INVALID
):
self._source_credentials.refresh(request)

body = {
Expand Down
4 changes: 4 additions & 0 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def __getstate__(self):
# because they need to be importable.
# Instead, the refresh_handler setter should be used to repopulate this.
del state_dict["_refresh_handler"]
# Remove worker as it contains multiproccessing queue objects.
del state_dict["_refresh_worker"]
return state_dict

def __setstate__(self, d):
Expand All @@ -183,6 +185,8 @@ def __setstate__(self, d):
self._universe_domain = d.get("_universe_domain") or _DEFAULT_UNIVERSE_DOMAIN
# The refresh_handler setter should be used to repopulate this.
self._refresh_handler = None
self._refresh_worker = None
self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh")

@property
def refresh_token(self):
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
10 changes: 8 additions & 2 deletions tests/oauth2/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth import transport
from google.auth.credentials import TokenState
from google.oauth2 import credentials


Expand Down Expand Up @@ -61,6 +62,7 @@ def test_default_state(self):
assert not credentials.expired
# Scopes aren't required for these credentials
assert not credentials.requires_scopes
assert credentials.token_state == TokenState.INVALID
# Test properties
assert credentials.refresh_token == self.REFRESH_TOKEN
assert credentials.token_uri == self.TOKEN_URI
Expand Down Expand Up @@ -911,7 +913,11 @@ def test_pickle_and_unpickle(self):
assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()

for attr in list(creds.__dict__):
assert getattr(creds, attr) == getattr(unpickled, attr)
# Worker should always be None
if attr == "_refresh_worker":
assert getattr(unpickled, attr) is None
else:
assert getattr(creds, attr) == getattr(unpickled, attr)

def test_pickle_and_unpickle_universe_domain(self):
# old version of auth lib doesn't have _universe_domain, so the pickled
Expand Down Expand Up @@ -945,7 +951,7 @@ def test_pickle_and_unpickle_with_refresh_handler(self):
for attr in list(creds.__dict__):
# For the _refresh_handler property, the unpickled creds should be
# set to None.
if attr == "_refresh_handler":
if attr == "_refresh_handler" or attr == "_refresh_worker":
assert getattr(unpickled, attr) is None
else:
assert getattr(creds, attr) == getattr(unpickled, attr)
Expand Down
Loading

0 comments on commit a6dc2c3

Please sign in to comment.