Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: start price field in position table #100

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""start_price field added in position

Revision ID: a009512f5362
Revises: b705d1435b64
Create Date: 2024-10-24 18:56:29.399344

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect


# revision identifiers, used by Alembic.
revision = 'a009512f5362'
down_revision = 'b705d1435b64'
branch_labels = None
depends_on = None


def column_exists(table_name, column_name):
"""Utility function to check if a column exists in the table."""
bind = op.get_bind()
inspector = inspect(bind)
columns = [col['name'] for col in inspector.get_columns(table_name)]
return column_name in columns

def upgrade() -> None:
"""Upgrade the database."""

if column_exists('position', 'start_price'):
print("Column 'start_price' already exists, skipping creation.")
else:
op.add_column('position', sa.Column('start_price', sa.DECIMAL(), nullable=False))
print("Column 'start_price' added to the 'position' table.")
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('position', 'start_price',
existing_type=sa.DOUBLE_PRECISION(precision=53),
type_=sa.DECIMAL(),
existing_nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade the database."""

if column_exists('position', 'start_price'):
print("Column 'start_price' exists, downgrading.")
op.drop_column('position', 'start_price')
else:
print("Column 'start_price' already removed, skipping.")
15 changes: 6 additions & 9 deletions web_app/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker

from web_app.db.database import SQLALCHEMY_DATABASE_URL
from web_app.db.models import (
Base,
User,
Position,
Status,
)
from web_app.db.models import Base, Position, Status, User

logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
Expand Down Expand Up @@ -168,6 +162,8 @@ class PositionDBConnector(UserDBConnector):
Provides database connection and operations management for the Position model.
"""

START_PRICE = 0.0

@staticmethod
def _position_to_dict(position: Position) -> dict:
"""
Expand Down Expand Up @@ -197,7 +193,7 @@ def _get_user_by_wallet_id(self, wallet_id: str) -> User | None:

def get_positions_by_wallet_id(self, wallet_id: str) -> list:
"""
Retrieves all positions for a user by their wallet ID
Retrieves all positions for a user by their wallet ID
and returns them as a list of dictionaries.
:param wallet_id: str
:return: list of dict
Expand Down Expand Up @@ -259,6 +255,7 @@ def create_position(
existing_position.token_symbol = token_symbol
existing_position.amount = amount
existing_position.multiplier = multiplier
existing_position.start_price = PositionDBConnector.START_PRICE
session.commit() # Commit the changes to the database
session.refresh(existing_position) # Refresh to get updated values
return existing_position
Expand All @@ -270,6 +267,7 @@ def create_position(
amount=amount,
multiplier=multiplier,
status=Status.PENDING.value, # Set status as 'pending' by default
start_price=PositionDBConnector.START_PRICE,
)

# Write the new position to the database
Expand All @@ -287,7 +285,6 @@ def get_position_id_by_wallet_id(self, wallet_id: str) -> str | None:
return position[0]["id"]
return None


def update_position(self, position: Position, amount: str, multiplier: int) -> None:
"""
Updates a position in the database.
Expand Down
2 changes: 1 addition & 1 deletion web_app/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def get_database() -> SessionLocal:
try:
yield database
finally:
database.close()
database.close()
15 changes: 13 additions & 2 deletions web_app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@
This module contains the SQLAlchemy models for the database.
"""

from enum import Enum as PyEnum
from uuid import uuid4

from sqlalchemy import (
DECIMAL,
Boolean,
Column,
DateTime,
Enum,
ForeignKey,
Integer,
String,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, DateTime, Enum
from enum import Enum as PyEnum
from web_app.db.database import Base


Expand Down Expand Up @@ -62,3 +72,4 @@ class Position(Base):
nullable=True,
default="pending",
)
start_price = Column(DECIMAL, nullable=False)
Loading