From 4b72fd82c28a473701237b3dfe0ff3c8c52bba64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ricks?= Date: Fri, 3 Nov 2023 11:01:20 +0100 Subject: [PATCH] Change: Improve rate limit algorithm for requests against the NVD API The original algorithm was very simple by just applying a pause after a specific number of allowed requests. When doing additional costly tasks after each requests this algorithm results in unnecessary delays and the sleep might be completely obsolete. Therefore calculate the time since the last check for a sleep. If this time delta is within the rate limit we wait for the required time before doing the next request. --- pontos/nvd/api.py | 14 +++++++++----- tests/nvd/cpe/test_api.py | 20 ++++++++++++++------ tests/nvd/cve/test_api.py | 22 +++++++++++++++++----- tests/nvd/test_api.py | 16 +++++++++++----- 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/pontos/nvd/api.py b/pontos/nvd/api.py index 598d2d359..34949b329 100644 --- a/pontos/nvd/api.py +++ b/pontos/nvd/api.py @@ -16,6 +16,7 @@ # along with this program. If not, see . import asyncio +import time from abc import ABC from datetime import datetime, timezone from types import TracebackType @@ -76,10 +77,6 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]: return converted -async def sleep() -> None: - await asyncio.sleep(SLEEP_TIMEOUT) - - class NVDApi(ABC): """ Abstract base class for querying the NIST NVD API. @@ -120,6 +117,7 @@ def __init__( self._rate_limit = None self._request_count = 0 + self._last_sleep = time.monotonic() def _request_headers(self) -> Headers: """ @@ -141,7 +139,13 @@ async def _consider_rate_limit(self) -> None: self._request_count += 1 if self._request_count > self._rate_limit: - await sleep() + time_since_last_sleep = time.monotonic() - self._last_sleep + + if time_since_last_sleep < SLEEP_TIMEOUT: + time_to_sleep = SLEEP_TIMEOUT - time_since_last_sleep + await asyncio.sleep(time_to_sleep) + + self._last_sleep = time.monotonic() self._request_count = 0 async def _get( diff --git a/tests/nvd/cpe/test_api.py b/tests/nvd/cpe/test_api.py index 05ecacf2d..b21b9a5ed 100644 --- a/tests/nvd/cpe/test_api.py +++ b/tests/nvd/cpe/test_api.py @@ -25,7 +25,7 @@ from httpx import AsyncClient, Response from pontos.errors import PontosError -from pontos.nvd.api import now, sleep +from pontos.nvd.api import now from pontos.nvd.cpe.api import CPEApi from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext from tests.nvd import get_cpe_data @@ -54,10 +54,12 @@ def create_cpes_responses(count: int = 2) -> List[MagicMock]: class CPEApiTestCase(IsolatedAsyncioTestCase): + @patch("pontos.nvd.api.time.monotonic", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) - def setUp(self, async_client: MagicMock) -> None: + def setUp(self, async_client: MagicMock, monotonic_mock: MagicMock) -> None: self.http_client = AsyncMock() async_client.return_value = self.http_client + monotonic_mock.return_value = 0 self.api = CPEApi() async def test_no_cpe_name_id(self): @@ -102,9 +104,15 @@ async def test_cpe(self): self.assertEqual(cpe.titles, []) self.assertEqual(cpe.deprecated_by, []) - @patch("pontos.nvd.api.sleep", spec=sleep) - async def test_rate_limit(self, sleep_mock: MagicMock): - self.http_client.get.side_effect = create_cpes_responses(6) + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + async def test_rate_limit( + self, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, + ): + self.http_client.get.side_effect = create_cpes_responses(8) + monotonic_mock.side_effect = [10, 11] it = aiter(self.api.cpes()) await anext(it) @@ -117,7 +125,7 @@ async def test_rate_limit(self, sleep_mock: MagicMock): await anext(it) - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) @patch("pontos.nvd.cpe.api.now", spec=now) async def test_cves_last_modified_start_date(self, now_mock: MagicMock): diff --git a/tests/nvd/cve/test_api.py b/tests/nvd/cve/test_api.py index 9cf87b65b..acfcb69ba 100644 --- a/tests/nvd/cve/test_api.py +++ b/tests/nvd/cve/test_api.py @@ -24,7 +24,7 @@ from httpx import AsyncClient, Response from pontos.errors import PontosError -from pontos.nvd.api import now, sleep +from pontos.nvd.api import now from pontos.nvd.cve.api import CVEApi from pontos.nvd.models import cvss_v2, cvss_v3 from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext @@ -54,10 +54,16 @@ def create_cves_responses(count: int = 2) -> List[MagicMock]: class CVEApiTestCase(IsolatedAsyncioTestCase): + @patch("pontos.nvd.api.time.monotonic", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) - def setUp(self, async_client: MagicMock) -> None: + def setUp( + self, + async_client: MagicMock, + monotonic_mock: MagicMock, + ) -> None: self.http_client = AsyncMock() async_client.return_value = self.http_client + monotonic_mock.return_value = 0 self.api = CVEApi(token="token") async def test_no_cve_id(self): @@ -798,10 +804,16 @@ async def test_context_manager(self): self.http_client.__aenter__.assert_awaited_once() self.http_client.__aexit__.assert_awaited_once() - @patch("pontos.nvd.api.sleep", spec=sleep) - async def test_rate_limit(self, sleep_mock: MagicMock): + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) + async def test_rate_limit( + self, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, + ): self.http_client.get.side_effect = create_cves_responses(6) self.api._rate_limit = 5 # pylint: disable=protected-access + monotonic_mock.side_effect = [10.0, 11.0] it = aiter(self.api.cves()) await anext(it) @@ -814,4 +826,4 @@ async def test_rate_limit(self, sleep_mock: MagicMock): await anext(it) - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) diff --git a/tests/nvd/test_api.py b/tests/nvd/test_api.py index ca32f446a..1df2391f2 100644 --- a/tests/nvd/test_api.py +++ b/tests/nvd/test_api.py @@ -23,7 +23,7 @@ from httpx import AsyncClient -from pontos.nvd.api import NVDApi, convert_camel_case, format_date, sleep +from pontos.nvd.api import NVDApi, convert_camel_case, format_date from tests import IsolatedAsyncioTestCase @@ -84,13 +84,19 @@ async def test_get_with_token(self, async_client: MagicMock): "https://foo.bar/baz", headers={"apiKey": "token"}, params=None ) - @patch("pontos.nvd.api.sleep", spec=sleep) + @patch("pontos.nvd.api.time.monotonic", autospec=True) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) async def test_rate_limit( - self, async_client: MagicMock, sleep_mock: MagicMock + self, + async_client: MagicMock, + sleep_mock: MagicMock, + monotonic_mock: MagicMock, ): http_client = AsyncMock() async_client.return_value = http_client + monotonic_mock.side_effect = [0.0, 10.0, 11.0] + api = NVDApi("https://foo.bar/baz") await api._get() @@ -103,9 +109,9 @@ async def test_rate_limit( await api._get() - sleep_mock.assert_called_once_with() + sleep_mock.assert_called_once_with(20.0) - @patch("pontos.nvd.api.sleep", spec=sleep) + @patch("pontos.nvd.api.asyncio.sleep", autospec=True) @patch("pontos.nvd.api.AsyncClient", spec=AsyncClient) async def test_no_rate_limit( self, async_client: MagicMock, sleep_mock: MagicMock