From a704fea97488f8502932d65761a54b98a21d00e7 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:32:02 -0400 Subject: [PATCH 1/9] convert from using tokens to using api keys --- store/app/api/crud/users.py | 49 +++++--- store/app/api/crypto.py | 33 +++++ store/app/api/db.py | 4 +- store/app/api/email.py | 12 +- store/app/api/model.py | 26 +++- store/app/api/routers/users.py | 214 +++++++++------------------------ store/app/api/token.py | 109 ----------------- tests/api/test_users.py | 62 ++++------ 8 files changed, 176 insertions(+), 333 deletions(-) create mode 100644 store/app/api/crypto.py delete mode 100644 store/app/api/token.py diff --git a/store/app/api/crud/users.py b/store/app/api/crud/users.py index f5189607..44ebe0ff 100644 --- a/store/app/api/crud/users.py +++ b/store/app/api/crud/users.py @@ -3,11 +3,13 @@ import asyncio import uuid import warnings +from typing import cast from boto3.dynamodb.conditions import Key as KeyCondition from store.app.api.crud.base import BaseCrud -from store.app.api.model import Token, User +from store.app.api.crypto import hash_api_key +from store.app.api.model import ApiKey, User class UserCrud(BaseCrud): @@ -15,9 +17,9 @@ async def add_user(self, user: User) -> None: table = await self.db.Table("Users") await table.put_item(Item=user.model_dump()) - async def get_user(self, user_id: str) -> User | None: + async def get_user(self, user_id: uuid.UUID) -> User | None: table = await self.db.Table("Users") - user_dict = await table.get_item(Key={"user_id": user_id}) + user_dict = await table.get_item(Key={"user_id": str(user_id)}) if "Item" not in user_dict: return None user = User.model_validate(user_dict["Item"]) @@ -34,6 +36,15 @@ async def get_user_from_email(self, email: str) -> User | None: user = User.model_validate(items[0]) return user + async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None: + table = await self.db.Table("ApiKeys") + api_key_hash = hash_api_key(api_key) + row = await table.get_item(Key={"api_key_hash": api_key_hash}) + if "Item" not in row: + return None + user_id = cast(str, row["Item"]["user_id"]) + return uuid.UUID(user_id) + async def delete_user(self, user: User) -> None: table = await self.db.Table("Users") await table.delete_item(Key={"user_id": user.user_id}) @@ -48,23 +59,21 @@ async def get_user_count(self) -> int: table = await self.db.Table("Users") return await table.item_count - async def add_token(self, token: Token) -> None: - table = await self.db.Table("Tokens") - await table.put_item(Item=token.model_dump()) - - async def get_token(self, token_id: str) -> Token | None: - table = await self.db.Table("Tokens") - token_dict = await table.get_item(Key={"token_id": token_id}) - if "Item" not in token_dict: - return None - token = Token.model_validate(token_dict["Item"]) - return token - - async def get_user_tokens(self, user_id: str) -> list[Token]: - table = await self.db.Table("Tokens") - tokens = table.query(IndexName="userIdIndex", KeyConditionExpression=KeyCondition("user_id").eq(user_id)) - tokens = [Token.model_validate(token) for token in await tokens] - return tokens + async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None: + row = ApiKey.from_api_key(api_key, user_id) + table = await self.db.Table("ApiKeys") + await table.put_item(Item=row.model_dump()) + + async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool: + table = await self.db.Table("ApiKeys") + row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)}) + if "Item" not in row: + return False + return row["Item"]["user_id"] == str(user_id) + + async def delete_api_key(self, api_key: uuid.UUID) -> None: + table = await self.db.Table("ApiKeys") + await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)}) async def test_adhoc() -> None: diff --git a/store/app/api/crypto.py b/store/app/api/crypto.py new file mode 100644 index 00000000..fd725f21 --- /dev/null +++ b/store/app/api/crypto.py @@ -0,0 +1,33 @@ +"""Defines crypto functions.""" + +import datetime +import hashlib +import uuid +from typing import Any + +import jwt + +from store.settings import settings + + +def hash_api_key(api_key: uuid.UUID) -> str: + return hashlib.sha256(api_key.bytes).hexdigest() + + +def get_new_user_id() -> uuid.UUID: + return uuid.uuid4() + + +def get_new_api_key(user_id: uuid.UUID) -> uuid.UUID: + user_id_hash = hashlib.sha1(user_id.bytes).digest() + return uuid.UUID(bytes=user_id_hash[:16], version=5) + + +def encode_jwt(data: dict[str, Any], expire_after: datetime.timedelta | None = None) -> str: # noqa: ANN401 + if expire_after is not None: + data["exp"] = datetime.datetime.utcnow() + expire_after + return jwt.encode(data, settings.crypto.jwt_secret, algorithm=settings.crypto.algorithm) + + +def decode_jwt(token: str) -> dict[str, Any]: # noqa: ANN401 + return jwt.decode(token, settings.crypto.jwt_secret, algorithms=[settings.crypto.algorithm]) diff --git a/store/app/api/db.py b/store/app/api/db.py index 56722b6a..5bb4ec72 100644 --- a/store/app/api/db.py +++ b/store/app/api/db.py @@ -43,9 +43,9 @@ async def create_tables(crud: Crud | None = None) -> None: ], ) await crud._create_dynamodb_table( - name="Tokens", + name="ApiKeys", keys=[ - ("token_id", "S", "HASH"), + ("api_key_hash", "S", "HASH"), ], gsis=[ ("userIdIndex", "user_id", "S", "HASH"), diff --git a/store/app/api/email.py b/store/app/api/email.py index 857a5749..6b396a6b 100644 --- a/store/app/api/email.py +++ b/store/app/api/email.py @@ -11,7 +11,7 @@ import aiosmtplib -from store.app.api.token import create_token, load_token +from store.app.api.crypto import decode_jwt, encode_jwt from store.settings import settings logger = logging.getLogger(__name__) @@ -40,11 +40,11 @@ class OneTimePassPayload: def encode(self) -> str: expire_minutes = settings.crypto.expire_otp_minutes expire_after = datetime.timedelta(minutes=expire_minutes) - return create_token({"email": self.email}, expire_after=expire_after) + return encode_jwt({"email": self.email}, expire_after=expire_after) @classmethod def decode(cls, payload: str) -> "OneTimePassPayload": - data = load_token(payload) + data = decode_jwt(payload) return cls(email=data["email"]) @@ -53,7 +53,7 @@ async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None: body = textwrap.dedent( f""" -

don't panic
stay human

+

K-Scale Labs

log in

Or copy-paste this link: {url}

""" @@ -65,7 +65,7 @@ async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None: async def send_delete_email(email: str) -> None: body = textwrap.dedent( """ -

don't panic
stay human

+

K-Scale Labs

your account has been deleted

""" ) @@ -76,7 +76,7 @@ async def send_delete_email(email: str) -> None: async def send_waitlist_email(email: str) -> None: body = textwrap.dedent( """ -

don't panic
stay human

+

K-Scale Labs

you're on the waitlist!

Thanks for signing up! We'll let you know when you can log in.

""" diff --git a/store/app/api/model.py b/store/app/api/model.py index 0d35eea6..2fa2f165 100644 --- a/store/app/api/model.py +++ b/store/app/api/model.py @@ -1,22 +1,42 @@ -"""Defines the table models for the API.""" +"""Defines the table models for the API. + +These correspond directly with the rows in our database, and provide helper +methods for converting from our input data into the format the database +expects (for example, converting a UUID into a string). +""" import datetime +import uuid from dataclasses import field from decimal import Decimal from pydantic import BaseModel +from store.app.api.crypto import hash_api_key + class User(BaseModel): user_id: str # Primary key email: str + @classmethod + def from_uuid(cls, user_id: uuid.UUID, email: str) -> "User": + return cls(user_id=str(user_id), email=email) + + def to_uuid(self) -> uuid.UUID: + return uuid.UUID(self.user_id) -class Token(BaseModel): - token_id: str # Primary key + +class ApiKey(BaseModel): + api_key_hash: str # Primary key user_id: str issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp())) + @classmethod + def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID) -> "ApiKey": + api_key_hash = hash_api_key(api_key) + return cls(api_key_hash=api_key_hash, user_id=str(user_id)) + class PurchaseLink(BaseModel): name: str diff --git a/store/app/api/routers/users.py b/store/app/api/routers/users.py index f27e43aa..6d30b2bf 100644 --- a/store/app/api/routers/users.py +++ b/store/app/api/routers/users.py @@ -1,29 +1,23 @@ """Defines the API endpoint for creating, deleting and updating user information.""" -import datetime import logging import uuid from email.utils import parseaddr as parse_email_address from typing import Annotated -import aiohttp from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security.utils import get_authorization_scheme_param from pydantic.main import BaseModel +from store.app.api.crypto import get_new_api_key, get_new_user_id from store.app.api.db import Crud from store.app.api.email import OneTimePassPayload, send_delete_email, send_otp_email from store.app.api.model import User -from store.app.api.token import create_refresh_token, create_token, load_refresh_token, load_token -from store.settings import settings logger = logging.getLogger(__name__) users_router = APIRouter() -REFRESH_TOKEN_COOKIE_KEY = "__REFRESH_TOKEN" -SESSION_TOKEN_COOKIE_KEY = "__SESSION_TOKEN" - TOKEN_TYPE = "Bearer" @@ -38,33 +32,22 @@ def set_token_cookie(response: Response, token: str, key: str) -> None: ) -class RefreshTokenData(BaseModel): - user_id: str - token_id: str - - @classmethod - async def encode(cls, user: User, crud: Crud) -> str: - return await create_refresh_token(user.user_id, crud) - - @classmethod - def decode(cls, payload: str) -> "RefreshTokenData": - user_id, token_id = load_refresh_token(payload) - return cls(user_id=user_id, token_id=token_id) +class ApiKeyData(BaseModel): + api_key: uuid.UUID -class SessionTokenData(BaseModel): - user_id: str - token_id: str - - def encode(self) -> str: - expire_minutes = settings.crypto.expire_token_minutes - expire_after = datetime.timedelta(minutes=expire_minutes) - return create_token({"uid": self.user_id, "tid": self.token_id}, expire_after=expire_after, extra="session") +async def get_api_key(request: Request) -> ApiKeyData: + # Tries Authorization header. + authorization = request.headers.get("Authorization") or request.headers.get("authorization") + if authorization: + scheme, credentials = get_authorization_scheme_param(authorization) + if not (scheme and credentials): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + if scheme.lower() != TOKEN_TYPE.lower(): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + return ApiKeyData(api_key=uuid.UUID(credentials)) - @classmethod - def decode(cls, payload: str) -> "SessionTokenData": - data = load_token(payload, extra="session") - return cls(user_id=data["uid"], token_id=data["tid"]) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") class UserSignup(BaseModel): @@ -80,12 +63,17 @@ def validate_email(email: str) -> str: return email -def get_new_user_id() -> str: - return str(uuid.uuid4()) - - @users_router.post("/login") async def login_user_endpoint(data: UserSignup) -> bool: + """Takes the user email and sends them a one-time login password. + + Args: + data: The payload with the user email and the login URL to redirect to + when the user logs in. + + Returns: + True if the email was sent successfully. + """ email = validate_email(data.email) payload = OneTimePassPayload(email) await send_otp_email(payload, data.login_url) @@ -97,109 +85,39 @@ class OneTimePass(BaseModel): class UserLoginResponse(BaseModel): - token: str - token_type: str - - -async def create_or_get(email: str, crud: Crud) -> User: - # Gets or creates the user object. - user_obj = await crud.get_user_from_email(email) - if user_obj is None: - await crud.add_user(User(user_id=get_new_user_id(), email=email)) - if (user_obj := await crud.get_user_from_email(email)) is None: - raise RuntimeError("Failed to add user to the database") - return user_obj - - -async def get_login_response( - response: Response, - user_obj: User, - crud: Crud, -) -> UserLoginResponse: - refresh_token = await RefreshTokenData.encode(user_obj, crud) - set_token_cookie(response, refresh_token, REFRESH_TOKEN_COOKIE_KEY) - return UserLoginResponse(token=refresh_token, token_type=TOKEN_TYPE) + api_key: str @users_router.post("/otp", response_model=UserLoginResponse) async def otp_endpoint( data: OneTimePass, - response: Response, crud: Annotated[Crud, Depends(Crud.get)], ) -> UserLoginResponse: - payload = OneTimePassPayload.decode(data.payload) - user_obj = await create_or_get(payload.email, crud) - return await get_login_response(response, user_obj, crud) - - -class GoogleLogin(BaseModel): - token: str - + """Takes the one-time password and returns an API key. -async def get_google_user_info(token: str) -> dict: - async with aiohttp.ClientSession() as session: - response = await session.get( - "https://www.googleapis.com/oauth2/v3/userinfo", - params={"access_token": token}, - ) - if response.status != 200: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") - return await response.json() - - -@users_router.post("/google") -async def google_login_endpoint( - data: GoogleLogin, - response: Response, - crud: Annotated[Crud, Depends(Crud.get)], -) -> UserLoginResponse: - try: - idinfo = await get_google_user_info(data.token) - email = idinfo["email"] - except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") - if idinfo.get("email_verified") is not True: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Google email not verified") - user_obj = await create_or_get(email, crud) - return await get_login_response(response, user_obj, crud) - - -async def get_refresh_token(request: Request) -> RefreshTokenData: - # Tries Authorization header. - authorization = request.headers.get("Authorization") - if authorization: - scheme, credentials = get_authorization_scheme_param(authorization) - if not (scheme and credentials): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - if scheme.lower() != TOKEN_TYPE.lower(): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return RefreshTokenData.decode(credentials) - - # Tries Cookie. - cookie_token = request.cookies.get(REFRESH_TOKEN_COOKIE_KEY) - if cookie_token: - return RefreshTokenData.decode(cookie_token) - - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + Args: + data: The one-time password payload. + crud: The database CRUD object. + Returns: + The API key if the one-time password is valid. + """ + payload = OneTimePassPayload.decode(data.payload) -async def get_session_token(request: Request) -> SessionTokenData: - # Tries Authorization header. - authorization = request.headers.get("Authorization") or request.headers.get("authorization") - if authorization: - scheme, credentials = get_authorization_scheme_param(authorization) - if not (scheme and credentials): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - if scheme.lower() != TOKEN_TYPE.lower(): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") - return SessionTokenData.decode(credentials) + # If the user doesn't exist, then create a new user. + email = payload.email + user_obj = await crud.get_user_from_email(email) + if user_obj is None: + await crud.add_user(User(user_id=str(get_new_user_id()), email=email)) + if (user_obj := await crud.get_user_from_email(email)) is None: + raise RuntimeError("Failed to add user to the database") - # Tries Cookie. - cookie_token = request.cookies.get(SESSION_TOKEN_COOKIE_KEY) - if cookie_token: - return SessionTokenData.decode(cookie_token) + # Issue a new API key for the user. + user_id: uuid.UUID = user_obj.to_uuid() + api_key: uuid.UUID = get_new_api_key(user_id) + await crud.add_api_key(api_key, user_id) - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") + return UserLoginResponse(api_key=str(api_key)) class UserInfoResponse(BaseModel): @@ -208,23 +126,29 @@ class UserInfoResponse(BaseModel): @users_router.get("/me", response_model=UserInfoResponse) async def get_user_info_endpoint( - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], crud: Annotated[Crud, Depends(Crud.get)], ) -> UserInfoResponse: - user_obj = await crud.get_user(data.user_id) + user_id = await crud.get_user_id_from_api_key(data.api_key) + if user_id is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user_obj = await crud.get_user(user_id) if user_obj is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") return UserInfoResponse(email=user_obj.email) @users_router.delete("/me") async def delete_user_endpoint( - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_obj = await crud.get_user(data.user_id) + user_id = await crud.get_user_id_from_api_key(data.api_key) + if user_id is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user_obj = await crud.get_user(user_id) if user_obj is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not found") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") await crud.delete_user(user_obj) await send_delete_email(user_obj.email) return True @@ -232,28 +156,8 @@ async def delete_user_endpoint( @users_router.delete("/logout") async def logout_user_endpoint( - response: Response, - data: Annotated[SessionTokenData, Depends(get_session_token)], + data: Annotated[ApiKeyData, Depends(get_api_key)], + crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - response.delete_cookie(key=SESSION_TOKEN_COOKIE_KEY) - response.delete_cookie(key=REFRESH_TOKEN_COOKIE_KEY) + await crud.delete_api_key(data.api_key) return True - - -class RefreshTokenResponse(BaseModel): - token: str - token_type: str - - -@users_router.post("/refresh", response_model=RefreshTokenResponse) -async def refresh_endpoint( - response: Response, - data: Annotated[RefreshTokenData, Depends(get_refresh_token)], - crud: Annotated[Crud, Depends(Crud.get)], -) -> RefreshTokenResponse: - token = await crud.get_token(data.token_id) - if token is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid token") - session_token = SessionTokenData(user_id=data.user_id, token_id=data.token_id).encode() - set_token_cookie(response, session_token, SESSION_TOKEN_COOKIE_KEY) - return RefreshTokenResponse(token=session_token, token_type=TOKEN_TYPE) diff --git a/store/app/api/token.py b/store/app/api/token.py deleted file mode 100644 index e68e9af7..00000000 --- a/store/app/api/token.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Defines functions for controlling access tokens.""" - -import datetime -import logging -import uuid - -import jwt -from fastapi import HTTPException, status - -from store.app.api.db import Crud -from store.app.api.model import Token -from store.settings import settings -from store.utils import server_time - -logger = logging.getLogger(__name__) - -TIME_FORMAT = "%Y-%m-%d %H:%M:%S" - - -def get_token_id() -> str: - """Generates a unique token ID. - - Returns: - A unique token ID. - """ - return str(uuid.uuid4()) - - -def create_token(data: dict, expire_after: datetime.timedelta | None = None, extra: str | None = None) -> str: - """Creates a token from a dictionary. - - The "exp" key is reserved for internal use. - - Args: - data: The data to encode. - expire_after: If provided, token will expire after this amount of time. - extra: Additional secret to append to the secret key. - - Returns: - The encoded JWT. - """ - secret = settings.crypto.jwt_secret - if extra is not None: - secret += extra - if "exp" in data: - raise ValueError("The payload should not contain an expiration time") - to_encode = data.copy() - - # JWT exp claim expects a timestamp in seconds. This will automatically be - # used to determine if the token is expired. - if expire_after is not None: - expires = server_time() + expire_after - to_encode.update({"exp": expires}) - - encoded_jwt = jwt.encode(to_encode, secret, algorithm=settings.crypto.algorithm) - return encoded_jwt - - -def load_token(payload: str, extra: str | None = None) -> dict: - """Loads the token payload. - - Args: - payload: The JWT-encoded payload. - only_once: If ``True``, the token will be marked as used. - extra: Additional secret to append to the secret key. - - Returns: - The decoded payload. - """ - secret = settings.crypto.jwt_secret - if extra is not None: - secret += extra - try: - data: dict = jwt.decode(payload, secret, algorithms=[settings.crypto.algorithm]) - except Exception: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") - return data - - -async def create_refresh_token(user_id: str, crud: Crud) -> str: - """Creates a refresh token for a user. - - Refresh tokens never expire. They are used to generate short-lived session - tokens which are used for authentication. - - Args: - user_id: The user ID to associate with the token. - crud: The CRUD class for the databases. - - Returns: - The encoded JWT. - """ - token_id = get_token_id() - token = Token(user_id=user_id, token_id=token_id) - await crud.add_token(token) - return create_token({"uid": user_id, "tid": token_id}) - - -def load_refresh_token(payload: str) -> tuple[str, str]: - """Loads the refresh token payload. - - Args: - payload: The JWT-encoded payload. - - Returns: - The decoded refresh token data. - """ - data = load_token(payload) - return data["uid"], data["tid"] diff --git a/tests/api/test_users.py b/tests/api/test_users.py index 1a1b9ce7..af8a62c3 100644 --- a/tests/api/test_users.py +++ b/tests/api/test_users.py @@ -14,68 +14,54 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) test_email = "test@example.com" login_url = "/" - bad_actor_email = "badactor@gmail.com" - # Creates a bad actor user for testing admin actions later. - otp = OneTimePassPayload(email=bad_actor_email) - response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) - assert response.status_code == 200, response.json() - - # Sends an email to the user with their one-time pass. - response = app_client.post( - "/api/users/login", - json={ - "email": test_email, - "login_url": login_url, - }, - ) + # Sends the one-time password to the test email. + response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url}) assert response.status_code == 200, response.json() assert mock_send_email.call_count == 1 - # Uses the one-time pass to set client cookies. + # Uses the one-time pass to get an API key. We need to make a new OTP + # manually because we can't send emails in unit tests. otp = OneTimePassPayload(email=test_email) response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) assert response.status_code == 200, response.json() + response_data = response.json() + api_key = response_data["api_key"] - # Checks that we get a 401 without a session token. + # Checks that without the API key we get a 401 response. response = app_client.get("/api/users/me") assert response.status_code == 401, response.json() assert response.json()["detail"] == "Not authenticated" - # Get a session token. - response = app_client.post("/api/users/refresh") - assert response.status_code == 200, response.json() - assert response.json()["token_type"] == "Bearer" - - # Gets the user's profile using the token. - response = app_client.get("/api/users/me") + # Checks that with the API key we get a 200 response. + response = app_client.get("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json()["email"] == test_email - # Log the user out. + # Checks that we can't log the user out without the API key. response = app_client.delete("/api/users/logout") + assert response.status_code == 401, response.json() + + # Log the user out, which deletes the API key. + response = app_client.delete("/api/users/logout", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json() is True - # Check that the user cookie has been cleared. - response = app_client.get("/api/users/me") - assert response.status_code == 401, response.json() - assert response.json()["detail"] == "Not authenticated" + # Checks that we can no longer use that API key to get the user's info. + response = app_client.get("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) + assert response.status_code == 404, response.json() + assert response.json()["detail"] == "User not found" - # Log the user back in. + # Log the user back in, getting new API key. response = app_client.post("/api/users/otp", json={"payload": otp.encode()}) assert response.status_code == 200, response.json() - # Gets another session token. - response = app_client.post("/api/users/refresh") - assert response.status_code == 200, response.json() - - # Delete the user. - response = app_client.delete("/api/users/me") + # Delete the user using the new API key. + response = app_client.delete("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() assert response.json() is True - # Make sure the user is gone. - response = app_client.get("/api/users/me") - assert response.status_code == 400, response.json() + # Tries deleting the user again, which should fail. + response = app_client.delete("/api/users/me", headers={"Authorization": f"Bearer {api_key}"}) + assert response.status_code == 404, response.json() assert response.json()["detail"] == "User not found" From 8aaa3b0968daecea3c747a038129a01db616ff60 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:38:36 -0400 Subject: [PATCH 2/9] convert frontend to only use api keys instead of refresh tokens --- frontend/src/hooks/auth.tsx | 134 +++++++----------------------------- 1 file changed, 24 insertions(+), 110 deletions(-) diff --git a/frontend/src/hooks/auth.tsx b/frontend/src/hooks/auth.tsx index 178afec2..df8b58c1 100644 --- a/frontend/src/hooks/auth.tsx +++ b/frontend/src/hooks/auth.tsx @@ -10,49 +10,28 @@ import { } from "react"; import { useNavigate, useSearchParams } from "react-router-dom"; -const REFRESH_TOKEN_KEY = "__REFRESH_TOKEN"; -const SESSION_TOKEN_KEY = "__SESSION_TOKEN"; - -type TokenType = "refresh" | "session"; - -const getLocalStorageValueKey = (tokenType: TokenType) => { - switch (tokenType) { - case "refresh": - return REFRESH_TOKEN_KEY; - case "session": - return SESSION_TOKEN_KEY; - default: - throw new Error("Invalid token type"); - } -}; +const API_KEY_ID = "__API_KEY"; -const getLocalStorageToken = (tokenType: TokenType): string | null => { - return localStorage.getItem(getLocalStorageValueKey(tokenType)); +const getLocalStorageApiKey = (): string | null => { + return localStorage.getItem(API_KEY_ID); }; -const setLocalStorageToken = (token: string, tokenType: TokenType) => { - localStorage.setItem(getLocalStorageValueKey(tokenType), token); +const setLocalStorageApiKey = (token: string) => { +localStorage.setItem(API_KEY_ID, token); }; -const deleteLocalStorageToken = (tokenType: TokenType) => { - localStorage.removeItem(getLocalStorageValueKey(tokenType)); +const deleteLocalStorageApiKey = () => { + localStorage.removeItem(API_KEY_ID); }; interface AuthenticationContextProps { - sessionToken: string | null; - setSessionToken: (token: string) => void; - refreshToken: string | null; - setRefreshToken: (token: string) => void; + apiKey: string | null; + setApiKey: (token: string) => void; logout: () => void; isAuthenticated: boolean; api: AxiosInstance; } -interface RefreshTokenResponse { - token: string; - token_type: string; -} - const AuthenticationContext = createContext< AuthenticationContextProps | undefined >(undefined); @@ -64,16 +43,11 @@ interface AuthenticationProviderProps { export const AuthenticationProvider = (props: AuthenticationProviderProps) => { const { children } = props; - const [sessionToken, setSessionToken] = useState( - getLocalStorageToken("session"), - ); - const [refreshToken, setRefreshToken] = useState( - getLocalStorageToken("refresh"), - ); + const [apiKey, setApiKey] = useState(getLocalStorageApiKey()); const navigate = useNavigate(); - const isAuthenticated = refreshToken !== null; + const isAuthenticated = apiKey !== null; const api = axios.create({ baseURL: BACKEND_URL, @@ -85,31 +59,23 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => { }); useEffect(() => { - if (sessionToken === null) { - deleteLocalStorageToken("session"); - } else { - setLocalStorageToken(sessionToken, "session"); - } - }, [sessionToken]); - - useEffect(() => { - if (refreshToken === null) { - deleteLocalStorageToken("refresh"); + if (apiKey === null) { + deleteLocalStorageApiKey(); } else { - setLocalStorageToken(refreshToken, "refresh"); + setLocalStorageApiKey(apiKey); } - }, [refreshToken]); + }, [apiKey]); const logout = useCallback(() => { - setSessionToken(null); - setRefreshToken(null); + setApiKey(null); navigate("/"); }, [navigate]); + // Adds the API key to the request header, if it is set. api.interceptors.request.use( (config) => { - if (sessionToken !== null) { - config.headers.Authorization = `Bearer ${sessionToken}`; + if (apiKey !== null) { + config.headers.Authorization = `Bearer ${apiKey}`; config.headers["Access-Control-Allow-Origin"] = "*"; } return config; @@ -119,63 +85,11 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => { }, ); - api.interceptors.response.use( - (response) => response, - async (error) => { - const originalRequest = error.config; - if (error.response.status === 401 && !originalRequest._retry) { - originalRequest._retry = true; - if (refreshToken === null) { - return Promise.reject(error); - } - - let localSessionToken; - try { - // Gets a new session token and try the request again. - const response = await baseApi.post( - "/users/refresh", - {}, - { - headers: { - Authorization: `Bearer ${refreshToken}`, - "Access-Control-Allow-Origin": "*", - }, - }, - ); - localSessionToken = response.data.token; - } catch (refreshError) { - if (isAxiosError(refreshError)) { - const axiosError = refreshError as AxiosError; - if (axiosError?.response?.status === 401) { - logout(); - } - } - return Promise.reject(refreshError); - } - - // Retry the request with the new session token. - setSessionToken(localSessionToken); - const updatedRequest = { - ...originalRequest, - headers: { - Authorization: `Bearer ${localSessionToken}`, - "Access-Control-Allow-Origin": "*", - }, - }; - return await baseApi(updatedRequest); - } - - return Promise.reject(error); - }, - ); - return ( { const [searchParams] = useSearchParams(); const navigate = useNavigate(); - const { setRefreshToken, api } = useAuthentication(); + const { setApiKey, api } = useAuthentication(); useEffect(() => { (async () => { @@ -220,7 +134,7 @@ export const OneTimePasswordWrapper = ({ const response = await api.post("/users/otp", { payload, }); - setRefreshToken(response.data.token); + setApiKey(response.data.token); navigate("/"); } catch (error) { // Do nothing @@ -229,7 +143,7 @@ export const OneTimePasswordWrapper = ({ } } })(); - }, [searchParams, navigate, setRefreshToken, api]); + }, [searchParams, navigate, setApiKey, api]); return <>{children}; }; From 129cd2751d7e5b085a46167b319e35b14119e147 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:44:34 -0400 Subject: [PATCH 3/9] call the logout endpoint --- frontend/src/hooks/auth.tsx | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/frontend/src/hooks/auth.tsx b/frontend/src/hooks/auth.tsx index df8b58c1..7ecdd86b 100644 --- a/frontend/src/hooks/auth.tsx +++ b/frontend/src/hooks/auth.tsx @@ -17,7 +17,7 @@ const getLocalStorageApiKey = (): string | null => { }; const setLocalStorageApiKey = (token: string) => { -localStorage.setItem(API_KEY_ID, token); + localStorage.setItem(API_KEY_ID, token); }; const deleteLocalStorageApiKey = () => { @@ -67,8 +67,15 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => { }, [apiKey]); const logout = useCallback(() => { - setApiKey(null); - navigate("/"); + (async () => { + try { + await api.delete("/users/logout"); + setApiKey(null); + navigate("/"); + } catch (error) { + // Do nothing + } + })(); }, [navigate]); // Adds the API key to the request header, if it is set. From ce65ca4e3727635abb9b82f094306040573dccc2 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:50:02 -0400 Subject: [PATCH 4/9] Add back google authentication --- .../components/auth/GoogleAuthComponent.tsx | 6 +- store/app/api/routers/users.py | 68 ++++++++++++++++--- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/frontend/src/components/auth/GoogleAuthComponent.tsx b/frontend/src/components/auth/GoogleAuthComponent.tsx index 8b89667f..1de8bd7a 100644 --- a/frontend/src/components/auth/GoogleAuthComponent.tsx +++ b/frontend/src/components/auth/GoogleAuthComponent.tsx @@ -18,7 +18,7 @@ const GoogleAuthComponentInner = () => { const [credential, setCredential] = useState(null); const [disableButton, setDisableButton] = useState(false); - const { setRefreshToken, api } = useAuthentication(); + const { setApiKey, api } = useAuthentication(); const { addAlert } = useAlertQueue(); useEffect(() => { @@ -28,7 +28,7 @@ const GoogleAuthComponentInner = () => { const response = await api.post("/users/google", { token: credential, }); - setRefreshToken(response.data.token); + setApiKey(response.data.token); } catch (error) { addAlert(humanReadableError(error), "error"); } finally { @@ -36,7 +36,7 @@ const GoogleAuthComponentInner = () => { } } })(); - }, [credential, setRefreshToken, api, addAlert]); + }, [credential, setApiKey, api, addAlert]); const login = useGoogleLogin({ onSuccess: (tokenResponse) => { diff --git a/store/app/api/routers/users.py b/store/app/api/routers/users.py index 6d30b2bf..dcd98315 100644 --- a/store/app/api/routers/users.py +++ b/store/app/api/routers/users.py @@ -5,6 +5,7 @@ from email.utils import parseaddr as parse_email_address from typing import Annotated +import aiohttp from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi.security.utils import get_authorization_scheme_param from pydantic.main import BaseModel @@ -88,24 +89,20 @@ class UserLoginResponse(BaseModel): api_key: str -@users_router.post("/otp", response_model=UserLoginResponse) -async def otp_endpoint( - data: OneTimePass, - crud: Annotated[Crud, Depends(Crud.get)], -) -> UserLoginResponse: - """Takes the one-time password and returns an API key. +async def get_login_response(email: str, crud: Crud) -> UserLoginResponse: + """Takes the user email and returns an API key. + + This function gets a user API key for an email which has been validated, + either through an OTP or through Google OAuth. Args: - data: The one-time password payload. + email: The validated email of the user. crud: The database CRUD object. Returns: - The API key if the one-time password is valid. + The API key for the user. """ - payload = OneTimePassPayload.decode(data.payload) - # If the user doesn't exist, then create a new user. - email = payload.email user_obj = await crud.get_user_from_email(email) if user_obj is None: await crud.add_user(User(user_id=str(get_new_user_id()), email=email)) @@ -120,6 +117,55 @@ async def otp_endpoint( return UserLoginResponse(api_key=str(api_key)) +@users_router.post("/otp", response_model=UserLoginResponse) +async def otp_endpoint( + data: OneTimePass, + crud: Annotated[Crud, Depends(Crud.get)], +) -> UserLoginResponse: + """Takes the one-time password and returns an API key. + + Args: + data: The one-time password payload. + crud: The database CRUD object. + + Returns: + The API key if the one-time password is valid. + """ + payload = OneTimePassPayload.decode(data.payload) + return await get_login_response(payload.email, crud) + + +async def get_google_user_info(token: str) -> dict: + async with aiohttp.ClientSession() as session: + response = await session.get( + "https://www.googleapis.com/oauth2/v3/userinfo", + params={"access_token": token}, + ) + if response.status != 200: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") + return await response.json() + + +class GoogleLogin(BaseModel): + token: str # This is the token that Google gives us for authenticated users. + + +@users_router.post("/google") +async def google_login_endpoint( + data: GoogleLogin, + crud: Annotated[Crud, Depends(Crud.get)], +) -> UserLoginResponse: + try: + idinfo = await get_google_user_info(data.token) + email = idinfo["email"] + except ValueError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Google token") + if idinfo.get("email_verified") is not True: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Google email not verified") + + return await get_login_response(email, crud) + + class UserInfoResponse(BaseModel): email: str From dfaf27f178e48122e9949399dc21f7c87efd3bc9 Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:54:32 -0400 Subject: [PATCH 5/9] fix lint --- frontend/src/hooks/auth.tsx | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/frontend/src/hooks/auth.tsx b/frontend/src/hooks/auth.tsx index 7ecdd86b..ea0d9975 100644 --- a/frontend/src/hooks/auth.tsx +++ b/frontend/src/hooks/auth.tsx @@ -1,4 +1,4 @@ -import axios, { AxiosError, AxiosInstance, isAxiosError } from "axios"; +import axios, { AxiosInstance } from "axios"; import { BACKEND_URL } from "constants/backend"; import { createContext, @@ -54,10 +54,6 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => { withCredentials: true, }); - const baseApi = axios.create({ - baseURL: BACKEND_URL, - }); - useEffect(() => { if (apiKey === null) { deleteLocalStorageApiKey(); From 5a73fd94aba486cc765059150d63f1a5aec3b01d Mon Sep 17 00:00:00 2001 From: Isaac Light Date: Mon, 3 Jun 2024 03:57:05 -0400 Subject: [PATCH 6/9] add otp wrapper --- frontend/src/App.tsx | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index f142ee31..ac85810e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -2,7 +2,7 @@ import "bootstrap/dist/css/bootstrap.min.css"; import TopNavbar from "components/nav/TopNavbar"; import NotFoundRedirect from "components/NotFoundRedirect"; import { AlertQueue, AlertQueueProvider } from "hooks/alerts"; -import { AuthenticationProvider } from "hooks/auth"; +import { AuthenticationProvider, OneTimePasswordWrapper } from "hooks/auth"; import { ThemeProvider } from "hooks/theme"; import Home from "pages/Home"; import NotFound from "pages/NotFound"; @@ -24,15 +24,17 @@ const App = () => { - - } /> - } /> - } /> - } /> - } /> - } /> - } /> - + + + } /> + } /> + } /> + } /> + } /> + } /> + } /> + +