Skip to content

Commit

Permalink
Change: Improve rate limit algorithm for requests against the NVD API
Browse files Browse the repository at this point in the history
The original algorithm was very simple by just applying a pause after a
specific number of allowed requests. When doing additional costly tasks
after each requests this algorithm results in unnecessary delays and the
sleep might be completely obsolete. Therefore calculate the time since
the last check for a sleep. If this time delta is within the rate limit
we wait for the required time before doing the next request.
  • Loading branch information
bjoernricks authored and greenbonebot committed Nov 3, 2023
1 parent 124bd52 commit 4b72fd8
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 21 deletions.
14 changes: 9 additions & 5 deletions pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import time
from abc import ABC
from datetime import datetime, timezone
from types import TracebackType
Expand Down Expand Up @@ -76,10 +77,6 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]:
return converted


async def sleep() -> None:
await asyncio.sleep(SLEEP_TIMEOUT)


class NVDApi(ABC):
"""
Abstract base class for querying the NIST NVD API.
Expand Down Expand Up @@ -120,6 +117,7 @@ def __init__(
self._rate_limit = None

self._request_count = 0
self._last_sleep = time.monotonic()

def _request_headers(self) -> Headers:
"""
Expand All @@ -141,7 +139,13 @@ async def _consider_rate_limit(self) -> None:

self._request_count += 1
if self._request_count > self._rate_limit:
await sleep()
time_since_last_sleep = time.monotonic() - self._last_sleep

if time_since_last_sleep < SLEEP_TIMEOUT:
time_to_sleep = SLEEP_TIMEOUT - time_since_last_sleep
await asyncio.sleep(time_to_sleep)

self._last_sleep = time.monotonic()
self._request_count = 0

async def _get(
Expand Down
20 changes: 14 additions & 6 deletions tests/nvd/cpe/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from httpx import AsyncClient, Response

from pontos.errors import PontosError
from pontos.nvd.api import now, sleep
from pontos.nvd.api import now
from pontos.nvd.cpe.api import CPEApi
from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext
from tests.nvd import get_cpe_data
Expand Down Expand Up @@ -54,10 +54,12 @@ def create_cpes_responses(count: int = 2) -> List[MagicMock]:


class CPEApiTestCase(IsolatedAsyncioTestCase):
@patch("pontos.nvd.api.time.monotonic", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
def setUp(self, async_client: MagicMock) -> None:
def setUp(self, async_client: MagicMock, monotonic_mock: MagicMock) -> None:
self.http_client = AsyncMock()
async_client.return_value = self.http_client
monotonic_mock.return_value = 0
self.api = CPEApi()

async def test_no_cpe_name_id(self):
Expand Down Expand Up @@ -102,9 +104,15 @@ async def test_cpe(self):
self.assertEqual(cpe.titles, [])
self.assertEqual(cpe.deprecated_by, [])

@patch("pontos.nvd.api.sleep", spec=sleep)
async def test_rate_limit(self, sleep_mock: MagicMock):
self.http_client.get.side_effect = create_cpes_responses(6)
@patch("pontos.nvd.api.time.monotonic", autospec=True)
@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
async def test_rate_limit(
self,
sleep_mock: MagicMock,
monotonic_mock: MagicMock,
):
self.http_client.get.side_effect = create_cpes_responses(8)
monotonic_mock.side_effect = [10, 11]

it = aiter(self.api.cpes())
await anext(it)
Expand All @@ -117,7 +125,7 @@ async def test_rate_limit(self, sleep_mock: MagicMock):

await anext(it)

sleep_mock.assert_called_once_with()
sleep_mock.assert_called_once_with(20.0)

@patch("pontos.nvd.cpe.api.now", spec=now)
async def test_cves_last_modified_start_date(self, now_mock: MagicMock):
Expand Down
22 changes: 17 additions & 5 deletions tests/nvd/cve/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from httpx import AsyncClient, Response

from pontos.errors import PontosError
from pontos.nvd.api import now, sleep
from pontos.nvd.api import now
from pontos.nvd.cve.api import CVEApi
from pontos.nvd.models import cvss_v2, cvss_v3
from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext
Expand Down Expand Up @@ -54,10 +54,16 @@ def create_cves_responses(count: int = 2) -> List[MagicMock]:


class CVEApiTestCase(IsolatedAsyncioTestCase):
@patch("pontos.nvd.api.time.monotonic", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
def setUp(self, async_client: MagicMock) -> None:
def setUp(
self,
async_client: MagicMock,
monotonic_mock: MagicMock,
) -> None:
self.http_client = AsyncMock()
async_client.return_value = self.http_client
monotonic_mock.return_value = 0
self.api = CVEApi(token="token")

async def test_no_cve_id(self):
Expand Down Expand Up @@ -798,10 +804,16 @@ async def test_context_manager(self):
self.http_client.__aenter__.assert_awaited_once()
self.http_client.__aexit__.assert_awaited_once()

@patch("pontos.nvd.api.sleep", spec=sleep)
async def test_rate_limit(self, sleep_mock: MagicMock):
@patch("pontos.nvd.api.time.monotonic", autospec=True)
@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
async def test_rate_limit(
self,
sleep_mock: MagicMock,
monotonic_mock: MagicMock,
):
self.http_client.get.side_effect = create_cves_responses(6)
self.api._rate_limit = 5 # pylint: disable=protected-access
monotonic_mock.side_effect = [10.0, 11.0]

it = aiter(self.api.cves())
await anext(it)
Expand All @@ -814,4 +826,4 @@ async def test_rate_limit(self, sleep_mock: MagicMock):

await anext(it)

sleep_mock.assert_called_once_with()
sleep_mock.assert_called_once_with(20.0)
16 changes: 11 additions & 5 deletions tests/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from httpx import AsyncClient

from pontos.nvd.api import NVDApi, convert_camel_case, format_date, sleep
from pontos.nvd.api import NVDApi, convert_camel_case, format_date
from tests import IsolatedAsyncioTestCase


Expand Down Expand Up @@ -84,13 +84,19 @@ async def test_get_with_token(self, async_client: MagicMock):
"https://foo.bar/baz", headers={"apiKey": "token"}, params=None
)

@patch("pontos.nvd.api.sleep", spec=sleep)
@patch("pontos.nvd.api.time.monotonic", autospec=True)
@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_rate_limit(
self, async_client: MagicMock, sleep_mock: MagicMock
self,
async_client: MagicMock,
sleep_mock: MagicMock,
monotonic_mock: MagicMock,
):
http_client = AsyncMock()
async_client.return_value = http_client
monotonic_mock.side_effect = [0.0, 10.0, 11.0]

api = NVDApi("https://foo.bar/baz")

await api._get()
Expand All @@ -103,9 +109,9 @@ async def test_rate_limit(

await api._get()

sleep_mock.assert_called_once_with()
sleep_mock.assert_called_once_with(20.0)

@patch("pontos.nvd.api.sleep", spec=sleep)
@patch("pontos.nvd.api.asyncio.sleep", autospec=True)
@patch("pontos.nvd.api.AsyncClient", spec=AsyncClient)
async def test_no_rate_limit(
self, async_client: MagicMock, sleep_mock: MagicMock
Expand Down

0 comments on commit 4b72fd8

Please sign in to comment.