diff --git a/web_app/api/dashboard.py b/web_app/api/dashboard.py index 85ec8419..aee329e6 100644 --- a/web_app/api/dashboard.py +++ b/web_app/api/dashboard.py @@ -48,6 +48,7 @@ async def get_dashboard(wallet_id: str) -> DashboardResponse: start_sum=0, borrowed="0", balance="0", + position_id="0", ) if not contract_address: return default_dashboard_response @@ -95,4 +96,5 @@ async def get_dashboard(wallet_id: str) -> DashboardResponse: start_sum=start_sum, borrowed=str(start_sum * Decimal(tvl)), balance=str(balance), + position_id=first_opened_position["id"], ) diff --git a/web_app/api/serializers/dashboard.py b/web_app/api/serializers/dashboard.py index 393ac0ec..87dd90bf 100644 --- a/web_app/api/serializers/dashboard.py +++ b/web_app/api/serializers/dashboard.py @@ -44,3 +44,8 @@ class DashboardResponse(BaseModel): example="12", description="The balance of the position.", ) + position_id: str = Field( + ..., + example="12", + description="The position ID.", + ) diff --git a/web_app/contract_tools/blockchain_call.py b/web_app/contract_tools/blockchain_call.py index b6349d67..eac1a3f7 100644 --- a/web_app/contract_tools/blockchain_call.py +++ b/web_app/contract_tools/blockchain_call.py @@ -14,11 +14,10 @@ import starknet_py.hash.selector import starknet_py.net.client_models import starknet_py.net.networks +from .constants import MULTIPLIER_POWER, ZKLEND_MARKET_ADDRESS, TokenParams from starknet_py.contract import Contract from starknet_py.net.full_node_client import FullNodeClient -from .constants import ZKLEND_MARKET_ADDRESS, TokenParams, MULTIPLIER_POWER - logger = logging.getLogger(__name__) @@ -357,7 +356,7 @@ async def add_extra_deposit( async def withdraw_all(self, contract_address: str) -> dict[str, str]: """ Withdraws all supported tokens from the contract by calling withdraw with amount=0. - + :param contract_address: The contract address to withdraw from :return: A dictionary summarizing the results for each token. """ @@ -368,7 +367,7 @@ async def withdraw_all(self, contract_address: str) -> dict[str, str]: token_symbol = token.name try: - token_addr_int = self._convert_address(token.address) + token_addr_int = self._convert_address(token.address) except ValueError as e: logger.error(f"Invalid address format for {token_symbol}: {str(e)}") @@ -376,11 +375,13 @@ async def withdraw_all(self, contract_address: str) -> dict[str, str]: continue try: - logger.info(f"Withdrawing {token_symbol} from contract {contract_address}") + logger.info( + f"Withdrawing {token_symbol} from contract {contract_address}" + ) await self._func_call( addr=contract_addr_int, selector="withdraw", - calldata=[token_addr_int, 0] + calldata=[token_addr_int, 0], ) results[token_symbol] = "Success" except Exception as e: @@ -389,5 +390,33 @@ async def withdraw_all(self, contract_address: str) -> dict[str, str]: return results + async def fetch_portfolio(self, contract_address: str) -> dict: + """ + Fetches the portfolio of the contract + + :param contract_address: the contract address to fetch the portfolio from. + :return: A dictionary containing dictionaries of available tokens in the contract address, + and the balance + """ + results = {} + z_addresses = await self.get_z_addresses() + + for key, value in z_addresses.items(): + decimals, z_address = value + balance: int = await self.get_balance(z_address, contract_address) + + key = f"z{key}" + results[key] = {"balance": balance, "decimals": decimals} + + return results + CLIENT = StarknetClient() + +if __name__ == "__main__": + call = CLIENT + spotnet_address = ( + "0x05685d6b0b493c7c939d65c175305b893870cacad780842c79a611ad9122815f" + ) + res = asyncio.run(call.fetch_portfolio(spotnet_address)) + print(res) diff --git a/web_app/tests/test_dashboard.py b/web_app/tests/test_dashboard.py index 4699ea86..7a7eae2f 100644 --- a/web_app/tests/test_dashboard.py +++ b/web_app/tests/test_dashboard.py @@ -98,6 +98,7 @@ async def test_get_dashboard_success(): "start_price": "100.0", "amount": "2.0", "token_symbol": "ETH", + "id": "0", } ] mock_get_health_ratio_and_tvl.return_value = ("1.2", "1000.0") @@ -105,15 +106,6 @@ async def test_get_dashboard_success(): "ETH": 5.0, "USDC": 1000.0, } - # mock_get_zklend_position.return_value = { - # "products": [ - # { - # "name": "ZkLend", - # "groups": {"1": {"healthRatio": "1.2"}}, - # "positions": [], - # } - # ] - # } mock_get_current_position_sum.return_value = Decimal("200.0") mock_get_start_position_sum.return_value = Decimal("200.0") @@ -133,6 +125,7 @@ async def test_get_dashboard_success(): "borrowed": "200000.00", "balance": "2.020202020202020202020202020", "health_ratio": "1.2", + "position_id": "0", } @@ -286,6 +279,7 @@ async def test_empty_positions( "borrowed": "0", "balance": "0", "health_ratio": "0", + "position_id": "0", } diff --git a/web_app/tests/test_positions.py b/web_app/tests/test_positions.py index 70370404..2dcc011c 100644 --- a/web_app/tests/test_positions.py +++ b/web_app/tests/test_positions.py @@ -9,6 +9,7 @@ import uuid from datetime import datetime +from decimal import Decimal from unittest.mock import Mock, patch import pytest @@ -17,6 +18,7 @@ from httpx import AsyncClient from web_app.api.main import app +from web_app.api.position import add_extra_deposit app.dependency_overrides.clear() @@ -487,3 +489,127 @@ async def test_get_user_positions_no_positions(client: AsyncClient) -> None: assert response.status_code == 200 data = response.json() assert data == [] + +@pytest.mark.parametrize( + "position_id, amount, mock_position, expected_response", + [ + ( + 1, + "100.0", + { + "id": 1, + "token_symbol": "ETH", + "amount": "1000", + "status": "opened" + }, + {"detail": "Successfully added extra deposit"} + ), + ( + 123, + "50.5", + { + "id": 123, + "token_symbol": "ETH", + "amount": "500", + "status": "opened" + }, + {"detail": "Successfully added extra deposit"} + ), + ( + 999, + "75.25", + { + "id": 999, + "token_symbol": "ETH", + "amount": "750", + "status": "opened" + }, + {"detail": "Successfully added extra deposit"} + ), + ], +) +@pytest.mark.anyio +async def test_add_extra_deposit_success( + client: TestClient, + position_id: int, + amount: str, + mock_position: dict, + expected_response: dict, +) -> None: + """ + Test for successfully adding extra deposit to a position. + + """ + with ( + patch( + "web_app.db.crud.PositionDBConnector.get_position_by_id" + ) as mock_get_position, + patch( + "web_app.db.crud.PositionDBConnector.add_extra_deposit_to_position" + ) as mock_add_deposit, + ): + mock_get_position.return_value = mock_position + mock_add_deposit.return_value = None + + response = client.post( + f"/api/add-extra-deposit/{position_id}?amount={amount}" + ) + + assert response.status_code == 200 + assert response.json() == expected_response + mock_get_position.assert_called_once_with(position_id) + mock_add_deposit.assert_called_once_with(mock_position, amount) + + +@pytest.mark.parametrize( + "position_id, amount, error_status, error_detail", + [ + ( + None, + "100.0", + 422, + "Position ID is required" + ), + ( + 1, + "", + 404, + "Amount is required" + ), + ( + 999, + "100.0", + 404, + "Position not found" + ), + ( + "invalid", + "100.0", + 422, + "Invalid position ID format" + ), + ], +) +@pytest.mark.anyio +async def test_add_extra_deposit_failure( + client: TestClient, + position_id: int, + amount: str, + error_status: int, + error_detail: str, +) -> None: + """ + Test various failure scenarios when adding extra deposit to a position. + + """ + with patch( + "web_app.db.crud.PositionDBConnector.get_position_by_id" + ) as mock_get_position: + if error_detail == "Position not found": + mock_get_position.return_value = None + + response = client.post( + f"/api/add-extra-deposit/{position_id}?amount={amount}" + ) + + assert response.status_code == error_status \ No newline at end of file diff --git a/web_app/tests/test_user.py b/web_app/tests/test_user.py index 5329dddf..9ac5e82a 100644 --- a/web_app/tests/test_user.py +++ b/web_app/tests/test_user.py @@ -274,3 +274,90 @@ async def test_subscribe_to_notification( assert response.status_code == expected_status_code if expected_response: assert response.json() == expected_response + + +@pytest.mark.asyncio +@patch("web_app.contract_tools.blockchain_call.CLIENT.withdraw_all") +@patch("web_app.api.user.user_db.get_contract_address_by_wallet_id") +@pytest.mark.parametrize( + "wallet_id, contract_address, withdrawal_results, expected_status_code, expected_response", + [ + # Positive case - successful withdrawal + ( + "0x27994c503bd8c32525fbdaf9d398bdd4e86757988c64581b055a06c5955ea49", + "0x698b63df00be56ba39447c9b9ca576ffd0edba0526d98b3e8e4a902ffcf12f0", + {"ETH": "success", "USDT": "success"}, + 200, + { + "detail": "Successfully initiated withdrawals for all tokens", + "results": {"ETH": "success", "USDT": "success"} + } + ), + # Negative case - contract not found + ( + "invalid_wallet_id", + None, + None, + 404, + {"detail": "Contract not found"} + ), + # Negative case - empty wallet_id + ( + "", + None, + None, + 404, + {"detail": "Contract not found"} + ), + # Edge case - valid wallet but no tokens to withdraw + ( + "0x27994c503bd8c32525fbdaf9d398bdd4e86757988c64581b055a06c5955ea49", + "0x698b63df00be56ba39447c9b9ca576ffd0edba0526d98b3e8e4a902ffcf12f0", + {}, + 200, + { + "detail": "Successfully initiated withdrawals for all tokens", + "results": {} + } + ), + ], +) +async def test_withdraw_all( + mock_get_contract_address: MagicMock, + mock_withdraw_all: MagicMock, + client: client, + wallet_id: str, + contract_address: str, + withdrawal_results: dict, + expected_status_code: int, + expected_response: dict, +) -> None: + """ + Test withdraw_all endpoint with various scenarios + + :param mock_get_contract_address: Mock for get_contract_address_by_wallet_id + :param mock_withdraw_all: Mock for CLIENT.withdraw_all + :param client: FastAPI test client + :param wallet_id: Wallet ID to test + :param contract_address: Expected contract address + :param withdrawal_results: Mock results from withdrawal operation + :param expected_status_code: Expected HTTP status code + :param expected_response: Expected response body + :return: None + """ + # Configure mocks + mock_get_contract_address.return_value = contract_address + if withdrawal_results is not None: + mock_withdraw_all.return_value = withdrawal_results + + response = client.post( + url="/api/withdraw-all", + params={"wallet_id": wallet_id}, + ) + + assert response.status_code == expected_status_code + assert response.json() == expected_response + + mock_get_contract_address.assert_called_once_with(wallet_id) + if contract_address: + mock_withdraw_all.assert_called_once_with(contract_address) \ No newline at end of file