Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Controlled Parameter Dictionary #9

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed src/__init__.py
Empty file.
41 changes: 11 additions & 30 deletions src/cnbc/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests

from .endpoints import Endpoints
from .exceptions import APIRequestException, NetworkError, InvalidParameterConfiguration
from .exceptions import APIRequestException, NetworkError


class APIWrapper:
Expand All @@ -22,25 +22,18 @@ def __init__(self, api_key: str, endpoint: Endpoints, timeout: int = 10):
:param timeout: The timeout of the API request.
"""
self._endpoint: str
self._params: dict[str, str]
self._params: Endpoints.Parameters
self._headers: dict[str, str]
self._timeout: int

self._translate_table: dict[str, str]
self._translation_table: dict[str, str]

self._endpoint, self._params = endpoint.value
self._headers = {'x-rapidapi-host': Endpoints.HOST.value, 'x-rapidapi-key': api_key}
self._timeout = timeout

self._translation_table = {}

def _safe_delete(self):
"""
Safely delete the attributes.
"""
del self._endpoint
del self._params

@property
def endpoint(self):
"""
Expand All @@ -55,7 +48,8 @@ def endpoint(self, endpoint: Endpoints):

@endpoint.deleter
def endpoint(self):
self._safe_delete()
del self._endpoint
del self._params

@property
def params(self):
Expand All @@ -65,21 +59,6 @@ def params(self):
"""
return self._params

@params.setter
def params(self, params: dict[str, str]):
try:
# Update the parameters if the input parameters match the endpoint parameters.
if set(self._params.keys()) == set(params.keys()):
self._params.update(params)
else:
raise InvalidParameterConfiguration()
except AttributeError:
raise InvalidParameterConfiguration()

@params.deleter
def params(self):
self._safe_delete()

@property
def headers(self):
"""
Expand Down Expand Up @@ -134,7 +113,7 @@ def translation_table_load(self, file_path: str):
:param file_path: The file path of the translation table.
:return: None
"""
with open(file_path, "r") as file:
with open(file_path, "r", encoding="utf-8") as file:
self._translation_table = json.load(file)

def translation_table_save(self, file_path: str):
Expand All @@ -143,7 +122,7 @@ def translation_table_save(self, file_path: str):
:param file_path: The file path of the translation table.
:return: None
"""
with open(file_path, "w") as file:
with open(file_path, "w", encoding="utf-8") as file:
json.dump(self._translation_table, file)

def request(self) -> dict:
Expand All @@ -154,17 +133,19 @@ def request(self) -> dict:
# If the endpoint is the translate endpoint, then check if the symbol is in the translation table.
if self._endpoint == Endpoints.TRANSLATE.get_endpoint():
# If the symbol is in the translation table, then return a faux JSON response.
if issueId := self._translation_table.get(self._params['symbol']):
return {'issueId': issueId, 'errorMessage': '', 'errorCode': ''}
if issue_id := self._translation_table.get(self._params['symbol']):
return {'issueId': issue_id, 'errorMessage': '', 'errorCode': ''}

with requests.request("GET", self._endpoint,
headers=self._headers, params=self._params, timeout=self._timeout) as response:
try:
response.raise_for_status()
response_json = response.json()

# If the endpoint is the translate endpoint, then update the translation table.
if self._endpoint == Endpoints.TRANSLATE.get_endpoint():
self._translation_table[self._params['symbol']] = response_json['issueId']

return response_json
except requests.exceptions.HTTPError as e:
raise APIRequestException(response.status_code, response.text) from e
Expand Down
57 changes: 42 additions & 15 deletions src/cnbc/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,58 @@
"""
from enum import Enum

from .exceptions import InvalidParameterConfiguration


class Endpoints(Enum):
"""
Endpoints for CNBC API.
"""
class Parameters(dict):
"""
Parameters for the endpoint.
"""
def __init__(self, *args, **kwargs):
"""
Initialize the parameters.

:param args: The arguments.
:param kwargs: The keyword arguments.
"""
super().__init__(*args, **kwargs)
self._init_keys = set(self.keys())

def __setitem__(self, key: str, value: str):
"""
Set the item.

:param key: The key.
:param value: The value.
"""
if key not in self._init_keys:
raise InvalidParameterConfiguration()
super().__setitem__(key, value)

HOST: str = "cnbc.p.rapidapi.com"
BASE_URL: str = f"https://{HOST}"
# Default
GET_METADATA: tuple[str, None] = (f"{BASE_URL}/get-meta-data", None)
AUTO_COMPLETE: tuple[str, dict[str, str]] = (f"{BASE_URL}/v2/auto-complete", {"q": None})
GET_METADATA: tuple[str, Parameters] = (f"{BASE_URL}/get-meta-data", Parameters())
AUTO_COMPLETE: tuple[str, Parameters] = (f"{BASE_URL}/v2/auto-complete", Parameters({"q": None}))
# Market
LIST_INDICES: tuple[str, None] = (f"{BASE_URL}/market/list-indices", None)
LIST_INDICES: tuple[str, Parameters] = (f"{BASE_URL}/market/list-indices", Parameters())
# News
LIST_TRENDING_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-trending", {"tag": "Articles", "count": None})
LIST_SPECIAL_REPORTS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-special-reports", {"pageSize": None, "page": None})
LIST_SYMBOL_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-by-symbol", {"symbol": None, "page": None, "pageSize": None})
LIST_TRENDING_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-trending", Parameters({"tag": "Articles", "count": None}))
LIST_SPECIAL_REPORTS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-special-reports", Parameters({"pageSize": None, "page": None}))
LIST_SYMBOL_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-by-symbol", Parameters({"symbol": None, "page": None, "pageSize": None}))
# Symbol
GET_EARNINGS_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-earnings-chart", {"issueId": None, "numberOfYears": None})
GET_PROFILE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-profile", {"issueId": None})
GET_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-chart", {"symbol": None, "interval": None})
TRANSLATE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/translate", {"symbol": None})
GET_SUMMARY: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-summary", {"issueIds": None})
GET_FUNDAMENTALS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-fundamentals", {"issueIds": None})
GET_PRICELINE_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-priceline-chart", {"issueId": None, "numberOfDays": None})
GET_PEERS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-peers", {"symbol": None})
GET_EARNINGS_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-earnings-chart", Parameters({"issueId": None, "numberOfYears": None}))
GET_PROFILE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-profile", Parameters({"issueId": None}))
GET_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-chart", Parameters({"symbol": None, "interval": None}))
TRANSLATE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/translate", Parameters({"symbol": None}))
GET_SUMMARY: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-summary", Parameters({"issueIds": None}))
GET_FUNDAMENTALS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-fundamentals", Parameters({"issueIds": None}))
GET_PRICELINE_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-priceline-chart", Parameters({"issueId": None, "numberOfDays": None}))
GET_PEERS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-peers", Parameters({"symbol": None}))

def get_endpoint(self) -> str | None:
"""
Expand All @@ -38,7 +65,7 @@ def get_endpoint(self) -> str | None:
return None
return self.value[0]

def get_parameters(self) -> dict[str, str] | None:
def get_parameters(self) -> Parameters | None:
"""
Get the parameters of the endpoint.
:return: The parameters of the endpoint.
Expand Down
2 changes: 1 addition & 1 deletion src/cnbc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ class InvalidParameterConfiguration(Exception):
Custom exception for invalid parameter configuration.
"""
def __init__(self):
super().__init__("The supplied parameters are incompatible with the required parameters")
super().__init__("The supplied parameters are incompatible with the required parameters.")
24 changes: 2 additions & 22 deletions tests/cnbc/test_api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,13 @@

from src.cnbc.api_wrapper import APIWrapper
from src.cnbc.endpoints import Endpoints
from src.cnbc.exceptions import InvalidParameterConfiguration


class TestAPIWrapper(unittest.TestCase):
"""
Test the APIWrapper class.
"""

def test_set_params(self):
"""
Test the set_params method.
:return: None
"""
endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS)
endpoint_params = {"issueIds": None}
endpoint.params = endpoint_params
self.assertEqual(endpoint_params, endpoint.params)

def test_set_params_invalid_parameter_configuration(self):
"""
Test the set_params method with an invalid parameter configuration.
:return: None
"""
endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS)
with self.assertRaises(InvalidParameterConfiguration):
endpoint.params = {"key": "value"}

@patch('src.cnbc.api_wrapper.requests.request')
def test_request(self, mock_request: MagicMock):
"""
Expand All @@ -53,7 +33,7 @@ def test_request(self, mock_request: MagicMock):
"GET", endpoint.endpoint,
headers=endpoint.headers, params=endpoint.params, timeout=endpoint.timeout
)
self.assertEqual({"key": "value"}, response)
self.assertEqual(mock_response.json.return_value, response)

def test_request_translate_translation_table(self):
"""
Expand All @@ -64,7 +44,7 @@ def test_request_translate_translation_table(self):
json_response_expected = {'issueId': '123', 'errorMessage': '', 'errorCode': ''}

api_wrapper = APIWrapper('API_KEY', Endpoints.TRANSLATE)
api_wrapper._translation_table = translation_table
api_wrapper.translation_table = translation_table
api_wrapper_params = api_wrapper.params
api_wrapper_params['symbol'] = 'AAPL'
json_response = api_wrapper.request()
Expand Down
9 changes: 9 additions & 0 deletions tests/cnbc/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
import unittest

from src.cnbc.endpoints import Endpoints
from src.cnbc.exceptions import InvalidParameterConfiguration


class TestEndpoints(unittest.TestCase):
"""
Test the Endpoints enum.
"""

def test_parameters_set_item_invalid_parameter_configuration(self):
"""
Test the __setitem__ method of the Parameters class for an invalid parameter configuration.
:return: None
"""
with self.assertRaises(InvalidParameterConfiguration):
Endpoints.GET_FUNDAMENTALS.get_parameters()["key"] = "value"

def test_get_endpoint_host(self):
"""
Test the get_endpoint method of the Endpoints enum for the attributes which aren't endpoints.
Expand Down
Loading