Skip to content

Commit

Permalink
fix to use user ids instead
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed May 31, 2024
1 parent 722aa17 commit 5cd8b39
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 44 deletions.
24 changes: 19 additions & 5 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ async def add_user(self, user: User) -> None:
table = await self.db.Table("Users")
await table.put_item(Item=user.model_dump())

async def get_user(self, email: str) -> User | None:
async def get_user(self, user_id: str) -> User | None:
table = await self.db.Table("Users")
user_dict = await table.get_item(Key={"user_id": user_id})
if "Item" not in user_dict:
return None
user = User.model_validate(user_dict["Item"])
return user

async def get_user_from_email(self, email: str) -> User | None:
table = await self.db.Table("Users")
user_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=KeyCondition("email").eq(email))
items = user_dict["Items"]
Expand Down Expand Up @@ -44,14 +52,20 @@ async def add_token(self, token: Token) -> None:
table = await self.db.Table("Tokens")
await table.put_item(Item=token.model_dump())

async def get_token(self, email: str) -> Token | None:
async def get_token(self, token_id: str) -> Token | None:
table = await self.db.Table("Tokens")
token_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=KeyCondition("email").eq(email))
if len(token_dict["Items"]) == 0:
token_dict = await table.get_item(Key={"token_id": token_id})
if "Item" not in token_dict:
return None
token = Token.model_validate(token_dict["Items"][0])
token = Token.model_validate(token_dict["Item"])
return token

async def get_user_tokens(self, user_id: str) -> list[Token]:
table = await self.db.Table("Tokens")
tokens = table.query(IndexName="userIdIndex", KeyConditionExpression=KeyCondition("user_id").eq(user_id))
tokens = [Token.model_validate(token) for token in await tokens]
return tokens


async def test_adhoc() -> None:
async with UserCrud() as crud:
Expand Down
50 changes: 21 additions & 29 deletions store/app/api/routers/users.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Defines the API endpoint for creating, deleting and updating user information."""

import asyncio
import datetime
import logging
import uuid
from email.utils import parseaddr as parse_email_address
from typing import Annotated

Expand All @@ -12,7 +12,7 @@
from pydantic.main import BaseModel

from store.app.api.db import Crud
from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email, send_waitlist_email
from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email
from store.app.api.model import User
from store.app.api.token import create_refresh_token, create_token, load_refresh_token, load_token
from store.settings import settings
Expand All @@ -39,31 +39,32 @@ def set_token_cookie(response: Response, token: str, key: str) -> None:


class RefreshTokenData(BaseModel):
email: str
user_id: str
token_id: str

@classmethod
async def encode(cls, user: User, crud: Crud) -> str:
return await create_refresh_token(user.email, crud)

@classmethod
def decode(cls, payload: str) -> "RefreshTokenData":
email = load_refresh_token(payload)
return cls(email=email)
user_id, token_id = load_refresh_token(payload)
return cls(user_id=user_id, token_id=token_id)


class SessionTokenData(BaseModel):
email: str
user_id: str
token_id: str

def encode(self) -> str:
expire_minutes = settings.crypto.expire_token_minutes
expire_after = datetime.timedelta(minutes=expire_minutes)
return create_token({"eml": self.email}, expire_after=expire_after)
return create_token({"uid": self.user_id, "tid": self.token_id}, expire_after=expire_after, extra="session")

@classmethod
def decode(cls, payload: str) -> "SessionTokenData":
data = load_token(payload)
email = data["eml"]
return cls(email=email)
data = load_token(payload, extra="session")
return cls(user_id=data["uid"], token_id=data["tid"])


class UserSignup(BaseModel):
Expand All @@ -79,6 +80,10 @@ def validate_email(email: str) -> str:
return email


def get_user_id() -> str:
return str(uuid.uuid4())


@users_router.post("/login")
async def login_user_endpoint(data: UserSignup) -> bool:
email = validate_email(data.email)
Expand All @@ -96,27 +101,14 @@ class UserLoginResponse(BaseModel):
token_type: str


async def add_to_waitlist(email: str, crud: Crud) -> None:
await asyncio.gather(
send_waitlist_email(email),
crud.add_user(User(email=email, banned=True)),
)


async def create_or_get(email: str, crud: Crud) -> User:
# Gets or creates the user object.
user_obj = await crud.get_user(email)
if user_obj is None:
await crud.add_user(User(email=email))
await crud.add_user(User(user_id=get_user_id(), email=email))
if (user_obj := await crud.get_user(email)) is None:
raise RuntimeError("Failed to add user to the database")

# Validates user.
if user_obj.banned:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User is not allowed to log in")
if user_obj.deleted:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User is deleted")

return user_obj


Expand Down Expand Up @@ -220,7 +212,7 @@ async def get_user_info_endpoint(
data: Annotated[SessionTokenData, Depends(get_session_token)],
crud: Annotated[Crud, Depends(Crud.get)],
) -> UserInfoResponse:
user_obj = await crud.get_user(data.email)
user_obj = await crud.get_user(data.user_id)
if user_obj is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found")
return UserInfoResponse(email=user_obj.email)
Expand All @@ -231,7 +223,7 @@ async def delete_user_endpoint(
data: Annotated[SessionTokenData, Depends(get_session_token)],
crud: Annotated[Crud, Depends(Crud.get)],
) -> bool:
user_obj = await crud.get_user(data.email)
user_obj = await crud.get_user(data.user_id)
if user_obj is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found")
await crud.delete_user(user_obj)
Expand Down Expand Up @@ -260,9 +252,9 @@ async def refresh_endpoint(
data: Annotated[RefreshTokenData, Depends(get_refresh_token)],
crud: Annotated[Crud, Depends(Crud.get)],
) -> RefreshTokenResponse:
token = await crud.get_token(data.email)
if not token or token.disabled:
token = await crud.get_token(data.token_id)
if token is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token")
session_token = SessionTokenData(email=data.email).encode()
session_token = SessionTokenData(user_id=data.user_id, token_id=data.token_id).encode()
set_token_cookie(response, session_token, SESSION_TOKEN_COOKIE_KEY)
return RefreshTokenResponse(token=session_token, token_type=TOKEN_TYPE)
39 changes: 29 additions & 10 deletions store/app/api/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import logging
import uuid

import jwt
from fastapi import HTTPException, status
Expand All @@ -16,18 +17,31 @@
TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def create_token(data: dict, expire_after: datetime.timedelta | None = None) -> str:
def get_token_id() -> str:
"""Generates a unique token ID.
Returns:
A unique token ID.
"""
return str(uuid.uuid4())


def create_token(data: dict, expire_after: datetime.timedelta | None = None, extra: str | None = None) -> str:
"""Creates a token from a dictionary.
The "exp" key is reserved for internal use.
Args:
data: The data to encode.
expire_after: If provided, token will expire after this amount of time.
extra: Additional secret to append to the secret key.
Returns:
The encoded JWT.
"""
secret = settings.crypto.jwt_secret
if extra is not None:
secret += extra
if "exp" in data:
raise ValueError("The payload should not contain an expiration time")
to_encode = data.copy()
Expand All @@ -38,46 +52,51 @@ def create_token(data: dict, expire_after: datetime.timedelta | None = None) ->
expires = server_time() + expire_after
to_encode.update({"exp": expires})

encoded_jwt = jwt.encode(to_encode, settings.crypto.jwt_secret, algorithm=settings.crypto.algorithm)
encoded_jwt = jwt.encode(to_encode, secret, algorithm=settings.crypto.algorithm)
return encoded_jwt


def load_token(payload: str) -> dict:
def load_token(payload: str, extra: str | None = None) -> dict:
"""Loads the token payload.
Args:
payload: The JWT-encoded payload.
only_once: If ``True``, the token will be marked as used.
extra: Additional secret to append to the secret key.
Returns:
The decoded payload.
"""
secret = settings.crypto.jwt_secret
if extra is not None:
secret += extra
try:
data: dict = jwt.decode(payload, settings.crypto.jwt_secret, algorithms=[settings.crypto.algorithm])
data: dict = jwt.decode(payload, secret, algorithms=[settings.crypto.algorithm])
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
return data


async def create_refresh_token(email: str, crud: Crud) -> str:
async def create_refresh_token(user_id: str, crud: Crud) -> str:
"""Creates a refresh token for a user.
Refresh tokens never expire. They are used to generate short-lived session
tokens which are used for authentication.
Args:
email: The email for the user associated with this token.
user_id: The user ID to associate with the token.
crud: The CRUD class for the databases.
Returns:
The encoded JWT.
"""
token = Token(email=email)
token_id = get_token_id()
token = Token(user_id=user_id, token_id=token_id)
await crud.add_token(token)
return create_token({"eml": email})
return create_token({"uid": user_id, "tid": token_id})


def load_refresh_token(payload: str) -> str:
def load_refresh_token(payload: str) -> tuple[str, str]:
"""Loads the refresh token payload.
Args:
Expand All @@ -87,4 +106,4 @@ def load_refresh_token(payload: str) -> str:
The decoded refresh token data.
"""
data = load_token(payload)
return data["eml"]
return data["uid"], data["tid"]

0 comments on commit 5cd8b39

Please sign in to comment.