Skip to content

Commit

Permalink
Merge pull request #9 from AG3NTZ3R0/controlParams
Browse files Browse the repository at this point in the history
Implement Controlled `Parameter` Dictionary
  • Loading branch information
AG3NTZ3R0 authored Feb 24, 2024
2 parents eed859b + 66e4bcd commit 0f449c2
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 68 deletions.
Empty file removed src/__init__.py
Empty file.
41 changes: 11 additions & 30 deletions src/cnbc/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests

from .endpoints import Endpoints
from .exceptions import APIRequestException, NetworkError, InvalidParameterConfiguration
from .exceptions import APIRequestException, NetworkError


class APIWrapper:
Expand All @@ -22,25 +22,18 @@ def __init__(self, api_key: str, endpoint: Endpoints, timeout: int = 10):
:param timeout: The timeout of the API request.
"""
self._endpoint: str
self._params: dict[str, str]
self._params: Endpoints.Parameters
self._headers: dict[str, str]
self._timeout: int

self._translate_table: dict[str, str]
self._translation_table: dict[str, str]

self._endpoint, self._params = endpoint.value
self._headers = {'x-rapidapi-host': Endpoints.HOST.value, 'x-rapidapi-key': api_key}
self._timeout = timeout

self._translation_table = {}

def _safe_delete(self):
"""
Safely delete the attributes.
"""
del self._endpoint
del self._params

@property
def endpoint(self):
"""
Expand All @@ -55,7 +48,8 @@ def endpoint(self, endpoint: Endpoints):

@endpoint.deleter
def endpoint(self):
self._safe_delete()
del self._endpoint
del self._params

@property
def params(self):
Expand All @@ -65,21 +59,6 @@ def params(self):
"""
return self._params

@params.setter
def params(self, params: dict[str, str]):
try:
# Update the parameters if the input parameters match the endpoint parameters.
if set(self._params.keys()) == set(params.keys()):
self._params.update(params)
else:
raise InvalidParameterConfiguration()
except AttributeError:
raise InvalidParameterConfiguration()

@params.deleter
def params(self):
self._safe_delete()

@property
def headers(self):
"""
Expand Down Expand Up @@ -134,7 +113,7 @@ def translation_table_load(self, file_path: str):
:param file_path: The file path of the translation table.
:return: None
"""
with open(file_path, "r") as file:
with open(file_path, "r", encoding="utf-8") as file:
self._translation_table = json.load(file)

def translation_table_save(self, file_path: str):
Expand All @@ -143,7 +122,7 @@ def translation_table_save(self, file_path: str):
:param file_path: The file path of the translation table.
:return: None
"""
with open(file_path, "w") as file:
with open(file_path, "w", encoding="utf-8") as file:
json.dump(self._translation_table, file)

def request(self) -> dict:
Expand All @@ -154,17 +133,19 @@ def request(self) -> dict:
# If the endpoint is the translate endpoint, then check if the symbol is in the translation table.
if self._endpoint == Endpoints.TRANSLATE.get_endpoint():
# If the symbol is in the translation table, then return a faux JSON response.
if issueId := self._translation_table.get(self._params['symbol']):
return {'issueId': issueId, 'errorMessage': '', 'errorCode': ''}
if issue_id := self._translation_table.get(self._params['symbol']):
return {'issueId': issue_id, 'errorMessage': '', 'errorCode': ''}

with requests.request("GET", self._endpoint,
headers=self._headers, params=self._params, timeout=self._timeout) as response:
try:
response.raise_for_status()
response_json = response.json()

# If the endpoint is the translate endpoint, then update the translation table.
if self._endpoint == Endpoints.TRANSLATE.get_endpoint():
self._translation_table[self._params['symbol']] = response_json['issueId']

return response_json
except requests.exceptions.HTTPError as e:
raise APIRequestException(response.status_code, response.text) from e
Expand Down
57 changes: 42 additions & 15 deletions src/cnbc/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,58 @@
"""
from enum import Enum

from .exceptions import InvalidParameterConfiguration


class Endpoints(Enum):
"""
Endpoints for CNBC API.
"""
class Parameters(dict):
"""
Parameters for the endpoint.
"""
def __init__(self, *args, **kwargs):
"""
Initialize the parameters.
:param args: The arguments.
:param kwargs: The keyword arguments.
"""
super().__init__(*args, **kwargs)
self._init_keys = set(self.keys())

def __setitem__(self, key: str, value: str):
"""
Set the item.
:param key: The key.
:param value: The value.
"""
if key not in self._init_keys:
raise InvalidParameterConfiguration()
super().__setitem__(key, value)

HOST: str = "cnbc.p.rapidapi.com"
BASE_URL: str = f"https://{HOST}"
# Default
GET_METADATA: tuple[str, None] = (f"{BASE_URL}/get-meta-data", None)
AUTO_COMPLETE: tuple[str, dict[str, str]] = (f"{BASE_URL}/v2/auto-complete", {"q": None})
GET_METADATA: tuple[str, Parameters] = (f"{BASE_URL}/get-meta-data", Parameters())
AUTO_COMPLETE: tuple[str, Parameters] = (f"{BASE_URL}/v2/auto-complete", Parameters({"q": None}))
# Market
LIST_INDICES: tuple[str, None] = (f"{BASE_URL}/market/list-indices", None)
LIST_INDICES: tuple[str, Parameters] = (f"{BASE_URL}/market/list-indices", Parameters())
# News
LIST_TRENDING_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-trending", {"tag": "Articles", "count": None})
LIST_SPECIAL_REPORTS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-special-reports", {"pageSize": None, "page": None})
LIST_SYMBOL_NEWS: tuple[str, dict[str, str]] = (f"{BASE_URL}/news/v2/list-by-symbol", {"symbol": None, "page": None, "pageSize": None})
LIST_TRENDING_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-trending", Parameters({"tag": "Articles", "count": None}))
LIST_SPECIAL_REPORTS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-special-reports", Parameters({"pageSize": None, "page": None}))
LIST_SYMBOL_NEWS: tuple[str, Parameters] = (f"{BASE_URL}/news/v2/list-by-symbol", Parameters({"symbol": None, "page": None, "pageSize": None}))
# Symbol
GET_EARNINGS_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-earnings-chart", {"issueId": None, "numberOfYears": None})
GET_PROFILE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-profile", {"issueId": None})
GET_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-chart", {"symbol": None, "interval": None})
TRANSLATE: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/translate", {"symbol": None})
GET_SUMMARY: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-summary", {"issueIds": None})
GET_FUNDAMENTALS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-fundamentals", {"issueIds": None})
GET_PRICELINE_CHART: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-priceline-chart", {"issueId": None, "numberOfDays": None})
GET_PEERS: tuple[str, dict[str, str]] = (f"{BASE_URL}/symbols/get-peers", {"symbol": None})
GET_EARNINGS_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-earnings-chart", Parameters({"issueId": None, "numberOfYears": None}))
GET_PROFILE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-profile", Parameters({"issueId": None}))
GET_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-chart", Parameters({"symbol": None, "interval": None}))
TRANSLATE: tuple[str, Parameters] = (f"{BASE_URL}/symbols/translate", Parameters({"symbol": None}))
GET_SUMMARY: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-summary", Parameters({"issueIds": None}))
GET_FUNDAMENTALS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-fundamentals", Parameters({"issueIds": None}))
GET_PRICELINE_CHART: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-priceline-chart", Parameters({"issueId": None, "numberOfDays": None}))
GET_PEERS: tuple[str, Parameters] = (f"{BASE_URL}/symbols/get-peers", Parameters({"symbol": None}))

def get_endpoint(self) -> str | None:
"""
Expand All @@ -38,7 +65,7 @@ def get_endpoint(self) -> str | None:
return None
return self.value[0]

def get_parameters(self) -> dict[str, str] | None:
def get_parameters(self) -> Parameters | None:
"""
Get the parameters of the endpoint.
:return: The parameters of the endpoint.
Expand Down
2 changes: 1 addition & 1 deletion src/cnbc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ class InvalidParameterConfiguration(Exception):
Custom exception for invalid parameter configuration.
"""
def __init__(self):
super().__init__("The supplied parameters are incompatible with the required parameters")
super().__init__("The supplied parameters are incompatible with the required parameters.")
24 changes: 2 additions & 22 deletions tests/cnbc/test_api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,13 @@

from src.cnbc.api_wrapper import APIWrapper
from src.cnbc.endpoints import Endpoints
from src.cnbc.exceptions import InvalidParameterConfiguration


class TestAPIWrapper(unittest.TestCase):
"""
Test the APIWrapper class.
"""

def test_set_params(self):
"""
Test the set_params method.
:return: None
"""
endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS)
endpoint_params = {"issueIds": None}
endpoint.params = endpoint_params
self.assertEqual(endpoint_params, endpoint.params)

def test_set_params_invalid_parameter_configuration(self):
"""
Test the set_params method with an invalid parameter configuration.
:return: None
"""
endpoint = APIWrapper("API_KEY", Endpoints.GET_FUNDAMENTALS)
with self.assertRaises(InvalidParameterConfiguration):
endpoint.params = {"key": "value"}

@patch('src.cnbc.api_wrapper.requests.request')
def test_request(self, mock_request: MagicMock):
"""
Expand All @@ -53,7 +33,7 @@ def test_request(self, mock_request: MagicMock):
"GET", endpoint.endpoint,
headers=endpoint.headers, params=endpoint.params, timeout=endpoint.timeout
)
self.assertEqual({"key": "value"}, response)
self.assertEqual(mock_response.json.return_value, response)

def test_request_translate_translation_table(self):
"""
Expand All @@ -64,7 +44,7 @@ def test_request_translate_translation_table(self):
json_response_expected = {'issueId': '123', 'errorMessage': '', 'errorCode': ''}

api_wrapper = APIWrapper('API_KEY', Endpoints.TRANSLATE)
api_wrapper._translation_table = translation_table
api_wrapper.translation_table = translation_table
api_wrapper_params = api_wrapper.params
api_wrapper_params['symbol'] = 'AAPL'
json_response = api_wrapper.request()
Expand Down
9 changes: 9 additions & 0 deletions tests/cnbc/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
import unittest

from src.cnbc.endpoints import Endpoints
from src.cnbc.exceptions import InvalidParameterConfiguration


class TestEndpoints(unittest.TestCase):
"""
Test the Endpoints enum.
"""

def test_parameters_set_item_invalid_parameter_configuration(self):
"""
Test the __setitem__ method of the Parameters class for an invalid parameter configuration.
:return: None
"""
with self.assertRaises(InvalidParameterConfiguration):
Endpoints.GET_FUNDAMENTALS.get_parameters()["key"] = "value"

def test_get_endpoint_host(self):
"""
Test the get_endpoint method of the Endpoints enum for the attributes which aren't endpoints.
Expand Down

0 comments on commit 0f449c2

Please sign in to comment.