Skip to content

Commit

Permalink
Merge pull request #307 from Elsie-ND/test-case-delete
Browse files Browse the repository at this point in the history
add test cases for delete methods
  • Loading branch information
djeck1432 authored Dec 1, 2024
2 parents 5dce11e + 1ca6173 commit 9de4e20
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 74 deletions.
145 changes: 86 additions & 59 deletions web_app/tests/db/test_PositionDBConnector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@

import uuid
from decimal import Decimal
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

from web_app.db.models import Position, Status, User
from sqlalchemy.exc import SQLAlchemyError
from web_app.db.crud import PositionDBConnector
from web_app.db.models import Position, Status, User


@pytest.fixture(scope="function")
def sample_user():
"""Fixture to create a sample user for testing."""
return User(
wallet_id="test_wallet_id"
)
return User(wallet_id="test_wallet_id")


@pytest.fixture(scope="function")
Expand All @@ -27,7 +25,7 @@ def sample_position(sample_user):
amount="100",
multiplier=2,
status=Status.PENDING.value,
start_price=0.0
start_price=0.0,
)


Expand All @@ -40,10 +38,11 @@ def mock_session():

### Positive Test Cases ###


def test_position_to_dict(mock_position_db_connector, sample_position):
"""Test converting a Position object to dictionary."""
result = PositionDBConnector._position_to_dict(sample_position)

assert isinstance(result, dict)
assert result["id"] == str(sample_position.id)
assert result["user_id"] == str(sample_position.user_id)
Expand All @@ -54,23 +53,21 @@ def test_position_to_dict(mock_position_db_connector, sample_position):


def test_get_positions_by_wallet_id_success(
mock_position_db_connector,
sample_user,
sample_position
mock_position_db_connector, sample_user, sample_position
):
"""Test successfully retrieving positions by wallet ID."""
mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user

position_dict = {
"id": str(sample_position.id),
"user_id": str(sample_position.user_id),
"token_symbol": sample_position.token_symbol,
"amount": sample_position.amount,
"multiplier": sample_position.multiplier,
"status": sample_position.status,
"start_price": sample_position.start_price
"start_price": sample_position.start_price,
}

mock_position_db_connector.get_positions_by_wallet_id.return_value = [position_dict]

positions = mock_position_db_connector.get_positions_by_wallet_id("test_wallet_id")
Expand All @@ -86,56 +83,48 @@ def test_get_positions_by_wallet_id_success(
def test_create_position_success(mock_position_db_connector, sample_user):
"""Test successfully creating a new position."""
mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user

new_position = Position(
user_id=sample_user.id,
token_symbol="ETH",
amount="200",
multiplier=3,
status=Status.PENDING.value,
start_price=0.0
start_price=0.0,
)
mock_position_db_connector.write_to_db.return_value = new_position
mock_position_db_connector.create_position.return_value = new_position

result = mock_position_db_connector.create_position(
wallet_id="test_wallet_id",
token_symbol="ETH",
amount="200",
multiplier=3
wallet_id="test_wallet_id", token_symbol="ETH", amount="200", multiplier=3
)

assert result is not None
assert result.token_symbol == "ETH"
assert result.amount == "200"
assert result.multiplier == 3


def test_update_existing_pending_position(
mock_position_db_connector,
sample_user,
sample_position
mock_position_db_connector, sample_user, sample_position
):
"""Test updating an existing pending position."""
mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user

updated_position = Position(
user_id=sample_user.id,
token_symbol="LTC",
amount="300",
multiplier=4,
status=Status.PENDING.value,
start_price=0.0
start_price=0.0,
)
mock_position_db_connector.create_position.return_value = updated_position

result = mock_position_db_connector.create_position(
wallet_id="test_wallet_id",
token_symbol="LTC",
amount="300",
multiplier=4
wallet_id="test_wallet_id", token_symbol="LTC", amount="300", multiplier=4
)

assert result.token_symbol == "LTC"
assert result.amount == "300"
assert result.multiplier == 4
Expand All @@ -145,92 +134,130 @@ def test_close_position_success(mock_position_db_connector, sample_position):
"""Test successfully closing a position."""
mock_position_db_connector.get_object.return_value = sample_position
mock_position_db_connector.close_position.return_value = Status.CLOSED.value

result = mock_position_db_connector.close_position(sample_position.id)

assert result == Status.CLOSED.value


def test_open_position_success(mock_position_db_connector, sample_position):
"""Test successfully opening a position."""
mock_position_db_connector.get_object.return_value = sample_position
mock_position_db_connector.open_position.return_value = Status.OPENED.value

result = mock_position_db_connector.open_position(sample_position.id)

assert result == Status.OPENED.value

assert result == Status.OPENED.value


def test_get_total_amounts_for_open_positions(mock_position_db_connector):
"""Test getting total amounts for open positions."""
mock_position_db_connector.get_total_amounts_for_open_positions.return_value = Decimal(1000.0)

mock_position_db_connector.get_total_amounts_for_open_positions.return_value = (
Decimal(1000.0)
)

result = mock_position_db_connector.get_total_amounts_for_open_positions()

assert result == Decimal(1000.0)


@patch("web_app.db.connectors.PositionDBConnector")
def test_delete_all_user_positions_success(mock_position_db_connector):
"""Test successfully deleting all positions for a user."""
user_id = uuid.uuid4()
mock_positions = [
Position(id=uuid.uuid4(), user_id=user_id, token_symbol="BTC", amount="10"),
Position(id=uuid.uuid4(), user_id=user_id, token_symbol="ETH", amount="5"),
]
mock_session = mock_position_db_connector.Session.return_value
mock_session.query.return_value.filter_by.return_value.all.return_value = (
mock_positions
)

position_connector = PositionDBConnector()
position_connector.delete_all_user_positions(user_id)

mock_session.query.assert_called_once_with(Position)
mock_session.query.return_value.filter_by.assert_called_once_with(user_id=user_id)
assert mock_session.delete.call_count == len(mock_positions)
mock_session.commit.assert_called_once()


### Negative Test Cases ###


def test_get_positions_by_wallet_id_no_user(mock_position_db_connector):
"""Test retrieving positions for non-existent user."""
mock_position_db_connector._get_user_by_wallet_id.return_value = None
mock_position_db_connector.get_positions_by_wallet_id.return_value = []

positions = mock_position_db_connector.get_positions_by_wallet_id("nonexistent_wallet")


positions = mock_position_db_connector.get_positions_by_wallet_id(
"nonexistent_wallet"
)

assert positions == []


def test_get_positions_by_wallet_id_db_error(mock_position_db_connector, sample_user):
"""Test handling database error when retrieving positions."""
mock_position_db_connector._get_user_by_wallet_id.return_value = sample_user

positions = mock_position_db_connector.get_positions_by_wallet_id("test_wallet_id")

assert positions == []


def test_create_position_no_user(mock_position_db_connector):
"""Test creating position for non-existent user."""
mock_position_db_connector._get_user_by_wallet_id.return_value = None
mock_position_db_connector.create_position.return_value = None

result = mock_position_db_connector.create_position(
wallet_id="nonexistent_wallet",
token_symbol="ETH",
amount="100",
multiplier=2
wallet_id="nonexistent_wallet", token_symbol="ETH", amount="100", multiplier=2
)

assert result is None


def test_close_position_not_found(mock_position_db_connector):
"""Test closing non-existent position."""
mock_position_db_connector.get_object.return_value = None
mock_position_db_connector.close_position.return_value = None

result = mock_position_db_connector.close_position(uuid.uuid4())

assert result is None


def test_get_total_amounts_db_error(mock_position_db_connector):
"""Test handling database error when getting total amounts."""
mock_position_db_connector.get_total_amounts_for_open_positions.return_value = None

result = mock_position_db_connector.get_total_amounts_for_open_positions()

assert result is None


def test_get_position_id_by_wallet_id_no_positions(mock_position_db_connector):
"""Test getting position ID when no positions exist."""
mock_position_db_connector.get_positions_by_wallet_id.return_value = []
mock_position_db_connector.get_position_id_by_wallet_id.return_value = None

result = mock_position_db_connector.get_position_id_by_wallet_id("test_wallet_id")

assert result is None

assert result is None


@patch("web_app.db.connectors.PositionDBConnector")
def test_delete_all_user_positions_failure(mock_position_db_connector):
"""Test failure during deletion of all positions for a user."""
user_id = uuid.uuid4()
mock_session = mock_position_db_connector.Session.return_value
mock_session.query.side_effect = SQLAlchemyError("Database error")

position_connector = PositionDBConnector()

with pytest.raises(SQLAlchemyError):
position_connector.delete_all_user_positions(user_id)

mock_session.rollback.assert_called_once()
Loading

0 comments on commit 9de4e20

Please sign in to comment.