Skip to content

Commit

Permalink
Add: Retrying for NVDApi (#1070)
Browse files Browse the repository at this point in the history
* Add: Retrying of requests to NVDApi

* Add: Usage of retries to CPEApi

* Add: Usage of retries to CVEApi

* Add: Usage of retries to CVEChangesApi

* Change: Rename attempts to request_attempts

* Add: Unit tests for retrying of requests in NVDApi

* Change: Fix unit test for full coverage
  • Loading branch information
n-thumann authored Dec 13, 2024
1 parent 01eeea2 commit 1fc864f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 3 deletions.
19 changes: 17 additions & 2 deletions pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SLEEP_TIMEOUT = 30.0 # in seconds
DEFAULT_TIMEOUT = 180.0 # three minutes
DEFAULT_TIMEOUT_CONFIG = Timeout(DEFAULT_TIMEOUT) # three minutes
RETRY_DELAY = 2.0 # in seconds

Headers = Dict[str, str]
Params = Dict[str, Union[str, int]]
Expand Down Expand Up @@ -342,6 +343,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE API.
Expand All @@ -357,6 +359,7 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
self._url = url
self._token = token
Expand All @@ -370,6 +373,8 @@ def __init__(
self._request_count = 0
self._last_sleep = time.monotonic()

self._request_attempts = request_attempts

def _request_headers(self) -> Headers:
"""
Get the default request headers
Expand Down Expand Up @@ -409,9 +414,19 @@ async def _get(
"""
headers = self._request_headers()

await self._consider_rate_limit()
for attempt in range(self._request_attempts):
if attempt > 0:
delay = RETRY_DELAY**attempt
await asyncio.sleep(delay)

await self._consider_rate_limit()
response = await self._client.get(
self._url, headers=headers, params=params
)
if not response.is_server_error:
break

return await self._client.get(self._url, headers=headers, params=params)
return response

async def __aenter__(self) -> "NVDApi":
# reset rate limit counter
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cpe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CPE API.
Expand All @@ -75,12 +76,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CPES_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

async def cpe(self, cpe_name_id: Union[str, UUID]) -> CPE:
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE API.
Expand All @@ -79,12 +80,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CVES_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

def cves(
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cve_changes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE Change History API.
Expand All @@ -69,12 +70,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CVE_HISTORY_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

def changes(
Expand Down
48 changes: 47 additions & 1 deletion tests/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest
from datetime import datetime
from typing import Any, Iterator
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, call, patch

from httpx import AsyncClient, Response

Expand Down Expand Up @@ -129,6 +129,52 @@ async def test_no_rate_limit(

sleep_mock.assert_not_called()

@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_retry(
self,
async_client: MagicMock,
sleep_mock: MagicMock,
):
response_mocks = [
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=False),
]
http_client = AsyncMock()
http_client.get.side_effect = response_mocks
async_client.return_value = http_client

api = NVDApi("https://foo.bar/baz", request_attempts=4)

result = await api._get()

calls = [call(2.0), call(4.0), call(8.0)]
sleep_mock.assert_has_calls(calls)
self.assertFalse(result.is_server_error)

@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_no_retry(
self,
async_client: MagicMock,
sleep_mock: MagicMock,
):
response_mock = MagicMock(spec=Response)
response_mock.is_server_error = False

http_client = AsyncMock()
http_client.get.return_value = response_mock
async_client.return_value = http_client

api = NVDApi("https://foo.bar/baz")

result = await api._get()

sleep_mock.assert_not_called()
self.assertFalse(result.is_server_error)


class Result:
def __init__(self, value: int) -> None:
Expand Down

0 comments on commit 1fc864f

Please sign in to comment.