diff --git a/store/app/crud/base.py b/store/app/crud/base.py index b612af01..ac3b1da6 100644 --- a/store/app/crud/base.py +++ b/store/app/crud/base.py @@ -10,7 +10,6 @@ from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource from types_aiobotocore_s3.service_resource import S3ServiceResource -from store.app.crypto import hash_token from store.app.model import RobolistBaseModel TABLE_NAME = "Robolist" @@ -77,15 +76,6 @@ async def _add_item(self, item: RobolistBaseModel) -> None: item_data["type"] = item.__class__.__name__ await table.put_item(Item=item_data) - async def _add_hashed_item(self, item: RobolistBaseModel) -> None: - table = await self.db.Table(TABLE_NAME) - item_data = item.model_dump() - if "type" in item_data: - raise ValueError("Cannot add item with 'type' attribute") - item_data["type"] = item.__class__.__name__ - item_data["id"] = hash_token(item_data["id"]) - await table.put_item(Item=item_data) - async def _delete_item(self, item: RobolistBaseModel | str) -> None: table = await self.db.Table(TABLE_NAME) if isinstance(item, str): @@ -93,13 +83,6 @@ async def _delete_item(self, item: RobolistBaseModel | str) -> None: else: await table.delete_item(Key={"id": item.id}) - async def _delete_hashed_item(self, item: RobolistBaseModel | str) -> None: - table = await self.db.Table(TABLE_NAME) - if isinstance(item, str): - await table.delete_item(Key={"id": hash_token(item)}) - else: - await table.delete_item(Key={"id": hash_token(item.id)}) - async def _list_items( self, item_class: type[T], @@ -130,7 +113,11 @@ async def _list_items( return [self._validate_item(item, item_class) for item in items] async def _list( - self, item_class: type[T], page: int, sort_key: Callable[[T], int], search_query: str | None = None + self, + item_class: type[T], + page: int, + sort_key: Callable[[T], int], + search_query: str | None = None, ) -> tuple[list[T], bool]: if search_query: response = await self._list_items( @@ -187,7 +174,7 @@ def _validate_item(self, data: dict[str, Any], item_class: type[T]) -> T: async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[True]) -> T: ... @overload - async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[False]) -> T | None: ... + async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: bool = False) -> T | None: ... async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: bool = False) -> T | None: table = await self.db.Table(TABLE_NAME) @@ -204,11 +191,6 @@ async def _item_exists(self, item_id: str) -> bool: item_dict = await table.get_item(Key={"id": item_id}) return "Item" in item_dict - async def _hashed_item_exists(self, item_id: str) -> bool: - table = await self.db.Table(TABLE_NAME) - item_dict = await table.get_item(Key={"id": hash_token(item_id)}) - return "Item" in item_dict - async def _get_item_batch( self, item_ids: list[str], @@ -256,7 +238,7 @@ async def _get_unique_item_from_secondary_index( secondary_index_name: str, secondary_index_value: str, item_class: type[T], - throw_if_missing: Literal[False] = False, + throw_if_missing: bool = False, ) -> T | None: ... async def _get_unique_item_from_secondary_index( diff --git a/store/app/crud/users.py b/store/app/crud/users.py index a5af2d4d..ec5165ad 100644 --- a/store/app/crud/users.py +++ b/store/app/crud/users.py @@ -2,18 +2,12 @@ import asyncio import warnings -from datetime import datetime +from typing import Literal, overload from store.app.crud.base import BaseCrud, GlobalSecondaryIndex -from store.app.crypto import hash_token -from store.app.model import APIKey, OAuthKey, User +from store.app.model import APIKey, APIKeyPermissionSet, APIKeySource, OAuthKey, User from store.settings import settings -from store.utils import LRUCache - -# This dictionary is used to locally cache the last time a token was validated -# against the database. We give the tokens some buffer time to avoid hitting -# the database too often. -LAST_API_KEY_VALIDATION = LRUCache[str, tuple[datetime, bool]](2**20) +from store.utils import cache_result def github_auth_key(github_id: str) -> str: @@ -39,8 +33,14 @@ def get_gsis(cls) -> list[GlobalSecondaryIndex]: ("emailIndex", "email", "S", "HASH"), ] - async def get_user(self, id: str) -> User | None: - return await self._get_item(id, User, throw_if_missing=False) + @overload + async def get_user(self, id: str) -> User | None: ... + + @overload + async def get_user(self, id: str, throw_if_missing: Literal[True]) -> User: ... + + async def get_user(self, id: str, throw_if_missing: bool = False) -> User | None: + return await self._get_item(id, User, throw_if_missing=throw_if_missing) async def create_user_from_token(self, token: str, email: str) -> User: user = User.create(email=email) @@ -78,9 +78,9 @@ async def create_user_from_email(self, email: str) -> User: async def get_user_batch(self, ids: list[str]) -> list[User]: return await self._get_item_batch(ids, User) - async def get_user_from_api_key(self, key: str) -> User: - key = await self.get_api_key(key) - return await self._get_item(key.user_id, User, throw_if_missing=True) + async def get_user_from_api_key(self, api_key_id: str) -> User: + api_key = await self.get_api_key(api_key_id) + return await self._get_item(api_key.user_id, User, throw_if_missing=True) async def delete_user(self, id: str) -> None: await self._delete_item(id) @@ -92,38 +92,22 @@ async def list_users(self) -> list[User]: async def get_user_count(self) -> int: return await self._count_items(User) - async def get_api_key(self, id: str) -> APIKey: - hashed_id = hash_token(id) - return await self._get_item(hashed_id, APIKey, throw_if_missing=True) - - async def add_api_key(self, id: str) -> APIKey: - token = APIKey.create(id=id) - await self._add_hashed_item(token) + @cache_result(settings.crypto.cache_token_db_result_seconds) + async def get_api_key(self, api_key_id: str) -> APIKey: + return await self._get_item(api_key_id, APIKey, throw_if_missing=True) + + async def add_api_key( + self, + user_id: str, + source: APIKeySource, + permissions: APIKeyPermissionSet, + ) -> APIKey: + token = APIKey.create(user_id=user_id, source=source, permissions=permissions) + await self._add_item(token) return token async def delete_api_key(self, token: APIKey | str) -> None: - await self._delete_hashed_item(token) - - async def api_key_is_valid(self, token: str) -> bool: - """Validates a token against the database, with caching. - - In order to reduce the number of database queries, we locally cache - whether or not a token is valid for some amount of time. - - Args: - token: The token to validate. - - Returns: - If the token is valid, meaning, if it exists in the database. - """ - cur_time = datetime.now() - if token in LAST_API_KEY_VALIDATION: - last_time, is_valid = LAST_API_KEY_VALIDATION[token] - if (cur_time - last_time).seconds < settings.crypto.cache_token_db_result_seconds: - return is_valid - is_valid = await self._hashed_item_exists(token) - LAST_API_KEY_VALIDATION[token] = (cur_time, is_valid) - return is_valid + await self._delete_item(token) async def test_adhoc() -> None: diff --git a/store/app/crypto.py b/store/app/crypto.py deleted file mode 100644 index 7394b5dd..00000000 --- a/store/app/crypto.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Defines crypto functions.""" - -import hashlib -import secrets -import string -import uuid - -from argon2 import PasswordHasher - - -def new_token(length: int = 64) -> str: - """Generates a cryptographically secure random 64 character alphanumeric token.""" - return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) - - -def hash_token(token: str) -> str: - return hashlib.sha256(token.encode()).hexdigest() - - -def check_hash(token: str, hash: str) -> bool: - return hash_token(token) == hash - - -def new_uuid() -> uuid.UUID: - return uuid.uuid4() - - -def check_password(password: str, hash: str) -> bool: - try: - return PasswordHasher().verify(hash, password) - except Exception: - return False diff --git a/store/app/model.py b/store/app/model.py index 0492c850..f051a6a5 100644 --- a/store/app/model.py +++ b/store/app/model.py @@ -5,11 +5,11 @@ expects (for example, converting a UUID into a string). """ -from typing import Self +from typing import Literal, Self from pydantic import BaseModel -from store.app.crypto import new_uuid +from store.utils import new_uuid class RobolistBaseModel(BaseModel): @@ -53,21 +53,37 @@ def create(cls, token: str, user_id: str) -> Self: return cls(id=token, user_id=user_id) +APIKeySource = Literal["user", "oauth"] +APIKeyPermission = Literal["read", "write", "admin"] +APIKeyPermissionSet = set[APIKeyPermission] | Literal["full"] + + class APIKey(RobolistBaseModel): """The API key is used for querying the API. - Downstream users keep the JWT locally, and it is used to authenticate + Downstream users keep the API key, and it is used to authenticate requests to the API. The key is stored in the database, and can be revoked by the user at any time. """ user_id: str + source: APIKeySource + permissions: set[APIKeyPermission] @classmethod - def create(cls, id: str) -> Self: + def create( + cls, + user_id: str, + source: APIKeySource, + permissions: APIKeyPermissionSet, + ) -> Self: + if permissions == "full": + permissions = {"read", "write", "admin"} return cls( id=str(new_uuid()), - user_id=id, + user_id=user_id, + source=source, + permissions=permissions, ) diff --git a/store/app/routers/auth/github.py b/store/app/routers/auth/github.py index 643efb4c..b27191e4 100644 --- a/store/app/routers/auth/github.py +++ b/store/app/routers/auth/github.py @@ -93,7 +93,11 @@ async def github_code( github_id=github_id, ) - api_key = await crud.add_api_key(user.id) + api_key = await crud.add_api_key( + user_id=user.id, + source="oauth", + permissions="full", # OAuth tokens have full permissions. + ) response.set_cookie(key="session_token", value=api_key.id, httponly=True, samesite="lax") diff --git a/store/app/routers/image.py b/store/app/routers/image.py index 9be5d9a8..6575349d 100644 --- a/store/app/routers/image.py +++ b/store/app/routers/image.py @@ -8,9 +8,9 @@ from fastapi.responses import JSONResponse, RedirectResponse from PIL import Image -from store.app.crypto import new_uuid from store.app.db import Crud from store.settings import settings +from store.utils import new_uuid image_router = APIRouter() diff --git a/store/app/routers/part.py b/store/app/routers/part.py index c5e0323d..d2022c37 100644 --- a/store/app/routers/part.py +++ b/store/app/routers/part.py @@ -7,10 +7,13 @@ from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -from store.app.crypto import new_uuid from store.app.db import Crud -from store.app.model import Image, Part -from store.app.routers.users import get_session_token +from store.app.model import Image, Part, User +from store.app.routers.users import ( + get_session_user_with_read_permission, + get_session_user_with_write_permission, +) +from store.utils import new_uuid parts_router = APIRouter() @@ -34,11 +37,10 @@ async def dump_parts(crud: Annotated[Crud, Depends(Crud.get)]) -> list[Part]: @parts_router.get("/your/") async def list_your_parts( crud: Annotated[Crud, Depends(Crud.get)], - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_read_permission)], page: int = Query(description="Page number for pagination"), search_query: str = Query(None, description="Search query string"), ) -> tuple[list[Part], bool]: - user = await crud.get_user_from_api_key(token) return await crud.list_your_parts(user.id, page, search_query=search_query) @@ -49,10 +51,8 @@ async def get_part(part_id: str, crud: Annotated[Crud, Depends(Crud.get)]) -> Pa @parts_router.get("/user/") async def current_user( - crud: Annotated[Crud, Depends(Crud.get)], - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_read_permission)], ) -> str | None: - user = await crud.get_user_from_api_key(token) return user.id @@ -65,10 +65,9 @@ class NewPart(BaseModel): @parts_router.post("/add/") async def add_part( part: NewPart, - token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], ) -> bool: - user = await crud.get_user_from_api_key(token) await crud.add_part( Part( name=part.name, @@ -85,29 +84,26 @@ async def add_part( @parts_router.delete("/delete/{part_id}") async def delete_part( part_id: str, - token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], ) -> bool: part = await crud.get_part(part_id) if part is None: raise HTTPException(status_code=404, detail="Part not found") - user = await crud.get_user_from_api_key(token) if part.owner != user.id: raise HTTPException(status_code=403, detail="You do not own this part") await crud.delete_part(part_id) return True +# TODO: Improve part type annotations. @parts_router.post("/edit-part/{part_id}/") async def edit_part( part_id: str, - part: dict[ - str, Any - ], # There has got to be a better type annotation than this (possibly the deleted) EditPart class - token: Annotated[str, Depends(get_session_token)], + part: dict[str, Any], + user: Annotated[User, Depends(get_session_user_with_write_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user = await crud.get_user_from_api_key(token) part_info = await crud.get_part(part_id) if part_info is None: raise HTTPException(status_code=404, detail="Part not found") diff --git a/store/app/routers/robot.py b/store/app/routers/robot.py index 50f8e925..089905f7 100644 --- a/store/app/routers/robot.py +++ b/store/app/routers/robot.py @@ -7,10 +7,13 @@ from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -from store.app.crypto import new_uuid from store.app.db import Crud -from store.app.model import Bom, Image, Package, Robot -from store.app.routers.users import get_session_token +from store.app.model import Bom, Image, Package, Robot, User +from store.app.routers.users import ( + get_session_user_with_read_permission, + get_session_user_with_write_permission, +) +from store.utils import new_uuid robots_router = APIRouter() @@ -46,22 +49,19 @@ async def list_robots( @robots_router.get("/your/") async def list_your_robots( crud: Annotated[Crud, Depends(Crud.get)], - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_read_permission)], page: int = Query(description="Page number for pagination"), search_query: str = Query(None, description="Search query string"), ) -> tuple[list[Robot], bool]: - user = await crud.get_user_from_api_key(token) return await crud.list_your_robots(user.id, page, search_query=search_query) @robots_router.post("/add/") async def add_robot( new_robot: NewRobot, - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user = await crud.get_user_from_api_key(token) - await crud.add_robot( Robot( id=str(new_uuid()), @@ -84,13 +84,12 @@ async def add_robot( @robots_router.delete("/delete/{robot_id}/") async def delete_robot( robot_id: str, - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: robot = await crud.get_robot(robot_id) if robot is None: raise HTTPException(status_code=404, detail="Robot not found") - user = await crud.get_user_from_api_key(token) if robot.owner != user.id: raise HTTPException(status_code=403, detail="You do not own this robot") await crud.delete_robot(robot_id) @@ -101,13 +100,12 @@ async def delete_robot( async def edit_robot( id: str, robot: dict[str, Any], - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: robot_info = await crud.get_robot(id) if robot_info is None: raise HTTPException(status_code=404, detail="Robot not found") - user = await crud.get_user_from_api_key(token) if robot_info.owner != user.id: raise HTTPException(status_code=403, detail="You do not own this robot") robot["owner"] = user.id diff --git a/store/app/routers/users.py b/store/app/routers/users.py index 147369fb..2528018d 100644 --- a/store/app/routers/users.py +++ b/store/app/routers/users.py @@ -9,7 +9,7 @@ from pydantic.main import BaseModel as PydanticBaseModel from store.app.db import Crud -from store.app.model import UserPermissions +from store.app.model import User, UserPermissions from store.app.routers.auth.github import github_auth_router from store.app.utils.email import send_delete_email @@ -35,9 +35,9 @@ def set_token_cookie(response: Response, token: str, key: str) -> None: ) -async def get_session_token(request: Request) -> str: - token = request.cookies.get("session_token") - if not token: +async def get_request_api_key_id(request: Request) -> str: + api_key_id = request.cookies.get("session_token") + if not api_key_id: authorization = request.headers.get("Authorization") or request.headers.get("authorization") if authorization: scheme, credentials = get_authorization_scheme_param(authorization) @@ -56,7 +56,37 @@ async def get_session_token(request: Request) -> str: status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated", ) - return token + return api_key_id + + +async def get_session_user_with_read_permission( + crud: Annotated[Crud, Depends(Crud.get)], + api_key_id: Annotated[str, Depends(get_request_api_key_id)], +) -> User: + api_key = await crud.get_api_key(api_key_id) + if "read" not in api_key.permissions: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied") + return await crud.get_user(api_key.user_id, throw_if_missing=True) + + +async def get_session_user_with_write_permission( + crud: Annotated[Crud, Depends(Crud.get)], + api_key_id: Annotated[str, Depends(get_request_api_key_id)], +) -> User: + api_key = await crud.get_api_key(api_key_id) + if "write" not in api_key.permissions: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied") + return await crud.get_user(api_key.user_id, throw_if_missing=True) + + +async def get_session_user_with_admin_permission( + crud: Annotated[Crud, Depends(Crud.get)], + api_key_id: Annotated[str, Depends(get_request_api_key_id)], +) -> User: + api_key = await crud.get_api_key(api_key_id) + if "admin" not in api_key.permissions: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied") + return await crud.get_user(api_key.user_id, throw_if_missing=True) def validate_email(email: str) -> str: @@ -82,11 +112,10 @@ class UserInfoResponse(BaseModel): @users_router.get("/me", response_model=UserInfoResponse) async def get_user_info_endpoint( - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_read_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> UserInfoResponse | None: try: - user = await crud.get_user_from_api_key(token) return UserInfoResponse( id=user.id, permissions=user.permissions, @@ -97,10 +126,9 @@ async def get_user_info_endpoint( @users_router.delete("/me") async def delete_user_endpoint( - token: Annotated[str, Depends(get_session_token)], + user: Annotated[User, Depends(get_session_user_with_write_permission)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user = await crud.get_user_from_api_key(token) await crud.delete_user(user.id) await send_delete_email(user.email) return True @@ -108,7 +136,7 @@ async def delete_user_endpoint( @users_router.delete("/logout") async def logout_user_endpoint( - token: Annotated[str, Depends(get_session_token)], + token: Annotated[str, Depends(get_request_api_key_id)], crud: Annotated[Crud, Depends(Crud.get)], response: Response, ) -> bool: diff --git a/store/utils.py b/store/utils.py index 1e32018d..95db043f 100644 --- a/store/utils.py +++ b/store/utils.py @@ -1,11 +1,14 @@ """Defines package-wide utility functions.""" import datetime +import functools +import uuid from collections import OrderedDict -from typing import Generic, Hashable, TypeVar, overload +from typing import Callable, Generic, Hashable, ParamSpec, TypeVar, overload Tk = TypeVar("Tk", bound=Hashable) Tv = TypeVar("Tv") +P = ParamSpec("P") class LRUCache(Generic[Tk, Tv]): @@ -40,6 +43,9 @@ def put(self, key: Tk, value: Tv) -> None: if len(self.cache) > self.capacity: self.cache.popitem(last=False) + def pop(self, key: Tk) -> Tv: + return self.cache.pop(key) + def __getitem__(self, key: Tk) -> Tv: if (item := self.get(key)) is None: raise KeyError(key) @@ -49,5 +55,48 @@ def __setitem__(self, key: Tk, value: Tv) -> None: self.put(key, value) +def cache_result(num_seconds: float, capacity: int = 2**16) -> Callable[[Callable[P, Tv]], Callable[P, Tv]]: + """Cache the result of a function for a certain number of seconds. + + Usage: + + ```python + @cache_result(num_seconds=60) + def expensive_function(arg): + ... + ``` + + Args: + num_seconds: The number of seconds to cache the result. + capacity: The number of results to cache. + + Returns: + A decorator that caches the result of the function. + """ + + def decorator(func: Callable[P, Tv]) -> Callable[P, Tv]: + cache = LRUCache[str, tuple[datetime.datetime, Tv]](capacity) + + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Tv: + cur_time = datetime.datetime.now() + key = str((args, kwargs)) + if key in cache: + last_time, result = cache[key] + if (cur_time - last_time).total_seconds() < num_seconds: + return result + result = func(*args, **kwargs) + cache[key] = (cur_time, result) + return result + + return wrapper + + return decorator + + def server_time() -> datetime.datetime: return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + + +def new_uuid() -> uuid.UUID: + return uuid.uuid4()