diff --git a/web_app/tests/test_vault.py b/web_app/tests/test_vault.py index 94966a97..7423426a 100644 --- a/web_app/tests/test_vault.py +++ b/web_app/tests/test_vault.py @@ -42,7 +42,13 @@ async def async_client(): {"detail": "User not found"} ), ]) -async def test_deposit_to_vault(test_data, expected_status, expected_response, mock_user_db_connector, async_client): +async def test_deposit_to_vault( + test_data, + expected_status, + expected_response, + mock_user_db_connector, + async_client +): """Test vault deposit with different scenarios.""" mock_user = MagicMock() mock_vault = MagicMock() @@ -50,44 +56,70 @@ async def test_deposit_to_vault(test_data, expected_status, expected_response, m mock_vault.amount = test_data["amount"] with patch.object(UserDBConnector, "__new__", return_value=mock_user_db_connector): - mock_user_db_connector.get_user_by_wallet_id.return_value = \ + mock_user_db_connector.get_user_by_wallet_id.return_value = ( mock_user if test_data["wallet_id"] == "test_wallet" else None + ) if test_data["wallet_id"] == "test_wallet": - with patch("web_app.db.crud.DepositDBConnector.create_vault", return_value=mock_vault): + with patch( + "web_app.db.crud.DepositDBConnector.create_vault", + return_value=mock_vault + ): response = await async_client.post("/api/vault/deposit", json=test_data) else: response = await async_client.post("/api/vault/deposit", json=test_data) assert response.status_code == expected_status - expected = expected_response(mock_vault.id) if callable(expected_response) else expected_response + expected = ( + expected_response(mock_vault.id) + if callable(expected_response) + else expected_response + ) assert response.json() == expected @pytest.mark.anyio -@pytest.mark.parametrize("wallet_id, symbol, balance, expected_status, expected_response", [ - ( - "test_wallet", - "ETH", - "1.5", - 200, - lambda w, s, b: {"wallet_id": w, "symbol": s, "amount": b} - ), - ( - "invalid_wallet", - "ETH", - None, - 404, - {"detail": "Vault not found or user does not exist"} - ), -]) -async def test_get_vault_balance(wallet_id, symbol, balance, expected_status, expected_response, async_client): +@pytest.mark.parametrize( + "wallet_id, symbol, balance, expected_status, expected_response", + [ + ( + "test_wallet", + "ETH", + "1.5", + 200, + lambda w, s, b: {"wallet_id": w, "symbol": s, "amount": b} + ), + ( + "invalid_wallet", + "ETH", + None, + 404, + {"detail": "Vault not found or user does not exist"} + ), + ] +) +async def test_get_vault_balance( + wallet_id, + symbol, + balance, + expected_status, + expected_response, + async_client +): """Test vault balance retrieval with different scenarios.""" - with patch("web_app.db.crud.DepositDBConnector.get_vault_balance", return_value=balance): - response = await async_client.get(f"/api/vault/api/balance?wallet_id={wallet_id}&symbol={symbol}") + with patch( + "web_app.db.crud.DepositDBConnector.get_vault_balance", + return_value=balance + ): + url = f"/api/vault/api/balance?wallet_id={wallet_id}&symbol={symbol}" + response = await async_client.get(url) assert response.status_code == expected_status - expected = expected_response(wallet_id, symbol, balance) if callable(expected_response) else expected_response + expected = ( + expected_response(wallet_id, symbol, balance) + if callable(expected_response) + else expected_response + ) assert response.json() == expected @@ -104,7 +136,12 @@ async def test_get_vault_balance(wallet_id, symbol, balance, expected_status, ex {"detail": "Failed to update vault balance: Amount must be positive"} ), ]) -async def test_add_vault_balance(test_data, expected_status, expected_response, async_client): +async def test_add_vault_balance( + test_data, + expected_status, + expected_response, + async_client +): """Test adding to vault balance with different scenarios.""" mock_vault = MagicMock() mock_vault.amount = "2.0" @@ -118,9 +155,19 @@ async def test_add_vault_balance(test_data, expected_status, expected_response, "return_value": mock_vault } - with patch("web_app.db.crud.DepositDBConnector.add_vault_balance", **patch_kwargs): - response = await async_client.post("/api/vault/api/add_balance", json=test_data) + with patch( + "web_app.db.crud.DepositDBConnector.add_vault_balance", + **patch_kwargs + ): + response = await async_client.post( + "/api/vault/api/add_balance", + json=test_data + ) assert response.status_code == expected_status - expected = expected_response(mock_vault.amount) if callable(expected_response) else expected_response + expected = ( + expected_response(mock_vault.amount) + if callable(expected_response) + else expected_response + ) assert response.json() == expected