diff --git a/web_app/tests/db/test_PositionDBConnector.py b/web_app/tests/db/test_PositionDBConnector.py index 8099f22d..d70f5314 100644 --- a/web_app/tests/db/test_PositionDBConnector.py +++ b/web_app/tests/db/test_PositionDBConnector.py @@ -2,20 +2,18 @@ import uuid from decimal import Decimal -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest - -from web_app.db.models import Position, Status, User +from sqlalchemy.exc import SQLAlchemyError from web_app.db.crud import PositionDBConnector +from web_app.db.models import Position, Status, User @pytest.fixture(scope="function") def sample_user(): """Fixture to create a sample user for testing.""" - return User( - wallet_id="test_wallet_id" - ) + return User(wallet_id="test_wallet_id") @pytest.fixture(scope="function") @@ -27,7 +25,7 @@ def sample_position(sample_user): amount="100", multiplier=2, status=Status.PENDING.value, - start_price=0.0 + start_price=0.0, ) @@ -40,10 +38,11 @@ def mock_session(): ### Positive Test Cases ### + def test_position_to_dict(mock_position_db_connector, sample_position): """Test converting a Position object to dictionary.""" result = PositionDBConnector._position_to_dict(sample_position) - + assert isinstance(result, dict) assert result["id"] == str(sample_position.id) assert result["user_id"] == str(sample_position.user_id) @@ -54,13 +53,11 @@ def test_position_to_dict(mock_position_db_connector, sample_position): def test_get_positions_by_wallet_id_success( - mock_position_db_connector, - sample_user, - sample_position + mock_position_db_connector, sample_user, sample_position ): """Test successfully retrieving positions by wallet ID.""" mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user - + position_dict = { "id": str(sample_position.id), "user_id": str(sample_position.user_id), @@ -68,9 +65,9 @@ def test_get_positions_by_wallet_id_success( "amount": sample_position.amount, "multiplier": sample_position.multiplier, "status": sample_position.status, - "start_price": sample_position.start_price + "start_price": sample_position.start_price, } - + mock_position_db_connector.get_positions_by_wallet_id.return_value = [position_dict] positions = mock_position_db_connector.get_positions_by_wallet_id("test_wallet_id") @@ -86,25 +83,22 @@ def test_get_positions_by_wallet_id_success( def test_create_position_success(mock_position_db_connector, sample_user): """Test successfully creating a new position.""" mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user - + new_position = Position( user_id=sample_user.id, token_symbol="ETH", amount="200", multiplier=3, status=Status.PENDING.value, - start_price=0.0 + start_price=0.0, ) mock_position_db_connector.write_to_db.return_value = new_position mock_position_db_connector.create_position.return_value = new_position - + result = mock_position_db_connector.create_position( - wallet_id="test_wallet_id", - token_symbol="ETH", - amount="200", - multiplier=3 + wallet_id="test_wallet_id", token_symbol="ETH", amount="200", multiplier=3 ) - + assert result is not None assert result.token_symbol == "ETH" assert result.amount == "200" @@ -112,30 +106,25 @@ def test_create_position_success(mock_position_db_connector, sample_user): def test_update_existing_pending_position( - mock_position_db_connector, - sample_user, - sample_position + mock_position_db_connector, sample_user, sample_position ): """Test updating an existing pending position.""" mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user - + updated_position = Position( user_id=sample_user.id, token_symbol="LTC", amount="300", multiplier=4, status=Status.PENDING.value, - start_price=0.0 + start_price=0.0, ) mock_position_db_connector.create_position.return_value = updated_position - + result = mock_position_db_connector.create_position( - wallet_id="test_wallet_id", - token_symbol="LTC", - amount="300", - multiplier=4 + wallet_id="test_wallet_id", token_symbol="LTC", amount="300", multiplier=4 ) - + assert result.token_symbol == "LTC" assert result.amount == "300" assert result.multiplier == 4 @@ -145,9 +134,9 @@ def test_close_position_success(mock_position_db_connector, sample_position): """Test successfully closing a position.""" mock_position_db_connector.get_object.return_value = sample_position mock_position_db_connector.close_position.return_value = Status.CLOSED.value - + result = mock_position_db_connector.close_position(sample_position.id) - + assert result == Status.CLOSED.value @@ -155,40 +144,66 @@ def test_open_position_success(mock_position_db_connector, sample_position): """Test successfully opening a position.""" mock_position_db_connector.get_object.return_value = sample_position mock_position_db_connector.open_position.return_value = Status.OPENED.value - + result = mock_position_db_connector.open_position(sample_position.id) - - assert result == Status.OPENED.value + assert result == Status.OPENED.value def test_get_total_amounts_for_open_positions(mock_position_db_connector): """Test getting total amounts for open positions.""" - mock_position_db_connector.get_total_amounts_for_open_positions.return_value = Decimal(1000.0) - + mock_position_db_connector.get_total_amounts_for_open_positions.return_value = ( + Decimal(1000.0) + ) + result = mock_position_db_connector.get_total_amounts_for_open_positions() - + assert result == Decimal(1000.0) +@patch("web_app.db.connectors.PositionDBConnector") +def test_delete_all_user_positions_success(mock_position_db_connector): + """Test successfully deleting all positions for a user.""" + user_id = uuid.uuid4() + mock_positions = [ + Position(id=uuid.uuid4(), user_id=user_id, token_symbol="BTC", amount="10"), + Position(id=uuid.uuid4(), user_id=user_id, token_symbol="ETH", amount="5"), + ] + mock_session = mock_position_db_connector.Session.return_value + mock_session.query.return_value.filter_by.return_value.all.return_value = ( + mock_positions + ) + + position_connector = PositionDBConnector() + position_connector.delete_all_user_positions(user_id) + + mock_session.query.assert_called_once_with(Position) + mock_session.query.return_value.filter_by.assert_called_once_with(user_id=user_id) + assert mock_session.delete.call_count == len(mock_positions) + mock_session.commit.assert_called_once() + + ### Negative Test Cases ### + def test_get_positions_by_wallet_id_no_user(mock_position_db_connector): """Test retrieving positions for non-existent user.""" mock_position_db_connector._get_user_by_wallet_id.return_value = None mock_position_db_connector.get_positions_by_wallet_id.return_value = [] - - positions = mock_position_db_connector.get_positions_by_wallet_id("nonexistent_wallet") - + + positions = mock_position_db_connector.get_positions_by_wallet_id( + "nonexistent_wallet" + ) + assert positions == [] def test_get_positions_by_wallet_id_db_error(mock_position_db_connector, sample_user): """Test handling database error when retrieving positions.""" mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user - + positions = mock_position_db_connector.get_positions_by_wallet_id("test_wallet_id") - + assert positions == [] @@ -196,14 +211,11 @@ def test_create_position_no_user(mock_position_db_connector): """Test creating position for non-existent user.""" mock_position_db_connector._get_user_by_wallet_id.return_value = None mock_position_db_connector.create_position.return_value = None - + result = mock_position_db_connector.create_position( - wallet_id="nonexistent_wallet", - token_symbol="ETH", - amount="100", - multiplier=2 + wallet_id="nonexistent_wallet", token_symbol="ETH", amount="100", multiplier=2 ) - + assert result is None @@ -211,18 +223,18 @@ def test_close_position_not_found(mock_position_db_connector): """Test closing non-existent position.""" mock_position_db_connector.get_object.return_value = None mock_position_db_connector.close_position.return_value = None - + result = mock_position_db_connector.close_position(uuid.uuid4()) - + assert result is None def test_get_total_amounts_db_error(mock_position_db_connector): """Test handling database error when getting total amounts.""" mock_position_db_connector.get_total_amounts_for_open_positions.return_value = None - + result = mock_position_db_connector.get_total_amounts_for_open_positions() - + assert result is None @@ -230,7 +242,22 @@ def test_get_position_id_by_wallet_id_no_positions(mock_position_db_connector): """Test getting position ID when no positions exist.""" mock_position_db_connector.get_positions_by_wallet_id.return_value = [] mock_position_db_connector.get_position_id_by_wallet_id.return_value = None - + result = mock_position_db_connector.get_position_id_by_wallet_id("test_wallet_id") - - assert result is None \ No newline at end of file + + assert result is None + + +@patch("web_app.db.connectors.PositionDBConnector") +def test_delete_all_user_positions_failure(mock_position_db_connector): + """Test failure during deletion of all positions for a user.""" + user_id = uuid.uuid4() + mock_session = mock_position_db_connector.Session.return_value + mock_session.query.side_effect = SQLAlchemyError("Database error") + + position_connector = PositionDBConnector() + + with pytest.raises(SQLAlchemyError): + position_connector.delete_all_user_positions(user_id) + + mock_session.rollback.assert_called_once() diff --git a/web_app/tests/db/test_user_dbconnector.py b/web_app/tests/db/test_user_dbconnector.py index d670615f..1ebe10f4 100644 --- a/web_app/tests/db/test_user_dbconnector.py +++ b/web_app/tests/db/test_user_dbconnector.py @@ -2,10 +2,13 @@ Unit tests for the UserDBConnector module. """ -import pytest from unittest.mock import MagicMock, patch -from web_app.db.models import User -from web_app.db.crud import UserDBConnector + +import pytest +from sqlalchemy.exc import SQLAlchemyError +from web_app.db.crud import AirDropDBConnector, UserDBConnector +from web_app.db.models import AirDrop, User + @pytest.fixture def mock_db_connector(): @@ -14,17 +17,20 @@ def mock_db_connector(): """ return MagicMock() + @pytest.fixture def user_db(mock_db_connector): """ Fixture to create a UserDBConnector instance with mocked dependencies. """ - with patch('web_app.db.crud.UserDBConnector.get_object_by_field', - new_callable=MagicMock) as mock_get: + with patch( + "web_app.db.crud.UserDBConnector.get_object_by_field", new_callable=MagicMock + ) as mock_get: mock_get.side_effect = mock_db_connector.get_object_by_field connector = UserDBConnector() yield connector + def test_get_user_by_wallet_id_success(user_db, mock_db_connector): """ Test successful retrieval of user by wallet ID. @@ -41,11 +47,10 @@ def test_get_user_by_wallet_id_success(user_db, mock_db_connector): assert result == expected_user mock_db_connector.get_object_by_field.assert_called_once_with( - User, - "wallet_id", - wallet_id + User, "wallet_id", wallet_id ) + def test_get_user_by_wallet_id_not_found(user_db, mock_db_connector): """ Test when user is not found by wallet ID. @@ -57,11 +62,10 @@ def test_get_user_by_wallet_id_not_found(user_db, mock_db_connector): assert result is None mock_db_connector.get_object_by_field.assert_called_once_with( - User, - "wallet_id", - wallet_id + User, "wallet_id", wallet_id ) + def test_get_user_by_wallet_id_empty_wallet_id(user_db, mock_db_connector): """ Test behavior when empty wallet ID is provided. @@ -73,9 +77,7 @@ def test_get_user_by_wallet_id_empty_wallet_id(user_db, mock_db_connector): assert result is None mock_db_connector.get_object_by_field.assert_called_once_with( - User, - "wallet_id", - wallet_id + User, "wallet_id", wallet_id ) @@ -85,4 +87,51 @@ def test_get_unique_users_count(mock_user_db_connector): result = mock_user_db_connector.get_unique_users_count() - assert result == 5 \ No newline at end of file + assert result == 5 + + +def test_delete_all_users_airdrop_success(user_db): + """ + Test successful deletion of all airdrops for a user. + """ + user_id = "123e4567-e89b-12d3-a456-426614174000" + mock_session = MagicMock() + mock_airdrops = [ + AirDrop(id=1, user_id=user_id), + AirDrop(id=2, user_id=user_id), + ] + with patch.object( + "web_app.db.crud.AirDropDBConnector.Session", return_value=mock_session + ): + mock_session.query.return_value.filter_by.return_value.all.return_value = ( + mock_airdrops + ) + + air_drop_connector = AirDropDBConnector() + air_drop_connector.delete_all_users_airdrop(user_id) + + mock_session.query.assert_called_once_with(AirDrop) + mock_session.query.return_value.filter_by.assert_called_once_with( + user_id=user_id + ) + assert mock_session.delete.call_count == len(mock_airdrops) + mock_session.commit.assert_called_once() + + +def test_delete_all_users_airdrop_failure(user_db): + """ + Test failure while deleting airdrops for a user. + """ + user_id = "123e4567-e89b-12d3-a456-426614174000" + mock_session = MagicMock() + mock_session.query.side_effect = SQLAlchemyError("Database error") + with patch.object( + "web_app.db.crud.AirDropDBConnector.Session", return_value=mock_session + ): + air_drop_connector = AirDropDBConnector() + + with pytest.raises(SQLAlchemyError): + air_drop_connector.delete_all_users_airdrop(user_id) + + mock_session.query.assert_called_once_with(AirDrop) + mock_session.rollback.assert_called_once()