diff --git a/cschwabpy/SchwabAsyncClient.py b/cschwabpy/SchwabAsyncClient.py index 46aa890..19c89ee 100644 --- a/cschwabpy/SchwabAsyncClient.py +++ b/cschwabpy/SchwabAsyncClient.py @@ -6,10 +6,12 @@ OptionExpiration, OptionExpirationChainResponse, ) +from cschwabpy.models.trade_models import AccountNumberModel from typing import Optional, List, Mapping from cschwabpy.costants import ( SCHWAB_API_BASE_URL, SCHWAB_MARKET_DATA_API_BASE_URL, + SCHWAB_TRADER_API_BASE_URL, SCHWAB_AUTH_PATH, SCHWAB_TOKEN_PATH, ) @@ -100,6 +102,26 @@ def __auth_header(self) -> Mapping[str, str]: "Accept": "application/json", } + async def get_account_numbers_async(self) -> List[AccountNumberModel]: + await self._ensure_valid_access_token() + import json + + target_url = f"{SCHWAB_TRADER_API_BASE_URL}/accounts/accountNumbers" + client = httpx.AsyncClient() if self.__client is None else self.__client + try: + response = await client.get( + url=target_url, params={}, headers=self.__auth_header() + ) + json_res = response.json() + print("json_res: ", json_res) + account_numbers: List[AccountNumberModel] = [] + for account_json in json_res: + account_numbers.append(AccountNumberModel(**account_json)) + return account_numbers + finally: + if not self.__keep_client_alive: + await client.aclose() + async def get_option_expirations_async( self, underlying_symbol: str ) -> List[OptionExpiration]: diff --git a/cschwabpy/costants.py b/cschwabpy/costants.py index 6be3d3f..678bc93 100644 --- a/cschwabpy/costants.py +++ b/cschwabpy/costants.py @@ -1,4 +1,5 @@ SCHWAB_API_BASE_URL = "https://api.schwabapi.com/v1" SCHWAB_MARKET_DATA_API_BASE_URL = "https://api.schwabapi.com/marketdata/v1" +SCHWAB_TRADER_API_BASE_URL = "https://api.schwabapi.com/trader/v1" SCHWAB_AUTH_PATH = "oauth/authorize" SCHWAB_TOKEN_PATH = "oauth/token" diff --git a/cschwabpy/models/__init__.py b/cschwabpy/models/__init__.py index 663ef50..392b2b8 100644 --- a/cschwabpy/models/__init__.py +++ b/cschwabpy/models/__init__.py @@ -33,6 +33,11 @@ def to_json(self) -> Mapping[str, Any]: return self.model_dump(by_alias=True) +class ErrorMessage(JSONSerializableBaseModel): + message: str + errors: List[str] + + class QueryFilterBase(JSONSerializableBaseModel): """Base class for query parameters filters.""" diff --git a/cschwabpy/models/trade_models.py b/cschwabpy/models/trade_models.py new file mode 100644 index 0000000..76cb2b2 --- /dev/null +++ b/cschwabpy/models/trade_models.py @@ -0,0 +1,6 @@ +from cschwabpy.models import JSONSerializableBaseModel + + +class AccountNumberModel(JSONSerializableBaseModel): + accountNumber: str + hashValue: str diff --git a/tests/data/mock_schwab_api_resp.json b/tests/data/mock_schwab_api_resp.json index f71cef8..5b9198b 100644 --- a/tests/data/mock_schwab_api_resp.json +++ b/tests/data/mock_schwab_api_resp.json @@ -1,5 +1,15 @@ { - "option_expirations_list": { + "account_numbers": [ + { + "accountNumber": "123456789", + "hashValue": "hash1" + }, + { + "accountNumber": "987654321", + "hashValue": "hash2" + } + ], + "option_expirations_list": { "expirationList": [ { "expirationDate": "2022-01-07", diff --git a/tests/test_models.py b/tests/test_models.py index e360798..ed74644 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -121,3 +121,38 @@ async def test_get_option_expirations(httpx_mock: HTTPXMock): assert opt_expirations_list[0].daysToExpiration == 2 assert opt_expirations_list[0].expirationType == "W" assert opt_expirations_list[0].standard + + +@pytest.mark.asyncio +async def test_get_account_numbers(httpx_mock: HTTPXMock): + # Mock response for account numbers API + mock_data = get_mock_response() + mocked_token = mock_tokens() + token_store = LocalTokenStore() + token_store.save_tokens(mocked_token) + if os.path.exists(Path(token_store.token_output_path)): + os.remove(token_store.token_output_path) # clean up before test + + mock_account_numbers_response = mock_data["account_numbers"] + # Combine mock response with token JSON + httpx_mock.add_response(json=mock_account_numbers_response) + + async with httpx.AsyncClient() as client: + cschwab_client = SchwabAsyncClient( + app_client_id="fake_id", + app_secret="fake_secret", + token_store=token_store, + tokens=mocked_token, + http_client=client, + ) + + account_numbers = await cschwab_client.get_account_numbers_async() + # Assertions to verify the correctness of the API call + assert account_numbers is not None + assert ( + len(account_numbers) == 2 + ) # Expecting 2 account numbers in the mock response + assert account_numbers[0].accountNumber == "123456789" + assert account_numbers[0].hashValue == "hash1" + assert account_numbers[1].accountNumber == "987654321" + assert account_numbers[1].hashValue == "hash2"