diff --git a/pontos/nvd/api.py b/pontos/nvd/api.py index dffe646e1..00aa571c5 100644 --- a/pontos/nvd/api.py +++ b/pontos/nvd/api.py @@ -20,10 +20,26 @@ from abc import ABC from datetime import datetime, timezone from types import TracebackType -from typing import Any, Dict, Optional, Type, Union +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Dict, + Generator, + Generic, + Iterator, + Optional, + Sequence, + Type, + TypeVar, + Union, +) -from httpx import AsyncClient, Response, Timeout +from httpx import URL, AsyncClient, Response, Timeout +from pontos.errors import PontosError from pontos.helper import snake_case SLEEP_TIMEOUT = 30.0 # in seconds @@ -78,6 +94,239 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]: return converted +class NoMoreResults(PontosError): + """ + Raised if the NVD API has no more results to consume + """ + + +class InvalidState(PontosError): + """ + Raised if the state of the NVD API is invalid + """ + + +T = TypeVar("T") + +result_iterator_func = Callable[[JSON], Iterator[T]] + + +class NVDResults(Generic[T], AsyncIterable[T], Awaitable["NVDResults"]): + """ + A generic object for accessing the results of a NVD API response + + It implements the pagination and will issue requests against the NVD API. + """ + + def __init__( + self, + api: "NVDApi", + params: Params, + result_func: result_iterator_func, + *, + request_results: Optional[int] = None, + results_per_page: Optional[int] = None, + start_index: int = 0, + ) -> None: + self._api = api + self._params = params + self._url: Optional[URL] = None + + self._data: Optional[JSON] = None + self._it: Optional[Iterator[T]] = None + self._total_results: Optional[int] = None + self._downloaded_results: int = 0 + + self._start_index = start_index + self._request_results = request_results + self._results_per_page = results_per_page + + self._current_index = start_index + self._current_request_results = request_results + self._current_results_per_page = results_per_page + + self._result_func = result_func + + async def chunks(self) -> AsyncIterator[Sequence[T]]: + """ + Return the results in chunks + + The size of the chunks is defined by results_per_page. + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for results in nvd_results.chunks(): + for result in results: + print(result) + """ + while True: + try: + if self._it: + yield list(self._it) + await self._next_iterator() + except NoMoreResults: + return + + async def items(self) -> AsyncIterator[T]: + """ + Return the results of the NVD API response + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for result in nvd_results.items(): + print(result) + """ + while True: + try: + if self._it: + for result in self._it: + yield result + await self._next_iterator() + except NoMoreResults: + return + + async def json(self) -> Optional[JSON]: + """ + Return the result from the NVD API request as JSON + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + while data := await nvd_results.json(): + print(data) + + Returns: + The response data as JSON or None if the response is exhausted. + """ + try: + if not self._data: + await self._next_iterator() + + data = self._data + self._data = None + return data + except NoMoreResults: + return None + + def __len__(self) -> int: + """ + Get the number of available result items for a NVD API request + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + total_results = len(nvd_results) # None because it hasn't been awaited yet + json = await nvd_results.json() # request the plain JSON data + total_results = len(nvd_results) # contains the total number of results now + + nvd_results: NVDResults = ... + total_results = len(nvd_results) # None because it hasn't been awaited yet + async for result in nvd_results: + print(result) + total_results = len(nvd_results) # contains the total number of results now + + Returns: + The total number of available results if the NVDResults has been awaited + """ + if self._total_results is None: + raise InvalidState( + f"{self.__class__.__name__} has not been awaited yet." + ) + return self._total_results + + def __aiter__(self) -> AsyncIterator[T]: + """ + Return the results of the NVD API response + + Same as the items() method. @see items() + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + + async for result in nvd_results: + print(result) + """ + return self.items() + + def __await__(self) -> Generator[Any, None, "NVDResults"]: + """ + Request the next results from the NVD API + + Examples: + .. code-block:: python + + nvd_results: NVDResults = ... + print(len(nvd_results)) # None, because no request has been send yet + await nvd_results # creates a request to the NVD API + print(len(nvd_results)) + + Returns: + The response data as JSON or None if the response is exhausted. + """ + + return self._next_iterator().__await__() + + async def _load_next_data(self) -> None: + if ( + not self._current_request_results + or self._downloaded_results < self._current_request_results + ): + params = self._params + params["startIndex"] = self._current_index + + if self._current_results_per_page is not None: + params["resultsPerPage"] = self._current_results_per_page + + response = await self._api._get(params=params) + response.raise_for_status() + + self._url = response.url + data: JSON = response.json(object_hook=convert_camel_case) + + self._data = data + self._current_results_per_page = int(data["results_per_page"]) # type: ignore + self._total_results = int(data["total_results"]) # type: ignore + self._current_index += self._current_results_per_page + self._downloaded_results += self._current_results_per_page + + if not self._current_request_results: + self._current_request_results = self._total_results + + if ( + self._request_results + and self._downloaded_results + self._current_results_per_page + > self._request_results + ): + # avoid downloading more results then requested + self._current_results_per_page = ( + self._request_results - self._downloaded_results + ) + + else: + raise NoMoreResults() + + async def _get_next_iterator(self) -> Iterator[T]: + await self._load_next_data() + return self._result_func(self._data) # type: ignore + + async def _next_iterator(self) -> "NVDResults": + self._it = await self._get_next_iterator() + return self + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} url="{self._url}" total_results={self._total_results} start_index={self._start_index} current_index={self._current_index} results_per_page={self._results_per_page}>' + + class NVDApi(ABC): """ Abstract base class for querying the NIST NVD API. @@ -155,7 +404,7 @@ async def _get( params: Optional[Params] = None, ) -> Response: """ - A request against the NIST NVD CVE REST API. + A request against the NIST NVD REST API. """ headers = self._request_headers() diff --git a/pontos/nvd/cpe/api.py b/pontos/nvd/cpe/api.py index 072b3b07d..d2001d601 100644 --- a/pontos/nvd/cpe/api.py +++ b/pontos/nvd/cpe/api.py @@ -19,8 +19,8 @@ from datetime import datetime from types import TracebackType from typing import ( - AsyncIterator, - Iterable, + Any, + Iterator, List, Optional, Type, @@ -35,6 +35,7 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, convert_camel_case, format_date, @@ -46,6 +47,11 @@ MAX_CPES_PER_PAGE = 10000 +def _result_iterator(data: JSON) -> Iterator[CPE]: + results: list[dict[str, Any]] = data.get("products", []) # type: ignore + return (CPE.from_dict(result["cpe"]) for result in results) + + class CPEApi(NVDApi): """ API for querying the NIST NVD CPE information. @@ -125,7 +131,7 @@ async def cpe(self, cpe_name_id: Union[str, UUID]) -> CPE: product = products[0] return CPE.from_dict(product["cpe"]) - async def cpes( + def cpes( self, *, last_modified_start_date: Optional[datetime] = None, @@ -134,7 +140,7 @@ async def cpes( keywords: Optional[Union[List[str], str]] = None, match_criteria_id: Optional[str] = None, request_results: Optional[int] = None, - ) -> AsyncIterator[CPE]: + ) -> NVDResults[CPE]: """ Get all CPEs for the provided arguments @@ -155,9 +161,9 @@ async def cpes( to download all available CPEs. Returns: - An async iterator of CPE model instances. + A NVDResponse for CPEs - Example: + Examples: .. code-block:: python from pontos.nvd.cpe import CPEApi @@ -165,6 +171,14 @@ async def cpes( async with CPEApi() as api: async for cpe in api.cpes(keywords=["Mac OS X"]): print(cpe.cpe_name, cpe.cpe_name_id) + + json = await api.cpes(request_results=10).json() + + async for cpes in api.cpes( + cpe_match_string="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:*:*", + ).chunks(): + for cpe in cpes: + print(cpe) """ params: Params = {} if last_modified_start_date: @@ -189,51 +203,20 @@ async def cpes( params["matchCriteriaId"] = match_criteria_id start_index = 0 - downloaded_results = 0 results_per_page = ( request_results if request_results and request_results < MAX_CPES_PER_PAGE else MAX_CPES_PER_PAGE ) - total_results = None - requested_results = request_results - - while ( - requested_results is None or downloaded_results < requested_results - ): - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - data: JSON = response.json(object_hook=convert_camel_case) - - results_per_page: int = data["results_per_page"] # type: ignore - total_results: int = data["total_results"] # type: ignore - products: Iterable = data.get("products", []) # type: ignore - - if not requested_results: - requested_results = total_results - - for product in products: - yield CPE.from_dict(product["cpe"]) - - if results_per_page is None: - # just be safe here. should never occur - results_per_page = len(products) - - start_index += results_per_page - downloaded_results += results_per_page - - if ( - request_results - and downloaded_results + results_per_page > request_results - ): - # avoid downloading more results then requested - results_per_page = request_results - downloaded_results + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def __aenter__(self) -> "CPEApi": await super().__aenter__() diff --git a/pontos/nvd/cve/api.py b/pontos/nvd/cve/api.py index cd35acc7d..06565f89a 100644 --- a/pontos/nvd/cve/api.py +++ b/pontos/nvd/cve/api.py @@ -18,8 +18,8 @@ from datetime import datetime from types import TracebackType from typing import ( - AsyncIterator, Iterable, + Iterator, List, Optional, Type, @@ -33,6 +33,7 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, convert_camel_case, format_date, @@ -48,6 +49,13 @@ MAX_CVES_PER_PAGE = 2000 +def _result_iterator(data: JSON) -> Iterator[CVE]: + vulnerabilities: Iterable = data.get("vulnerabilities", []) # type: ignore + return ( + CVE.from_dict(vulnerability["cve"]) for vulnerability in vulnerabilities + ) + + class CVEApi(NVDApi): """ API for querying the NIST NVD CVE information. @@ -91,7 +99,7 @@ def __init__( rate_limit=rate_limit, ) - async def cves( + def cves( self, *, last_modified_start_date: Optional[datetime] = None, @@ -113,7 +121,7 @@ async def cves( has_kev: Optional[bool] = None, has_oval: Optional[bool] = None, request_results: Optional[int] = None, - ) -> AsyncIterator[CVE]: + ) -> NVDResults[CVE]: """ Get all CVEs for the provided arguments @@ -165,9 +173,9 @@ async def cves( to download all available CVEs. Returns: - An async iterator to iterate over CVE model instances + A NVDResponse for CVEs - Example: + Examples: .. code-block:: python from pontos.nvd.cve import CVEApi @@ -175,6 +183,16 @@ async def cves( async with CVEApi() as api: async for cve in api.cves(keywords=["Mac OS X", "kernel"]): print(cve.id) + + json = await api.cves( + cpe_name="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:x64:*", + ).json() + + async for cves in api.cves( + virtual_match_string="cpe:2.3:o:microsoft:windows_7:-:*:*:*:*:*:x64:*", + ).chunks(): + for cve in cves: + print(cve) """ params: Params = {} if last_modified_start_date: @@ -231,54 +249,20 @@ async def cves( if has_oval: params["hasOval"] = "" - start_index: int = 0 - downloaded_results = 0 + start_index = 0 results_per_page = ( request_results if request_results and request_results < MAX_CVES_PER_PAGE else MAX_CVES_PER_PAGE ) - total_results = None - requested_results = request_results - - while ( - requested_results is None or downloaded_results < requested_results - ): - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - - data: JSON = response.json(object_hook=convert_camel_case) - - results_per_page: int = data["results_per_page"] # type: ignore - total_results: int = data["total_results"] # type: ignore - vulnerabilities: Iterable = data.get( # type: ignore - "vulnerabilities", [] - ) - - if not requested_results: - requested_results = total_results - - for vulnerability in vulnerabilities: - yield CVE.from_dict(vulnerability["cve"]) - - if results_per_page is None: - # just be safe here. should never occur - results_per_page = len(vulnerabilities) - - start_index += results_per_page - downloaded_results += results_per_page - - if ( - request_results - and downloaded_results + results_per_page > request_results - ): - # avoid downloading more results then requested - results_per_page = request_results - downloaded_results + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def cve(self, cve_id: str) -> CVE: """ diff --git a/pontos/nvd/cve_changes/api.py b/pontos/nvd/cve_changes/api.py index 777a60841..1378c4dfd 100644 --- a/pontos/nvd/cve_changes/api.py +++ b/pontos/nvd/cve_changes/api.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from datetime import datetime, timedelta -from typing import AsyncIterator, Iterable, Optional, Union +from typing import Any, Iterator, Optional, Union from httpx import Timeout @@ -12,8 +12,8 @@ DEFAULT_TIMEOUT_CONFIG, JSON, NVDApi, + NVDResults, Params, - convert_camel_case, format_date, now, ) @@ -25,6 +25,13 @@ "https://services.nvd.nist.gov/rest/json/cvehistory/2.0" ) +MAX_CVE_CHANGES_PER_PAGE = 5000 + + +def _result_iterator(data: JSON) -> Iterator[CVEChange]: + results: list[dict[str, Any]] = data.get("cve_changes", []) # type: ignore + return (CVEChange.from_dict(result["change"]) for result in results) + class CVEChangesApi(NVDApi): """ @@ -70,14 +77,15 @@ def __init__( rate_limit=rate_limit, ) - async def changes( + def changes( self, *, change_start_date: Optional[datetime] = None, change_end_date: Optional[datetime] = None, cve_id: Optional[str] = None, event_name: Optional[Union[EventName, str]] = None, - ) -> AsyncIterator[CVEChange]: + request_results: Optional[int] = None, + ) -> NVDResults[CVEChange]: """ Get all CVEs for the provided arguments @@ -88,10 +96,11 @@ async def changes( change_end_date: Return all CVE changes before this date. cve_id: Return all CVE changes for this Common Vulnerabilities and Exposures identifier. event_name: Return all CVE changes with this event name. - + request_results: Number of CVEs changes to download. Set to None + (default) to download all available CPEs. Returns: - An async iterator to iterate over CVEChange model instances + A NVDResponse for CVE changes Example: .. code-block:: python @@ -101,9 +110,15 @@ async def changes( async with CVEChangesApi() as api: async for cve_change in api.changes(event_name=EventName.INITIAL_ANALYSIS): print(cve_change) - """ - total_results: Optional[int] = None + json = api.changes(event_name=EventName.INITIAL_ANALYSIS).json() + + async for changes in api.changes( + event_name=EventName.INITIAL_ANALYSIS, + ).chunks(): + for cve_change in changes: + print(cve_change) + """ if change_start_date and not change_end_date: change_end_date = min( now(), change_start_date + timedelta(days=120) @@ -128,27 +143,19 @@ async def changes( params["eventName"] = event_name start_index: int = 0 - results_per_page = None - - while total_results is None or start_index < total_results: - params["startIndex"] = start_index - - if results_per_page is not None: - params["resultsPerPage"] = results_per_page - - response = await self._get(params=params) - response.raise_for_status() - - data: JSON = response.json(object_hook=convert_camel_case) - - total_results = data["total_results"] # type: ignore - results_per_page: int = data["results_per_page"] # type: ignore - cve_changes: Iterable = data.get("cve_changes", []) # type: ignore - - for cve_change in cve_changes: - yield CVEChange.from_dict(cve_change["change"]) - - start_index += results_per_page # type: ignore + results_per_page = ( + request_results + if request_results and request_results < MAX_CVE_CHANGES_PER_PAGE + else MAX_CVE_CHANGES_PER_PAGE + ) + return NVDResults( + self, + params, + _result_iterator, + request_results=request_results, + results_per_page=results_per_page, + start_index=start_index, + ) async def __aenter__(self) -> "CVEChangesApi": await super().__aenter__() diff --git a/tests/nvd/cve_changes/test_api.py b/tests/nvd/cve_changes/test_api.py index 1b167c2ee..f4e4ca3ca 100644 --- a/tests/nvd/cve_changes/test_api.py +++ b/tests/nvd/cve_changes/test_api.py @@ -12,7 +12,7 @@ from pontos.errors import PontosError from pontos.nvd.api import now -from pontos.nvd.cve_changes.api import CVEChangesApi +from pontos.nvd.cve_changes.api import MAX_CVE_CHANGES_PER_PAGE, CVEChangesApi from pontos.nvd.models.cve_change import Detail, EventName from tests import AsyncMock, IsolatedAsyncioTestCase, aiter, anext from tests.nvd import get_cve_change_data @@ -88,7 +88,10 @@ async def test_cve_changes(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0}, + params={ + "startIndex": 0, + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -123,6 +126,7 @@ async def test_cve_changes_change_dates(self): "startIndex": 0, "changeStartDate": "2022-12-01T00:00:00", "changeEndDate": "2022-12-31T00:00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -155,7 +159,11 @@ async def test_cve_changes_cve_id(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0, "cveId": "CVE-1"}, + params={ + "startIndex": 0, + "cveId": "CVE-1", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -186,7 +194,11 @@ async def test_cve_changes_event_name(self): self.http_client.get.assert_awaited_once_with( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", headers={"apiKey": "token"}, - params={"startIndex": 0, "eventName": "Initial Analysis"}, + params={ + "startIndex": 0, + "eventName": "Initial Analysis", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, + }, ) self.http_client.get.reset_mock() @@ -227,6 +239,7 @@ async def test_cve_changes_calculate_end_date(self, now_mock: MagicMock): "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-01-02T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -252,6 +265,7 @@ async def test_cve_changes_calculate_end_date_with_limit( "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-05-01T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) @@ -273,18 +287,32 @@ async def test_cve_changes_calculate_start_date(self): "startIndex": 0, "changeStartDate": "2023-01-01T00:00:00+00:00", "changeEndDate": "2023-05-01T00:00:00+00:00", + "resultsPerPage": MAX_CVE_CHANGES_PER_PAGE, }, ) async def test_cve_changes_range_too_long(self): - it = aiter( + with self.assertRaises(PontosError): self.api.changes( change_start_date=datetime(2023, 1, 1), change_end_date=datetime(2023, 5, 2), ) + + async def test_cve_changes_request_results(self): + self.http_client.get.side_effect = create_cve_changes_responses() + + it = aiter(self.api.changes(request_results=10)) + + await anext(it) + + self.http_client.get.assert_awaited_once_with( + "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", + headers={"apiKey": "token"}, + params={ + "startIndex": 0, + "resultsPerPage": 10, + }, ) - with self.assertRaises(PontosError): - await anext(it) async def test_context_manager(self): async with self.api: diff --git a/tests/nvd/test_api.py b/tests/nvd/test_api.py index 1df2391f2..b86166050 100644 --- a/tests/nvd/test_api.py +++ b/tests/nvd/test_api.py @@ -19,12 +19,21 @@ import unittest from datetime import datetime +from typing import Any, Iterator from unittest.mock import AsyncMock, MagicMock, patch -from httpx import AsyncClient +from httpx import AsyncClient, Response -from pontos.nvd.api import NVDApi, convert_camel_case, format_date -from tests import IsolatedAsyncioTestCase +from pontos.nvd.api import ( + JSON, + InvalidState, + NoMoreResults, + NVDApi, + NVDResults, + convert_camel_case, + format_date, +) +from tests import IsolatedAsyncioTestCase, aiter, anext class ConvertCamelCaseTestCase(unittest.TestCase): @@ -131,3 +140,329 @@ async def test_no_rate_limit( await api._get() sleep_mock.assert_not_called() + + +class Result: + def __init__(self, value: int) -> None: + self.value = value + + +def result_func(data: JSON) -> Iterator[Result]: + return (Result(d) for d in data["values"]) # type: ignore + + +class NVDResultsTestCase(IsolatedAsyncioTestCase): + async def test_items(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(results.items()) + + result = await anext(it) + self.assertEqual(result.value, 1) + + result = await anext(it) + self.assertEqual(result.value, 2) + + result = await anext(it) + self.assertEqual(result.value, 3) + + result = await anext(it) + self.assertEqual(result.value, 4) + + result = await anext(it) + self.assertEqual(result.value, 5) + + result = await anext(it) + self.assertEqual(result.value, 6) + + async def test_aiter(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(results) + + result = await anext(it) + self.assertEqual(result.value, 1) + + result = await anext(it) + self.assertEqual(result.value, 2) + + result = await anext(it) + self.assertEqual(result.value, 3) + + result = await anext(it) + self.assertEqual(result.value, 4) + + result = await anext(it) + self.assertEqual(result.value, 5) + + result = await anext(it) + self.assertEqual(result.value, 6) + + async def test_len(self): + response_mock = MagicMock(spec=Response) + response_mock.json.return_value = { + "values": [1, 2, 3], + "total_results": 3, + "results_per_page": 3, + } + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + with self.assertRaisesRegex( + InvalidState, "NVDResults has not been awaited yet" + ): + len(results) + + await results + + self.assertEqual(len(results), 3) + + async def test_chunks(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + it = aiter(nvd_results.chunks()) + + results = await anext(it) + self.assertEqual([result.value for result in results], [1, 2, 3]) + + results = await anext(it) + self.assertEqual([result.value for result in results], [4, 5, 6]) + + async def test_json(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + self.assertIsNone(await nvd_results.json()) + + async def test_await(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + await nvd_results + self.assertEqual(len(nvd_results), 6) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + await nvd_results + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + with self.assertRaises(NoMoreResults): + await nvd_results + + async def test_mix_and_match(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + await nvd_results + self.assertEqual(len(nvd_results), 6) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [1, 2, 3]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + self.assertEqual( + [result.value async for result in nvd_results], [1, 2, 3, 4, 5, 6] + ) + + json: dict[str, Any] = await nvd_results.json() # type: ignore + self.assertEqual(json["values"], [4, 5, 6]) + self.assertEqual(json["total_results"], 6) + self.assertEqual(json["results_per_page"], 3) + + async def test_response_error(self): + response_mock = MagicMock(spec=Response) + response_mock.json.side_effect = [ + { + "values": [1, 2, 3], + "total_results": 6, + "results_per_page": 3, + }, + ] + api_mock = AsyncMock(spec=NVDApi) + api_mock._get.return_value = response_mock + + nvd_results: NVDResults[Result] = NVDResults( + api_mock, + {}, + result_func, + ) + + json = await nvd_results.json() + self.assertEqual(json["values"], [1, 2, 3]) # type: ignore + + api_mock._get.assert_called_once_with(params={"startIndex": 0}) + + response_mock.raise_for_status.side_effect = Exception("Server Error") + + api_mock.reset_mock() + + with self.assertRaises(Exception): + json = await nvd_results.json() + + api_mock._get.assert_called_once_with( + params={ + "startIndex": 3, + "resultsPerPage": 3, + } + ) + + response_mock.reset_mock(return_value=True, side_effect=True) + api_mock.reset_mock() + + response_mock.json.side_effect = [ + { + "values": [4, 5, 6], + "total_results": 6, + "results_per_page": 3, + }, + ] + + json = await nvd_results.json() + self.assertEqual(json["values"], [4, 5, 6]) # type: ignore + + api_mock._get.assert_called_once_with( + params={ + "startIndex": 3, + "resultsPerPage": 3, + } + )