diff --git a/pontos/nvd/api.py b/pontos/nvd/api.py index 598d2d359..34949b329 100644 --- a/pontos/nvd/api.py +++ b/pontos/nvd/api.py @@ -16,6 +16,7 @@ # along with this program. If not, see . import asyncio +import time from abc import ABC from datetime import datetime, timezone from types import TracebackType @@ -76,10 +77,6 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]: return converted -async def sleep() -> None: - await asyncio.sleep(SLEEP_TIMEOUT) - - class NVDApi(ABC): """ Abstract base class for querying the NIST NVD API. @@ -120,6 +117,7 @@ def __init__( self._rate_limit = None self._request_count = 0 + self._last_sleep = time.monotonic() def _request_headers(self) -> Headers: """ @@ -141,7 +139,13 @@ async def _consider_rate_limit(self) -> None: self._request_count += 1 if self._request_count > self._rate_limit: - await sleep() + time_since_last_sleep = time.monotonic() - self._last_sleep + + if time_since_last_sleep < SLEEP_TIMEOUT: + time_to_sleep = SLEEP_TIMEOUT - time_since_last_sleep + await asyncio.sleep(time_to_sleep) + + self._last_sleep = time.monotonic() self._request_count = 0 async def _get( diff --git a/tests/nvd/cpe/test_api.py b/tests/nvd/cpe/test_api.py index 05ecacf2d..b21b9a5ed 100644 --- a/tests/nvd/cpe/test_api.py +++ b/tests/nvd/cpe/test_api.py @@ -25,7 +25,7 @@ from httpx import AsyncClient, Response from pontos.errors import PontosError -from pontos.nvd.api import now, sleep +from pontos.nvd.api import now from pontos.nvd.cpe.api import CPEApi from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext from tests.nvd import get_cpe_data @@ -54,10 +54,12 @@ def create_cpes_responses(count: int = 2) -> List[MagicMock]: class CPEApiTestCase(IsolatedAsyncioTestCase): + @patch("pontos.nvd.api.time.monotonic", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) - def setUp(self, async_client: MagicMock) -> None: + def setUp(self, async_client: MagicMock, monotonic_mock: MagicMock) -> None: self.http_client = AsyncMock() async_client.return_value = self.http_client + monotonic_mock.return_value = 0 self.api = CPEApi() async def test_no_cpe_name_id(self): @@ -102,9 +104,15 @@ async def test_cpe(self): self.assertEqual(cpe.titles, []) self.assertEqual(cpe.deprecated_by, []) - @patch("pontos.nvd.api.sleep", spec=sleep) - async def test_rate_limit(self, sleep_mock: MagicMock): - self.http_client.get.side_effect = create_cpes_responses(6) + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + async def test_rate_limit( + self, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, + ): + self.http_client.get.side_effect = create_cpes_responses(8) + monotonic_mock.side_effect = [10, 11] it = aiter(self.api.cpes()) await anext(it) @@ -117,7 +125,7 @@ async def test_rate_limit(self, sleep_mock: MagicMock): await anext(it) - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) @patch("pontos.nvd.cpe.api.now", spec=now) async def test_cves_last_modified_start_date(self, now_mock: MagicMock): diff --git a/tests/nvd/cve/test_api.py b/tests/nvd/cve/test_api.py index 9cf87b65b..acfcb69ba 100644 --- a/tests/nvd/cve/test_api.py +++ b/tests/nvd/cve/test_api.py @@ -24,7 +24,7 @@ from httpx import AsyncClient, Response from pontos.errors import PontosError -from pontos.nvd.api import now, sleep +from pontos.nvd.api import now from pontos.nvd.cve.api import CVEApi from pontos.nvd.models import cvss_v2, cvss_v3 from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext @@ -54,10 +54,16 @@ def create_cves_responses(count: int = 2) -> List[MagicMock]: class CVEApiTestCase(IsolatedAsyncioTestCase): + @patch("pontos.nvd.api.time.monotonic", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) - def setUp(self, async_client: MagicMock) -> None: + def setUp( + self, + async_client: MagicMock, + monotonic_mock: MagicMock, + ) -> None: self.http_client = AsyncMock() async_client.return_value = self.http_client + monotonic_mock.return_value = 0 self.api = CVEApi(token="token") async def test_no_cve_id(self): @@ -798,10 +804,16 @@ async def test_context_manager(self): self.http_client.__aenter__.assert_awaited_once() self.http_client.__aexit__.assert_awaited_once() - @patch("pontos.nvd.api.sleep", spec=sleep) - async def test_rate_limit(self, sleep_mock: MagicMock): + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + async def test_rate_limit( + self, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, + ): self.http_client.get.side_effect = create_cves_responses(6) self.api._rate_limit = 5 # pylint: disable=protected-access + monotonic_mock.side_effect = [10.0, 11.0] it = aiter(self.api.cves()) await anext(it) @@ -814,4 +826,4 @@ async def test_rate_limit(self, sleep_mock: MagicMock): await anext(it) - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) diff --git a/tests/nvd/test_api.py b/tests/nvd/test_api.py index ca32f446a..1df2391f2 100644 --- a/tests/nvd/test_api.py +++ b/tests/nvd/test_api.py @@ -23,7 +23,7 @@ from httpx import AsyncClient -from pontos.nvd.api import NVDApi, convert_camel_case, format_date, sleep +from pontos.nvd.api import NVDApi, convert_camel_case, format_date from tests import IsolatedAsyncioTestCase @@ -84,13 +84,19 @@ async def test_get_with_token(self, async_client: MagicMock): "https://foo.bar/baz", headers={"apiKey": "token"}, params=None ) - @patch("pontos.nvd.api.sleep", spec=sleep) + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) async def test_rate_limit( - self, async_client: MagicMock, sleep_mock: MagicMock + self, + async_client: MagicMock, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, ): http_client = AsyncMock() async_client.return_value = http_client + monotonic_mock.side_effect = [0.0, 10.0, 11.0] + api = NVDApi("https://foo.bar/baz") await api._get() @@ -103,9 +109,9 @@ async def test_rate_limit( await api._get() - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) - @patch("pontos.nvd.api.sleep", spec=sleep) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) async def test_no_rate_limit( self, async_client: MagicMock, sleep_mock: MagicMock