Skip to content

Commit

Permalink
Merge pull request #442 from greatest0fallt1me/crud-method
Browse files Browse the repository at this point in the history
Create seperate crud method to fetch positions independent on status
  • Loading branch information
iamnovichek authored Dec 25, 2024
2 parents 9e06f47 + 58b74bd commit 65d8fa0
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
2 changes: 1 addition & 1 deletion web_app/api/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def get_dashboard(wallet_id: str) -> DashboardResponse:

# Fetching first 10 positions at the moment
opened_positions = position_db_connector.get_positions_by_wallet_id(
wallet_id, 0, 10
wallet_id
)

# At the moment, we only support one position per wallet
Expand Down
2 changes: 1 addition & 1 deletion web_app/api/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def get_user_positions(wallet_id: str, start: Optional[int] = None) -> lis

start_index = max(0, start) if start is not None else 0

positions = position_db_connector.get_positions_by_wallet_id(
positions = position_db_connector.get_all_positions_by_wallet_id(
wallet_id, start_index, PAGINATION_STEP
)
return positions
36 changes: 36 additions & 0 deletions web_app/db/crud/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,42 @@ def get_positions_by_wallet_id(
logger.error(f"Failed to retrieve positions: {str(e)}")
return []

def get_all_positions_by_wallet_id(
self, wallet_id: str, start: int, limit: int
) -> list:
"""
Retrieves paginated positions for a user by their wallet ID
and returns them as a list of dictionaries.
:param wallet_id: str
:param start: starting index for pagination
:param limit: number of records to return
:return: list of dict
"""
with self.Session() as db:
user = self._get_user_by_wallet_id(wallet_id)
if not user:
return []

try:
positions = (
db.query(Position)
.filter(
Position.user_id == user.id,
)
.offset(start)
.limit(limit)
.all()
)
# Convert positions to a list of dictionaries
positions_dicts = [
self._position_to_dict(position) for position in positions
]
return positions_dict

except SQLAlchemyError as e:
logger.error(f"Failed to retrieve positions: {str(e)}")
return []

def has_opened_position(self, wallet_id: str) -> bool:
"""
Checks if a user has any opened positions.
Expand Down
2 changes: 1 addition & 1 deletion web_app/tests/test_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ async def test_get_user_positions_success(client: TestClient) -> None:
]

with patch(
"web_app.db.crud.PositionDBConnector.get_positions_by_wallet_id"
"web_app.db.crud.PositionDBConnector.get_all_positions_by_wallet_id"
) as mock_get_positions:
mock_get_positions.return_value = mock_positions
response = client.get(f"/api/user-positions/{wallet_id}")
Expand Down

0 comments on commit 65d8fa0

Please sign in to comment.