From 5cd8b39460d6fa56688605b89eba7475731a5ee3 Mon Sep 17 00:00:00 2001 From: Benjamin Bolte Date: Thu, 30 May 2024 22:54:26 -0700 Subject: [PATCH] fix to use user ids instead --- store/app/api/crud/users.py | 24 ++++++++++++---- store/app/api/routers/users.py | 50 ++++++++++++++-------------------- store/app/api/token.py | 39 +++++++++++++++++++------- 3 files changed, 69 insertions(+), 44 deletions(-) diff --git a/store/app/api/crud/users.py b/store/app/api/crud/users.py index e28bb3af..948fba6e 100644 --- a/store/app/api/crud/users.py +++ b/store/app/api/crud/users.py @@ -15,7 +15,15 @@ async def add_user(self, user: User) -> None: table = await self.db.Table("Users") await table.put_item(Item=user.model_dump()) - async def get_user(self, email: str) -> User | None: + async def get_user(self, user_id: str) -> User | None: + table = await self.db.Table("Users") + user_dict = await table.get_item(Key={"user_id": user_id}) + if "Item" not in user_dict: + return None + user = User.model_validate(user_dict["Item"]) + return user + + async def get_user_from_email(self, email: str) -> User | None: table = await self.db.Table("Users") user_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=KeyCondition("email").eq(email)) items = user_dict["Items"] @@ -44,14 +52,20 @@ async def add_token(self, token: Token) -> None: table = await self.db.Table("Tokens") await table.put_item(Item=token.model_dump()) - async def get_token(self, email: str) -> Token | None: + async def get_token(self, token_id: str) -> Token | None: table = await self.db.Table("Tokens") - token_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=KeyCondition("email").eq(email)) - if len(token_dict["Items"]) == 0: + token_dict = await table.get_item(Key={"token_id": token_id}) + if "Item" not in token_dict: return None - token = Token.model_validate(token_dict["Items"][0]) + token = Token.model_validate(token_dict["Item"]) return token + async def get_user_tokens(self, user_id: str) -> list[Token]: + table = await self.db.Table("Tokens") + tokens = table.query(IndexName="userIdIndex", KeyConditionExpression=KeyCondition("user_id").eq(user_id)) + tokens = [Token.model_validate(token) for token in await tokens] + return tokens + async def test_adhoc() -> None: async with UserCrud() as crud: diff --git a/store/app/api/routers/users.py b/store/app/api/routers/users.py index cec9c80f..eba84580 100644 --- a/store/app/api/routers/users.py +++ b/store/app/api/routers/users.py @@ -1,8 +1,8 @@ """Defines the API endpoint for creating, deleting and updating user information.""" -import asyncio import datetime import logging +import uuid from email.utils import parseaddr as parse_email_address from typing import Annotated @@ -12,7 +12,7 @@ from pydantic.main import BaseModel from store.app.api.db import Crud -from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email, send_waitlist_email +from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email from store.app.api.model import User from store.app.api.token import create_refresh_token, create_token, load_refresh_token, load_token from store.settings import settings @@ -39,7 +39,8 @@ def set_token_cookie(response: Response, token: str, key: str) -> None: class RefreshTokenData(BaseModel): - email: str + user_id: str + token_id: str @classmethod async def encode(cls, user: User, crud: Crud) -> str: @@ -47,23 +48,23 @@ async def encode(cls, user: User, crud: Crud) -> str: @classmethod def decode(cls, payload: str) -> "RefreshTokenData": - email = load_refresh_token(payload) - return cls(email=email) + user_id, token_id = load_refresh_token(payload) + return cls(user_id=user_id, token_id=token_id) class SessionTokenData(BaseModel): - email: str + user_id: str + token_id: str def encode(self) -> str: expire_minutes = settings.crypto.expire_token_minutes expire_after = datetime.timedelta(minutes=expire_minutes) - return create_token({"eml": self.email}, expire_after=expire_after) + return create_token({"uid": self.user_id, "tid": self.token_id}, expire_after=expire_after, extra="session") @classmethod def decode(cls, payload: str) -> "SessionTokenData": - data = load_token(payload) - email = data["eml"] - return cls(email=email) + data = load_token(payload, extra="session") + return cls(user_id=data["uid"], token_id=data["tid"]) class UserSignup(BaseModel): @@ -79,6 +80,10 @@ def validate_email(email: str) -> str: return email +def get_user_id() -> str: + return str(uuid.uuid4()) + + @users_router.post("/login") async def login_user_endpoint(data: UserSignup) -> bool: email = validate_email(data.email) @@ -96,27 +101,14 @@ class UserLoginResponse(BaseModel): token_type: str -async def add_to_waitlist(email: str, crud: Crud) -> None: - await asyncio.gather( - send_waitlist_email(email), - crud.add_user(User(email=email, banned=True)), - ) - - async def create_or_get(email: str, crud: Crud) -> User: # Gets or creates the user object. user_obj = await crud.get_user(email) if user_obj is None: - await crud.add_user(User(email=email)) + await crud.add_user(User(user_id=get_user_id(), email=email)) if (user_obj := await crud.get_user(email)) is None: raise RuntimeError("Failed to add user to the database") - # Validates user. - if user_obj.banned: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User is not allowed to log in") - if user_obj.deleted: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User is deleted") - return user_obj @@ -220,7 +212,7 @@ async def get_user_info_endpoint( data: Annotated[SessionTokenData, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> UserInfoResponse: - user_obj = await crud.get_user(data.email) + user_obj = await crud.get_user(data.user_id) if user_obj is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") return UserInfoResponse(email=user_obj.email) @@ -231,7 +223,7 @@ async def delete_user_endpoint( data: Annotated[SessionTokenData, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_obj = await crud.get_user(data.email) + user_obj = await crud.get_user(data.user_id) if user_obj is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") await crud.delete_user(user_obj) @@ -260,9 +252,9 @@ async def refresh_endpoint( data: Annotated[RefreshTokenData, Depends(get_refresh_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> RefreshTokenResponse: - token = await crud.get_token(data.email) - if not token or token.disabled: + token = await crud.get_token(data.token_id) + if token is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token") - session_token = SessionTokenData(email=data.email).encode() + session_token = SessionTokenData(user_id=data.user_id, token_id=data.token_id).encode() set_token_cookie(response, session_token, SESSION_TOKEN_COOKIE_KEY) return RefreshTokenResponse(token=session_token, token_type=TOKEN_TYPE) diff --git a/store/app/api/token.py b/store/app/api/token.py index 8ce9ae18..e68e9af7 100644 --- a/store/app/api/token.py +++ b/store/app/api/token.py @@ -2,6 +2,7 @@ import datetime import logging +import uuid import jwt from fastapi import HTTPException, status @@ -16,7 +17,16 @@ TIME_FORMAT = "%Y-%m-%d %H:%M:%S" -def create_token(data: dict, expire_after: datetime.timedelta | None = None) -> str: +def get_token_id() -> str: + """Generates a unique token ID. + + Returns: + A unique token ID. + """ + return str(uuid.uuid4()) + + +def create_token(data: dict, expire_after: datetime.timedelta | None = None, extra: str | None = None) -> str: """Creates a token from a dictionary. The "exp" key is reserved for internal use. @@ -24,10 +34,14 @@ def create_token(data: dict, expire_after: datetime.timedelta | None = None) -> Args: data: The data to encode. expire_after: If provided, token will expire after this amount of time. + extra: Additional secret to append to the secret key. Returns: The encoded JWT. """ + secret = settings.crypto.jwt_secret + if extra is not None: + secret += extra if "exp" in data: raise ValueError("The payload should not contain an expiration time") to_encode = data.copy() @@ -38,46 +52,51 @@ def create_token(data: dict, expire_after: datetime.timedelta | None = None) -> expires = server_time() + expire_after to_encode.update({"exp": expires}) - encoded_jwt = jwt.encode(to_encode, settings.crypto.jwt_secret, algorithm=settings.crypto.algorithm) + encoded_jwt = jwt.encode(to_encode, secret, algorithm=settings.crypto.algorithm) return encoded_jwt -def load_token(payload: str) -> dict: +def load_token(payload: str, extra: str | None = None) -> dict: """Loads the token payload. Args: payload: The JWT-encoded payload. only_once: If ``True``, the token will be marked as used. + extra: Additional secret to append to the secret key. Returns: The decoded payload. """ + secret = settings.crypto.jwt_secret + if extra is not None: + secret += extra try: - data: dict = jwt.decode(payload, settings.crypto.jwt_secret, algorithms=[settings.crypto.algorithm]) + data: dict = jwt.decode(payload, secret, algorithms=[settings.crypto.algorithm]) except Exception: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") return data -async def create_refresh_token(email: str, crud: Crud) -> str: +async def create_refresh_token(user_id: str, crud: Crud) -> str: """Creates a refresh token for a user. Refresh tokens never expire. They are used to generate short-lived session tokens which are used for authentication. Args: - email: The email for the user associated with this token. + user_id: The user ID to associate with the token. crud: The CRUD class for the databases. Returns: The encoded JWT. """ - token = Token(email=email) + token_id = get_token_id() + token = Token(user_id=user_id, token_id=token_id) await crud.add_token(token) - return create_token({"eml": email}) + return create_token({"uid": user_id, "tid": token_id}) -def load_refresh_token(payload: str) -> str: +def load_refresh_token(payload: str) -> tuple[str, str]: """Loads the refresh token payload. Args: @@ -87,4 +106,4 @@ def load_refresh_token(payload: str) -> str: The decoded refresh token data. """ data = load_token(payload) - return data["eml"] + return data["uid"], data["tid"]