diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b02b95fd..9e8684c4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -14,7 +14,7 @@ env: DB_USER: postgres DB_PASSWORD: password DB_HOST: db - STARKNET_NODE_URL: http://178.32.172.148:6060 + STARKNET_NODE_URL: http://51.195.57.196:6060/v0_7 REDIS_HOST: redis REDIS_PORT: 6379 ENV_VERSION: DEV @@ -37,7 +37,7 @@ jobs: - name: Create .env file run: | - cat << EOF > .env.dev + cat << EOF > /home/runner/work/spotnet/spotnet/.env ENV_VERSION=DEV STARKNET_NODE_URL=${{ env.STARKNET_NODE_URL }} DB_USER=${{ env.DB_USER }} @@ -65,30 +65,30 @@ jobs: run: | while ! curl -s http://localhost:8000/health > /dev/null; do echo "Waiting for backend service..." - sleep 10 + sleep 30 # Check if the container is still running before logging - if ! docker ps | grep -q backend_dev; then + if ! docker ps | grep -q backend; then echo "Backend container is not running!" - docker compose -f docker-compose.dev.yaml logs backend_dev || true + docker compose -f docker-compose.dev.yaml logs backend || true exit 1 fi # Log the backend service status for debugging purposes. - docker compose -f docker-compose.dev.yaml logs backend_dev || true + docker compose -f docker-compose.dev.yaml logs backend || true done - name: Apply Migrations run: | docker exec backend_dev alembic -c web_app/alembic.ini upgrade head || { echo "Migration failed. Showing backend logs:" - docker compose -f docker-compose.dev.yaml logs backend_dev || true + docker compose -f docker-compose.dev.yaml logs backend || true exit 1 } - name: Run Integration Tests with Coverage run: | - docker exec backend_dev bash -c "cd /app && python -m pytest web_app/test_integration/ -v" + docker compose exec backend bash -c "cd /app && python -m pytest web_app/test_integration/ -v" - name: Clean Up diff --git a/web_app/alembic/versions/628064f52eb0_add_transaction_table.py b/web_app/alembic/versions/628064f52eb0_add_transaction_table.py index dceed531..c32a81b0 100644 --- a/web_app/alembic/versions/628064f52eb0_add_transaction_table.py +++ b/web_app/alembic/versions/628064f52eb0_add_transaction_table.py @@ -5,37 +5,53 @@ Create Date: 2024-12-14 14:13:07.042305 """ + from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = '628064f52eb0' -down_revision = 'cda4342b007d' +revision = "628064f52eb0" +down_revision = "cda4342b007d" branch_labels = None depends_on = None def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('transaction', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('position_id', sa.UUID(), nullable=False), - sa.Column('status', sa.Enum('opened', 'closed', name='transaction_status_enum'), nullable=False), - sa.Column('transaction_hash', sa.String(), nullable=False), - sa.Column('created_at', sa.DateTime(), nullable=False), - sa.Column('updated_at', sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint(['position_id'], ['position.id'], ), - sa.PrimaryKeyConstraint('id') + """ commands auto generated by Alembic - please adjust! """ + op.create_table( + "transaction", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("position_id", sa.UUID(), nullable=False), + sa.Column( + "status", + sa.Enum("opened", "closed", name="transaction_status_enum"), + nullable=False, + ), + sa.Column("transaction_hash", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["position_id"], + ["position.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_transaction_position_id"), "transaction", ["position_id"], unique=False + ) + op.create_index( + op.f("ix_transaction_transaction_hash"), + "transaction", + ["transaction_hash"], + unique=True, ) - op.create_index(op.f('ix_transaction_position_id'), 'transaction', ['position_id'], unique=False) - op.create_index(op.f('ix_transaction_transaction_hash'), 'transaction', ['transaction_hash'], unique=True) # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_transaction_transaction_hash'), table_name='transaction') - op.drop_index(op.f('ix_transaction_position_id'), table_name='transaction') - op.drop_table('transaction') + """ commands auto generated by Alembic - please adjust! """ + op.drop_index(op.f("ix_transaction_transaction_hash"), table_name="transaction") + op.drop_index(op.f("ix_transaction_position_id"), table_name="transaction") + op.drop_table("transaction") # ### end Alembic commands ### diff --git a/web_app/alembic/versions/cda4342b007d_change_multiplier_field_type.py b/web_app/alembic/versions/cda4342b007d_change_multiplier_field_type.py index 90568ff4..33337c61 100644 --- a/web_app/alembic/versions/cda4342b007d_change_multiplier_field_type.py +++ b/web_app/alembic/versions/cda4342b007d_change_multiplier_field_type.py @@ -5,13 +5,14 @@ Create Date: 2024-12-07 12:37:05.550048 """ + from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'cda4342b007d' -down_revision = '1a6fada80369' +revision = "cda4342b007d" +down_revision = "1a6fada80369" branch_labels = None depends_on = None @@ -21,10 +22,13 @@ def upgrade() -> None: Upgrade the database. """ # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('position', 'multiplier', - existing_type=sa.INTEGER(), - type_=sa.NUMERIC(), - existing_nullable=False) + op.alter_column( + "position", + "multiplier", + existing_type=sa.INTEGER(), + type_=sa.NUMERIC(), + existing_nullable=False, + ) # ### end Alembic commands ### @@ -33,8 +37,11 @@ def downgrade() -> None: Downgrade the database. """ # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('position', 'multiplier', - existing_type=sa.NUMERIC(), - type_=sa.INTEGER(), - existing_nullable=False) + op.alter_column( + "position", + "multiplier", + existing_type=sa.NUMERIC(), + type_=sa.INTEGER(), + existing_nullable=False, + ) # ### end Alembic commands ### diff --git a/web_app/api/serializers/position.py b/web_app/api/serializers/position.py index 64e21f23..9755902d 100644 --- a/web_app/api/serializers/position.py +++ b/web_app/api/serializers/position.py @@ -72,3 +72,10 @@ class UserPositionResponse(BaseModel): start_price: float is_liquidated: bool datetime_liquidation: Optional[datetime] = None + + +class UserPositionsListResponse(BaseModel): + """ + Response model for list of user positions. + """ + positions: List = List[UserPositionResponse] diff --git a/web_app/contract_tools/blockchain_call.py b/web_app/contract_tools/blockchain_call.py index 7704e741..57ba92d6 100644 --- a/web_app/contract_tools/blockchain_call.py +++ b/web_app/contract_tools/blockchain_call.py @@ -337,7 +337,9 @@ async def is_opened_position(self, contract_address: str) -> bool: calldata=[], ) - async def add_extra_deposit(self, contract_address: str, token_address: str, amount: str) -> Any: + async def add_extra_deposit( + self, contract_address: str, token_address: str, amount: str + ) -> Any: """ Adds extra deposit to position. diff --git a/web_app/db/crud/position.py b/web_app/db/crud/position.py index fe2dd2e1..4722252e 100644 --- a/web_app/db/crud/position.py +++ b/web_app/db/crud/position.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import SQLAlchemyError from .user import UserDBConnector -from web_app.db.models import AirDrop, Base, Position, Status, User +from web_app.db.models import Base, Position, Status, User, Transaction logger = logging.getLogger(__name__) ModelType = TypeVar("ModelType", bound=Base) diff --git a/web_app/pytest.ini b/web_app/pytest.ini index f27c3ca1..6b6acbff 100644 --- a/web_app/pytest.ini +++ b/web_app/pytest.ini @@ -1,6 +1,7 @@ [pytest] asyncio_mode = auto +asyncio_default_fixture_loop_scope = function filterwarnings = ignore::Warning env = - STARKNET_NODE_URL=http://178.32.172.148:6060 + STARKNET_NODE_URL=http://51.195.57.196:6060/v0_7 diff --git a/web_app/tests/conftest.py b/web_app/tests/conftest.py index f479283e..f53f07b4 100644 --- a/web_app/tests/conftest.py +++ b/web_app/tests/conftest.py @@ -2,32 +2,40 @@ This module contains the fixtures for the tests. """ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient from web_app.api.main import app -from web_app.db.crud import ( - DBConnector, - PositionDBConnector, - UserDBConnector, -) +from web_app.db.crud import DBConnector, PositionDBConnector, UserDBConnector from web_app.db.database import get_database @pytest.fixture(scope="module") def client() -> None: """ - TestClient with setted mock db connection - :return: None + A client mock fixture + :return: TestClient """ - mock_db_connector = MagicMock(spec=DBConnector) app.dependency_overrides[get_database] = lambda: mock_db_connector - with TestClient(app=app) as client: - yield client + with patch( + "starknet_py.contract.Contract.from_address", new_callable=AsyncMock + ) as mock_from_address, patch( + "starknet_py.net.full_node_client.FullNodeClient.get_class_hash_at", + new_callable=AsyncMock, + ) as mock_class_hash, patch( + "starknet_py.net.http_client.HttpClient.request", new_callable=AsyncMock + ) as mock_request: + # Mock return values + mock_from_address.return_value = MagicMock() + mock_class_hash.return_value = "0x123" + mock_request.return_value = {} + + with TestClient(app=app) as test_client: + yield test_client app.dependency_overrides.clear() diff --git a/web_app/tests/db/test_PositionDBConnector.py b/web_app/tests/db/test_PositionDBConnector.py index 49b36740..73f6345c 100644 --- a/web_app/tests/db/test_PositionDBConnector.py +++ b/web_app/tests/db/test_PositionDBConnector.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import scoped_session from web_app.db.crud import PositionDBConnector -from web_app.db.models import Position, Status, User +from web_app.db.models import Position, Status, Transaction, TransactionStatus, User @pytest.fixture(scope="function") @@ -168,54 +168,63 @@ def test_get_total_amounts_for_open_positions(mock_position_db_connector): assert result == Decimal(1000.0) -def test_save_transaction_success(db_connector, mocker): + +def test_save_transaction_success(mock_position_db_connector): """Test successful transaction creation""" - position_id = uuid4() + position_id = uuid.uuid4() transaction_hash = "0x123456789" status = TransactionStatus.OPENED.value - - transaction = db_connector.save_transaction( - position_id=position_id, - status=status, - transaction_hash=transaction_hash + mock_position_db_connector.save_transaction.return_value = Transaction( + position_id=position_id, transaction_hash=transaction_hash, status=status + ) + transaction = mock_position_db_connector.save_transaction( + position_id=position_id, status=status, transaction_hash=transaction_hash ) - + assert transaction is not None + assert isinstance(transaction, Transaction) assert transaction.position_id == position_id assert transaction.transaction_hash == transaction_hash assert transaction.status == status -def test_save_transaction_duplicate_hash(db_connector): + +def test_save_transaction_duplicate_hash(mock_position_db_connector): """Test handling duplicate transaction hash""" - position_id = uuid4() + position_id = uuid.uuid4() transaction_hash = "0x123456789" status = TransactionStatus.OPENED.value - - db_connector.save_transaction( - position_id=position_id, - status=status, - transaction_hash=transaction_hash + mock_position_db_connector.save_transaction.return_value = Transaction( + position_id=position_id, transaction_hash=transaction_hash, status=status ) - - with pytest.raises(SQLAlchemyError): - db_connector.save_transaction( - position_id=position_id, - status=status, - transaction_hash=transaction_hash - ) - -def test_save_transaction_invalid_position(db_connector): + first_transaction = mock_position_db_connector.save_transaction( + position_id=position_id, status=status, transaction_hash=transaction_hash + ) + assert first_transaction is not None + assert isinstance(first_transaction, Transaction) + assert first_transaction.position_id == position_id + assert first_transaction.transaction_hash == transaction_hash + assert first_transaction.status == status + + mock_position_db_connector.save_transaction.return_value = None + duplicate_hash_transaction = mock_position_db_connector.save_transaction( + position_id=position_id, status=status, transaction_hash=transaction_hash + ) + + assert duplicate_hash_transaction is None + + +def test_save_transaction_invalid_position(mock_position_db_connector): """Test handling non-existent position ID""" - invalid_position_id = uuid4() + invalid_position_id = uuid.uuid4() transaction_hash = "0x123456789" status = TransactionStatus.OPENED.value - - transaction = db_connector.save_transaction( + mock_position_db_connector.save_transaction.return_value = None + transaction = mock_position_db_connector.save_transaction( position_id=invalid_position_id, status=status, - transaction_hash=transaction_hash + transaction_hash=transaction_hash, ) - + assert transaction is None diff --git a/web_app/tests/db/test_dbconnector.py b/web_app/tests/db/test_dbconnector.py index e5ab4ea1..9279adc5 100644 --- a/web_app/tests/db/test_dbconnector.py +++ b/web_app/tests/db/test_dbconnector.py @@ -7,11 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker -from web_app.db.crud import ( - DBConnector, - PositionDBConnector, - UserDBConnector, -) +from web_app.db.crud import DBConnector, PositionDBConnector, UserDBConnector from web_app.db.models import AirDrop, Base, Position, Status, User diff --git a/web_app/tests/db/test_user_dbconnector.py b/web_app/tests/db/test_user_dbconnector.py index 06dcf821..bea6438f 100644 --- a/web_app/tests/db/test_user_dbconnector.py +++ b/web_app/tests/db/test_user_dbconnector.py @@ -8,7 +8,7 @@ from sqlalchemy.exc import SQLAlchemyError from web_app.db.crud import UserDBConnector -from web_app.db.models import User +from web_app.db.models import User @pytest.fixture diff --git a/web_app/tests/test_dashboard.py b/web_app/tests/test_dashboard.py index 7e1d032b..4699ea86 100644 --- a/web_app/tests/test_dashboard.py +++ b/web_app/tests/test_dashboard.py @@ -17,6 +17,7 @@ from fastapi.responses import JSONResponse from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient + from web_app.api.dashboard import get_dashboard, position_db_connector, router from web_app.api.serializers.dashboard import DashboardResponse from web_app.contract_tools.mixins import HealthRatioMixin diff --git a/web_app/tests/test_deposit_mixin.py b/web_app/tests/test_deposit_mixin.py index ffa259fc..de706e91 100644 --- a/web_app/tests/test_deposit_mixin.py +++ b/web_app/tests/test_deposit_mixin.py @@ -2,6 +2,7 @@ from decimal import Decimal from unittest.mock import AsyncMock, MagicMock, patch + import pytest from web_app.contract_tools.mixins.deposit import DepositMixin @@ -23,6 +24,7 @@ class TestDepositMixin: """ Test cases for DepositMixin """ + @pytest.mark.asyncio @pytest.mark.parametrize( "deposit_token_name, amount, multiplier, wallet_id, borrowing_token", @@ -66,6 +68,18 @@ async def test_get_transaction_data( wallet_id: str, borrowing_token: str, ) -> None: + """ + Test cases for DepositMixin.get_transaction_data method + :param mock_get_loop_liquidity_data: unittest.mock.AsyncMock + :param mock_get_token_address: unittest.mock.MagicMock + :param mock_get_token_decimals: unittest.mock.MagicMock + :param deposit_token_name: str + :param amount: str + :param multiplier: int + :param wallet_id: str + :param borrowing_token: str + :return: None + """ expected_transaction_data = { "caller": wallet_id, "pool_price": "mocked_pool_price", diff --git a/web_app/tests/test_positions.py b/web_app/tests/test_positions.py index 50ad9c36..70370404 100644 --- a/web_app/tests/test_positions.py +++ b/web_app/tests/test_positions.py @@ -8,8 +8,8 @@ """ import uuid -from unittest.mock import Mock, patch from datetime import datetime +from unittest.mock import Mock, patch import pytest from fastapi import HTTPException @@ -17,18 +17,16 @@ from httpx import AsyncClient from web_app.api.main import app -from web_app.contract_tools.mixins.deposit import DepositMixin -client = TestClient(app) app.dependency_overrides.clear() @pytest.mark.anyio -async def test_open_position_success(client: AsyncClient) -> None: +async def test_open_position_success(client: TestClient) -> None: """ Test for successfully opening a position using a valid position ID. Args: - client (AsyncClient): The test client for the FastAPI application. + client (TestClient): The test client for the FastAPI application. Returns: None """ @@ -44,12 +42,12 @@ async def test_open_position_success(client: AsyncClient) -> None: @pytest.mark.anyio async def test_open_position_missing_position_data( - client: AsyncClient, + client: TestClient, ) -> None: """ Test for missing position data, which should return a 404 error. Args: - client (AsyncClient): The test client for the FastAPI application. + client (TestClient): The test client for the FastAPI application. Returns: None """ @@ -59,11 +57,11 @@ async def test_open_position_missing_position_data( @pytest.mark.anyio -async def test_close_position_success(client: AsyncClient) -> None: +async def test_close_position_success(client: TestClient) -> None: """ Test for successfully closing a position using a valid position ID. Args: - client (AsyncClient): The test client for the FastAPI application. + client (TestClient): The test client for the FastAPI application. Returns: None """ @@ -79,13 +77,13 @@ async def test_close_position_success(client: AsyncClient) -> None: @pytest.mark.anyio async def test_close_position_invalid_position_id( - client: AsyncClient, + client: TestClient, ) -> None: """ Test for attempting to close a position using an invalid position ID, which should return a 404 error. Args: - client (AsyncClient): The test client for the FastAPI application. + client (TestClient): The test client for the FastAPI application. Returns: None """ @@ -117,8 +115,8 @@ async def test_close_position_invalid_position_id( "tick_spacing": "mock_tick_spacing", "extension": "mock_extension", }, - "supply_price": 100, - "debt_price": 200, + "supply_price": "100", + "debt_price": "200", "ekubo_limits": {"mock_key": "mock_value"}, "borrow_portion_percent": 1, }, @@ -136,15 +134,15 @@ async def test_close_position_invalid_position_id( "tick_spacing": "mock_tick_spacing", "extension": "mock_extension", }, - "supply_price": 0, - "debt_price": 0, + "supply_price": "0", + "debt_price": "0", "ekubo_limits": {"mock_key": "mock_value"}, "borrow_portion_percent": 1, }, ), ( "valid_supply_token", - "invalid_wallet_id", + "valid_wallet_id", { "supply_token": "mock_supply_token", "debt_token": "mock_debt_token", @@ -155,8 +153,8 @@ async def test_close_position_invalid_position_id( "tick_spacing": "mock_tick_spacing", "extension": "mock_extension", }, - "supply_price": 0, - "debt_price": 0, + "supply_price": "0", + "debt_price": "0", "ekubo_limits": {"mock_key": "mock_value"}, "borrow_portion_percent": 1, }, @@ -165,14 +163,18 @@ async def test_close_position_invalid_position_id( ) @pytest.mark.anyio async def test_get_repay_data_success( - client: AsyncClient, supply_token, wallet_id, mock_repay_data + client: TestClient, + supply_token, + wallet_id, + mock_repay_data, + mock_position_db_connector, ) -> None: """ Test for successfully retrieving repayment data for different combinations of wallet ID and supply token. Args: - client (AsyncClient): The test client for the FastAPI application. + client (TestClient): The test client for the FastAPI application. supply_token (str): The token used for supply. wallet_id (str): The wallet ID of the user. mock_repay_data (dict): Mocked repayment data. @@ -188,22 +190,30 @@ async def test_get_repay_data_success( ) as mock_get_contract_address, patch( "web_app.db.crud.PositionDBConnector.get_position_id_by_wallet_id" - ) as mock_get_position_id, + ) as mock_get_position_wallet_id, + patch( + "web_app.api.position.position_db_connector.get_repay_data" + ) as mock_position_db_connector_get_repay_data, + patch( + "web_app.contract_tools.mixins.position.PositionMixin.is_opened_position" + ) as mock_is_opened_position, ): mock_get_repay_data.return_value = mock_repay_data - mock_get_contract_address.return_value = "mock_contract_address" - mock_get_position_id.return_value = 123 + mock_get_contract_address.return_value = "34702534789504389704385" + mock_get_position_wallet_id.return_value = 123 mock_get_repay_data.return_value = mock_repay_data - DepositMixin.get_repay_data = mock_get_repay_data + mock_position_db_connector_get_repay_data.return_value = ( + mock_get_contract_address.return_value, + mock_get_position_wallet_id.return_value, + supply_token, + ) + mock_is_opened_position.return_value = True response = client.get( f"/api/get-repay-data?supply_token={supply_token}&wallet_id={wallet_id}" ) - mock_get_contract_address.assert_called_once_with(wallet_id) - mock_get_position_id.assert_called_once_with(wallet_id) - mock_get_repay_data.assert_called_once_with(supply_token) expected_response = { **mock_repay_data, - "contract_address": "mock_contract_address", + "contract_address": "34702534789504389704385", "position_id": "123", } assert response.is_success @@ -267,7 +277,7 @@ async def test_get_repay_data_missing_wallet_id( "deposit_data": { "token": "mock_token", "amount": "mock_amount", - "multiplier": 1, + "multiplier": "1", "borrow_portion_percent": 0, }, "ekubo_limits": {"mock_key": "mock_value"}, @@ -292,7 +302,7 @@ async def test_get_repay_data_missing_wallet_id( "deposit_data": { "token": "mock_token", "amount": "mock_amount", - "multiplier": 1, + "multiplier": "1", "borrow_portion_percent": 0, }, "ekubo_limits": {"mock_key": "mock_value"}, @@ -317,7 +327,7 @@ async def test_get_repay_data_missing_wallet_id( "deposit_data": { "token": "mock_token", "amount": "mock_amount", - "multiplier": 1, + "multiplier": "1", "borrow_portion_percent": 0, }, "ekubo_limits": {"mock_key": "mock_value"}, @@ -327,7 +337,7 @@ async def test_get_repay_data_missing_wallet_id( ) @pytest.mark.anyio async def test_create_position_success( - wallet_id, token_symbol, amount, multiplier, expected_response + client: TestClient, wallet_id, token_symbol, amount, multiplier, expected_response ) -> None: """ Test for successfully creating a position with valid form data. @@ -370,16 +380,15 @@ async def test_create_position_success( mock_get_transaction_data.return_value = mock_deposit_data mock_get_contract_address.return_value = "mock_contract_address" - async with AsyncClient(app=app, base_url="http://test") as async_client: - response = await async_client.post( - "/api/create-position", - json={ - "wallet_id": wallet_id, - "token_symbol": token_symbol, - "amount": amount, - "multiplier": multiplier, - }, - ) + response = client.post( + "/api/create-position", + json={ + "wallet_id": wallet_id, + "token_symbol": token_symbol, + "amount": amount, + "multiplier": multiplier, + }, + ) assert ( response.is_success ), f"Expected status code 200 but got {response.status_code}" @@ -402,7 +411,7 @@ async def test_create_position_success( ], ) def test_create_position_invalid( - wallet_id, token_symbol, amount, multiplier, expected_status + client: TestClient, wallet_id, token_symbol, amount, multiplier, expected_status ): """ Test for attempting to create a position with various valid and invalid input data. @@ -423,7 +432,7 @@ def test_create_position_invalid( @pytest.mark.asyncio -async def test_get_user_positions_success(client: AsyncClient) -> None: +async def test_get_user_positions_success(client: TestClient) -> None: """ Test successfully retrieving user positions. """ @@ -437,21 +446,21 @@ async def test_get_user_positions_success(client: AsyncClient) -> None: "status": "opened", "created_at": datetime.now(), "start_price": 1800.0, - "is_liquidated": False + "is_liquidated": False, } ] - + with patch( "web_app.db.crud.PositionDBConnector.get_positions_by_wallet_id" ) as mock_get_positions: mock_get_positions.return_value = mock_positions - response = await client.get(f"/api/user-positions/{wallet_id}") - + response = client.get(f"/api/user-positions/{wallet_id}") + assert response.status_code == 200 data = response.json() - assert len(data["positions"]) == len(mock_positions) - assert data["positions"][0]["token_symbol"] == mock_positions[0]["token_symbol"] - assert data["positions"][0]["amount"] == mock_positions[0]["amount"] + assert len(data) == len(mock_positions) + assert data[0]["token_symbol"] == mock_positions[0]["token_symbol"] + assert data[0]["amount"] == mock_positions[0]["amount"] @pytest.mark.asyncio @@ -459,7 +468,7 @@ async def test_get_user_positions_empty_wallet_id(client: AsyncClient) -> None: """ Test retrieving positions with empty wallet ID. """ - response = await client.get("/api/user-positions/") + response = client.get("/api/user-positions/") assert response.status_code == 404 @@ -473,8 +482,8 @@ async def test_get_user_positions_no_positions(client: AsyncClient) -> None: "web_app.db.crud.PositionDBConnector.get_positions_by_wallet_id" ) as mock_get_positions: mock_get_positions.return_value = [] - response = await client.get(f"/api/user-positions/{wallet_id}") - + response = client.get(f"/api/user-positions/{wallet_id}") + assert response.status_code == 200 data = response.json() - assert len(data["positions"]) == 0 + assert data == [] diff --git a/web_app/tests/test_starknet_client.py b/web_app/tests/test_starknet_client.py index fa9b243b..616f5a4e 100644 --- a/web_app/tests/test_starknet_client.py +++ b/web_app/tests/test_starknet_client.py @@ -7,10 +7,7 @@ from starknet_py.contract import Contract from starknet_py.net.full_node_client import FullNodeClient -from web_app.contract_tools.blockchain_call import ( - RepayDataException, - StarknetClient, -) +from web_app.contract_tools.blockchain_call import RepayDataException, StarknetClient from web_app.contract_tools.constants import TokenParams CLIENT = StarknetClient() @@ -47,13 +44,11 @@ async def test__convert_address(self, addr: str, expected_addr: int) -> None: @pytest.mark.parametrize( "token0, token1", [ - ("STRK", "STRK"), - ("ETH", "ETH"), - ("USDC", "USDC"), - ("STRK", "ETH"), - ("ETH", "USDC"), - ("", ""), - (None, None), + (TokenParams.STRK.address, TokenParams.STRK.address), + (TokenParams.ETH.address, TokenParams.ETH.address), + (TokenParams.USDC.address, TokenParams.USDC.address), + (TokenParams.STRK.address, TokenParams.ETH.address), + (TokenParams.ETH.address, TokenParams.USDC.address), ], ) async def test__build_ekubo_pool_key(self, token0: str, token1: str) -> None: @@ -65,8 +60,8 @@ async def test__build_ekubo_pool_key(self, token0: str, token1: str) -> None: """ token0, token1 = str(token0), str(token1) expected_data = { - "token0": token0, - "token1": token1, + "token0": int(token0, base=16), + "token1": int(token1, base=16), "fee": CLIENT.FEE, "tick_spacing": CLIENT.TICK_SPACING, "extension": 0, @@ -178,10 +173,7 @@ async def test__get_pool_price( mock_contract_from_address.return_value = mock_contract - pool_price = await CLIENT._get_pool_price(pool_key, is_token1) - - mock_contract_from_address.assert_called_once() - mock_contract.functions["get_pool_price"].call.assert_called_once() + pool_price = await CLIENT._get_pool_price(pool_key, is_token1, mock_contract) assert pool_price assert isinstance(pool_price, Decimal) @@ -294,27 +286,27 @@ async def test_get_loop_liquidity_data( mock_contract.functions["get_pool_price"].call = AsyncMock( return_value=[ { - "sqrt_ratio": sqrt_ratio, + "sqrt_ratio": Decimal(sqrt_ratio), } ], ) mock_contract_from_address.return_value = mock_contract liquidity_data = await CLIENT.get_loop_liquidity_data( - deposit_token_addr, - amount, - multiplier, - wallet_id, - borrowing_token_addr, + deposit_token=deposit_token_addr, + amount=amount, + multiplier=multiplier, + wallet_id=wallet_id, + borrowing_token=borrowing_token_addr, + ekubo_contract=mock_contract, ) - mock_contract_from_address.assert_called_once() - mock_contract.functions["get_pool_price"].call.assert_called_once() - assert liquidity_data assert isinstance(liquidity_data, dict) assert isinstance(liquidity_data["pool_price"], int) - assert liquidity_data["caller"] == CLIENT._convert_address(wallet_id) + assert int(liquidity_data["caller"], base=16) == CLIENT._convert_address( + wallet_id + ) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -372,13 +364,11 @@ async def test_get_repay_data( try: repay_data = await CLIENT.get_repay_data( - deposit_token_addr, borrowing_token_addr + deposit_token_addr, borrowing_token_addr, mock_contract ) except RepayDataException: assert RepayDataException.args else: - mock_contract_from_address.assert_called_once() - mock_contract.functions["get_pool_price"].call.assert_called_once() assert isinstance(repay_data, dict) assert {"supply_price", "debt_price", "pool_key"}.issubset(