diff --git a/pontos/nvd/api.py b/pontos/nvd/api.py index 709023fa8..9ce9952e0 100644 --- a/pontos/nvd/api.py +++ b/pontos/nvd/api.py @@ -33,6 +33,7 @@ SLEEP_TIMEOUT = 30.0 # in seconds DEFAULT_TIMEOUT = 180.0 # three minutes DEFAULT_TIMEOUT_CONFIG = Timeout(DEFAULT_TIMEOUT) # three minutes +RETRY_DELAY = 2.0 # in seconds Headers = Dict[str, str] Params = Dict[str, Union[str, int]] @@ -342,6 +343,7 @@ def __init__( token: Optional[str] = None, timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG, rate_limit: bool = True, + request_attempts: int = 1, ) -> None: """ Create a new instance of the CVE API. @@ -357,6 +359,7 @@ def __init__( rolling 30 second window. See https://nvd.nist.gov/developers/start-here#divRateLimits Default: True. + request_attempts: The number of attempts per HTTP request. Defaults to 1. """ self._url = url self._token = token @@ -370,6 +373,8 @@ def __init__( self._request_count = 0 self._last_sleep = time.monotonic() + self._request_attempts = request_attempts + def _request_headers(self) -> Headers: """ Get the default request headers @@ -409,9 +414,19 @@ async def _get( """ headers = self._request_headers() - await self._consider_rate_limit() + for attempt in range(self._request_attempts): + if attempt > 0: + delay = RETRY_DELAY**attempt + await asyncio.sleep(delay) + + await self._consider_rate_limit() + response = await self._client.get( + self._url, headers=headers, params=params + ) + if not response.is_server_error: + break - return await self._client.get(self._url, headers=headers, params=params) + return response async def __aenter__(self) -> "NVDApi": # reset rate limit counter diff --git a/pontos/nvd/cpe/api.py b/pontos/nvd/cpe/api.py index ba4df766c..d397e50bb 100644 --- a/pontos/nvd/cpe/api.py +++ b/pontos/nvd/cpe/api.py @@ -61,6 +61,7 @@ def __init__( token: Optional[str] = None, timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG, rate_limit: bool = True, + request_attempts: int = 1, ) -> None: """ Create a new instance of the CPE API. @@ -75,12 +76,14 @@ def __init__( rolling 30 second window. See https://nvd.nist.gov/developers/start-here#divRateLimits Default: True. + request_attempts: The number of attempts per HTTP request. Defaults to 1. """ super().__init__( DEFAULT_NIST_NVD_CPES_URL, token=token, timeout=timeout, rate_limit=rate_limit, + request_attempts=request_attempts, ) async def cpe(self, cpe_name_id: Union[str, UUID]) -> CPE: diff --git a/pontos/nvd/cve/api.py b/pontos/nvd/cve/api.py index 98f0af10f..6dbce3dda 100644 --- a/pontos/nvd/cve/api.py +++ b/pontos/nvd/cve/api.py @@ -65,6 +65,7 @@ def __init__( token: Optional[str] = None, timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG, rate_limit: bool = True, + request_attempts: int = 1, ) -> None: """ Create a new instance of the CVE API. @@ -79,12 +80,14 @@ def __init__( rolling 30 second window. See https://nvd.nist.gov/developers/start-here#divRateLimits Default: True. + request_attempts: The number of attempts per HTTP request. Defaults to 1. """ super().__init__( DEFAULT_NIST_NVD_CVES_URL, token=token, timeout=timeout, rate_limit=rate_limit, + request_attempts=request_attempts, ) def cves( diff --git a/pontos/nvd/cve_changes/api.py b/pontos/nvd/cve_changes/api.py index faa8456b9..f6e37f6c6 100644 --- a/pontos/nvd/cve_changes/api.py +++ b/pontos/nvd/cve_changes/api.py @@ -55,6 +55,7 @@ def __init__( token: Optional[str] = None, timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG, rate_limit: bool = True, + request_attempts: int = 1, ) -> None: """ Create a new instance of the CVE Change History API. @@ -69,12 +70,14 @@ def __init__( rolling 30 second window. See https://nvd.nist.gov/developers/start-here#divRateLimits Default: True. + request_attempts: The number of attempts per HTTP request. Defaults to 1. """ super().__init__( DEFAULT_NIST_NVD_CVE_HISTORY_URL, token=token, timeout=timeout, rate_limit=rate_limit, + request_attempts=request_attempts, ) def changes( diff --git a/tests/nvd/test_api.py b/tests/nvd/test_api.py index 56f2bd1b8..ea27f7c9e 100644 --- a/tests/nvd/test_api.py +++ b/tests/nvd/test_api.py @@ -8,7 +8,7 @@ import unittest from datetime import datetime from typing import Any, Iterator -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch from httpx import AsyncClient, Response @@ -129,6 +129,52 @@ async def test_no_rate_limit( sleep_mock.assert_not_called() + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) + async def test_retry( + self, + async_client: MagicMock, + sleep_mock: MagicMock, + ): + response_mocks = [ + MagicMock(spec=Response, is_server_error=True), + MagicMock(spec=Response, is_server_error=True), + MagicMock(spec=Response, is_server_error=True), + MagicMock(spec=Response, is_server_error=False), + ] + http_client = AsyncMock() + http_client.get.side_effect = response_mocks + async_client.return_value = http_client + + api = NVDApi("https://foo.bar/baz", request_attempts=4) + + result = await api._get() + + calls = [call(2.0), call(4.0), call(8.0)] + sleep_mock.assert_has_calls(calls) + self.assertFalse(result.is_server_error) + + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) + async def test_no_retry( + self, + async_client: MagicMock, + sleep_mock: MagicMock, + ): + response_mock = MagicMock(spec=Response) + response_mock.is_server_error = False + + http_client = AsyncMock() + http_client.get.return_value = response_mock + async_client.return_value = http_client + + api = NVDApi("https://foo.bar/baz") + + result = await api._get() + + sleep_mock.assert_not_called() + self.assertFalse(result.is_server_error) + class Result: def __init__(self, value: int) -> None: