From 7913e9763a324a15292447d0de4e5f0e329f183d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ricks?= Date: Wed, 29 Nov 2023 17:25:22 +0100 Subject: [PATCH] Change: Introduce more flexible NVDResults class Before all NVD API classes methods returned an async iterator. This didn't allow much control of what the user actually wants and how the requests are issued. To improve the situation a new NVDResults class is returned which itself is an async iterator so that the previous API is kept compatible. But additionally the NVDResults instance allows to get the plain JSON data, the number of available results and also to iterate over chunks of results (which the NVD API is always returning). Most important improvement the NVDResults instance keeps the state. That means if an http error occurs it is possible to request the same data again. With the old API the requests need to start from the beginning if something did go wrong. For example if we downloaded already 100k CVEs and a http error was raised we needed to start from CVE number 1 again. With the new implementation we can just continue with the last request again. --- pontos/nvd/api.py | 255 +++++++++++++++++++++- pontos/nvd/cpe/api.py | 73 +++---- pontos/nvd/cve/api.py | 80 +++---- pontos/nvd/cve_changes/api.py | 65 +++--- tests/nvd/cve_changes/test_api.py | 42 +++- tests/nvd/test_api.py | 341 +++++++++++++++++++++++++++++- 6 files changed, 721 insertions(+), 135 deletions(-) diff --git a/pontos/nvd/api.py b/pontos/nvd/api.py index dffe646e1..00aa571c5 100644 --- a/pontos/nvd/api.py +++ b/pontos/nvd/api.py @@ -20,10 +20,26 @@ from abc import ABC from datetime import datetime, timezone from types import TracebackType -from typing import Any, Dict, Optional, Type, Union +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Dict, + Generator, + Generic, + Iterator, + Optional, + Sequence, + Type, + TypeVar, + Union, +) -from httpx import AsyncClient, Response, Timeout +from httpx import URL, AsyncClient, Response, Timeout +from pontos.errors import PontosError from pontos.helper import snake_case SLEEP_TIMEOUT = 30.0 # in seconds @@ -78,6 +94,239 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]: return converted +class NoMoreResults(PontosError): + """ + Raised if the NVD API has no more results to consume + """ + + +class InvalidState(PontosError): + """ + Raised if the state of the NVD API is invalid + """ + + +T = TypeVar("T") + +result_iterator_func = Callable[[JSON], Iterator[T]] + + +class NVDResults(Generic[T], AsyncIterable[T], Awaitable["NVDResults"]): + """ + A generic object for accessing the results of a NVD API response + + It implements the pagination and will issue requests against the NVD API. + """ + + def __init__( + self, + api: "NVDApi", + params: Params, + result_func: result_iterator_func, + *, + request_results: Optional[int] = None, + results_per_page: Optional[int] = None, + start_index: int = 0, + ) -> None: + self._api = api + self._params = params + self._url: Optional[URL] = None + + self._data: Optional[JSON] = None + self._it: Optional[Iterator[T]] = None + self._total_results: Optional[int] = None + self._downloaded_results: int = 0 + + self._start_index = start_index + self._request_results = request_results + self._results_per_page = results_per_page + + self._current_index = start_index + self._current_request_results = request_results + self._current_results_per_page = results_per_page + + self._result_func = result_func + + async def chunks(self) -> AsyncIterator[Sequence[T]]: + """ + Return the results in chunks + + The size of the chunks is defined by results_per_page. + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for results in nvd_results.chunks(): + for result in results: + print(result) + """ + while True: + try: + if self._it: + yield list(self._it) + await self._next_iterator() + except NoMoreResults: + return + + async def items(self) -> AsyncIterator[T]: + """ + Return the results of the NVD API response + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for result in nvd_results.items(): + print(result) + """ + while True: + try: + if self._it: + for result in self._it: + yield result + await self._next_iterator() + except NoMoreResults: + return + + async def json(self) -> Optional[JSON]: + """ + Return the result from the NVD API request as JSON + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + while data := await nvd_results.json(): + print(data) + + Returns: + The response data as JSON or None if the response is exhausted. + """ + try: + if not self._data: + await self._next_iterator() + + data = self._data + self._data = None + return data + except NoMoreResults: + return None + + def __len__(self) -> int: + """ + Get the number of available result items for a NVD API request + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + total_results = len(nvd_results) # None because it hasn't been awaited yet + json = await nvd_results.json() # request the plain JSON data + total_results = len(nvd_results) # contains the total number of results now + + nvd_results: NVDResults = ... + total_results = len(nvd_results) # None because it hasn't been awaited yet + async for result in nvd_results: + print(result) + total_results = len(nvd_results) # contains the total number of results now + + Returns: + The total number of available results if the NVDResults has been awaited + """ + if self._total_results is None: + raise InvalidState( + f"{self.__class__.__name__} has not been awaited yet." + ) + return self._total_results + + def __aiter__(self) -> AsyncIterator[T]: + """ + Return the results of the NVD API response + + Same as the items() method. @see items() + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for result in nvd_results: + print(result) + """ + return self.items() + + def __await__(self) -> Generator[Any, None, "NVDResults"]: + """ + Request the next results from the NVD API + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + print(len(nvd_results)) # None, because no request has been send yet + await nvd_results # creates a request to the NVD API + print(len(nvd_results)) + + Returns: + The response data as JSON or None if the response is exhausted. + """ + + return self._next_iterator().__await__() + + async def _load_next_data(self) -> None: + if ( + not self._current_request_results + or self._downloaded_results < self._current_request_results + ): + params = self._params + params["startIndex"] = self._current_index + + if self._current_results_per_page is not None: + params["resultsPerPage"] = self._current_results_per_page + + response = await self._api._get(params=params) + response.raise_for_status() + + self._url = response.url + data: JSON = response.json(object_hook=convert_camel_case) + + self._data = data + self._current_results_per_page = int(data["results_per_page"]) # type: ignore + self._total_results = int(data["total_results"]) # type: ignore + self._current_index += self._current_results_per_page + self._downloaded_results += self._current_results_per_page + + if not self._current_request_results: + self._current_request_results = self._total_results + + if ( + self._request_results + and self._downloaded_results + self._current_results_per_page + > self._request_results + ): + # avoid downloading more results then requested + self._current_results_per_page = ( + self._request_results - self._downloaded_results + ) + + else: + raise NoMoreResults() + + async def _get_next_iterator(self) -> Iterator[T]: + await self._load_next_data() + return self._result_func(self._data) # type: ignore + + async def _next_iterator(self) -> "NVDResults": + self._it = await self._get_next_iterator() + return self + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} url="{self._url}" total_results={self._total_results} start_index={self._start_index} current_index={self._current_index} results_per_page={self._results_per_page}>' + + class NVDApi(ABC): """ Abstract base class for querying the NIST NVD API. @@ -155,7 +404,7 @@ async def _get( params: Optional[Params] = None, ) -> Response: """ - A request against the NIST NVD CVE REST API. + A request against the NIST NVD REST API. """ headers = self._request_headers() diff --git a/pontos/nvd/cpe/api.py b/pontos/nvd/cpe/api.py index 072b3b07d..d2001d601 100644 --- a/pontos/nvd/cpe/api.py +++ b/pontos/nvd/cpe/api.py @@ -19,8 +19,8 @@ from datetime import datetime from types import TracebackType from typing import ( - AsyncIterator, - Iterable, + Any, + Iterator, List, Optional, Type, @@ -35,6 +35,7 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, convert_camel_case, format_date, @@ -46,6 +47,11 @@ MAX_CPES_PER_PAGE = 10000 +def _result_iterator(data: JSON) -> Iterator[CPE]: + results: list[dict[str, Any]] = data.get("products", []) # type: ignore + return (CPE.from_dict(result["cpe"]) for result in results) + + class CPEApi(NVDApi): """ API for querying the NIST NVD CPE information. @@ -125,7 +131,7 @@ async def cpe(self, cpe_name_id: Union[str, UUID]) -> CPE: product = products[0] return CPE.from_dict(product["cpe"]) - async def cpes( + def cpes( self, *, last_modified_start_date: Optional[datetime] = None, @@ -134,7 +140,7 @@ async def cpes( keywords: Optional[Union[List[str], str]] = None, match_criteria_id: Optional[str] = None, request_results: Optional[int] = None, - ) -> AsyncIterator[CPE]: + ) -> NVDResults[CPE]: """ Get all CPEs for the provided arguments @@ -155,9 +161,9 @@ async def cpes( to download all available CPEs. Returns: - An async iterator of CPE model instances. + A NVDResponse for CPEs - Example: + Examples: .. code-block:: python from pontos.nvd.cpe import CPEApi @@ -165,6 +171,14 @@ async def cpes( async with CPEApi() as api: async for cpe in api.cpes(keywords=["Mac OS X"]): print(cpe.cpe_name, cpe.cpe_name_id) + + json = await api.cpes(request_results=10).json() + + async for cpes in api.cpes( + cpe_match_string="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:*:*", + ).chunks(): + for cpe in cpes: + print(cpe) """ params: Params = {} if last_modified_start_date: @@ -189,51 +203,20 @@ async def cpes( params["matchCriteriaId"] = match_criteria_id start_index = 0 - downloaded_results = 0 results_per_page = ( request_results if request_results and request_results < MAX_CPES_PER_PAGE else MAX_CPES_PER_PAGE ) - total_results = None - requested_results = request_results - - while ( - requested_results is None or downloaded_results < requested_results - ): - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - data: JSON = response.json(object_hook=convert_camel_case) - - results_per_page: int = data["results_per_page"] # type: ignore - total_results: int = data["total_results"] # type: ignore - products: Iterable = data.get("products", []) # type: ignore - - if not requested_results: - requested_results = total_results - - for product in products: - yield CPE.from_dict(product["cpe"]) - - if results_per_page is None: - # just be safe here. should never occur - results_per_page = len(products) - - start_index += results_per_page - downloaded_results += results_per_page - - if ( - request_results - and downloaded_results + results_per_page > request_results - ): - # avoid downloading more results then requested - results_per_page = request_results - downloaded_results + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def __aenter__(self) -> "CPEApi": await super().__aenter__() diff --git a/pontos/nvd/cve/api.py b/pontos/nvd/cve/api.py index cd35acc7d..06565f89a 100644 --- a/pontos/nvd/cve/api.py +++ b/pontos/nvd/cve/api.py @@ -18,8 +18,8 @@ from datetime import datetime from types import TracebackType from typing import ( - AsyncIterator, Iterable, + Iterator, List, Optional, Type, @@ -33,6 +33,7 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, convert_camel_case, format_date, @@ -48,6 +49,13 @@ MAX_CVES_PER_PAGE = 2000 +def _result_iterator(data: JSON) -> Iterator[CVE]: + vulnerabilities: Iterable = data.get("vulnerabilities", []) # type: ignore + return ( + CVE.from_dict(vulnerability["cve"]) for vulnerability in vulnerabilities + ) + + class CVEApi(NVDApi): """ API for querying the NIST NVD CVE information. @@ -91,7 +99,7 @@ def __init__( rate_limit=rate_limit, ) - async def cves( + def cves( self, *, last_modified_start_date: Optional[datetime] = None, @@ -113,7 +121,7 @@ async def cves( has_kev: Optional[bool] = None, has_oval: Optional[bool] = None, request_results: Optional[int] = None, - ) -> AsyncIterator[CVE]: + ) -> NVDResults[CVE]: """ Get all CVEs for the provided arguments @@ -165,9 +173,9 @@ async def cves( to download all available CVEs. Returns: - An async iterator to iterate over CVE model instances + A NVDResponse for CVEs - Example: + Examples: .. code-block:: python from pontos.nvd.cve import CVEApi @@ -175,6 +183,16 @@ async def cves( async with CVEApi() as api: async for cve in api.cves(keywords=["Mac OS X", "kernel"]): print(cve.id) + + json = await api.cves( + cpe_name="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:x64:*", + ).json() + + async for cves in api.cves( + virtual_match_string="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:x64:*", + ).chunks(): + for cve in cves: + print(cve) """ params: Params = {} if last_modified_start_date: @@ -231,54 +249,20 @@ async def cves( if has_oval: params["hasOval"] = "" - start_index: int = 0 - downloaded_results = 0 + start_index = 0 results_per_page = ( request_results if request_results and request_results < MAX_CVES_PER_PAGE else MAX_CVES_PER_PAGE ) - total_results = None - requested_results = request_results - - while ( - requested_results is None or downloaded_results < requested_results - ): - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - - data: JSON = response.json(object_hook=convert_camel_case) - - results_per_page: int = data["results_per_page"] # type: ignore - total_results: int = data["total_results"] # type: ignore - vulnerabilities: Iterable = data.get( # type: ignore - "vulnerabilities", [] - ) - - if not requested_results: - requested_results = total_results - - for vulnerability in vulnerabilities: - yield CVE.from_dict(vulnerability["cve"]) - - if results_per_page is None: - # just be safe here. should never occur - results_per_page = len(vulnerabilities) - - start_index += results_per_page - downloaded_results += results_per_page - - if ( - request_results - and downloaded_results + results_per_page > request_results - ): - # avoid downloading more results then requested - results_per_page = request_results - downloaded_results + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def cve(self, cve_id: str) -> CVE: """ diff --git a/pontos/nvd/cve_changes/api.py b/pontos/nvd/cve_changes/api.py index 777a60841..1378c4dfd 100644 --- a/pontos/nvd/cve_changes/api.py +++ b/pontos/nvd/cve_changes/api.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from datetime import datetime, timedelta -from typing import AsyncIterator, Iterable, Optional, Union +from typing import Any, Iterator, Optional, Union from httpx import Timeout @@ -12,8 +12,8 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, - convert_camel_case, format_date, now, ) @@ -25,6 +25,13 @@ "https://services.nvd.nist.gov/rest/json/cvehistory/2.0" ) +MAX_CVE_CHANGES_PER_PAGE = 5000 + + +def _result_iterator(data: JSON) -> Iterator[CVEChange]: + results: list[dict[str, Any]] = data.get("cve_changes", []) # type: ignore + return (CVEChange.from_dict(result["change"]) for result in results) + class CVEChangesApi(NVDApi): """ @@ -70,14 +77,15 @@ def __init__( rate_limit=rate_limit, ) - async def changes( + def changes( self, *, change_start_date: Optional[datetime] = None, change_end_date: Optional[datetime] = None, cve_id: Optional[str] = None, event_name: Optional[Union[EventName, str]] = None, - ) -> AsyncIterator[CVEChange]: + request_results: Optional[int] = None, + ) -> NVDResults[CVEChange]: """ Get all CVEs for the provided arguments @@ -88,10 +96,11 @@ async def changes( change_end_date: Return all CVE changes before this date. cve_id: Return all CVE changes for this Common Vulnerabilities and Exposures identifier. event_name: Return all CVE changes with this event name. - + request_results: Number of CVEs changes to download. Set to None + (default) to download all available CPEs. Returns: - An async iterator to iterate over CVEChange model instances + A NVDResponse for CVE changes Example: .. code-block:: python @@ -101,9 +110,15 @@ async def changes( async with CVEChangesApi() as api: async for cve_change in api.changes(event_name=EventName.INITIAL_ANALYSIS): print(cve_change) - """ - total_results: Optional[int] = None + json = api.changes(event_name=EventName.INITIAL_ANALYSIS).json() + + async for changes in api.changes( + event_name=EventName.INITIAL_ANALYSIS, + ).chunks(): + for cve_change in changes: + print(cve_change) + """ if change_start_date and not change_end_date: change_end_date = min( now(), change_start_date + timedelta(days=120) @@ -128,27 +143,19 @@ async def changes( params["eventName"] = event_name start_index: int = 0 - results_per_page = None - - while total_results is None or start_index < total_results: - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - - data: JSON = response.json(object_hook=convert_camel_case) - - total_results = data["total_results"] # type: ignore - results_per_page: int = data["results_per_page"] # type: ignore - cve_changes: Iterable = data.get("cve_changes", []) # type: ignore - - for cve_change in cve_changes: - yield CVEChange.from_dict(cve_change["change"]) - - start_index += results_per_page # type: ignore + results_per_page = ( + request_results + if request_results and request_results < MAX_CVE_CHANGES_PER_PAGE + else MAX_CVE_CHANGES_PER_PAGE + ) + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def __aenter__(self) -> "CVEChangesApi": await super().__aenter__() diff --git a/tests/nvd/cve_changes/test_api.py b/tests/nvd/cve_changes/test_api.py index 1b167c2ee..f4e4ca3ca 100644 --- a/tests/nvd/cve_changes/test_api.py +++ b/tests/nvd/cve_changes/test_api.py @@ -12,7 +12,7 @@ from pontos.errors import PontosError from pontos.nvd.api import now -from pontos.nvd.cve_changes.api import CVEChangesApi +from pontos.nvd.cve_changes.api import MAX_CVE_CHANGES_PER_PAGE, CVEChangesApi from pontos.nvd.models.cve_change import Detail, EventName from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext from tests.nvd import get_cve_change_data @@ -88,7 +88,10 @@ async def test_cve_changes(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0}, + params={ + "startIndex": 0, + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -123,6 +126,7 @@ async def test_cve_changes_change_dates(self): "startIndex": 0, "changeStartDate": "2022-12-01T00:00:00", "changeEndDate": "2022-12-31T00:00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -155,7 +159,11 @@ async def test_cve_changes_cve_id(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0, "cveId": "CVE-1"}, + params={ + "startIndex": 0, + "cveId": "CVE-1", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -186,7 +194,11 @@ async def test_cve_changes_event_name(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0, "eventName": "Initial Analysis"}, + params={ + "startIndex": 0, + "eventName": "Initial Analysis", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -227,6 +239,7 @@ async def test_cve_changes_calculate_end_date(self, now_mock: MagicMock): "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-01-02T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -252,6 +265,7 @@ async def test_cve_changes_calculate_end_date_with_limit( "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-05-01T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -273,18 +287,32 @@ async def test_cve_changes_calculate_start_date(self): "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-05-01T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) async def test_cve_changes_range_too_long(self): - it = aiter( + with self.assertRaises(PontosError): self.api.changes( change_start_date=datetime(2023, 1, 1), change_end_date=datetime(2023, 5, 2), ) + + async def test_cve_changes_request_results(self): + self.http_client.get.side_effect = create_cve_changes_responses() + + it = aiter(self.api.changes(request_results=10)) + + await anext(it) + + self.http_client.get.assert_awaited_once_with( + "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", + headers={"apiKey": "token"}, + params={ + "startIndex": 0, + "resultsPerPage": 10, + }, ) - with self.assertRaises(PontosError): - await anext(it) async def test_context_manager(self): async with self.api: diff --git a/tests/nvd/test_api.py b/tests/nvd/test_api.py index 1df2391f2..b86166050 100644 --- a/tests/nvd/test_api.py +++ b/tests/nvd/test_api.py @@ -19,12 +19,21 @@ import unittest from datetime import datetime +from typing import Any, Iterator from unittest.mock import AsyncMock, MagicMock, patch -from httpx import AsyncClient +from httpx import AsyncClient, Response -from pontos.nvd.api import NVDApi, convert_camel_case, format_date -from tests import IsolatedAsyncioTestCase +from pontos.nvd.api import ( + JSON, + InvalidState, + NoMoreResults, + NVDApi, + NVDResults, + convert_camel_case, + format_date, +) +from tests import IsolatedAsyncioTestCase, aiter, anext class ConvertCamelCaseTestCase(unittest.TestCase): @@ -131,3 +140,329 @@ async def test_no_rate_limit( await api._get() sleep_mock.assert_not_called() + + +class Result: + def __init__(self, value: int) -> None: + self.value = value + + +def result_func(data: JSON) -> Iterator[Result]: + return (Result(d) for d in data["values"]) # type: ignore + + +class NVDResultsTestCase(IsolatedAsyncioTestCase): + async def test_items(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(results.items()) + + result = await anext(it) + self.assertEqual(result.value, 1) + + result = await anext(it) + self.assertEqual(result.value, 2) + + result = await anext(it) + self.assertEqual(result.value, 3) + + result = await anext(it) + self.assertEqual(result.value, 4) + + result = await anext(it) + self.assertEqual(result.value, 5) + + result = await anext(it) + self.assertEqual(result.value, 6) + + async def test_aiter(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(results) + + result = await anext(it) + self.assertEqual(result.value, 1) + + result = await anext(it) + self.assertEqual(result.value, 2) + + result = await anext(it) + self.assertEqual(result.value, 3) + + result = await anext(it) + self.assertEqual(result.value, 4) + + result = await anext(it) + self.assertEqual(result.value, 5) + + result = await anext(it) + self.assertEqual(result.value, 6) + + async def test_len(self): + response_mock = MagicMock(spec=Response) + response_mock.json.return_value = { + "values": [1, 2, 3], + "total_results": 3, + "results_per_page": 3, + } + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + with self.assertRaisesRegex( + InvalidState, "NVDResults has not been awaited yet" + ): + len(results) + + await results + + self.assertEqual(len(results), 3) + + async def test_chunks(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(nvd_results.chunks()) + + results = await anext(it) + self.assertEqual([result.value for result in results], [1, 2, 3]) + + results = await anext(it) + self.assertEqual([result.value for result in results], [4, 5, 6]) + + async def test_json(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + self.assertIsNone(await nvd_results.json()) + + async def test_await(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + await nvd_results + self.assertEqual(len(nvd_results), 6) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + await nvd_results + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + with self.assertRaises(NoMoreResults): + await nvd_results + + async def test_mix_and_match(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + await nvd_results + self.assertEqual(len(nvd_results), 6) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + self.assertEqual( + [result.value async for result in nvd_results], [1, 2, 3, 4, 5, 6] + ) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + async def test_response_error(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + json = await nvd_results.json() + self.assertEqual(json["values"], [1, 2, 3]) # type: ignore + + api_mock._get.assert_called_once_with(params={"startIndex": 0}) + + response_mock.raise_for_status.side_effect = Exception("Server Error") + + api_mock.reset_mock() + + with self.assertRaises(Exception): + json = await nvd_results.json() + + api_mock._get.assert_called_once_with( + params={ + "startIndex": 3, + "resultsPerPage": 3, + } + ) + + response_mock.reset_mock(return_value=True, side_effect=True) + api_mock.reset_mock() + + response_mock.json.side_effect = [ + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + + json = await nvd_results.json() + self.assertEqual(json["values"], [4, 5, 6]) # type: ignore + + api_mock._get.assert_called_once_with( + params={ + "startIndex": 3, + "resultsPerPage": 3, + } + )