Skip to content

Commit

Permalink
Merge pull request #136 from djeck1432/feat/refactoring
Browse files Browse the repository at this point in the history
Feat/refactoring
  • Loading branch information
djeck1432 authored Oct 29, 2024
2 parents 9e9125c + 2182181 commit a421b8d
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 179 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ This guide explains how to start the development environment for the project usi
```sh
docker-compose -f docker-compose.dev.yaml build --no-cache
```

## How to run test cases
In root folder run next commands:
```bash
poetry install
```
Activate env
```bash
poetry shell
```
Run test cases
```bash
poetry run pytest
```

## Stopping the Development Environment

Expand Down
180 changes: 90 additions & 90 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ celery = "5.4.0"
redis = "5.2.0"
trio = "0.27.0"
aiogram = ">=3.13.1"
aiohttp = "^3.10.10"

[tool.poetry.group.dev.dependencies]
black = "24.8.0"
Expand Down
22 changes: 0 additions & 22 deletions requirements.txt

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ def column_exists(table_name, column_name):

def upgrade() -> None:
"""Upgrade the database."""

if column_exists("position", "start_price"):
logger.info("Column 'start_price' already exists, skipping creation.")
else:
op.add_column(
"position", sa.Column("start_price", sa.DECIMAL(), nullable=False)
"position",
sa.Column("start_price", sa.DECIMAL(), nullable=False, server_default="0")
)
logger.info("Column 'start_price' added to the 'position' table.")
# ### commands auto generated by Alembic - please adjust! ###

# Remove the default value after applying it, if necessary
op.alter_column(
"position",
"start_price",
existing_type=sa.DOUBLE_PRECISION(precision=53),
type_=sa.DECIMAL(),
server_default=None,
existing_type=sa.DECIMAL(),
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
Expand Down
4 changes: 3 additions & 1 deletion web_app/api/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TokenMultipliers,
)
from web_app.contract_tools.mixins.deposit import DepositMixin
from web_app.contract_tools.mixins.dashboard import DashboardMixin
from web_app.db.crud import PositionDBConnector

router = APIRouter() # Initialize the router
Expand Down Expand Up @@ -184,5 +185,6 @@ async def open_position(position_id: str) -> str:
if not position_id:
raise HTTPException(status_code=404, detail="Position not found")

position_status = position_db_connector.open_position(position_id)
current_prices = await DashboardMixin.get_current_prices()
position_status = position_db_connector.open_position(position_id, current_prices)
return position_status
10 changes: 5 additions & 5 deletions web_app/api/serializers/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module defines the serializers for the user data.
"""

from decimal import Decimal
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -42,13 +42,13 @@ class GetStatsResponse(BaseModel):
Pydantic model for the get_stats response.
"""

total_opened_amount: float = Field(
...,
example=1000.0,
total_opened_amount: Decimal = Field(
default=None,
example="1000.0",
description="Total amount for all open positions across all users.",
)
unique_users: int = Field(
...,
default=0,
example=5,
description="Number of unique users in the database.",
)
20 changes: 11 additions & 9 deletions web_app/contract_tools/airdrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
"""

from typing import List
from api.serializers.airdrop import AirdropItem, AirdropResponseModel
from contract_tools.api_request import APIRequest
from web_app.api.serializers.airdrop import AirdropItem, AirdropResponseModel
from web_app.contract_tools.api_request import APIRequest
from web_app.contract_tools.constants import TokenParams


class ZkLendAirdrop:
"""
A class to fetch and validate airdrop data
for a specified contract.
"""
REWARD_API_ENDPOINT = "https://app.zklend.com/api/reward/all/"

def __init__(self, api: APIRequest):
def __init__(self):
"""
Initializes the ZkLendAirdrop class with an APIRequest instance.
Args:
api (APIRequest): An instance of APIRequest for making API calls.
"""
self.api = api
self.api = APIRequest(base_url=self.REWARD_API_ENDPOINT)

async def get_contract_airdrop(self, contract_id: str) -> AirdropResponseModel:
"""
Expand All @@ -31,11 +32,12 @@ async def get_contract_airdrop(self, contract_id: str) -> AirdropResponseModel:
AirdropResponseModel: A validated list of airdrop items
for the specified contract.
"""
endpoint = f"/contracts/{contract_id}/airdrops"
response = await self.api.fetch(endpoint)
underlying_contract_id = TokenParams.add_underlying_address(contract_id)
response = await self.api.fetch(underlying_contract_id)
return self._validate_response(response)

def _validate_response(self, data: List[dict]) -> AirdropResponseModel:
@staticmethod
def _validate_response(data: List[dict]) -> AirdropResponseModel:
"""
Validates and formats the response data, keeping only necessary fields.
Args:
Expand Down
51 changes: 25 additions & 26 deletions web_app/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from decimal import Decimal
from typing import List, Type, TypeVar

from sqlalchemy import create_engine, update
from sqlalchemy import create_engine, update, func, cast, Numeric
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker

from web_app.contract_tools.mixins.dashboard import DashboardMixin
from web_app.db.database import SQLALCHEMY_DATABASE_URL, get_database
from web_app.db.database import SQLALCHEMY_DATABASE_URL
from web_app.db.models import AirDrop, Base, Position, Status, User, TelegramUser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -169,6 +168,21 @@ def update_user_contract(self, user: User, contract_address: str) -> None:
user.contract_address = contract_address
self.write_to_db(user)

def get_unique_users_count(self) -> int:
"""
Retrieves the number of unique users in the database.
:return: The count of unique users.
"""
with self.Session() as db:
try:
# Query to count distinct users based on wallet ID
unique_users_count = db.query(User.wallet_id).distinct().count()
return unique_users_count

except SQLAlchemyError as e:
logger.error(f"Failed to retrieve unique users count: {str(e)}")
return 0


class PositionDBConnector(UserDBConnector):
"""
Expand Down Expand Up @@ -330,62 +344,47 @@ def close_position(self, position_id: uuid) -> Position | None:
self.write_to_db(position)
return position.status

def open_position(self, position_id: uuid.UUID) -> str | None:
def open_position(self, position_id: uuid.UUID, current_prices: dict) -> str | None:
"""
Opens a position by updating its status and creating an AirDrop claim.
:param position_id: uuid.UUID
:param current_prices: dict
:return: str | None
"""
position = self.get_object(Position, position_id)
if position:
position.status = Status.OPENED.value
self.write_to_db(position)
self.create_empty_claim(position.user_id)
self.save_current_price(position)
self.save_current_price(position, current_prices)
return position.status
else:
logger.error(f"Position with ID {position_id} not found")
return None

def get_unique_users_count(self) -> int:
"""
Retrieves the number of unique users in the database.
:return: The count of unique users.
"""
with self.Session() as db:
try:
# Query to count distinct users based on wallet ID
unique_users_count = db.query(User.wallet_id).distinct().count()
return unique_users_count

except SQLAlchemyError as e:
logger.error(f"Failed to retrieve unique users count: {str(e)}")
return 0

def get_total_amounts_for_open_positions(self) -> float | None:
def get_total_amounts_for_open_positions(self) -> Decimal:
"""
Calculates the total amount for all positions where status is 'OPENED'.
:return: Total amount for all opened positions, or None if no open positions are found
:return: Total amount for all opened positions
"""
with self.Session() as db:
try:
total_opened_amount = (
db.query(db.func.sum(Position.amount))
db.query(func.sum(cast(Position.amount, Numeric)))
.filter(Position.status == Status.OPENED.value)
.scalar()
)
return total_opened_amount
except SQLAlchemyError as e:
logger.error(f"Error calculating total amount for open positions: {e}")
return None
return Decimal(0.0)

def save_current_price(self, position: Position) -> None:
def save_current_price(self, position: Position, price_dict: dict) -> None:
"""
Saves current prices into db.
:return: None
"""
price_dict = DashboardMixin.get_current_prices()
start_price = price_dict.get(position.token_symbol)
try:
position.start_price = start_price
Expand Down
Empty file added web_app/tests/db/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test cases for PositionDBConnector"""

import uuid
from decimal import Decimal
from unittest.mock import MagicMock

import pytest
from sqlalchemy.exc import SQLAlchemyError

from web_app.db.models import Position, Status, User
from web_app.db.crud import PositionDBConnector
Expand Down Expand Up @@ -161,22 +161,14 @@ def test_open_position_success(mock_position_db_connector, sample_position):
assert result == Status.OPENED.value


def test_get_unique_users_count(mock_position_db_connector):
"""Test getting count of unique users."""
mock_position_db_connector.get_unique_users_count.return_value = 5

result = mock_position_db_connector.get_unique_users_count()

assert result == 5


def test_get_total_amounts_for_open_positions(mock_position_db_connector):
"""Test getting total amounts for open positions."""
mock_position_db_connector.get_total_amounts_for_open_positions.return_value = 1000.0
mock_position_db_connector.get_total_amounts_for_open_positions.return_value = Decimal(1000.0)

result = mock_position_db_connector.get_total_amounts_for_open_positions()

assert result == 1000.0
assert result == Decimal(1000.0)


### Negative Test Cases ###
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def user_db(mock_db_connector):
"""
Fixture to create a UserDBConnector instance with mocked dependencies.
"""
with patch('web_app.db.crud.UserDBConnector.get_object_by_field',
with patch('web_app.db.crud.UserDBConnector.get_object_by_field',
new_callable=MagicMock) as mock_get:
mock_get.side_effect = mock_db_connector.get_object_by_field
connector = UserDBConnector()
Expand All @@ -34,11 +34,11 @@ def test_get_user_by_wallet_id_success(user_db, mock_db_connector):
id=1,
wallet_id=wallet_id,
)

mock_db_connector.get_object_by_field.return_value = expected_user

result = user_db.get_user_by_wallet_id(wallet_id)

assert result == expected_user
mock_db_connector.get_object_by_field.assert_called_once_with(
User,
Expand All @@ -52,9 +52,9 @@ def test_get_user_by_wallet_id_not_found(user_db, mock_db_connector):
"""
wallet_id = "0x987654321"
mock_db_connector.get_object_by_field.return_value = None

result = user_db.get_user_by_wallet_id(wallet_id)

assert result is None
mock_db_connector.get_object_by_field.assert_called_once_with(
User,
Expand All @@ -68,12 +68,21 @@ def test_get_user_by_wallet_id_empty_wallet_id(user_db, mock_db_connector):
"""
wallet_id = ""
mock_db_connector.get_object_by_field.return_value = None

result = user_db.get_user_by_wallet_id(wallet_id)

assert result is None
mock_db_connector.get_object_by_field.assert_called_once_with(
User,
"wallet_id",
wallet_id
)
)


def test_get_unique_users_count(mock_user_db_connector):
"""Test getting count of unique users."""
mock_user_db_connector.get_unique_users_count.return_value = 5

result = mock_user_db_connector.get_unique_users_count()

assert result == 5

0 comments on commit a421b8d

Please sign in to comment.