From a704fea97488f8502932d65761a54b98a21d00e7 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:32:02 -0400 Subject: [PATCH] convert from using tokens to using api keys --- store/app/api/crud/users.py | 49 +++++--- store/app/api/crypto.py | 33 +++++ store/app/api/db.py | 4 +- store/app/api/email.py | 12 +- store/app/api/model.py | 26 +++- store/app/api/routers/users.py | 214 +++++++++------------------------ store/app/api/token.py | 109 ----------------- tests/api/test_users.py | 62 ++++------ 8 files changed, 176 insertions(+), 333 deletions(-) create mode 100644 store/app/api/crypto.py delete mode 100644 store/app/api/token.py diff --git a/store/app/api/crud/users.py b/store/app/api/crud/users.py index f5189607..44ebe0ff 100644 --- a/store/app/api/crud/users.py +++ b/store/app/api/crud/users.py @@ -3,11 +3,13 @@ import asyncio import uuid import warnings +from typing import cast from boto3.dynamodb.conditions import Key as KeyCondition from store.app.api.crud.base import BaseCrud -from store.app.api.model import Token, User +from store.app.api.crypto import hash_api_key +from store.app.api.model import ApiKey, User class UserCrud(BaseCrud): @@ -15,9 +17,9 @@ async def add_user(self, user: User) -> None: table = await self.db.Table("Users") await table.put_item(Item=user.model_dump()) - async def get_user(self, user_id: str) -> User | None: + async def get_user(self, user_id: uuid.UUID) -> User | None: table = await self.db.Table("Users") - user_dict = await table.get_item(Key={"user_id": user_id}) + user_dict = await table.get_item(Key={"user_id": str(user_id)}) if "Item" not in user_dict: return None user = User.model_validate(user_dict["Item"]) @@ -34,6 +36,15 @@ async def get_user_from_email(self, email: str) -> User | None: user = User.model_validate(items[0]) return user + async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None: + table = await self.db.Table("ApiKeys") + api_key_hash = hash_api_key(api_key) + row = await table.get_item(Key={"api_key_hash": api_key_hash}) + if "Item" not in row: + return None + user_id = cast(str, row["Item"]["user_id"]) + return uuid.UUID(user_id) + async def delete_user(self, user: User) -> None: table = await self.db.Table("Users") await table.delete_item(Key={"user_id": user.user_id}) @@ -48,23 +59,21 @@ async def get_user_count(self) -> int: table = await self.db.Table("Users") return await table.item_count - async def add_token(self, token: Token) -> None: - table = await self.db.Table("Tokens") - await table.put_item(Item=token.model_dump()) - - async def get_token(self, token_id: str) -> Token | None: - table = await self.db.Table("Tokens") - token_dict = await table.get_item(Key={"token_id": token_id}) - if "Item" not in token_dict: - return None - token = Token.model_validate(token_dict["Item"]) - return token - - async def get_user_tokens(self, user_id: str) -> list[Token]: - table = await self.db.Table("Tokens") - tokens = table.query(IndexName="userIdIndex", KeyConditionExpression=KeyCondition("user_id").eq(user_id)) - tokens = [Token.model_validate(token) for token in await tokens] - return tokens + async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None: + row = ApiKey.from_api_key(api_key, user_id) + table = await self.db.Table("ApiKeys") + await table.put_item(Item=row.model_dump()) + + async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool: + table = await self.db.Table("ApiKeys") + row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)}) + if "Item" not in row: + return False + return row["Item"]["user_id"] == str(user_id) + + async def delete_api_key(self, api_key: uuid.UUID) -> None: + table = await self.db.Table("ApiKeys") + await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)}) async def test_adhoc() -> None: diff --git a/store/app/api/crypto.py b/store/app/api/crypto.py new file mode 100644 index 00000000..fd725f21 --- /dev/null +++ b/store/app/api/crypto.py @@ -0,0 +1,33 @@ +"""Defines crypto functions.""" + +import datetime +import hashlib +import uuid +from typing import Any + +import jwt + +from store.settings import settings + + +def hash_api_key(api_key: uuid.UUID) -> str: + return hashlib.sha256(api_key.bytes).hexdigest() + + +def get_new_user_id() -> uuid.UUID: + return uuid.uuid4() + + +def get_new_api_key(user_id: uuid.UUID) -> uuid.UUID: + user_id_hash = hashlib.sha1(user_id.bytes).digest() + return uuid.UUID(bytes=user_id_hash[:16], version=5) + + +def encode_jwt(data: dict[str, Any], expire_after: datetime.timedelta | None = None) -> str: # noqa: ANN401 + if expire_after is not None: + data["exp"] = datetime.datetime.utcnow() + expire_after + return jwt.encode(data, settings.crypto.jwt_secret, algorithm=settings.crypto.algorithm) + + +def decode_jwt(token: str) -> dict[str, Any]: # noqa: ANN401 + return jwt.decode(token, settings.crypto.jwt_secret, algorithms=[settings.crypto.algorithm]) diff --git a/store/app/api/db.py b/store/app/api/db.py index 56722b6a..5bb4ec72 100644 --- a/store/app/api/db.py +++ b/store/app/api/db.py @@ -43,9 +43,9 @@ async def create_tables(crud: Crud | None = None) -> None: ], ) await crud._create_dynamodb_table( - name="Tokens", + name="ApiKeys", keys=[ - ("token_id", "S", "HASH"), + ("api_key_hash", "S", "HASH"), ], gsis=[ ("userIdIndex", "user_id", "S", "HASH"), diff --git a/store/app/api/email.py b/store/app/api/email.py index 857a5749..6b396a6b 100644 --- a/store/app/api/email.py +++ b/store/app/api/email.py @@ -11,7 +11,7 @@ import aiosmtplib -from store.app.api.token import create_token, load_token +from store.app.api.crypto import decode_jwt, encode_jwt from store.settings import settings logger = logging.getLogger(__name__) @@ -40,11 +40,11 @@ class OneTimePassPayload: def encode(self) -> str: expire_minutes = settings.crypto.expire_otp_minutes expire_after = datetime.timedelta(minutes=expire_minutes) - return create_token({"email": self.email}, expire_after=expire_after) + return encode_jwt({"email": self.email}, expire_after=expire_after) @classmethod def decode(cls, payload: str) -> "OneTimePassPayload": - data = load_token(payload) + data = decode_jwt(payload) return cls(email=data["email"]) @@ -53,7 +53,7 @@ async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None: body = textwrap.dedent( f""" -

don't panic
stay human

+

K-Scale Labs

log in

Or copy-paste this link: {url}

""" @@ -65,7 +65,7 @@ async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None: async def send_delete_email(email: str) -> None: body = textwrap.dedent( """ -

don't panic
stay human

+

K-Scale Labs

your account has been deleted

""" ) @@ -76,7 +76,7 @@ async def send_delete_email(email: str) -> None: async def send_waitlist_email(email: str) -> None: body = textwrap.dedent( """ -

don't panic
stay human

+

K-Scale Labs

you're on the waitlist!

Thanks for signing up! We'll let you know when you can log in.

""" diff --git a/store/app/api/model.py b/store/app/api/model.py index 0d35eea6..2fa2f165 100644 --- a/store/app/api/model.py +++ b/store/app/api/model.py @@ -1,22 +1,42 @@ -"""Defines the table models for the API.""" +"""Defines the table models for the API. + +These correspond directly with the rows in our database, and provide helper +methods for converting from our input data into the format the database +expects (for example, converting a UUID into a string). +""" import datetime +import uuid from dataclasses import field from decimal import Decimal from pydantic import BaseModel +from store.app.api.crypto import hash_api_key + class User(BaseModel): user_id: str # Primary key email: str + @classmethod + def from_uuid(cls, user_id: uuid.UUID, email: str) -> "User": + return cls(user_id=str(user_id), email=email) + + def to_uuid(self) -> uuid.UUID: + return uuid.UUID(self.user_id) -class Token(BaseModel): - token_id: str # Primary key + +class ApiKey(BaseModel): + api_key_hash: str # Primary key user_id: str issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp())) + @classmethod + def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID) -> "ApiKey": + api_key_hash = hash_api_key(api_key) + return cls(api_key_hash=api_key_hash, user_id=str(user_id)) + class PurchaseLink(BaseModel): name: str diff --git a/store/app/api/routers/users.py b/store/app/api/routers/users.py index f27e43aa..6d30b2bf 100644 --- a/store/app/api/routers/users.py +++ b/store/app/api/routers/users.py @@ -1,29 +1,23 @@ """Defines the API endpoint for creating, deleting and updating user information.""" -import datetime import logging import uuid from email.utils import parseaddr as parse_email_address from typing import Annotated -import aiohttp from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security.utils import get_authorization_scheme_param from pydantic.main import BaseModel +from store.app.api.crypto import get_new_api_key, get_new_user_id from store.app.api.db import Crud from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email from store.app.api.model import User -from store.app.api.token import create_refresh_token, create_token, load_refresh_token, load_token -from store.settings import settings logger = logging.getLogger(__name__) users_router = APIRouter() -REFRESH_TOKEN_COOKIE_KEY = "__REFRESH_TOKEN" -SESSION_TOKEN_COOKIE_KEY = "__SESSION_TOKEN" - TOKEN_TYPE = "Bearer" @@ -38,33 +32,22 @@ def set_token_cookie(response: Response, token: str, key: str) -> None: ) -class RefreshTokenData(BaseModel): - user_id: str - token_id: str - - @classmethod - async def encode(cls, user: User, crud: Crud) -> str: - return await create_refresh_token(user.user_id, crud) - - @classmethod - def decode(cls, payload: str) -> "RefreshTokenData": - user_id, token_id = load_refresh_token(payload) - return cls(user_id=user_id, token_id=token_id) +class ApiKeyData(BaseModel): + api_key: uuid.UUID -class SessionTokenData(BaseModel): - user_id: str - token_id: str - - def encode(self) -> str: - expire_minutes = settings.crypto.expire_token_minutes - expire_after = datetime.timedelta(minutes=expire_minutes) - return create_token({"uid": self.user_id, "tid": self.token_id}, expire_after=expire_after, extra="session") +async def get_api_key(request: Request) -> ApiKeyData: + # Tries Authorization header. + authorization = request.headers.get("Authorization") or request.headers.get("authorization") + if authorization: + scheme, credentials = get_authorization_scheme_param(authorization) + if not (scheme and credentials): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + if scheme.lower() != TOKEN_TYPE.lower(): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + return ApiKeyData(api_key=uuid.UUID(credentials)) - @classmethod - def decode(cls, payload: str) -> "SessionTokenData": - data = load_token(payload, extra="session") - return cls(user_id=data["uid"], token_id=data["tid"]) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") class UserSignup(BaseModel): @@ -80,12 +63,17 @@ def validate_email(email: str) -> str: return email -def get_new_user_id() -> str: - return str(uuid.uuid4()) - - @users_router.post("/login") async def login_user_endpoint(data: UserSignup) -> bool: + """Takes the user email and sends them a one-time login password. + + Args: + data: The payload with the user email and the login URL to redirect to + when the user logs in. + + Returns: + True if the email was sent successfully. + """ email = validate_email(data.email) payload = OneTimePassPayload(email) await send_otp_email(payload, data.login_url) @@ -97,109 +85,39 @@ class OneTimePass(BaseModel): class UserLoginResponse(BaseModel): - token: str - token_type: str - - -async def create_or_get(email: str, crud: Crud) -> User: - # Gets or creates the user object. - user_obj = await crud.get_user_from_email(email) - if user_obj is None: - await crud.add_user(User(user_id=get_new_user_id(), email=email)) - if (user_obj := await crud.get_user_from_email(email)) is None: - raise RuntimeError("Failed to add user to the database") - return user_obj - - -async def get_login_response( - response: Response, - user_obj: User, - crud: Crud, -) -> UserLoginResponse: - refresh_token = await RefreshTokenData.encode(user_obj, crud) - set_token_cookie(response, refresh_token, REFRESH_TOKEN_COOKIE_KEY) - return UserLoginResponse(token=refresh_token, token_type=TOKEN_TYPE) + api_key: str @users_router.post("/otp", response_model=UserLoginResponse) async def otp_endpoint( data: OneTimePass, - response: Response, crud: Annotated[Crud, Depends(Crud.get)], ) -> UserLoginResponse: - payload = OneTimePassPayload.decode(data.payload) - user_obj = await create_or_get(payload.email, crud) - return await get_login_response(response, user_obj, crud) - - -class GoogleLogin(BaseModel): - token: str - + """Takes the one-time password and returns an API key. -async def get_google_user_info(token: str) -> dict: - async with aiohttp.ClientSession() as session: - response = await session.get( - "https://www.googleapis.com/oauth2/v3/userinfo", - params={"access_token": token}, - ) - if response.status != 200: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") - return await response.json() - - -@users_router.post("/google") -async def google_login_endpoint( - data: GoogleLogin, - response: Response, - crud: Annotated[Crud, Depends(Crud.get)], -) -> UserLoginResponse: - try: - idinfo = await get_google_user_info(data.token) - email = idinfo["email"] - except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") - if idinfo.get("email_verified") is not True: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Google email not verified") - user_obj = await create_or_get(email, crud) - return await get_login_response(response, user_obj, crud) - - -async def get_refresh_token(request: Request) -> RefreshTokenData: - # Tries Authorization header. - authorization = request.headers.get("Authorization") - if authorization: - scheme, credentials = get_authorization_scheme_param(authorization) - if not (scheme and credentials): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - if scheme.lower() != TOKEN_TYPE.lower(): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return RefreshTokenData.decode(credentials) - - # Tries Cookie. - cookie_token = request.cookies.get(REFRESH_TOKEN_COOKIE_KEY) - if cookie_token: - return RefreshTokenData.decode(cookie_token) - - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + Args: + data: The one-time password payload. + crud: The database CRUD object. + Returns: + The API key if the one-time password is valid. + """ + payload = OneTimePassPayload.decode(data.payload) -async def get_session_token(request: Request) -> SessionTokenData: - # Tries Authorization header. - authorization = request.headers.get("Authorization") or request.headers.get("authorization") - if authorization: - scheme, credentials = get_authorization_scheme_param(authorization) - if not (scheme and credentials): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - if scheme.lower() != TOKEN_TYPE.lower(): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return SessionTokenData.decode(credentials) + # If the user doesn't exist, then create a new user. + email = payload.email + user_obj = await crud.get_user_from_email(email) + if user_obj is None: + await crud.add_user(User(user_id=str(get_new_user_id()), email=email)) + if (user_obj := await crud.get_user_from_email(email)) is None: + raise RuntimeError("Failed to add user to the database") - # Tries Cookie. - cookie_token = request.cookies.get(SESSION_TOKEN_COOKIE_KEY) - if cookie_token: - return SessionTokenData.decode(cookie_token) + # Issue a new API key for the user. + user_id: uuid.UUID = user_obj.to_uuid() + api_key: uuid.UUID = get_new_api_key(user_id) + await crud.add_api_key(api_key, user_id) - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + return UserLoginResponse(api_key=str(api_key)) class UserInfoResponse(BaseModel): @@ -208,23 +126,29 @@ class UserInfoResponse(BaseModel): @users_router.get("/me", response_model=UserInfoResponse) async def get_user_info_endpoint( - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], crud: Annotated[Crud, Depends(Crud.get)], ) -> UserInfoResponse: - user_obj = await crud.get_user(data.user_id) + user_id = await crud.get_user_id_from_api_key(data.api_key) + if user_id is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user_obj = await crud.get_user(user_id) if user_obj is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") return UserInfoResponse(email=user_obj.email) @users_router.delete("/me") async def delete_user_endpoint( - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_obj = await crud.get_user(data.user_id) + user_id = await crud.get_user_id_from_api_key(data.api_key) + if user_id is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user_obj = await crud.get_user(user_id) if user_obj is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") await crud.delete_user(user_obj) await send_delete_email(user_obj.email) return True @@ -232,28 +156,8 @@ async def delete_user_endpoint( @users_router.delete("/logout") async def logout_user_endpoint( - response: Response, - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], + crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - response.delete_cookie(key=SESSION_TOKEN_COOKIE_KEY) - response.delete_cookie(key=REFRESH_TOKEN_COOKIE_KEY) + await crud.delete_api_key(data.api_key) return True - - -class RefreshTokenResponse(BaseModel): - token: str - token_type: str - - -@users_router.post("/refresh", response_model=RefreshTokenResponse) -async def refresh_endpoint( - response: Response, - data: Annotated[RefreshTokenData, Depends(get_refresh_token)], - crud: Annotated[Crud, Depends(Crud.get)], -) -> RefreshTokenResponse: - token = await crud.get_token(data.token_id) - if token is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token") - session_token = SessionTokenData(user_id=data.user_id, token_id=data.token_id).encode() - set_token_cookie(response, session_token, SESSION_TOKEN_COOKIE_KEY) - return RefreshTokenResponse(token=session_token, token_type=TOKEN_TYPE) diff --git a/store/app/api/token.py b/store/app/api/token.py deleted file mode 100644 index e68e9af7..00000000 --- a/store/app/api/token.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Defines functions for controlling access tokens.""" - -import datetime -import logging -import uuid - -import jwt -from fastapi import HTTPException, status - -from store.app.api.db import Crud -from store.app.api.model import Token -from store.settings import settings -from store.utils import server_time - -logger = logging.getLogger(__name__) - -TIME_FORMAT = "%Y-%m-%d %H:%M:%S" - - -def get_token_id() -> str: - """Generates a unique token ID. - - Returns: - A unique token ID. - """ - return str(uuid.uuid4()) - - -def create_token(data: dict, expire_after: datetime.timedelta | None = None, extra: str | None = None) -> str: - """Creates a token from a dictionary. - - The "exp" key is reserved for internal use. - - Args: - data: The data to encode. - expire_after: If provided, token will expire after this amount of time. - extra: Additional secret to append to the secret key. - - Returns: - The encoded JWT. - """ - secret = settings.crypto.jwt_secret - if extra is not None: - secret += extra - if "exp" in data: - raise ValueError("The payload should not contain an expiration time") - to_encode = data.copy() - - # JWT exp claim expects a timestamp in seconds. This will automatically be - # used to determine if the token is expired. - if expire_after is not None: - expires = server_time() + expire_after - to_encode.update({"exp": expires}) - - encoded_jwt = jwt.encode(to_encode, secret, algorithm=settings.crypto.algorithm) - return encoded_jwt - - -def load_token(payload: str, extra: str | None = None) -> dict: - """Loads the token payload. - - Args: - payload: The JWT-encoded payload. - only_once: If ``True``, the token will be marked as used. - extra: Additional secret to append to the secret key. - - Returns: - The decoded payload. - """ - secret = settings.crypto.jwt_secret - if extra is not None: - secret += extra - try: - data: dict = jwt.decode(payload, secret, algorithms=[settings.crypto.algorithm]) - except Exception: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") - return data - - -async def create_refresh_token(user_id: str, crud: Crud) -> str: - """Creates a refresh token for a user. - - Refresh tokens never expire. They are used to generate short-lived session - tokens which are used for authentication. - - Args: - user_id: The user ID to associate with the token. - crud: The CRUD class for the databases. - - Returns: - The encoded JWT. - """ - token_id = get_token_id() - token = Token(user_id=user_id, token_id=token_id) - await crud.add_token(token) - return create_token({"uid": user_id, "tid": token_id}) - - -def load_refresh_token(payload: str) -> tuple[str, str]: - """Loads the refresh token payload. - - Args: - payload: The JWT-encoded payload. - - Returns: - The decoded refresh token data. - """ - data = load_token(payload) - return data["uid"], data["tid"] diff --git a/tests/api/test_users.py b/tests/api/test_users.py index 1a1b9ce7..af8a62c3 100644 --- a/tests/api/test_users.py +++ b/tests/api/test_users.py @@ -14,68 +14,54 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) test_email = "test@example.com" login_url = "/" - bad_actor_email = "badactor@gmail.com" - # Creates a bad actor user for testing admin actions later. - otp = OneTimePassPayload(email=bad_actor_email) - response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) - assert response.status_code == 200, response.json() - - # Sends an email to the user with their one-time pass. - response = app_client.post( - "/api/users/login", - json={ - "email": test_email, - "login_url": login_url, - }, - ) + # Sends the one-time password to the test email. + response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url}) assert response.status_code == 200, response.json() assert mock_send_email.call_count == 1 - # Uses the one-time pass to set client cookies. + # Uses the one-time pass to get an API key. We need to make a new OTP + # manually because we can't send emails in unit tests. otp = OneTimePassPayload(email=test_email) response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) assert response.status_code == 200, response.json() + response_data = response.json() + api_key = response_data["api_key"] - # Checks that we get a 401 without a session token. + # Checks that without the API key we get a 401 response. response = app_client.get("/api/users/me") assert response.status_code == 401, response.json() assert response.json()["detail"] == "Not authenticated" - # Get a session token. - response = app_client.post("/api/users/refresh") - assert response.status_code == 200, response.json() - assert response.json()["token_type"] == "Bearer" - - # Gets the user's profile using the token. - response = app_client.get("/api/users/me") + # Checks that with the API key we get a 200 response. + response = app_client.get("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json()["email"] == test_email - # Log the user out. + # Checks that we can't log the user out without the API key. response = app_client.delete("/api/users/logout") + assert response.status_code == 401, response.json() + + # Log the user out, which deletes the API key. + response = app_client.delete("/api/users/logout", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json() is True - # Check that the user cookie has been cleared. - response = app_client.get("/api/users/me") - assert response.status_code == 401, response.json() - assert response.json()["detail"] == "Not authenticated" + # Checks that we can no longer use that API key to get the user's info. + response = app_client.get("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) + assert response.status_code == 404, response.json() + assert response.json()["detail"] == "User not found" - # Log the user back in. + # Log the user back in, getting new API key. response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) assert response.status_code == 200, response.json() - # Gets another session token. - response = app_client.post("/api/users/refresh") - assert response.status_code == 200, response.json() - - # Delete the user. - response = app_client.delete("/api/users/me") + # Delete the user using the new API key. + response = app_client.delete("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json() is True - # Make sure the user is gone. - response = app_client.get("/api/users/me") - assert response.status_code == 400, response.json() + # Tries deleting the user again, which should fail. + response = app_client.delete("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) + assert response.status_code == 404, response.json() assert response.json()["detail"] == "User not found"