diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/cnbc/api_wrapper.py b/src/cnbc/api_wrapper.py index 444ec27..ff9daf4 100644 --- a/src/cnbc/api_wrapper.py +++ b/src/cnbc/api_wrapper.py @@ -5,7 +5,7 @@ import requests from .endpoints import Endpoints -from .exceptions import APIRequestException, NetworkError, InvalidParameterConfiguration +from .exceptions import APIRequestException, NetworkError class APIWrapper: @@ -22,11 +22,11 @@ def __init__(self, api_key: str, endpoint: Endpoints, timeout: int = 10): :param timeout: The timeout of the API request. """ self._endpoint: str - self._params: dict[str, str] + self._params: Endpoints.Parameters self._headers: dict[str, str] self._timeout: int - self._translate_table: dict[str, str] + self._translation_table: dict[str, str] self._endpoint, self._params = endpoint.value self._headers = {'x-rapidapi-host': Endpoints.HOST.value, 'x-rapidapi-key': api_key} @@ -34,13 +34,6 @@ def __init__(self, api_key: str, endpoint: Endpoints, timeout: int = 10): self._translation_table = {} - def _safe_delete(self): - """ - Safely delete the attributes. - """ - del self._endpoint - del self._params - @property def endpoint(self): """ @@ -55,7 +48,8 @@ def endpoint(self, endpoint: Endpoints): @endpoint.deleter def endpoint(self): - self._safe_delete() + del self._endpoint + del self._params @property def params(self): @@ -65,21 +59,6 @@ def params(self): """ return self._params - @params.setter - def params(self, params: dict[str, str]): - try: - # Update the parameters if the input parameters match the endpoint parameters. - if set(self._params.keys()) == set(params.keys()): - self._params.update(params) - else: - raise InvalidParameterConfiguration() - except AttributeError: - raise InvalidParameterConfiguration() - - @params.deleter - def params(self): - self._safe_delete() - @property def headers(self): """ @@ -134,7 +113,7 @@ def translation_table_load(self, file_path: str): :param file_path: The file path of the translation table. :return: None """ - with open(file_path, "r") as file: + with open(file_path, "r", encoding="utf-8") as file: self._translation_table = json.load(file) def translation_table_save(self, file_path: str): @@ -143,7 +122,7 @@ def translation_table_save(self, file_path: str): :param file_path: The file path of the translation table. :return: None """ - with open(file_path, "w") as file: + with open(file_path, "w", encoding="utf-8") as file: json.dump(self._translation_table, file) def request(self) -> dict: @@ -154,17 +133,19 @@ def request(self) -> dict: # If the endpoint is the translate endpoint, then check if the symbol is in the translation table. if self._endpoint == Endpoints.TRANSLATE.get_endpoint(): # If the symbol is in the translation table, then return a faux JSON response. - if issueId := self._translation_table.get(self._params['symbol']): - return {'issueId': issueId, 'errorMessage': '', 'errorCode': ''} + if issue_id := self._translation_table.get(self._params['symbol']): + return {'issueId': issue_id, 'errorMessage': '', 'errorCode': ''} with requests.request("GET", self._endpoint, headers=self._headers, params=self._params, timeout=self._timeout) as response: try: response.raise_for_status() response_json = response.json() + # If the endpoint is the translate endpoint, then update the translation table. if self._endpoint == Endpoints.TRANSLATE.get_endpoint(): self._translation_table[self._params['symbol']] = response_json['issueId'] + return response_json except requests.exceptions.HTTPError as e: raise APIRequestException(response.status_code, response.text) from e diff --git a/src/cnbc/endpoints.py b/src/cnbc/endpoints.py index a66a11b..30ccec5 100644 --- a/src/cnbc/endpoints.py +++ b/src/cnbc/endpoints.py @@ -3,31 +3,58 @@ """ from enum import Enum +from .exceptions import InvalidParameterConfiguration + class Endpoints(Enum): """ Endpoints for CNBC API. """ + class Parameters(dict): + """ + Parameters for the endpoint. + """ + def __init__(self, *args, **kwargs): + """ + Initialize the parameters. + + :param args: The arguments. + :param kwargs: The keyword arguments. + """ + super().__init__(*args, **kwargs) + self._init_keys = set(self.keys()) + + def __setitem__(self, key: str, value: str): + """ + Set the item. + + :param key: The key. + :param value: The value. + """ + if key not in self._init_keys: + raise InvalidParameterConfiguration() + super().__setitem__(key, value) + HOST: str = "cnbc.p.rapidapi.com" BASE_URL: str = f"https://{HOST}" # Default - GET_METADATA: tuple[str, None] = (f"{BASE_URL}/get-meta-data", None) - AUTO_COMPLETE: tuple[str, dict[str, str]] = (f"{BASE_URL}/v2/auto-complete", {"q": None}) + GET_METADATA: tuple[str, Parameters] = (f"{BASE_URL}/get-meta-data", Parameters()) + AUTO_COMPLETE: tuple[str, Parameters] = (f"{BASE_URL}/v2/auto-complete", Parameters({"q": None})) # Market - LIST_INDICES: tuple[str, None] = (f"{BASE_URL}/market/list-indices", None) + LIST_INDICES: tuple[str, Parameters] = (f"{BASE_URL}/market/list-indices", Parameters()) # News - LIST_TRENDING_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-trending", {"tag": "Articles", "count": None}) - LIST_SPECIAL_REPORTS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-special-reports", {"pageSize": None, "page": None}) - LIST_SYMBOL_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-by-symbol", {"symbol": None, "page": None, "pageSize": None}) + LIST_TRENDING_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-trending", Parameters({"tag": "Articles", "count": None})) + LIST_SPECIAL_REPORTS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-special-reports", Parameters({"pageSize": None, "page": None})) + LIST_SYMBOL_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-by-symbol", Parameters({"symbol": None, "page": None, "pageSize": None})) # Symbol - GET_EARNINGS_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-earnings-chart", {"issueId": None, "numberOfYears": None}) - GET_PROFILE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-profile", {"issueId": None}) - GET_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-chart", {"symbol": None, "interval": None}) - TRANSLATE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/translate", {"symbol": None}) - GET_SUMMARY: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-summary", {"issueIds": None}) - GET_FUNDAMENTALS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-fundamentals", {"issueIds": None}) - GET_PRICELINE_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-priceline-chart", {"issueId": None, "numberOfDays": None}) - GET_PEERS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-peers", {"symbol": None}) + GET_EARNINGS_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-earnings-chart", Parameters({"issueId": None, "numberOfYears": None})) + GET_PROFILE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-profile", Parameters({"issueId": None})) + GET_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-chart", Parameters({"symbol": None, "interval": None})) + TRANSLATE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/translate", Parameters({"symbol": None})) + GET_SUMMARY: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-summary", Parameters({"issueIds": None})) + GET_FUNDAMENTALS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-fundamentals", Parameters({"issueIds": None})) + GET_PRICELINE_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-priceline-chart", Parameters({"issueId": None, "numberOfDays": None})) + GET_PEERS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-peers", Parameters({"symbol": None})) def get_endpoint(self) -> str | None: """ @@ -38,7 +65,7 @@ def get_endpoint(self) -> str | None: return None return self.value[0] - def get_parameters(self) -> dict[str, str] | None: + def get_parameters(self) -> Parameters | None: """ Get the parameters of the endpoint. :return: The parameters of the endpoint. diff --git a/src/cnbc/exceptions.py b/src/cnbc/exceptions.py index 3e104d4..b861569 100644 --- a/src/cnbc/exceptions.py +++ b/src/cnbc/exceptions.py @@ -26,4 +26,4 @@ class InvalidParameterConfiguration(Exception): Custom exception for invalid parameter configuration. """ def __init__(self): - super().__init__("The supplied parameters are incompatible with the required parameters") + super().__init__("The supplied parameters are incompatible with the required parameters.") diff --git a/tests/cnbc/test_api_wrapper.py b/tests/cnbc/test_api_wrapper.py index 54c198c..b9849e2 100644 --- a/tests/cnbc/test_api_wrapper.py +++ b/tests/cnbc/test_api_wrapper.py @@ -6,7 +6,6 @@ from src.cnbc.api_wrapper import APIWrapper from src.cnbc.endpoints import Endpoints -from src.cnbc.exceptions import InvalidParameterConfiguration class TestAPIWrapper(unittest.TestCase): @@ -14,25 +13,6 @@ class TestAPIWrapper(unittest.TestCase): Test the APIWrapper class. """ - def test_set_params(self): - """ - Test the set_params method. - :return: None - """ - endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS) - endpoint_params = {"issueIds": None} - endpoint.params = endpoint_params - self.assertEqual(endpoint_params, endpoint.params) - - def test_set_params_invalid_parameter_configuration(self): - """ - Test the set_params method with an invalid parameter configuration. - :return: None - """ - endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS) - with self.assertRaises(InvalidParameterConfiguration): - endpoint.params = {"key": "value"} - @patch('src.cnbc.api_wrapper.requests.request') def test_request(self, mock_request: MagicMock): """ @@ -53,7 +33,7 @@ def test_request(self, mock_request: MagicMock): "GET", endpoint.endpoint, headers=endpoint.headers, params=endpoint.params, timeout=endpoint.timeout ) - self.assertEqual({"key": "value"}, response) + self.assertEqual(mock_response.json.return_value, response) def test_request_translate_translation_table(self): """ @@ -64,7 +44,7 @@ def test_request_translate_translation_table(self): json_response_expected = {'issueId': '123', 'errorMessage': '', 'errorCode': ''} api_wrapper = APIWrapper('API_KEY', Endpoints.TRANSLATE) - api_wrapper._translation_table = translation_table + api_wrapper.translation_table = translation_table api_wrapper_params = api_wrapper.params api_wrapper_params['symbol'] = 'AAPL' json_response = api_wrapper.request() diff --git a/tests/cnbc/test_endpoints.py b/tests/cnbc/test_endpoints.py index e45ffaa..eafdd77 100644 --- a/tests/cnbc/test_endpoints.py +++ b/tests/cnbc/test_endpoints.py @@ -4,6 +4,7 @@ import unittest from src.cnbc.endpoints import Endpoints +from src.cnbc.exceptions import InvalidParameterConfiguration class TestEndpoints(unittest.TestCase): @@ -11,6 +12,14 @@ class TestEndpoints(unittest.TestCase): Test the Endpoints enum. """ + def test_parameters_set_item_invalid_parameter_configuration(self): + """ + Test the __setitem__ method of the Parameters class for an invalid parameter configuration. + :return: None + """ + with self.assertRaises(InvalidParameterConfiguration): + Endpoints.GET_FUNDAMENTALS.get_parameters()["key"] = "value" + def test_get_endpoint_host(self): """ Test the get_endpoint method of the Endpoints enum for the attributes which aren't endpoints.