diff --git a/aiorobinhood/__init__.py b/aiorobinhood/__init__.py index b0e4caf..e709134 100644 --- a/aiorobinhood/__init__.py +++ b/aiorobinhood/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.0" +__version__ = "2.1.0" __all__ = [ "RobinhoodClient", # exceptions diff --git a/aiorobinhood/client.py b/aiorobinhood/client.py index e16d840..5b7d02f 100644 --- a/aiorobinhood/client.py +++ b/aiorobinhood/client.py @@ -5,10 +5,11 @@ from uuid import uuid4 import aiohttp +from yarl import URL from . import models, urls -from .decorators import check_session, check_tokens, mutually_exclusive -from .exceptions import ClientAPIError, ClientRequestError +from .decorators import check_tokens, mutually_exclusive +from .exceptions import ClientAPIError, ClientRequestError, ClientUninitializedError class RobinhoodClient: @@ -58,23 +59,64 @@ async def __aenter__(self) -> "RobinhoodClient": self._session = aiohttp.ClientSession() return self - @check_session async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - assert self._session is not None + if self._session is None: + raise ClientUninitializedError() await self._session.close() self._session = None + async def request( + self, + method: str, + url: URL, + json: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + success_code: int = 200, + ) -> Dict[str, Any]: + """Make a custom request to the Robinhood API servers. + + Args: + method: The HTTP request method. + url: The Robinhood API url. + json: JSON request parameters. + headers: HTTP headers to send with the request. + success_code: The HTTP status code indicating success. + + Returns: + The JSON response from the Robinhood API servers. + + Raises: + AssertionError: The origin of the url is not the Robinhood API servers. + ClientAPIError: Robinhood servers responded with an error. + ClientRequestError: The HTTP request timed out or failed. + ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. + """ + assert url.origin() == urls.BASE + if self._session is None: + raise ClientUninitializedError() + + try: + async with self._session.request( + method, url, headers=headers, json=json, timeout=self._timeout + ) as resp: + response = await resp.json() + if resp.status != success_code: + raise ClientAPIError(resp.method, resp.url, resp.status, response) + + return response + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise ClientRequestError(method, url) from e + ################################################################################### # OAUTH # ################################################################################### - @check_session async def login( self, username: str, @@ -96,67 +138,57 @@ async def login( ClientRequestError: The HTTP request timed out or failed. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - - url = urls.LOGIN + headers = {"x-robinhood-challenge-response-id": kwargs.get("challenge_id", "")} + json = { + "challenge_type": challenge_type.value, + "client_id": self._CLIENT_ID, + "device_token": self._device_token, + "expires_in": expires_in, + "grant_type": "password", + "mfa_code": kwargs.get("mfa_code", ""), + "password": password, + "scope": "internal", + "username": username, + } try: - async with self._session.post( - url, - timeout=self._timeout, - headers={ - "x-robinhood-challenge-response-id": kwargs.get("challenge_id", "") - }, - json={ - "challenge_type": challenge_type.value, - "client_id": self._CLIENT_ID, - "device_token": self._device_token, - "expires_in": expires_in, - "grant_type": "password", - "mfa_code": kwargs.get("mfa_code", ""), - "password": password, - "scope": "internal", - "username": username, - }, - ) as resp: - response = await resp.json() - while ( - "challenge" in response - and response["challenge"]["remaining_attempts"] > 0 - ): - url = urls.CHALLENGE / response["challenge"]["id"] / "respond/" - challenge_id = input(f"Enter the {challenge_type.value} code: ") - async with self._session.post( - url, timeout=self._timeout, json={"response": challenge_id}, - ) as resp: - response = await resp.json() - - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - elif "id" in response: - # Try again with challenge_id if challenge is passed - return await self.login( - username, password, expires_in, challenge_id=response["id"], - ) - elif response.get("mfa_required"): - # Try again with mfa_code if 2fac is enabled - mfa_code = input(f"Enter the {response['mfa_type']} code: ") - return await self.login( - username, password, expires_in, mfa_code=mfa_code, - ) - else: - self._access_token = f"Bearer {response['access_token']}" - self._refresh_token = response["refresh_token"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", url) from e + response = await self.request( + "POST", urls.LOGIN, headers=headers, json=json + ) + if response.get("mfa_required"): + # Try again with mfa_code if 2fac is enabled + mfa_code = input(f"Enter the {response['mfa_type']} code: ") + return await self.login( + username, password, expires_in, mfa_code=mfa_code + ) + except ClientAPIError as e: + response = e.response + if "challenge" not in response: + raise e + + while True: + url = urls.CHALLENGE / response["challenge"]["id"] / "respond/" + json = {"response": input(f"Enter the {challenge_type.value} code: ")} + try: + response = await self.request("POST", url, json=json) + if "id" in response: + # Try again with challenge_id if challenge is passed + return await self.login( + username, password, expires_in, challenge_id=response["id"] + ) + except ClientAPIError as e: + if e.response["challenge"]["remaining_attempts"] == 0: + raise e from None - # Fetch the account info during login for other methods + self._access_token = f"Bearer {response['access_token']}" + self._refresh_token = response["refresh_token"] + + # Fetch the account info for other methods account = await self.get_account() self._account_url = account["url"] self._account_num = account["account_number"] @check_tokens - @check_session async def logout(self) -> None: """Invalidate the current session tokens. @@ -166,25 +198,12 @@ async def logout(self) -> None: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - - try: - async with self._session.post( - urls.LOGOUT, - timeout=self._timeout, - json={"client_id": self._CLIENT_ID, "token": self._refresh_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - self._access_token = None - self._refresh_token = None - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", urls.LOGOUT) from e + json = {"client_id": self._CLIENT_ID, "token": self._refresh_token} + await self.request("POST", urls.LOGOUT, json=json) + self._access_token = None + self._refresh_token = None @check_tokens - @check_session async def refresh(self, expires_in: int = 86400) -> None: """Fetch a fresh set session tokens. @@ -197,28 +216,16 @@ async def refresh(self, expires_in: int = 86400) -> None: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - - try: - async with self._session.post( - urls.LOGIN, - timeout=self._timeout, - json={ - "client_id": self._CLIENT_ID, - "expires_in": expires_in, - "grant_type": "refresh_token", - "refresh_token": self._refresh_token, - "scope": "internal", - }, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - self._access_token = f"Bearer {response['access_token']}" - self._refresh_token = response["refresh_token"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", urls.LOGIN) from e + json = { + "client_id": self._CLIENT_ID, + "expires_in": expires_in, + "grant_type": "refresh_token", + "refresh_token": self._refresh_token, + "scope": "internal", + } + response = await self.request("POST", urls.LOGIN, json=json) + self._access_token = f"Bearer {response['access_token']}" + self._refresh_token = response["refresh_token"] @check_tokens async def dump(self) -> None: @@ -259,7 +266,6 @@ async def load(self) -> None: ################################################################################### @check_tokens - @check_session async def get_account(self) -> Dict[str, Any]: """Fetch information associated with the Robinhood account. @@ -272,24 +278,11 @@ async def get_account(self) -> Dict[str, Any]: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - - try: - async with self._session.get( - urls.ACCOUNTS, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["results"][0] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", urls.ACCOUNTS) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", urls.ACCOUNTS, headers=headers) + return response["results"][0] @check_tokens - @check_session async def get_portfolio(self) -> Dict[str, Any]: """Fetch the portfolio information associated with the Robinhood account. @@ -303,24 +296,11 @@ async def get_portfolio(self) -> Dict[str, Any]: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - - try: - async with self._session.get( - urls.PORTFOLIOS, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["results"][0] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", urls.PORTFOLIOS) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", urls.PORTFOLIOS, headers=headers) + return response["results"][0] @check_tokens - @check_session async def get_historical_portfolio( self, interval: models.HistoricalInterval, @@ -348,8 +328,6 @@ async def get_historical_portfolio( Certain combinations of ``interval`` and ``span`` will be rejected by Robinhood. """ - assert self._session is not None - url = (urls.PORTFOLIOS / "historicals" / f"{self._account_num}/").with_query( { "bounds": "extended" if extended_hours else "regular", @@ -357,27 +335,14 @@ async def get_historical_portfolio( "span": span.value, } ) - - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + return await self.request("GET", url, headers=headers) ################################################################################### # ACCOUNT # ################################################################################### @check_tokens - @check_session async def get_positions( self, nonzero: bool = True, pages: Optional[int] = None ) -> List[Dict[str, Any]]: @@ -397,34 +362,18 @@ async def get_positions( ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - results = [] url = urls.POSITIONS.with_query({"nonzero": str(nonzero).lower()}) - + headers = {"Authorization": self._access_token} while url is not None and (pages is None or pages > 0): - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError( - resp.method, resp.url, resp.status, response - ) - - results += response["results"] - url = response["next"] - pages = pages and pages - 1 - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + response = await self.request("GET", URL(url), headers=headers) + results += response["results"] + url = response["next"] + pages = pages and pages - 1 return results @check_tokens - @check_session async def get_watchlist( self, watchlist: str = "Default", pages: Optional[int] = None ) -> List[str]: @@ -443,34 +392,18 @@ async def get_watchlist( ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - results = [] url = urls.WATCHLISTS / f"{watchlist}/" - + headers = {"Authorization": self._access_token} while url is not None and (pages is None or pages > 0): - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError( - resp.method, resp.url, resp.status, response - ) - - results += response["results"] - url = response["next"] - pages = pages and pages - 1 - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + response = await self.request("GET", URL(url), headers=headers) + results += response["results"] + url = response["next"] + pages = pages and pages - 1 return [result["instrument"] for result in results] @check_tokens - @check_session async def add_to_watchlist( self, instrument: str, watchlist: str = "Default" ) -> None: @@ -486,25 +419,12 @@ async def add_to_watchlist( ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - url = urls.WATCHLISTS / f"{watchlist}/" - - try: - async with self._session.post( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - json={"instrument": instrument}, - ) as resp: - response = await resp.json() - if resp.status != 201: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", url) from e + headers = {"Authorization": self._access_token} + json = {"instrument": instrument} + await self.request("POST", url, headers=headers, json=json, success_code=201) @check_tokens - @check_session async def remove_from_watchlist(self, id_: str, watchlist: str = "Default") -> None: """Remove a security from the given watchlist. @@ -518,21 +438,9 @@ async def remove_from_watchlist(self, id_: str, watchlist: str = "Default") -> N ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - url = urls.WATCHLISTS / watchlist / f"{id_}/" - - try: - async with self._session.delete( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 204: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("DELETE", url) from e + headers = {"Authorization": self._access_token} + await self.request("DELETE", url, headers=headers, success_code=204) ################################################################################### # STOCKS # @@ -540,7 +448,6 @@ async def remove_from_watchlist(self, id_: str, watchlist: str = "Default") -> N @mutually_exclusive("symbols", "instruments") @check_tokens - @check_session async def get_fundamentals( self, symbols: Optional[Iterable[str]] = None, @@ -563,30 +470,17 @@ async def get_fundamentals( ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. ValueError: Both/neither of ``symbols`` and ``instruments`` are supplied. """ - assert self._session is not None - if symbols is not None: url = urls.FUNDAMENTALS.with_query({"symbols": ",".join(symbols)}) elif instruments is not None: url = urls.FUNDAMENTALS.with_query({"instruments": ",".join(instruments)}) - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["results"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", url, headers=headers) + return response["results"] @mutually_exclusive("symbol", "ids") @check_tokens - @check_session async def get_instruments( self, symbol: Optional[str] = None, @@ -611,38 +505,23 @@ async def get_instruments( ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. ValueError: Both/neither of ``symbol`` and ``ids`` are supplied. """ - assert self._session is not None - results = [] if symbol is not None: url = urls.INSTRUMENTS.with_query({"symbol": symbol}) elif ids is not None: url = urls.INSTRUMENTS.with_query({"ids": ",".join(ids)}) + headers = {"Authorization": self._access_token} while url is not None and (pages is None or pages > 0): - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError( - resp.method, resp.url, resp.status, response - ) - - results += response["results"] - url = response["next"] - pages = pages and pages - 1 - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + response = await self.request("GET", URL(url), headers=headers) + results += response["results"] + url = response["next"] + pages = pages and pages - 1 return results @mutually_exclusive("symbols", "instruments") @check_tokens - @check_session async def get_quotes( self, symbols: Optional[Iterable[str]] = None, @@ -665,30 +544,17 @@ async def get_quotes( ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. ValueError: Both/neither of ``symbols`` and ``instruments`` are supplied. """ - assert self._session is not None - if symbols is not None: url = urls.QUOTES.with_query({"symbols": ",".join(symbols)}) elif instruments is not None: url = urls.QUOTES.with_query({"instruments": ",".join(instruments)}) - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["results"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", url, headers=headers) + return response["results"] @mutually_exclusive("symbols", "instruments") @check_tokens - @check_session async def get_historical_quotes( self, interval: models.HistoricalInterval, @@ -721,8 +587,6 @@ async def get_historical_quotes( Certain combinations of ``interval`` and ``span`` will be rejected by Robinhood. """ - assert self._session is not None - url = urls.HISTORICALS.with_query( { "bounds": "extended" if extended_hours else "regular", @@ -735,22 +599,11 @@ async def get_historical_quotes( elif instruments is not None: url = url.update_query({"instruments": ",".join(instruments)}) - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["results"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", url, headers=headers) + return response["results"] @check_tokens - @check_session async def get_ratings( self, ids: Iterable[str], pages: Optional[int] = None ) -> List[Dict[str, Any]]: @@ -769,34 +622,18 @@ async def get_ratings( ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - results = [] url = urls.RATINGS.with_query({"ids": ",".join(ids)}) - + headers = {"Authorization": self._access_token} while url is not None and (pages is None or pages > 0): - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError( - resp.method, resp.url, resp.status, response - ) - - results += response["results"] - url = response.get("next") - pages = pages and pages - 1 - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + response = await self.request("GET", URL(url), headers=headers) + results += response["results"] + url = response["next"] + pages = pages and pages - 1 return results @check_tokens - @check_session async def get_tags(self, id_: str) -> List[str]: """Fetch the tags for a particular security. @@ -812,26 +649,12 @@ async def get_tags(self, id_: str) -> List[str]: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - url = urls.TAGS / "instrument" / f"{id_}/" - - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return [tag["slug"] for tag in response["tags"]] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", url, headers=headers) + return [tag["slug"] for tag in response["tags"]] @check_tokens - @check_session async def get_tag_members(self, tag: str) -> List[str]: """Fetch the instruments belonging to a particular tag. @@ -847,30 +670,16 @@ async def get_tag_members(self, tag: str) -> List[str]: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - url = urls.TAGS / "tag" / f"{tag}/" - - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - - return response["instruments"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + headers = {"Authorization": self._access_token} + response = await self.request("GET", url, headers=headers) + return response["instruments"] ################################################################################### # ORDERS # ################################################################################### @check_tokens - @check_session async def get_orders( self, order_id: Optional[str] = None, pages: Optional[int] = None ) -> List[Dict[str, Any]]: @@ -891,34 +700,20 @@ async def get_orders( ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - results = [] - url = urls.ORDERS if order_id is None else urls.ORDERS / f"{order_id}/" - + # fmt: off + url: Optional[URL] = urls.ORDERS if order_id is None else urls.ORDERS / f"{order_id}/" # noqa: E501 + # fmt: on + headers = {"Authorization": self._access_token} while url is not None and (pages is None or pages > 0): - try: - async with self._session.get( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError( - resp.method, resp.url, resp.status, response - ) - - results += response.get("results", [response]) - url = response.get("next") - pages = pages and pages - 1 - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("GET", url) from e + response = await self.request("GET", URL(url), headers=headers) + results += response.get("results", [response]) + url = response.get("next") + pages = pages and pages - 1 return results @check_tokens - @check_session async def cancel_order(self, order_id: str) -> None: """Cancel an order. @@ -931,41 +726,29 @@ async def cancel_order(self, order_id: str) -> None: ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. """ - assert self._session is not None - url = urls.ORDERS / order_id / "cancel/" - - try: - async with self._session.post( - url, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - ) as resp: - response = await resp.json() - if resp.status != 200: - raise ClientAPIError(resp.method, resp.url, resp.status, response) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", url) from e + headers = {"Authorization": self._access_token} + await self.request("POST", url, headers=headers) @check_tokens - @check_session async def place_order(self, **kwargs) -> str: - assert self._session is not None + """Place a custom order. - try: - async with self._session.post( - urls.ORDERS, - timeout=self._timeout, - headers={"Authorization": self._access_token}, - json={"account": self._account_url, "ref_id": str(uuid4()), **kwargs}, - ) as resp: - response = await resp.json() - if resp.status != 201: - raise ClientAPIError(resp.method, resp.url, resp.status, response) + Returns: + The order ID. - return response["id"] - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - raise ClientRequestError("POST", urls.ORDERS) from e + Raises: + ClientAPIError: Robinhood server responded with an error. + ClientRequestError: The HTTP request timed out or failed. + ClientUnauthenticatedError: The :class:`~.RobinhoodClient` is not logged in. + ClientUninitializedError: The :class:`~.RobinhoodClient` is not initialized. + """ + headers = {"Authorization": self._access_token} + json = {"account": self._account_url, "ref_id": str(uuid4()), **kwargs} + response = await self.request( + "POST", urls.ORDERS, headers=headers, json=json, success_code=201 + ) + return response["id"] async def place_limit_buy_order( self, diff --git a/aiorobinhood/decorators.py b/aiorobinhood/decorators.py index 47cd8ba..cc658b6 100644 --- a/aiorobinhood/decorators.py +++ b/aiorobinhood/decorators.py @@ -1,17 +1,7 @@ from functools import wraps from typing import Callable -from .exceptions import ClientUnauthenticatedError, ClientUninitializedError - - -def check_session(func: Callable): - @wraps(func) - async def inner(self, *args, **kwargs): - if self._session is None: - raise ClientUninitializedError() - return await func(self, *args, **kwargs) - - return inner +from .exceptions import ClientUnauthenticatedError def check_tokens(func: Callable): diff --git a/docs/client.rst b/docs/client.rst index 6711052..fb4b560 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -11,6 +11,12 @@ All communication with the Robinhood servers is done through the .. autoclass:: RobinhoodClient +All :class:`~.RobinhoodClient` request methods call the :meth:`~.request` helper +method below. This method can also be used to craft custom API requests that are not +encapsulated by :class:`~.RobinhoodClient` API methods. + +.. automethod:: RobinhoodClient.request + Authentication ============== @@ -58,6 +64,11 @@ Placing Orders .. warning:: Robinhood rate limits the ``/orders`` endpoint used by the following methods. +All :class:`~.RobinhoodClient` order methods call the :meth:`~.place_order` helper +method below. This method can also be used to craft custom order requests that are not +encapsulated by :class:`~.RobinhoodClient` order API methods. + +.. automethod:: RobinhoodClient.place_order .. automethod:: RobinhoodClient.place_limit_buy_order .. automethod:: RobinhoodClient.place_limit_sell_order .. automethod:: RobinhoodClient.place_market_buy_order @@ -65,4 +76,4 @@ Placing Orders .. automethod:: RobinhoodClient.place_stop_buy_order .. automethod:: RobinhoodClient.place_stop_sell_order .. automethod:: RobinhoodClient.place_stop_limit_buy_order -.. automethod:: RobinhoodClient.place_stop_limit_sell_order \ No newline at end of file +.. automethod:: RobinhoodClient.place_stop_limit_sell_order diff --git a/docs/exceptions.rst b/docs/exceptions.rst index cfc7796..7bf85c8 100644 --- a/docs/exceptions.rst +++ b/docs/exceptions.rst @@ -20,4 +20,4 @@ Hierarchy * :exc:`ClientUnauthenticatedError` * :exc:`ClientRequestError` - * :exc:`ClientAPIError` \ No newline at end of file + * :exc:`ClientAPIError` diff --git a/docs/misc.rst b/docs/misc.rst index e196687..1757cc8 100644 --- a/docs/misc.rst +++ b/docs/misc.rst @@ -9,4 +9,4 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` -* :ref:`search` \ No newline at end of file +* :ref:`search` diff --git a/docs/reference.rst b/docs/reference.rst index 19945c2..f726c71 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -9,4 +9,4 @@ Reference Client Exceptions - Models \ No newline at end of file + Models diff --git a/tests/test_account.py b/tests/test_account.py index 4e33804..daa35c9 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -3,7 +3,6 @@ import pytest -from aiorobinhood import ClientAPIError, ClientRequestError from aiorobinhood.urls import POSITIONS, WATCHLISTS @@ -38,39 +37,6 @@ async def test_get_positions(logged_in_client): assert result == [{"foo": "bar"}, {"baz": "quux"}] -@pytest.mark.asyncio -async def test_get_positions_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_positions(nonzero=False)) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == POSITIONS.path - assert request.query["nonzero"] == "false" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_positions_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_positions()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == POSITIONS.path - assert request.query["nonzero"] == "true" - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_get_watchlist(logged_in_client): client, server = logged_in_client @@ -100,37 +66,6 @@ async def test_get_watchlist(logged_in_client): assert result == ["<>", "><"] -@pytest.mark.asyncio -async def test_get_watchlist_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_watchlist()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_watchlist_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_watchlist()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default/").path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_add_to_watchlist(logged_in_client): client, server = logged_in_client @@ -147,39 +82,6 @@ async def test_add_to_watchlist(logged_in_client): assert result is None -@pytest.mark.asyncio -async def test_add_to_watchlist_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.add_to_watchlist(instrument="<>")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default/").path - assert (await request.json())["instrument"] == "<>" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_add_to_watchlist_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.add_to_watchlist(instrument="<>")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default/").path - assert (await request.json())["instrument"] == "<>" - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_remove_from_watchlist(logged_in_client): client, server = logged_in_client @@ -193,34 +95,3 @@ async def test_remove_from_watchlist(logged_in_client): result = await asyncio.wait_for(task, pytest.TIMEOUT) assert result is None - - -@pytest.mark.asyncio -async def test_remove_from_watchlist_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.remove_from_watchlist(id_="12345")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "DELETE" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default" / "12345/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_remove_from_watchlist_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.remove_from_watchlist(id_="12345")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "DELETE" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (WATCHLISTS / "Default" / "12345/").path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..3972e60 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,34 @@ +import asyncio + +import pytest + +from aiorobinhood import ClientAPIError, ClientRequestError + + +@pytest.mark.asyncio +async def test_request_api_error(logged_in_client): + client, server = logged_in_client + task = asyncio.create_task(client.request("GET", pytest.NEXT)) + + request = await server.receive_request(timeout=pytest.TIMEOUT) + assert request.method == "GET" + assert request.path == pytest.NEXT.path + server.send_response(request, status=400, content_type="application/json") + + with pytest.raises(ClientAPIError): + await task + + +@pytest.mark.asyncio +async def test_request_timeout_error(logged_in_client): + client, server = logged_in_client + task = asyncio.create_task(client.request("GET", pytest.NEXT)) + + request = await server.receive_request(timeout=pytest.TIMEOUT) + assert request.method == "GET" + assert request.path == pytest.NEXT.path + + with pytest.raises(ClientRequestError) as exc_info: + await asyncio.sleep(pytest.TIMEOUT + 1) + await task + assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) diff --git a/tests/test_magic.py b/tests/test_magic.py index a7cb2ad..41eb5af 100644 --- a/tests/test_magic.py +++ b/tests/test_magic.py @@ -1,7 +1,7 @@ import aiohttp import pytest -from aiorobinhood import RobinhoodClient +from aiorobinhood import ClientUninitializedError, RobinhoodClient @pytest.mark.asyncio @@ -9,3 +9,10 @@ async def test_async_context_manager(): async with RobinhoodClient(timeout=pytest.TIMEOUT) as client: assert client._session is not None assert isinstance(client._session, aiohttp.ClientSession) + + +@pytest.mark.asyncio +async def test_async_context_manager_client_uninitialized_error(): + with pytest.raises(ClientUninitializedError): + async with RobinhoodClient(timeout=pytest.TIMEOUT) as client: + client._session = None diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 6b49826..8490c54 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -41,6 +41,7 @@ async def test_login_sfa_flow(logged_out_client): assert request.path == LOGIN.path server.send_response( request, + status=400, content_type="application/json", text=json.dumps( {"challenge": {"id": challenge_id, "remaining_attempts": 3}} @@ -150,41 +151,67 @@ async def test_login_mfa_flow(logged_out_client): @pytest.mark.asyncio -async def test_login_uninitialized_client(): - client = RobinhoodClient(timeout=pytest.TIMEOUT) - with pytest.raises(ClientUninitializedError): - await client.login(username="robin", password="hood") +async def test_login_api_error(logged_out_client): + client, server = logged_out_client + challenge_code = "123456" + + with replace_input(StringIO(challenge_code)): + task = asyncio.create_task(client.login(username="robin", password="hood")) + + request = await server.receive_request(timeout=pytest.TIMEOUT) + assert request.method == "POST" + assert request.path == LOGIN.path + server.send_response( + request, status=400, content_type="application/json", text=json.dumps({}), + ) + + with pytest.raises(ClientAPIError): + await task @pytest.mark.asyncio -async def test_login_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.login(username="robin", password="hood")) +async def test_login_sfa_zero_challenge_attempts(logged_out_client): + client, server = logged_out_client + challenge_code = "123456" + challenge_id = "abcdef" - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.path == LOGIN.path - server.send_response( - request, status=400, content_type="application/json", text=json.dumps({}) - ) + with replace_input(StringIO(challenge_code)): + task = asyncio.create_task(client.login(username="robin", password="hood")) - with pytest.raises(ClientAPIError): - await task + request = await server.receive_request(timeout=pytest.TIMEOUT) + assert request.method == "POST" + assert request.path == LOGIN.path + server.send_response( + request, + status=400, + content_type="application/json", + text=json.dumps( + {"challenge": {"id": challenge_id, "remaining_attempts": 1}} + ), + ) + request = await server.receive_request(timeout=pytest.TIMEOUT) + assert request.method == "POST" + assert (await request.json())["response"] == challenge_code + assert request.path == f"{CHALLENGE.path}{challenge_id}/respond/" + server.send_response( + request, + status=400, + content_type="application/json", + text=json.dumps( + {"challenge": {"id": challenge_id, "remaining_attempts": 0}} + ), + ) -@pytest.mark.asyncio -async def test_login_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.login(username="robin", password="hood")) + with pytest.raises(ClientAPIError): + await task - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.path == LOGIN.path - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) +@pytest.mark.asyncio +async def test_login_uninitialized_client(): + client = RobinhoodClient(timeout=pytest.TIMEOUT) + with pytest.raises(ClientUninitializedError): + await client.login(username="robin", password="hood") @pytest.mark.asyncio @@ -230,37 +257,6 @@ async def test_logout(logged_in_client): assert result is None -@pytest.mark.asyncio -async def test_logout_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.logout()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert (await request.json())["token"] == pytest.REFRESH_TOKEN - assert request.path == LOGOUT.path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_logout_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.logout()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert (await request.json())["token"] == pytest.REFRESH_TOKEN - assert request.path == LOGOUT.path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_logout_unauthenticated_client(logged_out_client): client, _ = logged_out_client @@ -294,41 +290,6 @@ async def test_refresh(logged_in_client): assert result is None -@pytest.mark.asyncio -async def test_refresh_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.refresh()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - request_json = await request.json() - assert request_json["grant_type"] == "refresh_token" - assert request_json["refresh_token"] == pytest.REFRESH_TOKEN - assert request.path == LOGIN.path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_refresh_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.refresh()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - request_json = await request.json() - assert request_json["grant_type"] == "refresh_token" - assert request_json["refresh_token"] == pytest.REFRESH_TOKEN - assert request.path == LOGIN.path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_dump(logged_in_client): client, _ = logged_in_client diff --git a/tests/test_orders.py b/tests/test_orders.py index 20a0513..826962b 100644 --- a/tests/test_orders.py +++ b/tests/test_orders.py @@ -3,7 +3,6 @@ import pytest -from aiorobinhood import ClientAPIError, ClientRequestError from aiorobinhood.urls import INSTRUMENTS, ORDERS, QUOTES @@ -36,38 +35,6 @@ async def test_get_orders(logged_in_client): assert result == [{"foo": "bar"}, {"baz": "quux"}] -@pytest.mark.asyncio -async def test_get_orders_api_error(logged_in_client): - client, server = logged_in_client - order_id = "12345" - task = asyncio.create_task(client.get_orders(order_id)) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (ORDERS / f"{order_id}/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_orders_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_orders()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == ORDERS.path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_cancel_order(logged_in_client): client, server = logged_in_client @@ -84,72 +51,6 @@ async def test_cancel_order(logged_in_client): assert result is None -@pytest.mark.asyncio -async def test_cancel_order_api_error(logged_in_client): - client, server = logged_in_client - order_id = "12345" - task = asyncio.create_task(client.cancel_order(order_id)) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (ORDERS / order_id / "cancel/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_cancel_order_timeout_error(logged_in_client): - client, server = logged_in_client - order_id = "12345" - task = asyncio.create_task(client.cancel_order(order_id)) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (ORDERS / order_id / "cancel/").path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - -@pytest.mark.asyncio -async def test_order_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.place_order()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == ORDERS.path - assert (await request.json())["account"] == pytest.ACCOUNT_URL - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_order_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.place_order()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "POST" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == ORDERS.path - assert (await request.json())["account"] == pytest.ACCOUNT_URL - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_place_limit_buy_order(logged_in_client): client, server = logged_in_client diff --git a/tests/test_profile.py b/tests/test_profile.py index 9f888a6..b447312 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -3,12 +3,7 @@ import pytest -from aiorobinhood import ( - ClientAPIError, - ClientRequestError, - HistoricalInterval, - HistoricalSpan, -) +from aiorobinhood import HistoricalInterval, HistoricalSpan from aiorobinhood.urls import ACCOUNTS, PORTFOLIOS @@ -29,37 +24,6 @@ async def test_get_account(logged_in_client): assert result == {} -@pytest.mark.asyncio -async def test_get_account_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_account()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == ACCOUNTS.path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_account_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_account()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == ACCOUNTS.path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_get_portfolio(logged_in_client): client, server = logged_in_client @@ -77,37 +41,6 @@ async def test_get_portfolio(logged_in_client): assert result == {} -@pytest.mark.asyncio -async def test_get_portfolio_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_portfolio()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == PORTFOLIOS.path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_portfolio_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_portfolio()) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == PORTFOLIOS.path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_get_historical_portfolio(logged_in_client): client, server = logged_in_client @@ -130,48 +63,3 @@ async def test_get_historical_portfolio(logged_in_client): result = await asyncio.wait_for(task, pytest.TIMEOUT) assert result == {} - - -@pytest.mark.asyncio -async def test_get_historical_portfolio_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task( - client.get_historical_portfolio( - interval=HistoricalInterval.FIVE_MIN, span=HistoricalSpan.DAY - ) - ) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (PORTFOLIOS / "historicals" / f"{pytest.ACCOUNT_NUM}/").path - assert request.query["bounds"] == "regular" - assert request.query["interval"] == HistoricalInterval.FIVE_MIN.value - assert request.query["span"] == HistoricalSpan.DAY.value - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_historical_portfolio_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task( - client.get_historical_portfolio( - interval=HistoricalInterval.FIVE_MIN, span=HistoricalSpan.DAY - ) - ) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (PORTFOLIOS / "historicals" / f"{pytest.ACCOUNT_NUM}/").path - assert request.query["bounds"] == "regular" - assert request.query["interval"] == HistoricalInterval.FIVE_MIN.value - assert request.query["span"] == HistoricalSpan.DAY.value - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) diff --git a/tests/test_stocks.py b/tests/test_stocks.py index 36acef9..bc3d365 100644 --- a/tests/test_stocks.py +++ b/tests/test_stocks.py @@ -3,12 +3,7 @@ import pytest -from aiorobinhood import ( - ClientAPIError, - ClientRequestError, - HistoricalInterval, - HistoricalSpan, -) +from aiorobinhood import HistoricalInterval, HistoricalSpan from aiorobinhood.urls import ( FUNDAMENTALS, HISTORICALS, @@ -20,7 +15,7 @@ @pytest.mark.asyncio -async def test_get_fundamentals(logged_in_client): +async def test_get_fundamentals_by_symbols(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_fundamentals(symbols=["ABCD"])) @@ -38,7 +33,7 @@ async def test_get_fundamentals(logged_in_client): @pytest.mark.asyncio -async def test_get_fundamentals_api_error(logged_in_client): +async def test_get_fundamentals_by_instruments(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_fundamentals(instruments=["<>"])) @@ -47,27 +42,12 @@ async def test_get_fundamentals_api_error(logged_in_client): assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" assert request.path == FUNDAMENTALS.path assert request.query["instruments"] == "<>" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_fundamentals_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_fundamentals(symbols=["ABCD"])) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == FUNDAMENTALS.path - assert request.query["symbols"] == "ABCD" + server.send_response( + request, content_type="application/json", text=json.dumps({"results": [{}]}), + ) - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) + result = await asyncio.wait_for(task, pytest.TIMEOUT) + assert result == [{}] @pytest.mark.asyncio @@ -80,7 +60,7 @@ async def test_get_fundamentals_value_error(logged_in_client): @pytest.mark.asyncio -async def test_get_instruments(logged_in_client): +async def test_get_instruments_by_symbol(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_instruments(symbol="ABCD")) @@ -110,7 +90,7 @@ async def test_get_instruments(logged_in_client): @pytest.mark.asyncio -async def test_get_instruments_api_error(logged_in_client): +async def test_get_instruments_by_ids(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_instruments(ids=["12345"])) @@ -119,27 +99,24 @@ async def test_get_instruments_api_error(logged_in_client): assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" assert request.path == INSTRUMENTS.path assert request.query["ids"] == "12345" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_instruments_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_instruments(symbol="ABCD")) + server.send_response( + request, + content_type="application/json", + text=json.dumps({"next": str(pytest.NEXT), "results": [{"foo": "bar"}]}), + ) request = await server.receive_request(timeout=pytest.TIMEOUT) assert request.method == "GET" assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == INSTRUMENTS.path - assert request.query["symbol"] == "ABCD" + assert request.path == pytest.NEXT.path + server.send_response( + request, + content_type="application/json", + text=json.dumps({"next": None, "results": [{"baz": "quux"}]}), + ) - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) + result = await asyncio.wait_for(task, pytest.TIMEOUT) + assert result == [{"foo": "bar"}, {"baz": "quux"}] @pytest.mark.asyncio @@ -152,7 +129,7 @@ async def test_get_instruments_value_error(logged_in_client): @pytest.mark.asyncio -async def test_get_quotes(logged_in_client): +async def test_get_quotes_by_symbols(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_quotes(symbols=["ABCD"])) @@ -170,7 +147,7 @@ async def test_get_quotes(logged_in_client): @pytest.mark.asyncio -async def test_get_quotes_api_error(logged_in_client): +async def test_get_quotes_by_instruments(logged_in_client): client, server = logged_in_client task = asyncio.create_task(client.get_quotes(instruments=["<>"])) @@ -179,27 +156,12 @@ async def test_get_quotes_api_error(logged_in_client): assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" assert request.path == QUOTES.path assert request.query["instruments"] == "<>" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_quotes_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_quotes(symbols=["ABCD"])) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == QUOTES.path - assert request.query["symbols"] == "ABCD" + server.send_response( + request, content_type="application/json", text=json.dumps({"results": [{}]}), + ) - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) + result = await asyncio.wait_for(task, pytest.TIMEOUT) + assert result == [{}] @pytest.mark.asyncio @@ -212,7 +174,7 @@ async def test_get_quotes_value_error(logged_in_client): @pytest.mark.asyncio -async def test_get_historical_quotes(logged_in_client): +async def test_get_historical_quotes_by_symbols(logged_in_client): client, server = logged_in_client task = asyncio.create_task( client.get_historical_quotes( @@ -239,13 +201,12 @@ async def test_get_historical_quotes(logged_in_client): @pytest.mark.asyncio -async def test_get_historical_quotes_api_error(logged_in_client): +async def test_get_historical_quotes_by_instruments(logged_in_client): client, server = logged_in_client task = asyncio.create_task( client.get_historical_quotes( interval=HistoricalInterval.FIVE_MIN, span=HistoricalSpan.DAY, - extended_hours=True, instruments=["<>"], ) ) @@ -254,40 +215,16 @@ async def test_get_historical_quotes_api_error(logged_in_client): assert request.method == "GET" assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" assert request.path == HISTORICALS.path - assert request.query["bounds"] == "extended" + assert request.query["bounds"] == "regular" assert request.query["interval"] == HistoricalInterval.FIVE_MIN.value assert request.query["span"] == HistoricalSpan.DAY.value assert request.query["instruments"] == "<>" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_historical_quotes_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task( - client.get_historical_quotes( - interval=HistoricalInterval.FIVE_MIN, - span=HistoricalSpan.DAY, - symbols=["ABCD"], - ) + server.send_response( + request, content_type="application/json", text=json.dumps({"results": [{}]}), ) - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == HISTORICALS.path - assert request.query["bounds"] == "regular" - assert request.query["interval"] == HistoricalInterval.FIVE_MIN.value - assert request.query["span"] == HistoricalSpan.DAY.value - assert request.query["symbols"] == "ABCD" - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) + result = await asyncio.wait_for(task, pytest.TIMEOUT) + assert result == [{}] @pytest.mark.asyncio @@ -329,39 +266,6 @@ async def test_get_ratings(logged_in_client): assert result == [{"foo": "bar"}, {"baz": "quux"}] -@pytest.mark.asyncio -async def test_get_ratings_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_ratings(ids=["12345", "67890"])) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == RATINGS.path - assert request.query["ids"] == "12345,67890" - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_ratings_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_ratings(ids=["12345", "67890"])) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == RATINGS.path - assert request.query["ids"] == "12345,67890" - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_get_tags(logged_in_client): client, server = logged_in_client @@ -381,37 +285,6 @@ async def test_get_tags(logged_in_client): assert result == ["foo"] -@pytest.mark.asyncio -async def test_get_tags_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_tags(id_="12345")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (TAGS / "instrument" / "12345/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_tags_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_tags(id_="12345")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (TAGS / "instrument" / "12345/").path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError) - - @pytest.mark.asyncio async def test_get_tag_members(logged_in_client): client, server = logged_in_client @@ -429,34 +302,3 @@ async def test_get_tag_members(logged_in_client): result = await asyncio.wait_for(task, pytest.TIMEOUT) assert result == ["<>"] - - -@pytest.mark.asyncio -async def test_get_tag_members_api_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_tag_members(tag="foo")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (TAGS / "tag" / "foo/").path - server.send_response(request, status=400, content_type="application/json") - - with pytest.raises(ClientAPIError): - await task - - -@pytest.mark.asyncio -async def test_get_tag_members_timeout_error(logged_in_client): - client, server = logged_in_client - task = asyncio.create_task(client.get_tag_members(tag="foo")) - - request = await server.receive_request(timeout=pytest.TIMEOUT) - assert request.method == "GET" - assert request.headers["Authorization"] == f"Bearer {pytest.ACCESS_TOKEN}" - assert request.path == (TAGS / "tag" / "foo/").path - - with pytest.raises(ClientRequestError) as exc_info: - await asyncio.sleep(pytest.TIMEOUT + 1) - await task - assert isinstance(exc_info.value.__cause__, asyncio.TimeoutError)