From cc0b14ec8f24588a8cd1bf24d2a3ebe7c4d0aa96 Mon Sep 17 00:00:00 2001 From: Teri-anric <2005ahi2005@gmail.com> Date: Mon, 16 Dec 2024 00:13:09 +0200 Subject: [PATCH 1/4] temp check --- web_app/api/main.py | 3 -- web_app/api/serializers/transaction.py | 4 +- web_app/api/user.py | 23 ++++++--- web_app/api/vault.py | 6 +-- web_app/contract_tools/mixins/dashboard.py | 5 +- web_app/db/crud/airdrop.py | 8 +-- web_app/db/crud/telegram.py | 57 ++++++---------------- web_app/db/crud/user.py | 28 +---------- 8 files changed, 44 insertions(+), 90 deletions(-) diff --git a/web_app/api/main.py b/web_app/api/main.py index 3f1de5cf..af4a4f64 100644 --- a/web_app/api/main.py +++ b/web_app/api/main.py @@ -45,9 +45,6 @@ }, ) -# Set up the templates directory -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) - # Add session middleware with a secret key app.add_middleware(SessionMiddleware, secret_key=f"Secret:{str(uuid4())}") # CORS middleware for React frontend diff --git a/web_app/api/serializers/transaction.py b/web_app/api/serializers/transaction.py index 5597353a..f3a2a24d 100644 --- a/web_app/api/serializers/transaction.py +++ b/web_app/api/serializers/transaction.py @@ -91,7 +91,7 @@ class UpdateUserContractRequest(BaseModel): contract_address: str -class DeploymentStatus(BaseModel): +class DeploymentStatus(BaseModel): # FIXME: Not used anymore """ Pydantic model for the deployment status. """ @@ -99,7 +99,7 @@ class DeploymentStatus(BaseModel): is_contract_deployed: bool -class ContractAddress(BaseModel): +class ContractAddress(BaseModel): # FIXME: Not used anymore """ Pydantic model for the contract address. """ diff --git a/web_app/api/user.py b/web_app/api/user.py index 54c3a300..c6d2dc87 100644 --- a/web_app/api/user.py +++ b/web_app/api/user.py @@ -52,7 +52,7 @@ async def has_user_opened_position(wallet_id: str) -> dict: ) -@router.get( +@router.get( # FIXME: Not used, only used in tests "/api/get-user-contract", tags=["User Operations"], summary="Get user's contract status", @@ -155,18 +155,27 @@ async def subscribe_to_notification( Success status of the subscription. """ user = user_db.get_user_by_wallet_id(data.wallet_id) + # Check if the user exists; if not, raise a 404 error if not user: raise HTTPException(status_code=404, detail="User not found") - is_allowed_notification = telegram_db.allow_notification(data.telegram_id) - if is_allowed_notification: + telegram_id = data.telegram_id + # Is not provided, attempt to retrieve it from the database + if not telegram_id: + tg_user = telegram_db.get_telegram_user_by_wallet_id(data.wallet_id) + telegram_id = tg_user.telegram_id + # Is found, set the notification preference for the user + if telegram_id: + telegram_db.set_allow_notification(telegram_id, data.wallet_id) return {"detail": "User subscribed to notifications successfully"} + + # If no Telegram ID is available, raise raise HTTPException( status_code=400, detail="Failed to subscribe user to notifications" ) -@router.get( +@router.get( # FIXME: Not used, only used in tests "/api/get-user-contract-address", tags=["User Operations"], summary="Get user's contract address", @@ -244,7 +253,7 @@ async def get_stats() -> GetStatsResponse: raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") -@router.get( +@router.get( # FIXME: Not used anymore "/api/get-user-history", tags=["User Operations"], summary="Get user position history", @@ -281,14 +290,14 @@ async def get_user_history(user_id: str) -> list[dict]: raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") -@router.post("/allow-notification/{telegram_id}") +@router.post("/allow-notification/{telegram_id}") # FIXME: Not used anymore async def allow_notification( telegram_id: int, telegram_db: TelegramUserDBConnector = Depends(lambda: TelegramUserDBConnector()), ): """Enable notifications for a specific telegram user""" try: - telegram_db.allow_notification(telegram_id=telegram_id) + telegram_db.set_allow_notification(telegram_id=telegram_id) return {"message": "Notifications enabled successfully"} except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/web_app/api/vault.py b/web_app/api/vault.py index 840351e5..7cc49c19 100644 --- a/web_app/api/vault.py +++ b/web_app/api/vault.py @@ -19,7 +19,7 @@ router = APIRouter(prefix="/api/vault", tags=["vault"]) -@router.post("/deposit", response_model=VaultDepositResponse) +@router.post("/deposit", response_model=VaultDepositResponse) # FIXME: Not used, only used in tests async def deposit_to_vault( request: VaultDepositRequest, deposit_connector: DepositDBConnector = Depends(DepositDBConnector), @@ -50,7 +50,7 @@ async def deposit_to_vault( raise HTTPException(status_code=400, detail=str(e)) -@router.get("/balance", response_model=VaultBalanceResponse) +@router.get("/balance", response_model=VaultBalanceResponse) # FIXME: Not used, only used in tests async def get_user_vault_balance( wallet_id: str, symbol: str, @@ -67,7 +67,7 @@ async def get_user_vault_balance( return VaultBalanceResponse(wallet_id=wallet_id, symbol=symbol, amount=balance) -@router.post("/add_balance", response_model=UpdateVaultBalanceResponse) +@router.post("/add_balance", response_model=UpdateVaultBalanceResponse) # FIXME: Not used, only used in tests async def add_vault_balance( request: UpdateVaultBalanceRequest, deposit_connector: DepositDBConnector = Depends(DepositDBConnector), diff --git a/web_app/contract_tools/mixins/dashboard.py b/web_app/contract_tools/mixins/dashboard.py index d3cb30ad..84a231d5 100644 --- a/web_app/contract_tools/mixins/dashboard.py +++ b/web_app/contract_tools/mixins/dashboard.py @@ -85,7 +85,7 @@ async def get_wallet_balances(cls, holder_address: str) -> Dict[str, str]: return wallet_balances @classmethod - async def get_zklend_position( + async def get_zklend_position( # FIXME: Not used, only used in tests cls, contract_address: str, position: "Position" @@ -99,7 +99,7 @@ async def get_zklend_position( pass @classmethod - def _get_products(cls, dapps: list) -> list[dict]: + def _get_products(cls, dapps: list) -> list[dict]: # FIXME: Not used anymore """ Get the products from the dapps. :param dapps: List of dapps @@ -114,7 +114,6 @@ async def get_current_position_sum(cls, position: dict) -> Decimal: :param position: Position data :return: current position sum """ - # TODO add test cases for this method current_prices = await cls.get_current_prices() try: result = current_prices.get(position["token_symbol"], Decimal(0)) * Decimal( diff --git a/web_app/db/crud/airdrop.py b/web_app/db/crud/airdrop.py index bd177052..fa9a40ed 100644 --- a/web_app/db/crud/airdrop.py +++ b/web_app/db/crud/airdrop.py @@ -55,16 +55,16 @@ def get_all_unclaimed(self) -> List[AirDrop]: ) return [] - def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: + def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: # FIXME: Not used, only used in tests """ Delete all airdrops for a user. :param user_id: User ID """ with self.Session() as db: try: - airdrops = db.query(AirDrop).filter_by(user_id=user_id).all() - for airdrop in airdrops: - db.delete(airdrop) + db.query(AirDrop).filter_by(user_id=user_id).delete( + synchronize_session=False + ) db.commit() except SQLAlchemyError as e: logger.error(f"Error deleting airdrops for user {user_id}: {str(e)}") diff --git a/web_app/db/crud/telegram.py b/web_app/db/crud/telegram.py index 98ea55c8..d4386819 100644 --- a/web_app/db/crud/telegram.py +++ b/web_app/db/crud/telegram.py @@ -20,6 +20,14 @@ class TelegramUserDBConnector(DBConnector): Provides database connection and operations management for the TelegramUser model. """ + def get_telegram_user_by_wallet_id(self, wallet_id: str) -> TelegramUser | None: + """ + Retrieves a TelegramUser by their wallet ID. + :param wallet_id: str + :return: TelegramUser | None + """ + return self.get_object_by_field(TelegramUser, "wallet_id", wallet_id) + def get_user_by_telegram_id(self, telegram_id: str) -> TelegramUser | None: """ Retrieves a TelegramUser by their Telegram ID. @@ -91,46 +99,11 @@ def set_allow_notification(self, telegram_id: str, wallet_id: str) -> bool: """ Set wallet_id and is_allowed_notification to True for a user by their telegram ID. """ - with self.Session() as session: - if telegram_user := self.get_user_by_telegram_id(telegram_id): - telegram_user.is_allowed_notification = True - session.commit() - logger.info(f"Notification allowed for user with telegram_id {telegram_id}") - return telegram_user - else: - logger.info(f"User with telegram_id {telegram_id} not found, creating new one") - self.create_telegram_user( - dict(telegram_id=telegram_id, wallet_id=wallet_id) - ) - self.allow_notification(telegram_id) - - def allow_notification(self, telegram_id: str) -> bool: - """ - Update is_allowed_notification field to True for a specific telegram user - - Args: - telegram_id: Telegram user ID - - Raises: - ValueError: If the user with the given telegram_id is not found - """ - with self.Session() as session: - user = ( - session.query(TelegramUser).filter_by(telegram_id=telegram_id).first() + self.save_or_update_user( + dict( + telegram_id=telegram_id, + wallet_id=wallet_id, + is_allowed_notification=True, ) - if not user: - raise ValueError(f"User with telegram_id {telegram_id} not found") - - user.is_allowed_notification = True - session.commit() - return True - - def is_allowed_notification(self, wallet_id: str = None) -> bool | None: - """ - Returns true or false if a telegram user allowed notification. - - Args: - wallet_id: Wallet ID of the user. - """ - user = self.get_object_by_field(TelegramUser, "wallet_id", wallet_id) - return user.is_allowed_notification if user else None + ) + return True \ No newline at end of file diff --git a/web_app/db/crud/user.py b/web_app/db/crud/user.py index 2ae908b3..63768a90 100644 --- a/web_app/db/crud/user.py +++ b/web_app/db/crud/user.py @@ -20,6 +20,7 @@ class UserDBConnector(DBConnector): """ def get_all_users_with_opened_position(self) -> List[User]: + # FIXME: Not used anymore """ Retrieves all users with an OPENED position status from the database. First queries Position table for OPENED positions, then gets the associated users. @@ -118,31 +119,6 @@ def get_unique_users_count(self) -> int: logger.error(f"Failed to retrieve unique users count: {str(e)}") return 0 - def delete_user_by_wallet_id(self, wallet_id: str) -> None: - """ - Deletes a user from the database by their wallet ID. - Rolls back the transaction if the operation fails. - - :param wallet_id: str - :return: None - :raises SQLAlchemyError: If the operation fails - """ - with self.Session() as session: - try: - user = session.query(User).filter(User.wallet_id == wallet_id).first() - if user: - session.delete(user) - session.commit() - logger.info( - f"User with wallet_id {wallet_id} deleted successfully." - ) - else: - logger.warning(f"No user found with wallet_id {wallet_id}.") - except SQLAlchemyError as e: - session.rollback() - logger.error(f"Failed to delete user with wallet_id {wallet_id}: {e}") - raise e - def fetch_user_history(self, user_id: int) -> List[dict]: """ Fetches all positions for a user with the specified fields: @@ -191,7 +167,7 @@ def fetch_user_history(self, user_id: int) -> List[dict]: ) return [] - def delete_user_by_wallet_id(self, wallet_id: str) -> None: + def delete_user_by_wallet_id(self, wallet_id: str) -> None: # FIXME: Not used, only used in tests """ Deletes a user from the database by their wallet ID. Rolls back the transaction if the operation fails. From e6acdc8cd3d826836b7dae221f033fe387b5de7b Mon Sep 17 00:00:00 2001 From: Teri-anric <2005ahi2005@gmail.com> Date: Tue, 17 Dec 2024 12:48:59 +0200 Subject: [PATCH 2/4] Remove unused endpoints, models, functions --- web_app/api/serializers/transaction.py | 16 ------ web_app/api/telegram.py | 2 +- web_app/api/user.py | 57 +--------------------- web_app/api/vault.py | 6 +-- web_app/contract_tools/mixins/dashboard.py | 9 ---- web_app/db/crud/airdrop.py | 2 +- web_app/db/crud/user.py | 24 +-------- 7 files changed, 8 insertions(+), 108 deletions(-) diff --git a/web_app/api/serializers/transaction.py b/web_app/api/serializers/transaction.py index b877d22a..7ea86794 100644 --- a/web_app/api/serializers/transaction.py +++ b/web_app/api/serializers/transaction.py @@ -97,19 +97,3 @@ class UpdateUserContractRequest(BaseModel): wallet_id: str contract_address: str - - -class DeploymentStatus(BaseModel): # FIXME: Not used anymore - """ - Pydantic model for the deployment status. - """ - - is_contract_deployed: bool - - -class ContractAddress(BaseModel): # FIXME: Not used anymore - """ - Pydantic model for the contract address. - """ - - contract_address: str | None diff --git a/web_app/api/telegram.py b/web_app/api/telegram.py index 5048b848..8d0319d0 100644 --- a/web_app/api/telegram.py +++ b/web_app/api/telegram.py @@ -98,7 +98,7 @@ async def telegram_webhook(update: Update): return b"", 200 -@router.post( +@router.post( # FIXME REMOVE IT (delete and frontend, not used) "/api/telegram/save-user", tags=["Telegram Operations"], summary="Save or update Telegram user information", diff --git a/web_app/api/user.py b/web_app/api/user.py index 52894607..e5c2e0e7 100644 --- a/web_app/api/user.py +++ b/web_app/api/user.py @@ -56,7 +56,7 @@ async def has_user_opened_position(wallet_id: str) -> dict: ) -@router.get( # FIXME: Not used, only used in tests +@router.get( "/api/get-user-contract", tags=["User Operations"], summary="Get user's contract status", @@ -179,7 +179,7 @@ async def subscribe_to_notification( ) -@router.get( # FIXME: Not used, only used in tests +@router.get( "/api/get-user-contract-address", tags=["User Operations"], summary="Get user's contract address", @@ -255,56 +255,3 @@ async def get_stats() -> GetStatsResponse: except Exception as e: logger.error(f"Error in get_stats: {e}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@router.get( # FIXME: Not used anymore - "/api/get-user-history", - tags=["User Operations"], - summary="Get user position history", - response_model=UserHistoryResponse, - response_description="List of user positions including status,created_at, \ - start_price, amount, and multiplier.", -) -async def get_user_history(user_id: str) -> list[dict]: - """ - Retrieves the history of positions for a specified user. - - ### Parameters: - - **user_id**: The unique ID of the user whose position history is being fetched. - - ### Returns: - - A list of positions with the following details: - - `status`: Current status of the position. - - `created_at`: Timestamp when the position was created. - - `start_price`: Initial price of the asset when the position was opened. - - `amount`: Amount involved in the position. - - `multiplier`: Leverage multiplier applied to the position. - """ - # FIXME REMOVE IT - try: - # Fetch user history from the database - positions = user_db.fetch_user_history(user_id) - - if not positions: - logger.info(f"No positions found for user_id={user_id}") - return [] - - return positions - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@router.post("/allow-notification/{telegram_id}") # FIXME: Not used anymore -async def allow_notification( - telegram_id: int, - telegram_db: TelegramUserDBConnector = Depends(lambda: TelegramUserDBConnector()), -): - """Enable notifications for a specific telegram user""" - try: - telegram_db.set_allow_notification(telegram_id=telegram_id) - return {"message": "Notifications enabled successfully"} - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/web_app/api/vault.py b/web_app/api/vault.py index 7cc49c19..840351e5 100644 --- a/web_app/api/vault.py +++ b/web_app/api/vault.py @@ -19,7 +19,7 @@ router = APIRouter(prefix="/api/vault", tags=["vault"]) -@router.post("/deposit", response_model=VaultDepositResponse) # FIXME: Not used, only used in tests +@router.post("/deposit", response_model=VaultDepositResponse) async def deposit_to_vault( request: VaultDepositRequest, deposit_connector: DepositDBConnector = Depends(DepositDBConnector), @@ -50,7 +50,7 @@ async def deposit_to_vault( raise HTTPException(status_code=400, detail=str(e)) -@router.get("/balance", response_model=VaultBalanceResponse) # FIXME: Not used, only used in tests +@router.get("/balance", response_model=VaultBalanceResponse) async def get_user_vault_balance( wallet_id: str, symbol: str, @@ -67,7 +67,7 @@ async def get_user_vault_balance( return VaultBalanceResponse(wallet_id=wallet_id, symbol=symbol, amount=balance) -@router.post("/add_balance", response_model=UpdateVaultBalanceResponse) # FIXME: Not used, only used in tests +@router.post("/add_balance", response_model=UpdateVaultBalanceResponse) async def add_vault_balance( request: UpdateVaultBalanceRequest, deposit_connector: DepositDBConnector = Depends(DepositDBConnector), diff --git a/web_app/contract_tools/mixins/dashboard.py b/web_app/contract_tools/mixins/dashboard.py index 760abe05..6e8327cc 100644 --- a/web_app/contract_tools/mixins/dashboard.py +++ b/web_app/contract_tools/mixins/dashboard.py @@ -83,15 +83,6 @@ async def get_wallet_balances(cls, holder_address: str) -> Dict[str, str]: return wallet_balances - @classmethod - def _get_products(cls, dapps: list) -> list[dict]: - """ - Get the products from the dapps. - :param dapps: List of dapps - :return: List of positions - """ - return [product for dapp in dapps for product in dapp.get("products", [])] - @classmethod def _calculate_sum( cls, price: Decimal, amount: Decimal, multiplier: Decimal diff --git a/web_app/db/crud/airdrop.py b/web_app/db/crud/airdrop.py index fa9a40ed..5f5abc49 100644 --- a/web_app/db/crud/airdrop.py +++ b/web_app/db/crud/airdrop.py @@ -55,7 +55,7 @@ def get_all_unclaimed(self) -> List[AirDrop]: ) return [] - def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: # FIXME: Not used, only used in tests + def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: """ Delete all airdrops for a user. :param user_id: User ID diff --git a/web_app/db/crud/user.py b/web_app/db/crud/user.py index 63768a90..18d57128 100644 --- a/web_app/db/crud/user.py +++ b/web_app/db/crud/user.py @@ -19,28 +19,6 @@ class UserDBConnector(DBConnector): Provides database connection and operations management for the User model. """ - def get_all_users_with_opened_position(self) -> List[User]: - # FIXME: Not used anymore - """ - Retrieves all users with an OPENED position status from the database. - First queries Position table for OPENED positions, then gets the associated users. - - :return: List[User] - """ - with self.Session() as db: - try: - users = ( - db.query(User) - .join(Position, Position.user_id == User.id) - .filter(Position.status == Status.OPENED.value) - .distinct() - .all() - ) - return users - except SQLAlchemyError as e: - logger.error(f"Error retrieving users with OPENED positions: {e}") - return [] - def get_users_for_notifications(self) -> List[Tuple[str, str]]: """ Retrieves the contract_address of users with an OPENED position status and @@ -167,7 +145,7 @@ def fetch_user_history(self, user_id: int) -> List[dict]: ) return [] - def delete_user_by_wallet_id(self, wallet_id: str) -> None: # FIXME: Not used, only used in tests + def delete_user_by_wallet_id(self, wallet_id: str) -> None: """ Deletes a user from the database by their wallet ID. Rolls back the transaction if the operation fails. From dbbabff5e813e808d730f63ea664828f16559e41 Mon Sep 17 00:00:00 2001 From: Teri-anric <2005ahi2005@gmail.com> Date: Tue, 17 Dec 2024 18:36:46 +0200 Subject: [PATCH 3/4] fix test case --- web_app/api/serializers/user.py | 6 +-- web_app/api/user.py | 7 +-- web_app/tests/test_user.py | 82 ++++++++++++++++++--------------- 3 files changed, 52 insertions(+), 43 deletions(-) diff --git a/web_app/api/serializers/user.py b/web_app/api/serializers/user.py index 94b43221..a1464d91 100644 --- a/web_app/api/serializers/user.py +++ b/web_app/api/serializers/user.py @@ -91,12 +91,12 @@ class UserHistoryResponse(BaseModel): positions: list[PositionHistoryItem] -class SubscribeToNotificationResponse(BaseModel): +class SubscribeToNotificationRequest(BaseModel): """ Pydantic model for the notification subscription request. """ - telegram_id: str = Field( - ..., example="123456789", description="Telegram ID of the user" + telegram_id: str | None = Field( + None, example="123456789", description="Telegram ID of the user" ) wallet_id: str = Field(..., example="0xabc123", description="Wallet ID of the user") diff --git a/web_app/api/user.py b/web_app/api/user.py index e5c2e0e7..9c1648b1 100644 --- a/web_app/api/user.py +++ b/web_app/api/user.py @@ -11,7 +11,7 @@ CheckUserResponse, GetStatsResponse, GetUserContractAddressResponse, - SubscribeToNotificationResponse, + SubscribeToNotificationRequest, UpdateUserContractResponse, UserHistoryResponse, ) @@ -146,7 +146,7 @@ async def update_user_contract( response_description="Returns success status of notification subscription", ) async def subscribe_to_notification( - data: SubscribeToNotificationResponse, + data: SubscribeToNotificationRequest, ): """ This endpoint subscribes a user to notifications by linking their telegram ID to their wallet. @@ -167,7 +167,8 @@ async def subscribe_to_notification( # Is not provided, attempt to retrieve it from the database if not telegram_id: tg_user = telegram_db.get_telegram_user_by_wallet_id(data.wallet_id) - telegram_id = tg_user.telegram_id + if tg_user: + telegram_id = tg_user.telegram_id # Is found, set the notification preference for the user if telegram_id: telegram_db.set_allow_notification(telegram_id, data.wallet_id) diff --git a/web_app/tests/test_user.py b/web_app/tests/test_user.py index 6bdacdb1..53b32257 100644 --- a/web_app/tests/test_user.py +++ b/web_app/tests/test_user.py @@ -7,7 +7,7 @@ import pytest from web_app.api.serializers.transaction import UpdateUserContractRequest -from web_app.api.serializers.user import SubscribeToNotificationResponse +from web_app.db.models import TelegramUser, User from web_app.tests.conftest import client, mock_user_db_connector @@ -184,84 +184,92 @@ async def test_get_user_contract_address( @pytest.mark.asyncio -@patch("web_app.db.crud.TelegramUserDBConnector.allow_notification") +@patch("web_app.db.crud.TelegramUserDBConnector.set_allow_notification") +@patch("web_app.db.crud.TelegramUserDBConnector.get_telegram_user_by_wallet_id") @patch("web_app.db.crud.UserDBConnector.get_user_by_wallet_id") @pytest.mark.parametrize( - "telegram_id, wallet_id, expected_status_code, expected_response, is_allowed_notification", + "telegram_id, wallet_id, user_telegram_id, expected_status_code, expected_response", [ ( "123456789", "0x27994c503bd8c32525fbdaf9d398bdd4e86757988c64581b055a06c5955ea49", + "123456789", 200, {"detail": "User subscribed to notifications successfully"}, - True, ), ( + None, + "0x27994c503bd8c32525fbdaf9d398bdd4e86757988c64581b055a06c5955ea49", "123456789", + 200, + {"detail": "User subscribed to notifications successfully"}, + ), + ( + "123456789", "invalid_wallet_id", + None, 404, {"detail": "User not found"}, - False, ), ( None, "0x27994c503bd8c32525fbdaf9d398bdd4e86757988c64581b055a06c5955ea49", - 422, None, - False, + 400, + {"detail": "Failed to subscribe user to notifications"}, ), ], ) async def test_subscribe_to_notification( mock_get_user_by_wallet_id: MagicMock, - mock_allow_notification: MagicMock, + mock_get_telegram_user_by_wallet_id: MagicMock, + mock_set_allow_notification: MagicMock, client, - telegram_id: str, + telegram_id: str | None, wallet_id: str, + user_telegram_id: str | None, expected_status_code: int, - expected_response: dict | None, - is_allowed_notification: bool, + expected_response: dict, ) -> None: """ Test subscribe_to_notification endpoint with both positive and negative cases. :param client: fastapi.testclient.TestClient :param mock_get_user_by_wallet_id: unittest.mock.MagicMock for get_user_by_wallet_id - :param mock_allow_notification: unittest.mock.MagicMock for allow_notification + :param mock_get_telegram_user_by_wallet_id: unittest.mock.MagicMock for get_telegram_user_by_wallet_id + :param mock_set_allow_notification: unittest.mock.MagicMock for set_allow_notification :param telegram_id: str[Telegram ID of the user] - :param wallet_id: str[Wallet ID of the user] + :param wallet_id: str[Wallet ID of the user] + :param user_telegram_id: str[Telegram ID of the db user] :param expected_status_code: int[Expected HTTP status code] - :param expected_response: dict | None[Expected JSON response] + :param expected_response: dict[Expected JSON response] :return: None """ # Define the behavior of the mocks - mock_allow_notification.return_value = is_allowed_notification - - if wallet_id == "invalid_wallet_id": - mock_get_user_by_wallet_id.return_value = None - else: - mock_get_user_by_wallet_id.return_value = {"wallet_id": wallet_id} - - if telegram_id and wallet_id: - data = { - "telegram_id": telegram_id, - "wallet_id": wallet_id, - } - else: - data = {"telegram_id": telegram_id, "wallet_id": wallet_id} + mock_set_allow_notification.return_value = True + + mock_get_user_by_wallet_id.return_value = None + if wallet_id != "invalid_wallet_id": + mock_get_user_by_wallet_id.return_value = User( + wallet_id=wallet_id, + is_contract_deployed=True, + ) + + mock_get_telegram_user_by_wallet_id.return_value = None + if user_telegram_id: + tg_user = TelegramUser( + telegram_id=user_telegram_id, + wallet_id=wallet_id, + ) + mock_get_telegram_user_by_wallet_id.return_value = tg_user + data = {"telegram_id": telegram_id, "wallet_id": wallet_id} + response = client.post( url="/api/subscribe-to-notification", json=data, ) - response_json = response.json() + assert response.status_code == expected_status_code - if expected_response: - assert response_json == expected_response - elif expected_status_code == 422: - assert "detail" in response_json - assert isinstance(response_json["detail"], list) - elif expected_status_code == 404: - assert "detail" in response_json - assert response_json["detail"] == "User not found" + assert response.json() == expected_response From fb7554b17424985eb39b5017d4be60ddba1a22bc Mon Sep 17 00:00:00 2001 From: Teri-anric <2005ahi2005@gmail.com> Date: Tue, 17 Dec 2024 18:42:51 +0200 Subject: [PATCH 4/4] fix pylint --- web_app/tests/test_user.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web_app/tests/test_user.py b/web_app/tests/test_user.py index 53b32257..5329dddf 100644 --- a/web_app/tests/test_user.py +++ b/web_app/tests/test_user.py @@ -236,7 +236,8 @@ async def test_subscribe_to_notification( :param client: fastapi.testclient.TestClient :param mock_get_user_by_wallet_id: unittest.mock.MagicMock for get_user_by_wallet_id - :param mock_get_telegram_user_by_wallet_id: unittest.mock.MagicMock for get_telegram_user_by_wallet_id + :param mock_get_telegram_user_by_wallet_id: unittest.mock.MagicMock + for get_telegram_user_by_wallet_id :param mock_set_allow_notification: unittest.mock.MagicMock for set_allow_notification :param telegram_id: str[Telegram ID of the user] :param wallet_id: str[Wallet ID of the user]