Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Retrying for NVDApi #1070

Merged
merged 7 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SLEEP_TIMEOUT = 30.0 # in seconds
DEFAULT_TIMEOUT = 180.0 # three minutes
DEFAULT_TIMEOUT_CONFIG = Timeout(DEFAULT_TIMEOUT) # three minutes
RETRY_DELAY = 2.0 # in seconds

Headers = Dict[str, str]
Params = Dict[str, Union[str, int]]
Expand Down Expand Up @@ -342,6 +343,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE API.
Expand All @@ -357,6 +359,7 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
self._url = url
self._token = token
Expand All @@ -370,6 +373,8 @@ def __init__(
self._request_count = 0
self._last_sleep = time.monotonic()

self._request_attempts = request_attempts

def _request_headers(self) -> Headers:
"""
Get the default request headers
Expand Down Expand Up @@ -409,9 +414,19 @@ async def _get(
"""
headers = self._request_headers()

await self._consider_rate_limit()
for attempt in range(self._request_attempts):
if attempt > 0:
delay = RETRY_DELAY**attempt
await asyncio.sleep(delay)

await self._consider_rate_limit()
response = await self._client.get(
self._url, headers=headers, params=params
)
if not response.is_server_error:
break

return await self._client.get(self._url, headers=headers, params=params)
return response

async def __aenter__(self) -> "NVDApi":
# reset rate limit counter
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cpe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CPE API.
Expand All @@ -75,12 +76,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CPES_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

async def cpe(self, cpe_name_id: Union[str, UUID]) -> CPE:
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE API.
Expand All @@ -79,12 +80,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CVES_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

def cves(
Expand Down
3 changes: 3 additions & 0 deletions pontos/nvd/cve_changes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[Timeout] = DEFAULT_TIMEOUT_CONFIG,
rate_limit: bool = True,
request_attempts: int = 1,
) -> None:
"""
Create a new instance of the CVE Change History API.
Expand All @@ -69,12 +70,14 @@ def __init__(
rolling 30 second window.
See https://nvd.nist.gov/developers/start-here#divRateLimits
Default: True.
request_attempts: The number of attempts per HTTP request. Defaults to 1.
"""
super().__init__(
DEFAULT_NIST_NVD_CVE_HISTORY_URL,
token=token,
timeout=timeout,
rate_limit=rate_limit,
request_attempts=request_attempts,
)

def changes(
Expand Down
48 changes: 47 additions & 1 deletion tests/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest
from datetime import datetime
from typing import Any, Iterator
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, call, patch

from httpx import AsyncClient, Response

Expand Down Expand Up @@ -129,6 +129,52 @@ async def test_no_rate_limit(

sleep_mock.assert_not_called()

@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_retry(
self,
async_client: MagicMock,
sleep_mock: MagicMock,
):
response_mocks = [
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=True),
MagicMock(spec=Response, is_server_error=False),
]
http_client = AsyncMock()
http_client.get.side_effect = response_mocks
async_client.return_value = http_client

api = NVDApi("https://foo.bar/baz", request_attempts=4)

result = await api._get()

calls = [call(2.0), call(4.0), call(8.0)]
sleep_mock.assert_has_calls(calls)
self.assertFalse(result.is_server_error)

@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_no_retry(
self,
async_client: MagicMock,
sleep_mock: MagicMock,
):
response_mock = MagicMock(spec=Response)
response_mock.is_server_error = False

http_client = AsyncMock()
http_client.get.return_value = response_mock
async_client.return_value = http_client

api = NVDApi("https://foo.bar/baz")

result = await api._get()

sleep_mock.assert_not_called()
self.assertFalse(result.is_server_error)


class Result:
def __init__(self, value: int) -> None:
Expand Down
Loading