diff --git a/google/auth/_exponential_backoff.py b/google/auth/_exponential_backoff.py index 04f9f9764..89853448f 100644 --- a/google/auth/_exponential_backoff.py +++ b/google/auth/_exponential_backoff.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import random import time @@ -38,9 +39,8 @@ """ -class ExponentialBackoff: - """An exponential backoff iterator. This can be used in a for loop to - perform requests with exponential backoff. +class _BaseExponentialBackoff: + """An exponential backoff iterator base class. Args: total_attempts Optional[int]: @@ -84,9 +84,40 @@ def __init__( self._multiplier = multiplier self._backoff_count = 0 - def __iter__(self): + @property + def total_attempts(self): + """The total amount of backoff attempts that will be made.""" + return self._total_attempts + + @property + def backoff_count(self): + """The current amount of backoff attempts that have been made.""" + return self._backoff_count + + def _reset(self): self._backoff_count = 0 self._current_wait_in_seconds = self._initial_wait_seconds + + def _calculate_jitter(self): + jitter_variance = self._current_wait_in_seconds * self._randomization_factor + jitter = random.uniform( + self._current_wait_in_seconds - jitter_variance, + self._current_wait_in_seconds + jitter_variance, + ) + + return jitter + + +class ExponentialBackoff(_BaseExponentialBackoff): + """An exponential backoff iterator. This can be used in a for loop to + perform requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(ExponentialBackoff, self).__init__(*args, **kwargs) + + def __iter__(self): + self._reset() return self def __next__(self): @@ -97,23 +128,37 @@ def __next__(self): if self._backoff_count <= 1: return self._backoff_count - jitter_variance = self._current_wait_in_seconds * self._randomization_factor - jitter = random.uniform( - self._current_wait_in_seconds - jitter_variance, - self._current_wait_in_seconds + jitter_variance, - ) + jitter = self._calculate_jitter() time.sleep(jitter) self._current_wait_in_seconds *= self._multiplier return self._backoff_count - @property - def total_attempts(self): - """The total amount of backoff attempts that will be made.""" - return self._total_attempts - @property - def backoff_count(self): - """The current amount of backoff attempts that have been made.""" +class AsyncExponentialBackoff(_BaseExponentialBackoff): + """An async exponential backoff iterator. This can be used in a for loop to + perform async requests with exponential backoff. + """ + + def __init__(self, *args, **kwargs): + super(AsyncExponentialBackoff, self).__init__(*args, **kwargs) + + def __aiter__(self): + self._reset() + return self + + async def __anext__(self): + if self._backoff_count >= self._total_attempts: + raise StopAsyncIteration + self._backoff_count += 1 + + if self._backoff_count <= 1: + return self._backoff_count + + jitter = self._calculate_jitter() + + await asyncio.sleep(jitter) + + self._current_wait_in_seconds *= self._multiplier return self._backoff_count diff --git a/google/auth/aio/transport/__init__.py b/google/auth/aio/transport/__init__.py new file mode 100644 index 000000000..166a3be50 --- /dev/null +++ b/google/auth/aio/transport/__init__.py @@ -0,0 +1,144 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport - Asynchronous HTTP client library support. + +:mod:`google.auth.aio` is designed to work with various asynchronous client libraries such +as aiohttp. In order to work across these libraries with different +interfaces some abstraction is needed. + +This module provides two interfaces that are implemented by transport adapters +to support HTTP libraries. :class:`Request` defines the interface expected by +:mod:`google.auth` to make asynchronous requests. :class:`Response` defines the interface +for the return value of :class:`Request`. +""" + +import abc +from typing import AsyncGenerator, Mapping, Optional + +import google.auth.transport + + +_DEFAULT_TIMEOUT_SECONDS = 180 + +DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES +"""Sequence[int]: HTTP status codes indicating a request can be retried. +""" + + +DEFAULT_MAX_RETRY_ATTEMPTS = 3 +"""int: How many times to retry a request.""" + + +class Response(metaclass=abc.ABCMeta): + """Asynchronous HTTP Response Interface.""" + + @property + @abc.abstractmethod + def status_code(self) -> int: + """ + The HTTP response status code. + + Returns: + int: The HTTP response status code. + + """ + raise NotImplementedError("status_code must be implemented.") + + @property + @abc.abstractmethod + def headers(self) -> Mapping[str, str]: + """The HTTP response headers. + + Returns: + Mapping[str, str]: The HTTP response headers. + """ + raise NotImplementedError("headers must be implemented.") + + @abc.abstractmethod + async def content(self, chunk_size: int) -> AsyncGenerator[bytes, None]: + """The raw response content. + + Args: + chunk_size (int): The size of each chunk. + + Yields: + AsyncGenerator[bytes, None]: An asynchronous generator yielding + response chunks as bytes. + """ + raise NotImplementedError("content must be implemented.") + + @abc.abstractmethod + async def read(self) -> bytes: + """Read the entire response content as bytes. + + Returns: + bytes: The entire response content. + """ + raise NotImplementedError("read must be implemented.") + + @abc.abstractmethod + async def close(self): + """Close the response after it is fully consumed to resource.""" + raise NotImplementedError("close must be implemented.") + + +class Request(metaclass=abc.ABCMeta): + """Interface for a callable that makes HTTP requests. + + Specific transport implementations should provide an implementation of + this that adapts their specific request / response API. + + .. automethod:: __call__ + """ + + @abc.abstractmethod + async def __call__( + self, + url: str, + method: str, + body: Optional[bytes], + headers: Optional[Mapping[str, str]], + timeout: float, + **kwargs + ) -> Response: + """Make an HTTP request. + + Args: + url (str): The URI to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (Optional[bytes]): The payload / body in HTTP request. + headers (Mapping[str, str]): Request headers. + timeout (float): The number of seconds to wait for a + response from the server. If not specified or if None, the + transport-specific default timeout will be used. + kwargs: Additional arguments passed on to the transport's + request method. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + # pylint: disable=redundant-returns-doc, missing-raises-doc + # (pylint doesn't play well with abstract docstrings.) + raise NotImplementedError("__call__ must be implemented.") + + async def close(self) -> None: + """ + Close the underlying session. + """ + raise NotImplementedError("close must be implemented.") diff --git a/google/auth/aio/transport/aiohttp.py b/google/auth/aio/transport/aiohttp.py new file mode 100644 index 000000000..074d1491c --- /dev/null +++ b/google/auth/aio/transport/aiohttp.py @@ -0,0 +1,184 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for Asynchronous HTTP Requests based on aiohttp. +""" + +import asyncio +from typing import AsyncGenerator, Mapping, Optional + +try: + import aiohttp # type: ignore +except ImportError as caught_exc: # pragma: NO COVER + raise ImportError( + "The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport." + ) from caught_exc + +from google.auth import _helpers +from google.auth import exceptions +from google.auth.aio import transport + + +class Response(transport.Response): + """ + Represents an HTTP response and its data. It is returned by ``google.auth.aio.transport.sessions.AsyncAuthorizedSession``. + + Args: + response (aiohttp.ClientResponse): An instance of aiohttp.ClientResponse. + + Attributes: + status_code (int): The HTTP status code of the response. + headers (Mapping[str, str]): The HTTP headers of the response. + """ + + def __init__(self, response: aiohttp.ClientResponse): + self._response = response + + @property + @_helpers.copy_docstring(transport.Response) + def status_code(self) -> int: + return self._response.status + + @property + @_helpers.copy_docstring(transport.Response) + def headers(self) -> Mapping[str, str]: + return {key: value for key, value in self._response.headers.items()} + + @_helpers.copy_docstring(transport.Response) + async def content(self, chunk_size: int = 1024) -> AsyncGenerator[bytes, None]: + try: + async for chunk in self._response.content.iter_chunked( + chunk_size + ): # pragma: no branch + yield chunk + except aiohttp.ClientPayloadError as exc: + raise exceptions.ResponseError( + "Failed to read from the payload stream." + ) from exc + + @_helpers.copy_docstring(transport.Response) + async def read(self) -> bytes: + try: + return await self._response.read() + except aiohttp.ClientResponseError as exc: + raise exceptions.ResponseError("Failed to read the response body.") from exc + + @_helpers.copy_docstring(transport.Response) + async def close(self): + self._response.close() + + +class Request(transport.Request): + """Asynchronous Requests request adapter. + + This class is used internally for making requests using aiohttp + in a consistent way. If you use :class:`google.auth.aio.transport.sessions.AsyncAuthorizedSession` + you do not need to construct or use this class directly. + + This class can be useful if you want to configure a Request callable + with a custom ``aiohttp.ClientSession`` in :class:`AuthorizedSession` or if + you want to manually refresh a :class:`~google.auth.aio.credentials.Credentials` instance:: + + import aiohttp + import google.auth.aio.transport.aiohttp + + # Default example: + request = google.auth.aio.transport.aiohttp.Request() + await credentials.refresh(request) + + # Custom aiohttp Session Example: + session = session=aiohttp.ClientSession(auto_decompress=False) + request = google.auth.aio.transport.aiohttp.Request(session=session) + auth_sesion = google.auth.aio.transport.sessions.AsyncAuthorizedSession(auth_request=request) + + Args: + session (aiohttp.ClientSession): An instance :class:`aiohttp.ClientSession` used + to make HTTP requests. If not specified, a session will be created. + + .. automethod:: __call__ + """ + + def __init__(self, session: aiohttp.ClientSession = None): + self._session = session + self._closed = False + + async def __call__( + self, + url: str, + method: str = "GET", + body: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + """ + Make an HTTP request using aiohttp. + + Args: + url (str): The URL to be requested. + method (Optional[str]): + The HTTP method to use for the request. Defaults to 'GET'. + body (Optional[bytes]): + The payload or body in HTTP request. + headers (Optional[Mapping[str, str]]): + Request headers. + timeout (float): The number of seconds to wait for a + response from the server. If not specified or if None, the + requests default timeout will be used. + kwargs: Additional arguments passed through to the underlying + aiohttp :meth:`aiohttp.Session.request` method. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + - google.auth.exceptions.TransportError: If the request fails or if the session is closed. + - google.auth.exceptions.TimeoutError: If the request times out. + """ + + try: + if self._closed: + raise exceptions.TransportError("session is closed.") + + if not self._session: + self._session = aiohttp.ClientSession() + + client_timeout = aiohttp.ClientTimeout(total=timeout) + response = await self._session.request( + method, + url, + data=body, + headers=headers, + timeout=client_timeout, + **kwargs, + ) + return Response(response) + + except aiohttp.ClientError as caught_exc: + client_exc = exceptions.TransportError(f"Failed to send request to {url}.") + raise client_exc from caught_exc + + except asyncio.TimeoutError as caught_exc: + timeout_exc = exceptions.TimeoutError( + f"Request timed out after {timeout} seconds." + ) + raise timeout_exc from caught_exc + + async def close(self) -> None: + """ + Close the underlying aiohttp session to release the acquired resources. + """ + if not self._closed and self._session: + await self._session.close() + self._closed = True diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py new file mode 100644 index 000000000..fea7cbbb2 --- /dev/null +++ b/google/auth/aio/transport/sessions.py @@ -0,0 +1,268 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from contextlib import asynccontextmanager +import functools +import time +from typing import Mapping, Optional + +from google.auth import _exponential_backoff, exceptions +from google.auth.aio import transport +from google.auth.aio.credentials import Credentials +from google.auth.exceptions import TimeoutError + +try: + from google.auth.aio.transport.aiohttp import Request as AiohttpRequest + + AIOHTTP_INSTALLED = True +except ImportError: # pragma: NO COVER + AIOHTTP_INSTALLED = False + + +@asynccontextmanager +async def timeout_guard(timeout): + """ + timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code. + + Args: + timeout (float): The time in seconds before the context manager times out. + + Raises: + google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout. + + Usage: + async with timeout_guard(10) as with_timeout: + await with_timeout(async_function()) + """ + start = time.monotonic() + total_timeout = timeout + + def _remaining_time(): + elapsed = time.monotonic() - start + remaining = total_timeout - elapsed + if remaining <= 0: + raise TimeoutError( + f"Context manager exceeded the configured timeout of {total_timeout}s." + ) + return remaining + + async def with_timeout(coro): + try: + remaining = _remaining_time() + response = await asyncio.wait_for(coro, remaining) + return response + except (asyncio.TimeoutError, TimeoutError) as e: + raise TimeoutError( + f"The operation {coro} exceeded the configured timeout of {total_timeout}s." + ) from e + + try: + yield with_timeout + + finally: + _remaining_time() + + +class AsyncAuthorizedSession: + """This is an asynchronous implementation of :class:`google.auth.requests.AuthorizedSession` class. + We utilize an instance of a class that implements :class:`google.auth.aio.transport.Request` configured + by the caller or otherwise default to `google.auth.aio.transport.aiohttp.Request` if the external aiohttp + package is installed. + + A Requests Session class with credentials. + + This class is used to perform asynchronous requests to API endpoints that require + authorization:: + + import aiohttp + from google.auth.aio.transport import sessions + + async with sessions.AsyncAuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth.aio.credentials.Credentials): + The credentials to add to the request. + auth_request (Optional[google.auth.aio.transport.Request]): + An instance of a class that implements + :class:`~google.auth.aio.transport.Request` used to make requests + and refresh credentials. If not passed, + an instance of :class:`~google.auth.aio.transport.aiohttp.Request` + is created. + + Raises: + - google.auth.exceptions.TransportError: If `auth_request` is `None` + and the external package `aiohttp` is not installed. + - google.auth.exceptions.InvalidType: If the provided credentials are + not of type `google.auth.aio.credentials.Credentials`. + """ + + def __init__( + self, credentials: Credentials, auth_request: Optional[transport.Request] = None + ): + if not isinstance(credentials, Credentials): + raise exceptions.InvalidType( + f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`" + ) + self._credentials = credentials + _auth_request = auth_request + if not _auth_request and AIOHTTP_INSTALLED: + _auth_request = AiohttpRequest() + if _auth_request is None: + raise exceptions.TransportError( + "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." + ) + self._auth_request = _auth_request + + async def request( + self, + method: str, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + """ + Args: + method (str): The http method used to make the request. + url (str): The URI to be requested. + data (Optional[bytes]): The payload or body in HTTP request. + headers (Optional[Mapping[str, str]]): Request headers. + timeout (float): + The amount of time in seconds to wait for the server response + with each individual request. + max_allowed_time (float): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout`` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time``. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + + Returns: + google.auth.aio.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TimeoutError: If the method does not complete within + the configured `max_allowed_time` or the request exceeds the configured + `timeout`. + """ + + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS + ) + async with timeout_guard(max_allowed_time) as with_timeout: + await with_timeout( + # Note: before_request will attempt to refresh credentials if expired. + self._credentials.before_request( + self._auth_request, method, url, headers + ) + ) + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for _ in retries: # pragma: no branch + response = await with_timeout( + self._auth_request(url, method, data, headers, timeout, **kwargs) + ) + if response.status_code not in transport.DEFAULT_RETRYABLE_STATUS_CODES: + break + return response + + @functools.wraps(request) + async def get( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "GET", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def post( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "POST", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def put( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PUT", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def patch( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "PATCH", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + @functools.wraps(request) + async def delete( + self, + url: str, + data: Optional[bytes] = None, + headers: Optional[Mapping[str, str]] = None, + max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS, + timeout: float = transport._DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ) -> transport.Response: + return await self.request( + "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs + ) + + async def close(self) -> None: + """ + Close the underlying auth request session. + """ + await self._auth_request.close() diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index fcbe61b74..feb9f7411 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -98,3 +98,11 @@ class InvalidType(DefaultCredentialsError, TypeError): class OSError(DefaultCredentialsError, EnvironmentError): """Used to wrap EnvironmentError(OSError after python3.3).""" + + +class TimeoutError(GoogleAuthError): + """Used to indicate a timeout error occurred during an HTTP request.""" + + +class ResponseError(GoogleAuthError): + """Used to indicate an error occurred when reading an HTTP response.""" diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index c22d40f6a..7f0fd8d86 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/test__exponential_backoff.py b/tests/test__exponential_backoff.py index 95422502b..b7b6877b2 100644 --- a/tests/test__exponential_backoff.py +++ b/tests/test__exponential_backoff.py @@ -54,3 +54,44 @@ def test_minimum_total_attempts(): with pytest.raises(exceptions.InvalidValue): _exponential_backoff.ExponentialBackoff(total_attempts=-1) _exponential_backoff.ExponentialBackoff(total_attempts=1) + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", return_value=None) +async def test_exponential_backoff_async(mock_time_async): + eb = _exponential_backoff.AsyncExponentialBackoff() + curr_wait = eb._current_wait_in_seconds + iteration_count = 0 + + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for attempt in eb: # pragma: no branch + if attempt == 1: + assert mock_time_async.call_count == 0 + else: + backoff_interval = mock_time_async.call_args[0][0] + jitter = curr_wait * eb._randomization_factor + + assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter) + assert attempt == iteration_count + 1 + assert eb.backoff_count == iteration_count + 1 + assert eb._current_wait_in_seconds == eb._multiplier ** iteration_count + + curr_wait = eb._current_wait_in_seconds + iteration_count += 1 + + assert eb.total_attempts == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert eb.backoff_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert iteration_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert ( + mock_time_async.call_count + == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS - 1 + ) + + +def test_minimum_total_attempts_async(): + with pytest.raises(exceptions.InvalidValue): + _exponential_backoff.AsyncExponentialBackoff(total_attempts=0) + with pytest.raises(exceptions.InvalidValue): + _exponential_backoff.AsyncExponentialBackoff(total_attempts=-1) + _exponential_backoff.AsyncExponentialBackoff(total_attempts=1) diff --git a/tests/transport/aio/test_aiohttp.py b/tests/transport/aio/test_aiohttp.py new file mode 100644 index 000000000..632abff25 --- /dev/null +++ b/tests/transport/aio/test_aiohttp.py @@ -0,0 +1,170 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from aioresponses import aioresponses # type: ignore +from mock import AsyncMock, Mock, patch +import pytest # type: ignore +import pytest_asyncio # type: ignore + +from google.auth import exceptions +import google.auth.aio.transport.aiohttp as auth_aiohttp + + +try: + import aiohttp # type: ignore +except ImportError as caught_exc: # pragma: NO COVER + raise ImportError( + "The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport." + ) from caught_exc + + +@pytest.fixture +def mock_response(): + response = Mock() + response.status = 200 + response.headers = {"Content-Type": "application/json", "Content-Length": "100"} + mock_iterator = AsyncMock() + mock_iterator.__aiter__.return_value = iter( + [b"Cavefish ", b"have ", b"no ", b"sight."] + ) + response.content.iter_chunked = lambda chunk_size: mock_iterator + response.read = AsyncMock(return_value=b"Cavefish have no sight.") + response.close = AsyncMock() + + return auth_aiohttp.Response(response) + + +class TestResponse(object): + @pytest.mark.asyncio + async def test_response_status_code(self, mock_response): + assert mock_response.status_code == 200 + + @pytest.mark.asyncio + async def test_response_headers(self, mock_response): + assert mock_response.headers["Content-Type"] == "application/json" + assert mock_response.headers["Content-Length"] == "100" + + @pytest.mark.asyncio + async def test_response_content(self, mock_response): + content = b"".join([chunk async for chunk in mock_response.content()]) + assert content == b"Cavefish have no sight." + + @pytest.mark.asyncio + async def test_response_content_raises_error(self, mock_response): + with patch.object( + mock_response._response.content, + "iter_chunked", + side_effect=aiohttp.ClientPayloadError, + ): + with pytest.raises(exceptions.ResponseError) as exc: + [chunk async for chunk in mock_response.content()] + exc.match("Failed to read from the payload stream") + + @pytest.mark.asyncio + async def test_response_read(self, mock_response): + content = await mock_response.read() + assert content == b"Cavefish have no sight." + + @pytest.mark.asyncio + async def test_response_read_raises_error(self, mock_response): + with patch.object( + mock_response._response, + "read", + side_effect=aiohttp.ClientResponseError(None, None), + ): + with pytest.raises(exceptions.ResponseError) as exc: + await mock_response.read() + exc.match("Failed to read the response body.") + + @pytest.mark.asyncio + async def test_response_close(self, mock_response): + await mock_response.close() + mock_response._response.close.assert_called_once() + + @pytest.mark.asyncio + async def test_response_content_stream(self, mock_response): + itr = mock_response.content().__aiter__() + content = [] + try: + while True: + chunk = await itr.__anext__() + content.append(chunk) + except StopAsyncIteration: + pass + assert b"".join(content) == b"Cavefish have no sight." + + +@pytest.mark.asyncio +class TestRequest: + @pytest_asyncio.fixture + async def aiohttp_request(self): + request = auth_aiohttp.Request() + yield request + await request.close() + + async def test_request_call_success(self, aiohttp_request): + with aioresponses() as m: + mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."] + mocked_response = b"".join(mocked_chunks) + m.get("http://example.com", status=200, body=mocked_response) + response = await aiohttp_request("http://example.com") + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + content = b"".join([chunk async for chunk in response.content()]) + assert content == b"Cavefish have no sight." + + async def test_request_call_success_with_provided_session(self): + mock_session = aiohttp.ClientSession() + request = auth_aiohttp.Request(mock_session) + with aioresponses() as m: + mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."] + mocked_response = b"".join(mocked_chunks) + m.get("http://example.com", status=200, body=mocked_response) + response = await request("http://example.com") + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + content = b"".join([chunk async for chunk in response.content()]) + assert content == b"Cavefish have no sight." + + async def test_request_call_raises_client_error(self, aiohttp_request): + with aioresponses() as m: + m.get("http://example.com", exception=aiohttp.ClientError) + + with pytest.raises(exceptions.TransportError) as exc: + await aiohttp_request("http://example.com/api") + + exc.match("Failed to send request to http://example.com/api.") + + async def test_request_call_raises_timeout_error(self, aiohttp_request): + with aioresponses() as m: + m.get("http://example.com", exception=asyncio.TimeoutError) + + with pytest.raises(exceptions.TimeoutError) as exc: + await aiohttp_request("http://example.com") + + exc.match("Request timed out after 180 seconds.") + + async def test_request_call_raises_transport_error_for_closed_session( + self, aiohttp_request + ): + with aioresponses() as m: + m.get("http://example.com", exception=asyncio.TimeoutError) + aiohttp_request._closed = True + with pytest.raises(exceptions.TransportError) as exc: + await aiohttp_request("http://example.com") + + exc.match("session is closed.") + aiohttp_request._closed = False diff --git a/tests/transport/aio/test_sessions.py b/tests/transport/aio/test_sessions.py new file mode 100644 index 000000000..c91a7c40a --- /dev/null +++ b/tests/transport/aio/test_sessions.py @@ -0,0 +1,311 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import AsyncGenerator + +from aioresponses import aioresponses # type: ignore +from mock import Mock, patch +import pytest # type: ignore + +from google.auth.aio.credentials import AnonymousCredentials +from google.auth.aio.transport import ( + _DEFAULT_TIMEOUT_SECONDS, + DEFAULT_MAX_RETRY_ATTEMPTS, + DEFAULT_RETRYABLE_STATUS_CODES, + Request, + Response, + sessions, +) +from google.auth.exceptions import InvalidType, TimeoutError, TransportError + + +@pytest.fixture +async def simple_async_task(): + return True + + +class MockRequest(Request): + def __init__(self, response=None, side_effect=None): + self._closed = False + self._response = response + self._side_effect = side_effect + self.call_count = 0 + + async def __call__( + self, + url, + method="GET", + body=None, + headers=None, + timeout=_DEFAULT_TIMEOUT_SECONDS, + **kwargs, + ): + self.call_count += 1 + if self._side_effect: + raise self._side_effect + return self._response + + async def close(self): + self._closed = True + return None + + +class MockResponse(Response): + def __init__(self, status_code, headers=None, content=None): + self._status_code = status_code + self._headers = headers + self._content = content + self._close = False + + @property + def status_code(self): + return self._status_code + + @property + def headers(self): + return self._headers + + async def read(self) -> bytes: + content = await self.content(1024) + return b"".join([chunk async for chunk in content]) + + async def content(self, chunk_size=None) -> AsyncGenerator: + return self._content + + async def close(self) -> None: + self._close = True + + +class TestTimeoutGuard(object): + default_timeout = 1 + + def make_timeout_guard(self, timeout): + return sessions.timeout_guard(timeout) + + @pytest.mark.asyncio + async def test_timeout_with_simple_async_task_within_bounds( + self, simple_async_task + ): + task = False + with patch("time.monotonic", side_effect=[0, 0.25, 0.75]): + with patch("asyncio.wait_for", lambda coro, _: coro): + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + task = await with_timeout(simple_async_task) + + # Task succeeds. + assert task is True + + @pytest.mark.asyncio + async def test_timeout_with_simple_async_task_out_of_bounds( + self, simple_async_task + ): + task = False + with patch("time.monotonic", side_effect=[0, 1, 1]): + with pytest.raises(TimeoutError) as exc: + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + task = await with_timeout(simple_async_task) + + # Task does not succeed and the context manager times out i.e. no remaining time left. + assert task is False + assert exc.match( + f"Context manager exceeded the configured timeout of {self.default_timeout}s." + ) + + @pytest.mark.asyncio + async def test_timeout_with_async_task_timing_out_before_context( + self, simple_async_task + ): + task = False + with pytest.raises(TimeoutError) as exc: + async with self.make_timeout_guard( + timeout=self.default_timeout + ) as with_timeout: + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): + task = await with_timeout(simple_async_task) + + # Task does not complete i.e. the operation times out. + assert task is False + assert exc.match( + f"The operation {simple_async_task} exceeded the configured timeout of {self.default_timeout}s." + ) + + +class TestAsyncAuthorizedSession(object): + TEST_URL = "http://example.com/" + credentials = AnonymousCredentials() + + @pytest.fixture + async def mocked_content(self): + content = [b"Cavefish ", b"have ", b"no ", b"sight."] + for chunk in content: + yield chunk + + @pytest.mark.asyncio + async def test_constructor_with_default_auth_request(self): + with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", True): + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + assert authed_session._credentials == self.credentials + await authed_session.close() + + @pytest.mark.asyncio + async def test_constructor_with_provided_auth_request(self): + auth_request = MockRequest() + authed_session = sessions.AsyncAuthorizedSession( + self.credentials, auth_request=auth_request + ) + + assert authed_session._auth_request is auth_request + await authed_session.close() + + @pytest.mark.asyncio + async def test_constructor_raises_no_auth_request_error(self): + with patch("google.auth.aio.transport.sessions.AIOHTTP_INSTALLED", False): + with pytest.raises(TransportError) as exc: + sessions.AsyncAuthorizedSession(self.credentials) + + exc.match( + "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." + ) + + @pytest.mark.asyncio + async def test_constructor_raises_incorrect_credentials_error(self): + credentials = Mock() + with pytest.raises(InvalidType) as exc: + sessions.AsyncAuthorizedSession(credentials) + + exc.match( + f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`" + ) + + @pytest.mark.asyncio + async def test_request_default_auth_request_success(self): + with aioresponses() as m: + mocked_chunks = [b"Cavefish ", b"have ", b"no ", b"sight."] + mocked_response = b"".join(mocked_chunks) + m.get(self.TEST_URL, status=200, body=mocked_response) + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + response = await authed_session.request("GET", self.TEST_URL) + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + assert await response.read() == b"Cavefish have no sight." + await response.close() + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_provided_auth_request_success(self, mocked_content): + mocked_response = MockResponse( + status_code=200, + headers={"Content-Type": "application/json"}, + content=mocked_content, + ) + auth_request = MockRequest(mocked_response) + authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request) + response = await authed_session.request("GET", self.TEST_URL) + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + assert await response.read() == b"Cavefish have no sight." + await response.close() + assert response._close + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_raises_timeout_error(self): + auth_request = MockRequest(side_effect=asyncio.TimeoutError) + authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request) + with pytest.raises(TimeoutError): + await authed_session.request("GET", self.TEST_URL) + + @pytest.mark.asyncio + async def test_request_raises_transport_error(self): + auth_request = MockRequest(side_effect=TransportError) + authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request) + with pytest.raises(TransportError): + await authed_session.request("GET", self.TEST_URL) + + @pytest.mark.asyncio + async def test_request_max_allowed_time_exceeded_error(self): + auth_request = MockRequest(side_effect=TransportError) + authed_session = sessions.AsyncAuthorizedSession(self.credentials, auth_request) + with patch("time.monotonic", side_effect=[0, 1, 1]): + with pytest.raises(TimeoutError): + await authed_session.request("GET", self.TEST_URL, max_allowed_time=1) + + @pytest.mark.parametrize("retry_status", DEFAULT_RETRYABLE_STATUS_CODES) + @pytest.mark.asyncio + async def test_request_max_retries(self, retry_status): + mocked_response = MockResponse(status_code=retry_status) + auth_request = MockRequest(mocked_response) + with patch("asyncio.sleep", return_value=None): + authed_session = sessions.AsyncAuthorizedSession( + self.credentials, auth_request + ) + await authed_session.request("GET", self.TEST_URL) + assert auth_request.call_count == DEFAULT_MAX_RETRY_ATTEMPTS + + @pytest.mark.asyncio + async def test_http_get_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + with aioresponses() as m: + m.get(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.get(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_post_method_success(self): + expected_payload = b"content is posted." + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + with aioresponses() as m: + m.post(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.post(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_put_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + with aioresponses() as m: + m.put(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.put(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_patch_method_success(self): + expected_payload = b"content is retrieved." + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + with aioresponses() as m: + m.patch(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.patch(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close() + + @pytest.mark.asyncio + async def test_http_delete_method_success(self): + expected_payload = b"content is deleted." + authed_session = sessions.AsyncAuthorizedSession(self.credentials) + with aioresponses() as m: + m.delete(self.TEST_URL, status=200, body=expected_payload) + response = await authed_session.delete(self.TEST_URL) + assert await response.read() == expected_payload + response = await authed_session.close()