Skip to content

Commit

Permalink
add more fields to tokens (#174)
Browse files Browse the repository at this point in the history
* add more fields to tokens

* stuff

* more permissions

* rename, remove dead code

* ugh drop jwts

* permissions
  • Loading branch information
codekansas authored Jul 25, 2024
1 parent 8812d0c commit 3907bac
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 147 deletions.
32 changes: 7 additions & 25 deletions store/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource
from types_aiobotocore_s3.service_resource import S3ServiceResource

from store.app.crypto import hash_token
from store.app.model import RobolistBaseModel

TABLE_NAME = "Robolist"
Expand Down Expand Up @@ -77,29 +76,13 @@ async def _add_item(self, item: RobolistBaseModel) -> None:
item_data["type"] = item.__class__.__name__
await table.put_item(Item=item_data)

async def _add_hashed_item(self, item: RobolistBaseModel) -> None:
table = await self.db.Table(TABLE_NAME)
item_data = item.model_dump()
if "type" in item_data:
raise ValueError("Cannot add item with 'type' attribute")
item_data["type"] = item.__class__.__name__
item_data["id"] = hash_token(item_data["id"])
await table.put_item(Item=item_data)

async def _delete_item(self, item: RobolistBaseModel | str) -> None:
table = await self.db.Table(TABLE_NAME)
if isinstance(item, str):
await table.delete_item(Key={"id": item})
else:
await table.delete_item(Key={"id": item.id})

async def _delete_hashed_item(self, item: RobolistBaseModel | str) -> None:
table = await self.db.Table(TABLE_NAME)
if isinstance(item, str):
await table.delete_item(Key={"id": hash_token(item)})
else:
await table.delete_item(Key={"id": hash_token(item.id)})

async def _list_items(
self,
item_class: type[T],
Expand Down Expand Up @@ -130,7 +113,11 @@ async def _list_items(
return [self._validate_item(item, item_class) for item in items]

async def _list(
self, item_class: type[T], page: int, sort_key: Callable[[T], int], search_query: str | None = None
self,
item_class: type[T],
page: int,
sort_key: Callable[[T], int],
search_query: str | None = None,
) -> tuple[list[T], bool]:
if search_query:
response = await self._list_items(
Expand Down Expand Up @@ -187,7 +174,7 @@ def _validate_item(self, data: dict[str, Any], item_class: type[T]) -> T:
async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[True]) -> T: ...

@overload
async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[False]) -> T | None: ...
async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: bool = False) -> T | None: ...

async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: bool = False) -> T | None:
table = await self.db.Table(TABLE_NAME)
Expand All @@ -204,11 +191,6 @@ async def _item_exists(self, item_id: str) -> bool:
item_dict = await table.get_item(Key={"id": item_id})
return "Item" in item_dict

async def _hashed_item_exists(self, item_id: str) -> bool:
table = await self.db.Table(TABLE_NAME)
item_dict = await table.get_item(Key={"id": hash_token(item_id)})
return "Item" in item_dict

async def _get_item_batch(
self,
item_ids: list[str],
Expand Down Expand Up @@ -256,7 +238,7 @@ async def _get_unique_item_from_secondary_index(
secondary_index_name: str,
secondary_index_value: str,
item_class: type[T],
throw_if_missing: Literal[False] = False,
throw_if_missing: bool = False,
) -> T | None: ...

async def _get_unique_item_from_secondary_index(
Expand Down
70 changes: 27 additions & 43 deletions store/app/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@

import asyncio
import warnings
from datetime import datetime
from typing import Literal, overload

from store.app.crud.base import BaseCrud, GlobalSecondaryIndex
from store.app.crypto import hash_token
from store.app.model import APIKey, OAuthKey, User
from store.app.model import APIKey, APIKeyPermissionSet, APIKeySource, OAuthKey, User
from store.settings import settings
from store.utils import LRUCache

# This dictionary is used to locally cache the last time a token was validated
# against the database. We give the tokens some buffer time to avoid hitting
# the database too often.
LAST_API_KEY_VALIDATION = LRUCache[str, tuple[datetime, bool]](2**20)
from store.utils import cache_result


def github_auth_key(github_id: str) -> str:
Expand All @@ -39,8 +33,14 @@ def get_gsis(cls) -> list[GlobalSecondaryIndex]:
("emailIndex", "email", "S", "HASH"),
]

async def get_user(self, id: str) -> User | None:
return await self._get_item(id, User, throw_if_missing=False)
@overload
async def get_user(self, id: str) -> User | None: ...

@overload
async def get_user(self, id: str, throw_if_missing: Literal[True]) -> User: ...

async def get_user(self, id: str, throw_if_missing: bool = False) -> User | None:
return await self._get_item(id, User, throw_if_missing=throw_if_missing)

async def create_user_from_token(self, token: str, email: str) -> User:
user = User.create(email=email)
Expand Down Expand Up @@ -78,9 +78,9 @@ async def create_user_from_email(self, email: str) -> User:
async def get_user_batch(self, ids: list[str]) -> list[User]:
return await self._get_item_batch(ids, User)

async def get_user_from_api_key(self, key: str) -> User:
key = await self.get_api_key(key)
return await self._get_item(key.user_id, User, throw_if_missing=True)
async def get_user_from_api_key(self, api_key_id: str) -> User:
api_key = await self.get_api_key(api_key_id)
return await self._get_item(api_key.user_id, User, throw_if_missing=True)

async def delete_user(self, id: str) -> None:
await self._delete_item(id)
Expand All @@ -92,38 +92,22 @@ async def list_users(self) -> list[User]:
async def get_user_count(self) -> int:
return await self._count_items(User)

async def get_api_key(self, id: str) -> APIKey:
hashed_id = hash_token(id)
return await self._get_item(hashed_id, APIKey, throw_if_missing=True)

async def add_api_key(self, id: str) -> APIKey:
token = APIKey.create(id=id)
await self._add_hashed_item(token)
@cache_result(settings.crypto.cache_token_db_result_seconds)
async def get_api_key(self, api_key_id: str) -> APIKey:
return await self._get_item(api_key_id, APIKey, throw_if_missing=True)

async def add_api_key(
self,
user_id: str,
source: APIKeySource,
permissions: APIKeyPermissionSet,
) -> APIKey:
token = APIKey.create(user_id=user_id, source=source, permissions=permissions)
await self._add_item(token)
return token

async def delete_api_key(self, token: APIKey | str) -> None:
await self._delete_hashed_item(token)

async def api_key_is_valid(self, token: str) -> bool:
"""Validates a token against the database, with caching.
In order to reduce the number of database queries, we locally cache
whether or not a token is valid for some amount of time.
Args:
token: The token to validate.
Returns:
If the token is valid, meaning, if it exists in the database.
"""
cur_time = datetime.now()
if token in LAST_API_KEY_VALIDATION:
last_time, is_valid = LAST_API_KEY_VALIDATION[token]
if (cur_time - last_time).seconds < settings.crypto.cache_token_db_result_seconds:
return is_valid
is_valid = await self._hashed_item_exists(token)
LAST_API_KEY_VALIDATION[token] = (cur_time, is_valid)
return is_valid
await self._delete_item(token)


async def test_adhoc() -> None:
Expand Down
32 changes: 0 additions & 32 deletions store/app/crypto.py

This file was deleted.

26 changes: 21 additions & 5 deletions store/app/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
expects (for example, converting a UUID into a string).
"""

from typing import Self
from typing import Literal, Self

from pydantic import BaseModel

from store.app.crypto import new_uuid
from store.utils import new_uuid


class RobolistBaseModel(BaseModel):
Expand Down Expand Up @@ -53,21 +53,37 @@ def create(cls, token: str, user_id: str) -> Self:
return cls(id=token, user_id=user_id)


APIKeySource = Literal["user", "oauth"]
APIKeyPermission = Literal["read", "write", "admin"]
APIKeyPermissionSet = set[APIKeyPermission] | Literal["full"]


class APIKey(RobolistBaseModel):
"""The API key is used for querying the API.
Downstream users keep the JWT locally, and it is used to authenticate
Downstream users keep the API key, and it is used to authenticate
requests to the API. The key is stored in the database, and can be
revoked by the user at any time.
"""

user_id: str
source: APIKeySource
permissions: set[APIKeyPermission]

@classmethod
def create(cls, id: str) -> Self:
def create(
cls,
user_id: str,
source: APIKeySource,
permissions: APIKeyPermissionSet,
) -> Self:
if permissions == "full":
permissions = {"read", "write", "admin"}
return cls(
id=str(new_uuid()),
user_id=id,
user_id=user_id,
source=source,
permissions=permissions,
)


Expand Down
6 changes: 5 additions & 1 deletion store/app/routers/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ async def github_code(
github_id=github_id,
)

api_key = await crud.add_api_key(user.id)
api_key = await crud.add_api_key(
user_id=user.id,
source="oauth",
permissions="full", # OAuth tokens have full permissions.
)

response.set_cookie(key="session_token", value=api_key.id, httponly=True, samesite="lax")

Expand Down
2 changes: 1 addition & 1 deletion store/app/routers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image

from store.app.crypto import new_uuid
from store.app.db import Crud
from store.settings import settings
from store.utils import new_uuid

image_router = APIRouter()

Expand Down
30 changes: 13 additions & 17 deletions store/app/routers/part.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

from store.app.crypto import new_uuid
from store.app.db import Crud
from store.app.model import Image, Part
from store.app.routers.users import get_session_token
from store.app.model import Image, Part, User
from store.app.routers.users import (
get_session_user_with_read_permission,
get_session_user_with_write_permission,
)
from store.utils import new_uuid

parts_router = APIRouter()

Expand All @@ -34,11 +37,10 @@ async def dump_parts(crud: Annotated[Crud, Depends(Crud.get)]) -> list[Part]:
@parts_router.get("/your/")
async def list_your_parts(
crud: Annotated[Crud, Depends(Crud.get)],
token: Annotated[str, Depends(get_session_token)],
user: Annotated[User, Depends(get_session_user_with_read_permission)],
page: int = Query(description="Page number for pagination"),
search_query: str = Query(None, description="Search query string"),
) -> tuple[list[Part], bool]:
user = await crud.get_user_from_api_key(token)
return await crud.list_your_parts(user.id, page, search_query=search_query)


Expand All @@ -49,10 +51,8 @@ async def get_part(part_id: str, crud: Annotated[Crud, Depends(Crud.get)]) -> Pa

@parts_router.get("/user/")
async def current_user(
crud: Annotated[Crud, Depends(Crud.get)],
token: Annotated[str, Depends(get_session_token)],
user: Annotated[User, Depends(get_session_user_with_read_permission)],
) -> str | None:
user = await crud.get_user_from_api_key(token)
return user.id


Expand All @@ -65,10 +65,9 @@ class NewPart(BaseModel):
@parts_router.post("/add/")
async def add_part(
part: NewPart,
token: Annotated[str, Depends(get_session_token)],
crud: Annotated[Crud, Depends(Crud.get)],
user: Annotated[User, Depends(get_session_user_with_write_permission)],
) -> bool:
user = await crud.get_user_from_api_key(token)
await crud.add_part(
Part(
name=part.name,
Expand All @@ -85,29 +84,26 @@ async def add_part(
@parts_router.delete("/delete/{part_id}")
async def delete_part(
part_id: str,
token: Annotated[str, Depends(get_session_token)],
crud: Annotated[Crud, Depends(Crud.get)],
user: Annotated[User, Depends(get_session_user_with_write_permission)],
) -> bool:
part = await crud.get_part(part_id)
if part is None:
raise HTTPException(status_code=404, detail="Part not found")
user = await crud.get_user_from_api_key(token)
if part.owner != user.id:
raise HTTPException(status_code=403, detail="You do not own this part")
await crud.delete_part(part_id)
return True


# TODO: Improve part type annotations.
@parts_router.post("/edit-part/{part_id}/")
async def edit_part(
part_id: str,
part: dict[
str, Any
], # There has got to be a better type annotation than this (possibly the deleted) EditPart class
token: Annotated[str, Depends(get_session_token)],
part: dict[str, Any],
user: Annotated[User, Depends(get_session_user_with_write_permission)],
crud: Annotated[Crud, Depends(Crud.get)],
) -> bool:
user = await crud.get_user_from_api_key(token)
part_info = await crud.get_part(part_id)
if part_info is None:
raise HTTPException(status_code=404, detail="Part not found")
Expand Down
Loading

0 comments on commit 3907bac

Please sign in to comment.